1use std::collections::HashMap;
27use std::sync::Arc;
28
29use async_trait::async_trait;
30use hyperactor::Actor;
31use hyperactor::ActorHandle;
32use hyperactor::ActorRef;
33use hyperactor::Context;
34use hyperactor::HandleClient;
35use hyperactor::Handler;
36use hyperactor::Instance;
37use hyperactor::OncePortHandle;
38use hyperactor::OncePortRef;
39use hyperactor::RefClient;
40use hyperactor::RemoteSpawn;
41use hyperactor::context;
42use hyperactor::supervision::ActorSupervisionEvent;
43use hyperactor_config::Flattrs;
44use serde::Deserialize;
45use serde::Serialize;
46use typeuri::Named;
47
48use crate::backend::RdmaRemoteBackendContext;
49use crate::backend::ibverbs::manager_actor::IbvManagerActor;
50use crate::backend::ibverbs::manager_actor::IbvManagerLocalMessageClient;
51use crate::backend::ibverbs::manager_actor::IbvManagerMessageClient;
52use crate::backend::ibverbs::primitives::IbvConfig;
53use crate::backend::tcp::manager_actor::TcpManagerActor;
54use crate::local_memory::KeepaliveLocalMemory;
55use crate::rdma_components::RdmaRemoteBuffer;
56
57pub fn get_rdmaxcel_error_message(error_code: i32) -> String {
59 unsafe {
60 let c_str = rdmaxcel_sys::rdmaxcel_error_string(error_code);
61 std::ffi::CStr::from_ptr(c_str)
62 .to_string_lossy()
63 .into_owned()
64 }
65}
66
/// Local-only requests handled by [`RdmaManagerActor`].
///
/// These messages carry `Arc<KeepaliveLocalMemory>` values and reply over a
/// `OncePortHandle`. The enum derives neither `Serialize` nor `Deserialize`
/// and is not listed in the actor's exported handlers, so it is sent through
/// a local handle (see `RdmaManagerActor::local_handle`) rather than a remote
/// reference.
#[derive(Handler, HandleClient, Debug)]
pub enum RdmaManagerMessage {
    /// Registers `local` under a fresh buffer id and replies with an
    /// [`RdmaRemoteBuffer`] describing the backends through which it can be
    /// reached.
    RequestBuffer {
        /// Local memory to expose. The manager keeps this `Arc` until a
        /// [`ReleaseBuffer`] for the assigned id is handled.
        local: Arc<KeepaliveLocalMemory>,
        #[reply]
        reply: OncePortHandle<RdmaRemoteBuffer>,
    },
    /// Looks up the local memory registered under `remote_buf_id`; replies
    /// with `None` when no buffer with that id is currently registered.
    RequestLocalMemory {
        /// Buffer id previously returned in [`RdmaRemoteBuffer`]'s `id` field.
        remote_buf_id: usize,
        #[reply]
        reply: OncePortHandle<Option<Arc<KeepaliveLocalMemory>>>,
    },
}
88
/// Releases the buffer registered under `id`.
///
/// The manager drops its reference to the buffer's local memory and, when the
/// ibverbs backend is present, forwards the release to that backend.
#[derive(Handler, HandleClient, RefClient, Debug, Serialize, Deserialize, Named)]
pub struct ReleaseBuffer {
    /// Buffer id, as returned in [`RdmaRemoteBuffer`]'s `id` field.
    pub id: usize,
}
wirevalue::register_type!(ReleaseBuffer);
98
/// Asks the manager for a reference to its ibverbs backend actor.
///
/// The reply is `None` when the manager was built without an ibverbs backend,
/// i.e. ibverbs was disabled by configuration or failed to initialize while
/// TCP fallback was allowed.
#[derive(Handler, HandleClient, RefClient, Debug, Serialize, Deserialize, Named)]
pub struct GetIbvActorRef {
    #[reply]
    pub reply: OncePortRef<Option<ActorRef<IbvManagerActor>>>,
}
wirevalue::register_type!(GetIbvActorRef);
107
/// Asks the manager for a reference to its TCP backend actor.
///
/// The TCP backend is always created together with the manager, so the reply
/// is never optional.
#[derive(Handler, HandleClient, RefClient, Debug, Serialize, Deserialize, Named)]
pub struct GetTcpActorRef {
    #[reply]
    pub reply: OncePortRef<ActorRef<TcpManagerActor>>,
}
wirevalue::register_type!(GetTcpActorRef);
116
/// Lifecycle state of a backend actor owned by [`RdmaManagerActor`].
///
/// A backend is constructed as `Created` while the manager itself is being
/// built, and becomes `Spawned` once the manager spawns it through its
/// `Instance` during `init`.
#[derive(Debug)]
enum RdmaBackendActor<A: Actor> {
    /// Placeholder left behind while the actor value is moved out during
    /// `spawn`; it remains if spawning the actor fails.
    Uninit,
    /// Constructed but not yet spawned.
    Created(A),
    /// Spawned; holds the handle of the running actor.
    Spawned(ActorHandle<A>),
}
123
124impl<A: Actor> RdmaBackendActor<A> {
125 fn spawn(&mut self, rdma_manager: &Instance<RdmaManagerActor>) -> anyhow::Result<()> {
126 let created = std::mem::replace(self, RdmaBackendActor::Uninit);
127 let actor = if let RdmaBackendActor::Created(actor) = created {
128 actor
129 } else {
130 panic!("rdma backend actor already spawned");
131 };
132 let handle = rdma_manager.spawn(actor)?;
133 *self = RdmaBackendActor::Spawned(handle);
134 Ok(())
135 }
136
137 fn handle(&self) -> &ActorHandle<A> {
138 if let RdmaBackendActor::Spawned(handle) = self {
139 handle
140 } else {
141 panic!("cannot get handle")
142 }
143 }
144}
145
/// Actor that registers local memory for RDMA access and hands out
/// [`RdmaRemoteBuffer`] descriptors for it.
///
/// It owns an optional ibverbs backend and an always-present TCP backend. Both
/// are spawned from [`Actor::init`]. Registered local memory is kept alive
/// until a [`ReleaseBuffer`] message for its id is handled. Local callers reach
/// it under the actor name `rdma_manager` in their own proc (see
/// `RdmaManagerActor::local_handle`).
#[derive(Debug)]
#[hyperactor::export(
    handlers = [
        GetIbvActorRef,
        GetTcpActorRef,
        ReleaseBuffer,
    ],
)]
#[hyperactor::spawnable]
pub struct RdmaManagerActor {
    /// Id that the next `RequestBuffer` will be assigned.
    next_remote_buf_id: usize,
    /// Local memory of each registered buffer, keyed by buffer id.
    buffers: HashMap<usize, Arc<KeepaliveLocalMemory>>,
    /// ibverbs backend. `None` when ibverbs was disabled by configuration or
    /// failed to initialize while TCP fallback was allowed.
    ibverbs: Option<RdmaBackendActor<IbvManagerActor>>,
    /// TCP backend; always created.
    tcp: RdmaBackendActor<TcpManagerActor>,
}
161
162impl RdmaManagerActor {
163 pub fn local_handle(client: &impl context::Actor) -> ActorHandle<Self> {
166 let actor_ref = ActorRef::attest(
167 client
168 .mailbox()
169 .actor_addr()
170 .proc_addr()
171 .actor_addr("rdma_manager"),
172 );
173 actor_ref
174 .downcast_handle(client)
175 .expect("RdmaManagerActor is not in the local process")
176 }
177}
178
179#[async_trait]
180impl RemoteSpawn for RdmaManagerActor {
181 type Params = Option<IbvConfig>;
182
183 async fn new(params: Self::Params, _environment: Flattrs) -> Result<Self, anyhow::Error> {
184 let ibv = if hyperactor_config::global::get(crate::config::RDMA_DISABLE_IBVERBS) {
185 if hyperactor_config::global::get(crate::config::RDMA_ALLOW_TCP_FALLBACK) {
186 tracing::info!("ibverbs disabled by configuration, using TCP transport");
187 None
188 } else {
189 anyhow::bail!(
190 "ibverbs is disabled (rdma_disable_ibverbs=true) \
191 but TCP fallback is also disabled"
192 );
193 }
194 } else {
195 match IbvManagerActor::new(params).await {
196 Ok(actor) => Some(RdmaBackendActor::Created(actor)),
197 Err(e) => {
198 if hyperactor_config::global::get(crate::config::RDMA_ALLOW_TCP_FALLBACK) {
199 tracing::warn!(
200 "ibverbs initialization failed, TCP fallback enabled: {}",
201 e
202 );
203 None
204 } else {
205 return Err(e);
206 }
207 }
208 }
209 };
210
211 let tcp = RdmaBackendActor::Created(TcpManagerActor::new());
212
213 Ok(Self {
214 next_remote_buf_id: 0,
215 buffers: HashMap::new(),
216 ibverbs: ibv,
217 tcp,
218 })
219 }
220}
221
222#[async_trait]
223impl Actor for RdmaManagerActor {
224 async fn init(&mut self, this: &Instance<Self>) -> Result<(), anyhow::Error> {
225 if let Some(ibv) = &mut self.ibverbs {
226 ibv.spawn(this)?;
227 }
228 self.tcp.spawn(this)?;
229 tracing::debug!("RdmaManagerActor initialized with lazy domain/QP creation");
230 Ok(())
231 }
232
233 async fn handle_supervision_event(
234 &mut self,
235 _cx: &Instance<Self>,
236 event: &ActorSupervisionEvent,
237 ) -> Result<bool, anyhow::Error> {
238 if !event.is_error() {
239 return Ok(true);
240 }
241 tracing::error!("rdmaManagerActor supervision event: {:?}", event);
242 tracing::error!("rdmaManagerActor error occurred, stop the worker process, exit code: 1");
243 std::process::exit(1);
244 }
245}
246
247#[async_trait]
248#[hyperactor::handle(GetIbvActorRef)]
249impl GetIbvActorRefHandler for RdmaManagerActor {
250 async fn get_ibv_actor_ref(
251 &mut self,
252 _cx: &Context<Self>,
253 ) -> Result<Option<ActorRef<IbvManagerActor>>, anyhow::Error> {
254 Ok(self.ibverbs.as_ref().map(|ibv| ibv.handle().bind()))
255 }
256}
257
258#[async_trait]
259#[hyperactor::handle(GetTcpActorRef)]
260impl GetTcpActorRefHandler for RdmaManagerActor {
261 async fn get_tcp_actor_ref(
262 &mut self,
263 _cx: &Context<Self>,
264 ) -> Result<ActorRef<TcpManagerActor>, anyhow::Error> {
265 Ok(self.tcp.handle().bind())
266 }
267}
268
269#[async_trait]
270#[hyperactor::handle(ReleaseBuffer)]
271impl ReleaseBufferHandler for RdmaManagerActor {
272 async fn release_buffer(&mut self, cx: &Context<Self>, id: usize) -> Result<(), anyhow::Error> {
273 self.buffers.remove(&id);
274 if let Some(ibv) = &self.ibverbs {
275 ibv.handle().release_buffer(cx, id).await?;
276 }
277 Ok(())
278 }
279}
280
281#[async_trait]
282#[hyperactor::handle(RdmaManagerMessage)]
283impl RdmaManagerMessageHandler for RdmaManagerActor {
284 async fn request_buffer(
285 &mut self,
286 cx: &Context<Self>,
287 local: Arc<KeepaliveLocalMemory>,
288 ) -> Result<RdmaRemoteBuffer, anyhow::Error> {
289 let remote_buf_id = self.next_remote_buf_id;
290 self.next_remote_buf_id += 1;
291 let size = local.size();
292
293 let mut backends = Vec::new();
294
295 if let Some(ibv) = &self.ibverbs {
296 let ibv_buffer = ibv
297 .handle()
298 .register_remote_buffer(cx, remote_buf_id, local.clone())
299 .await?
300 .map_err(|e| anyhow::anyhow!(e))?;
301 backends.push(RdmaRemoteBackendContext::Ibverbs(
302 ibv.handle().bind(),
303 ibv_buffer,
304 ));
305 }
306
307 self.buffers.insert(remote_buf_id, local);
308
309 backends.push(RdmaRemoteBackendContext::Tcp(self.tcp.handle().bind()));
310
311 Ok(RdmaRemoteBuffer {
312 id: remote_buf_id,
313 size,
314 owner: cx.bind().clone(),
315 backends,
316 })
317 }
318
319 async fn request_local_memory(
320 &mut self,
321 _cx: &Context<Self>,
322 remote_buf_id: usize,
323 ) -> Result<Option<Arc<KeepaliveLocalMemory>>, anyhow::Error> {
324 Ok(self.buffers.get(&remote_buf_id).cloned())
325 }
326}