1use std::collections::HashMap;
27use std::sync::Arc;
28
29use async_trait::async_trait;
30use hyperactor::Actor;
31use hyperactor::ActorHandle;
32use hyperactor::Context;
33use hyperactor::HandleClient;
34use hyperactor::Handler;
35use hyperactor::Instance;
36use hyperactor::OncePortHandle;
37use hyperactor::RefClient;
38use hyperactor::RemoteSpawn;
39use hyperactor::context;
40use hyperactor::reference;
41use hyperactor::supervision::ActorSupervisionEvent;
42use hyperactor_config::Flattrs;
43use serde::Deserialize;
44use serde::Serialize;
45use tokio::sync::OnceCell;
46use typeuri::Named;
47
48use crate::backend::RdmaRemoteBackendContext;
49use crate::backend::ibverbs::manager_actor::IbvManagerActor;
50use crate::backend::ibverbs::manager_actor::IbvManagerMessageClient;
51use crate::backend::ibverbs::primitives::IbvConfig;
52use crate::backend::tcp::manager_actor::TcpManagerActor;
53use crate::local_memory::RdmaLocalMemory;
54use crate::rdma_components::RdmaRemoteBuffer;
55
56pub fn get_rdmaxcel_error_message(error_code: i32) -> String {
58 unsafe {
59 let c_str = rdmaxcel_sys::rdmaxcel_error_string(error_code);
60 std::ffi::CStr::from_ptr(c_str)
61 .to_string_lossy()
62 .into_owned()
63 }
64}
65
/// Messages handled by [`RdmaManagerActor`] for registering and looking up local memory.
///
/// NOTE(review): unlike the other messages in this module this enum derives neither
/// `Serialize`/`Deserialize` nor `RefClient`, and it carries `Arc<dyn RdmaLocalMemory>`
/// and `OncePortHandle` values, so it appears to be usable only through a local
/// `ActorHandle` — confirm.
#[derive(Handler, HandleClient, Debug)]
pub enum RdmaManagerMessage {
    /// Registers `local` under a newly allocated buffer id and replies with an
    /// [`RdmaRemoteBuffer`] describing that id, the buffer size and the backends
    /// through which it can be reached.
    RequestBuffer {
        /// Local memory to register; the manager keeps a clone of this `Arc` until the
        /// buffer is released with [`ReleaseBuffer`].
        local: Arc<dyn RdmaLocalMemory>,
        #[reply]
        reply: OncePortHandle<RdmaRemoteBuffer>,
    },
    /// Looks up the local memory registered under `remote_buf_id`.
    RequestLocalMemory {
        /// Buffer id previously returned in an [`RdmaRemoteBuffer`].
        remote_buf_id: usize,
        /// Receives `None` when no memory is registered under `remote_buf_id`.
        #[reply]
        reply: OncePortHandle<Option<Arc<dyn RdmaLocalMemory>>>,
    },
}
87
/// Releases a buffer previously registered through `RdmaManagerMessage::RequestBuffer`.
///
/// The manager drops its reference to the local memory for `id` and, when the ibverbs
/// backend is in use, forwards the release to the ibverbs manager actor.
#[derive(Handler, HandleClient, RefClient, Debug, Serialize, Deserialize, Named)]
pub struct ReleaseBuffer {
    /// Buffer id as returned in [`RdmaRemoteBuffer::id`].
    pub id: usize,
}
// NOTE(review): presumably registers the type for wire (de)serialization by `wirevalue`.
wirevalue::register_type!(ReleaseBuffer);
97
/// Requests a reference to the manager's ibverbs backend actor.
///
/// The reply is `None` when the manager runs without ibverbs, either because it was
/// disabled by configuration or because its initialization failed with TCP fallback
/// allowed.
#[derive(Handler, HandleClient, RefClient, Debug, Serialize, Deserialize, Named)]
pub struct GetIbvActorRef {
    #[reply]
    pub reply: reference::OncePortRef<Option<reference::ActorRef<IbvManagerActor>>>,
}
// NOTE(review): presumably registers the type for wire (de)serialization by `wirevalue`.
wirevalue::register_type!(GetIbvActorRef);
106
/// Requests a reference to the manager's TCP backend actor.
///
/// Unlike [`GetIbvActorRef`], the reply is never optional: the TCP backend is always
/// created by the manager.
#[derive(Handler, HandleClient, RefClient, Debug, Serialize, Deserialize, Named)]
pub struct GetTcpActorRef {
    #[reply]
    pub reply: reference::OncePortRef<reference::ActorRef<TcpManagerActor>>,
}
// NOTE(review): presumably registers the type for wire (de)serialization by `wirevalue`.
wirevalue::register_type!(GetTcpActorRef);
115
/// Lifecycle of a backend actor owned by [`RdmaManagerActor`].
///
/// Backends are constructed in `RemoteSpawn::new` (`Created`) and spawned later from
/// `Actor::init` (`Spawned`); `Uninit` is a transient placeholder used while moving the
/// actor value out during spawning.
#[derive(Debug)]
enum RdmaBackendActor<A: Actor> {
    /// Placeholder with no actor value.
    Uninit,
    /// Constructed actor value that has not been spawned yet.
    Created(A),
    /// Handle to the spawned actor.
    Spawned(ActorHandle<A>),
}
122
123impl<A: Actor> RdmaBackendActor<A> {
124 fn spawn(&mut self, rdma_manager: &Instance<RdmaManagerActor>) -> anyhow::Result<()> {
125 let created = std::mem::replace(self, RdmaBackendActor::Uninit);
126 let actor = if let RdmaBackendActor::Created(actor) = created {
127 actor
128 } else {
129 panic!("rdma backend actor already spawned");
130 };
131 let handle = rdma_manager.spawn(actor)?;
132 *self = RdmaBackendActor::Spawned(handle);
133 Ok(())
134 }
135
136 fn handle(&self) -> &ActorHandle<A> {
137 if let RdmaBackendActor::Spawned(handle) = self {
138 handle
139 } else {
140 panic!("cannot get handle")
141 }
142 }
143}
144
/// Per-process actor that registers local memory buffers and owns the RDMA backend
/// actors (ibverbs, when available, and TCP).
///
/// Only the handlers listed in `export` below (`GetIbvActorRef`, `GetTcpActorRef`,
/// `ReleaseBuffer`) are exported; `RdmaManagerMessage` is not listed there.
#[derive(Debug)]
#[hyperactor::export(
    spawn = true,
    handlers = [
        GetIbvActorRef,
        GetTcpActorRef,
        ReleaseBuffer,
    ],
)]
pub struct RdmaManagerActor {
    /// Id handed out to the next buffer registered via `RequestBuffer`; incremented per request.
    next_remote_buf_id: usize,
    /// Registered local memory, keyed by buffer id; entries are removed by `ReleaseBuffer`.
    buffers: HashMap<usize, Arc<dyn RdmaLocalMemory>>,
    /// ibverbs backend; `None` when ibverbs is disabled or failed to initialize with
    /// TCP fallback allowed.
    ibverbs: Option<RdmaBackendActor<IbvManagerActor>>,
    /// TCP backend; always created.
    tcp: RdmaBackendActor<TcpManagerActor>,
}
160
161impl RdmaManagerActor {
162 pub fn local_handle(client: &impl context::Actor) -> ActorHandle<Self> {
165 let proc_id = client.mailbox().actor_id().proc_id().clone();
166 let actor_ref =
167 reference::ActorRef::attest(reference::ActorId::new(proc_id, "rdma_manager", 0));
168 actor_ref
169 .downcast_handle(client)
170 .expect("RdmaManagerActor is not in the local process")
171 }
172}
173
174#[async_trait]
175impl RemoteSpawn for RdmaManagerActor {
176 type Params = Option<IbvConfig>;
177
178 async fn new(params: Self::Params, _environment: Flattrs) -> Result<Self, anyhow::Error> {
179 let ibv = if hyperactor_config::global::get(crate::config::RDMA_DISABLE_IBVERBS) {
180 if hyperactor_config::global::get(crate::config::RDMA_ALLOW_TCP_FALLBACK) {
181 tracing::info!("ibverbs disabled by configuration, using TCP transport");
182 None
183 } else {
184 anyhow::bail!(
185 "ibverbs is disabled (rdma_disable_ibverbs=true) \
186 but TCP fallback is also disabled"
187 );
188 }
189 } else {
190 match IbvManagerActor::new(params).await {
191 Ok(actor) => Some(RdmaBackendActor::Created(actor)),
192 Err(e) => {
193 if hyperactor_config::global::get(crate::config::RDMA_ALLOW_TCP_FALLBACK) {
194 tracing::warn!(
195 "ibverbs initialization failed, TCP fallback enabled: {}",
196 e
197 );
198 None
199 } else {
200 return Err(e);
201 }
202 }
203 }
204 };
205
206 let tcp = RdmaBackendActor::Created(TcpManagerActor::new());
207
208 Ok(Self {
209 next_remote_buf_id: 0,
210 buffers: HashMap::new(),
211 ibverbs: ibv,
212 tcp,
213 })
214 }
215}
216
217#[async_trait]
218impl Actor for RdmaManagerActor {
219 async fn init(&mut self, this: &Instance<Self>) -> Result<(), anyhow::Error> {
220 if let Some(ibv) = &mut self.ibverbs {
221 ibv.spawn(this)?;
222 }
223 self.tcp.spawn(this)?;
224 tracing::debug!("RdmaManagerActor initialized with lazy domain/QP creation");
225 Ok(())
226 }
227
228 async fn handle_supervision_event(
229 &mut self,
230 _cx: &Instance<Self>,
231 _event: &ActorSupervisionEvent,
232 ) -> Result<bool, anyhow::Error> {
233 tracing::error!("rdmaManagerActor supervision event: {:?}", _event);
234 tracing::error!("rdmaManagerActor error occurred, stop the worker process, exit code: 1");
235 std::process::exit(1);
236 }
237}
238
239#[async_trait]
240#[hyperactor::handle(GetIbvActorRef)]
241impl GetIbvActorRefHandler for RdmaManagerActor {
242 async fn get_ibv_actor_ref(
243 &mut self,
244 _cx: &Context<Self>,
245 ) -> Result<Option<reference::ActorRef<IbvManagerActor>>, anyhow::Error> {
246 Ok(self.ibverbs.as_ref().map(|ibv| ibv.handle().bind()))
247 }
248}
249
250#[async_trait]
251#[hyperactor::handle(GetTcpActorRef)]
252impl GetTcpActorRefHandler for RdmaManagerActor {
253 async fn get_tcp_actor_ref(
254 &mut self,
255 _cx: &Context<Self>,
256 ) -> Result<reference::ActorRef<TcpManagerActor>, anyhow::Error> {
257 Ok(self.tcp.handle().bind())
258 }
259}
260
261#[async_trait]
262#[hyperactor::handle(ReleaseBuffer)]
263impl ReleaseBufferHandler for RdmaManagerActor {
264 async fn release_buffer(&mut self, cx: &Context<Self>, id: usize) -> Result<(), anyhow::Error> {
265 self.buffers.remove(&id);
266 if let Some(ibv) = &self.ibverbs {
267 ibv.handle().release_buffer(cx, id).await?;
268 }
269 Ok(())
270 }
271}
272
273#[async_trait]
274#[hyperactor::handle(RdmaManagerMessage)]
275impl RdmaManagerMessageHandler for RdmaManagerActor {
276 async fn request_buffer(
277 &mut self,
278 cx: &Context<Self>,
279 local: Arc<dyn RdmaLocalMemory>,
280 ) -> Result<RdmaRemoteBuffer, anyhow::Error> {
281 let remote_buf_id = self.next_remote_buf_id;
282 self.next_remote_buf_id += 1;
283 let size = local.size();
284
285 self.buffers.insert(remote_buf_id, local);
286
287 let mut backends = Vec::new();
288
289 if let Some(ibv) = &self.ibverbs {
290 backends.push(RdmaRemoteBackendContext::Ibverbs(
291 ibv.handle().bind(),
292 Arc::new(OnceCell::new()),
293 ));
294 }
295
296 backends.push(RdmaRemoteBackendContext::Tcp(self.tcp.handle().bind()));
297
298 Ok(RdmaRemoteBuffer {
299 id: remote_buf_id,
300 size,
301 owner: cx.bind().clone(),
302 backends,
303 })
304 }
305
306 async fn request_local_memory(
307 &mut self,
308 _cx: &Context<Self>,
309 remote_buf_id: usize,
310 ) -> Result<Option<Arc<dyn RdmaLocalMemory>>, anyhow::Error> {
311 Ok(self.buffers.get(&remote_buf_id).cloned())
312 }
313}