monarch_rdma/
rdma_manager_actor.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! # RDMA Manager Actor
10//!
11//! Manages RDMA connections and operations using `hyperactor` for asynchronous messaging.
12//!
13//! ## Architecture
14//!
15//! `RdmaManagerActor` is a per-host entity that:
16//! - Manages connections to multiple remote RdmaManagerActors (i.e. across the hosts in a Monarch cluster)
17//! - Handles memory registration, connection setup, and data transfer
18//! - Manages all RdmaBuffers in its associated host
19//!
20//! ## Core Operations
21//!
22//! - Connection establishment with partner actors
23//! - RDMA operations (put/write, get/read)
24//! - Completion polling
25//! - Memory region management
26//!
27//! ## Usage
28//!
29//! See test examples: `test_rdma_write_loopback` and `test_rdma_read_loopback`.
30use std::collections::HashMap;
31
32use async_trait::async_trait;
33use hyperactor::Actor;
34use hyperactor::ActorId;
35use hyperactor::ActorRef;
36use hyperactor::Context;
37use hyperactor::HandleClient;
38use hyperactor::Handler;
39use hyperactor::Instance;
40use hyperactor::Named;
41use hyperactor::OncePortRef;
42use hyperactor::RefClient;
43use hyperactor::supervision::ActorSupervisionEvent;
44use serde::Deserialize;
45use serde::Serialize;
46
47use crate::ibverbs_primitives::IbverbsConfig;
48use crate::ibverbs_primitives::RdmaQpInfo;
49use crate::rdma_components::RdmaBuffer;
50use crate::rdma_components::RdmaDomain;
51use crate::rdma_components::RdmaQueuePair;
52use crate::validate_execution_context;
53
54/// Represents a reference to a remote RDMA buffer that can be accessed via RDMA operations.
55/// This struct encapsulates all the information needed to identify and access a memory region
56/// on a remote host using RDMA.
57#[derive(Handler, HandleClient, RefClient, Debug, Serialize, Deserialize, Named)]
58pub enum RdmaManagerMessage {
59    RequestBuffer {
60        addr: usize,
61        size: usize,
62        #[reply]
63        /// `reply` - Reply channel to return the RDMA buffer handle
64        reply: OncePortRef<RdmaBuffer>,
65    },
66    ReleaseBuffer {
67        buffer: RdmaBuffer,
68    },
69    RequestQueuePair {
70        remote: ActorRef<RdmaManagerActor>,
71        #[reply]
72        /// `reply` - Reply channel to return the queue pair for communication
73        reply: OncePortRef<RdmaQueuePair>,
74    },
75    IsConnected {
76        /// `other` - The ActorId of the actor to check connection with
77        other: ActorRef<RdmaManagerActor>,
78        #[reply]
79        /// `reply` - Reply channel to return whether the actors have connected
80        reply: OncePortRef<bool>,
81    },
82    Connect {
83        /// `other` - The ActorId of the actor to connect to
84        other: ActorRef<RdmaManagerActor>,
85        /// `endpoint` - Connection information needed to establish the RDMA connection
86        endpoint: RdmaQpInfo,
87    },
88    InitializeQP {
89        remote: ActorRef<RdmaManagerActor>,
90        #[reply]
91        /// `reply` - Reply channel to return the queue pair for communication
92        reply: OncePortRef<bool>,
93    },
94    ConnectionInfo {
95        /// `other` - The ActorId to get connection info for
96        other: ActorRef<RdmaManagerActor>,
97        #[reply]
98        /// `reply` - Reply channel to return connection information needed for the RDMA connection
99        reply: OncePortRef<RdmaQpInfo>,
100    },
101}
102
103#[derive(Debug)]
104#[hyperactor::export(
105    spawn = true,
106    handlers = [
107        RdmaManagerMessage,
108    ],
109)]
110pub struct RdmaManagerActor {
111    // Map between ActorIds and their corresponding RdmaQueuePair
112    qp_map: HashMap<ActorId, RdmaQueuePair>,
113
114    // The RDMA domain associated with this actor.
115    //
116    // The domain is responsible for managing the RDMA resources and configurations
117    // specific to this actor. It encapsulates the context and protection domain
118    // necessary for RDMA operations, ensuring that all RDMA activities are
119    // performed within a consistent and isolated environment.
120    //
121    // This domain is initialized during the creation of the `RdmaManagerActor`
122    // and is used throughout the actor's lifecycle to manage RDMA connections
123    // and operations.
124    domain: RdmaDomain,
125    config: IbverbsConfig,
126}
127
128#[async_trait]
129impl Actor for RdmaManagerActor {
130    type Params = IbverbsConfig;
131
132    async fn new(_params: Self::Params) -> Result<Self, anyhow::Error> {
133        let mut config = _params;
134
135        // check config and hardware support align
136        if config.use_gpu_direct {
137            match validate_execution_context().await {
138                Ok(_) => {
139                    tracing::info!("GPU Direct RDMA execution context validated successfully");
140                }
141                Err(e) => {
142                    tracing::warn!(
143                        "GPU Direct RDMA execution context validation failed: {}. Downgrading to standard ibverbs mode.",
144                        e
145                    );
146                    config.use_gpu_direct = false;
147                }
148            }
149        }
150
151        let domain = RdmaDomain::new(config.device.clone())
152            .map_err(|e| anyhow::anyhow!("rdmaManagerActor could not create domain: {}", e))?;
153        Ok(Self {
154            qp_map: HashMap::new(),
155            domain,
156            config,
157        })
158    }
159
160    async fn handle_supervision_event(
161        &mut self,
162        _cx: &Instance<Self>,
163        _event: &ActorSupervisionEvent,
164    ) -> Result<bool, anyhow::Error> {
165        tracing::error!("rdmaManagerActor supervision event: {:?}", _event);
166        tracing::error!("rdmaManagerActor error occurred, stop the worker process, exit code: 1");
167        std::process::exit(1);
168    }
169}
170
171#[async_trait]
172#[hyperactor::forward(RdmaManagerMessage)]
173impl RdmaManagerMessageHandler for RdmaManagerActor {
174    /// Requests a buffer to be registered with the RDMA domain.
175    ///
176    /// This function registers a memory region with the RDMA domain and returns an `RdmaBuffer`
177    /// that encapsulates the necessary information for RDMA operations.
178    ///
179    /// # Arguments
180    ///
181    /// * `this` - The context of the actor requesting the buffer.
182    /// * `addr` - The starting address of the memory region to be registered.
183    /// * `size` - The size of the memory region to be registered.
184    ///
185    /// # Returns
186    ///
187    /// * `Result<RdmaBuffer, anyhow::Error>` - On success, returns an `RdmaBuffer` containing
188    ///   the registered memory region's details. On failure, returns an error.
189    async fn request_buffer(
190        &mut self,
191        cx: &Context<Self>,
192        addr: usize,
193        size: usize,
194    ) -> Result<RdmaBuffer, anyhow::Error> {
195        let mr = self.domain.register_buffer(addr, size)?;
196        Ok(RdmaBuffer {
197            owner: cx.bind().clone(),
198            mr_id: mr.id,
199            addr: mr.addr,
200            size: mr.size,
201            rkey: mr.rkey,
202            lkey: mr.lkey,
203        })
204    }
205
206    /// Deregisters a buffer from the RDMA domain.
207    ///
208    /// This function removes the specified `RdmaBuffer` from the RDMA domain,
209    /// effectively releasing the resources associated with it.
210    ///
211    /// # Arguments
212    ///
213    /// * `_this` - The context of the actor releasing the buffer.
214    /// * `buffer` - The `RdmaBuffer` to be deregistered.
215    ///
216    /// # Returns
217    ///
218    /// * `Result<(), anyhow::Error>` - On success, returns `Ok(())`. On failure, returns an error.
219    async fn release_buffer(
220        &mut self,
221        _cx: &Context<Self>,
222        buffer: RdmaBuffer,
223    ) -> Result<(), anyhow::Error> {
224        self.domain
225            .deregister_buffer(buffer)
226            .map_err(|e| anyhow::anyhow!("could not deregister buffer: {}", e))?;
227        Ok(())
228    }
229
230    /// Requests a queue pair for communication with a remote RDMA manager actor.
231    ///
232    /// This function checks if a connection already exists with the specified remote actor.
233    /// If not, it initializes a new queue pair and establishes a connection with the remote actor.
234    /// It then retrieves the queue pair associated with the remote actor for communication.
235    ///
236    /// # Arguments
237    ///
238    /// * `this` - The context of the actor requesting the queue pair.
239    /// * `remote` - The ActorRef of the remote RDMA manager actor to communicate with.
240    ///
241    /// # Returns
242    ///
243    /// * `Result<RdmaQueuePair, anyhow::Error>` - On success, returns the queue pair for communication.
244    ///   On failure, returns an error.
245    async fn request_queue_pair(
246        &mut self,
247        cx: &Context<Self>,
248        remote: ActorRef<RdmaManagerActor>,
249    ) -> Result<RdmaQueuePair, anyhow::Error> {
250        if !self.is_connected(cx, remote.clone()).await? {
251            let is_loopback =
252                remote.actor_id().clone() == cx.bind::<RdmaManagerActor>().actor_id().clone();
253
254            if is_loopback {
255                self.initialize_qp(cx, remote.clone()).await?;
256                let endpoint = self.connection_info(cx, remote.clone()).await?;
257                self.connect(cx, remote.clone(), endpoint).await?;
258            } else {
259                self.initialize_qp(cx, remote.clone()).await?;
260                remote.initialize_qp(cx, cx.bind().clone()).await?;
261                let remote_endpoint = remote.connection_info(cx, cx.bind().clone()).await?;
262                self.connect(cx, remote.clone(), remote_endpoint).await?;
263                let local_endpoint = self.connection_info(cx, remote.clone()).await?;
264                remote
265                    .connect(cx, cx.bind().clone(), local_endpoint)
266                    .await?;
267            }
268        }
269
270        let qp = self
271            .qp_map
272            .get_mut(&remote.actor_id().clone())
273            .ok_or_else(|| anyhow::anyhow!("on get, no connection found for actor {}", remote))?;
274        Ok(qp.clone())
275    }
276
277    /// Convenience utility to create a new RdmaQueuePair.
278    ///
279    /// This function initializes a new RDMA connection with another actor if one doesn't already exist.
280    /// It creates a new RdmaQueuePair associated with the specified actor ID and adds it to the
281    /// connection map.
282    ///
283    /// # Arguments
284    ///
285    /// * `other` - The ActorRef of the remote actor to connect with
286    async fn initialize_qp(
287        &mut self,
288        _cx: &Context<Self>,
289        other: ActorRef<RdmaManagerActor>,
290    ) -> Result<bool, anyhow::Error> {
291        let key = other.actor_id().clone();
292
293        if let std::collections::hash_map::Entry::Vacant(e) = self.qp_map.entry(key) {
294            let qp = RdmaQueuePair::new(self.domain.context, self.domain.pd, self.config.clone())
295                .map_err(|e| anyhow::anyhow!("could not create RdmaQueuePair: {}", e))?;
296            e.insert(qp);
297            tracing::debug!("successfully created a connection with {:?}", other);
298        }
299        Ok(true)
300    }
301
302    /// Checks if a connection exists with another actor.
303    ///
304    /// # Arguments
305    /// * `other` - The ActorRef of the actor to check the connection with.
306    ///
307    /// # Returns
308    /// * `bool` - Returns true if connected, false otherwise.
309    async fn is_connected(
310        &mut self,
311        _cx: &Context<Self>,
312        other: ActorRef<RdmaManagerActor>,
313    ) -> Result<bool, anyhow::Error> {
314        tracing::debug!("checking if connected with {:?}", other);
315        if !self.qp_map.contains_key(&other.actor_id().clone()) {
316            return Ok(false);
317        }
318        let qp_state = self
319            .qp_map
320            .get_mut(&other.actor_id().clone())
321            .unwrap()
322            .state()?;
323        Ok(qp_state == rdmaxcel_sys::ibv_qp_state::IBV_QPS_RTS)
324    }
325
326    /// Establishes a connection with another actor
327    ///
328    /// # Arguments
329    /// * `other` - The ActorRef of the actor to connect to
330    /// * `endpoint` - Connection information needed to establish the RDMA connection
331    async fn connect(
332        &mut self,
333        _cx: &Context<Self>,
334        other: ActorRef<RdmaManagerActor>,
335        endpoint: RdmaQpInfo,
336    ) -> Result<(), anyhow::Error> {
337        tracing::debug!("connecting with {:?}", other);
338        let qp = self
339            .qp_map
340            .get_mut(&other.actor_id().clone())
341            .ok_or_else(|| {
342                anyhow::anyhow!("on connect, no connection found for actor {}", other)
343            })?;
344        qp.connect(&endpoint)
345            .map_err(|e| anyhow::anyhow!("could not connect to RDMA endpoint: {}", e))?;
346        Ok(())
347    }
348
349    /// Gets connection information for establishing an RDMA connection
350    ///
351    /// # Arguments
352    /// * `other` - The ActorRef to get connection info for
353    ///
354    /// # Returns
355    /// * `RdmaQpInfo` - Connection information needed for the RDMA connection
356    async fn connection_info(
357        &mut self,
358        _cx: &Context<Self>,
359        other: ActorRef<RdmaManagerActor>,
360    ) -> Result<RdmaQpInfo, anyhow::Error> {
361        tracing::debug!("getting connection info with {:?}", other);
362
363        let connection_info = self
364            .qp_map
365            .get_mut(&other.actor_id().clone())
366            .ok_or_else(|| anyhow::anyhow!("no connection found for actor {}", other))?
367            .get_qp_info()?;
368        Ok(connection_info)
369    }
370}