hyperactor_mesh/proc_mesh/
mesh_agent.rs1use std::collections::HashMap;
12use std::mem::replace;
13use std::sync::Arc;
14use std::sync::Mutex;
15use std::sync::RwLock;
16
17use async_trait::async_trait;
18use enum_as_inner::EnumAsInner;
19use hyperactor::Actor;
20use hyperactor::ActorHandle;
21use hyperactor::ActorId;
22use hyperactor::Context;
23use hyperactor::Data;
24use hyperactor::HandleClient;
25use hyperactor::Handler;
26use hyperactor::Instance;
27use hyperactor::Named;
28use hyperactor::OncePortRef;
29use hyperactor::PortHandle;
30use hyperactor::PortRef;
31use hyperactor::ProcId;
32use hyperactor::RefClient;
33use hyperactor::actor::ActorStatus;
34use hyperactor::actor::remote::Remote;
35use hyperactor::channel;
36use hyperactor::channel::ChannelAddr;
37use hyperactor::clock::Clock;
38use hyperactor::clock::RealClock;
39use hyperactor::mailbox::BoxedMailboxSender;
40use hyperactor::mailbox::DialMailboxRouter;
41use hyperactor::mailbox::IntoBoxedMailboxSender;
42use hyperactor::mailbox::MailboxClient;
43use hyperactor::mailbox::MailboxSender;
44use hyperactor::mailbox::MessageEnvelope;
45use hyperactor::mailbox::Undeliverable;
46use hyperactor::proc::Proc;
47use hyperactor::supervision::ActorSupervisionEvent;
48use serde::Deserialize;
49use serde::Serialize;
50
51#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Named)]
52pub enum GspawnResult {
53 Success { rank: usize, actor_id: ActorId },
54 Error(String),
55}
56
57#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Named)]
58pub enum StopActorResult {
59 Success,
60 Timeout,
61 NotFound,
62}
63
64#[derive(
65 Debug,
66 Clone,
67 PartialEq,
68 Serialize,
69 Deserialize,
70 Handler,
71 HandleClient,
72 RefClient,
73 Named
74)]
75pub(crate) enum MeshAgentMessage {
76 Configure {
78 rank: usize,
80 forwarder: ChannelAddr,
82 supervisor: PortRef<ActorSupervisionEvent>,
84 address_book: HashMap<ProcId, ChannelAddr>,
86 configured: PortRef<usize>,
89 },
90
91 Gspawn {
93 actor_type: String,
95 actor_name: String,
97 params_data: Data,
99 status_port: PortRef<GspawnResult>,
101 },
102
103 StopActor {
105 actor_id: ActorId,
107 timeout_ms: u64,
109 #[reply]
111 stopped: OncePortRef<StopActorResult>,
112 },
113}
114
115#[derive(Debug)]
117#[hyperactor::export(handlers=[MeshAgentMessage])]
118pub struct MeshAgent {
119 proc: Proc,
120 remote: Remote,
121 sender: ReconfigurableMailboxSender,
122 rank: Option<usize>,
123 supervisor: Option<PortRef<ActorSupervisionEvent>>,
124}
125
126impl MeshAgent {
127 #[hyperactor::instrument]
128 pub(crate) async fn bootstrap(
129 proc_id: ProcId,
130 ) -> Result<(Proc, ActorHandle<Self>), anyhow::Error> {
131 let sender = ReconfigurableMailboxSender::new();
132 let proc = Proc::new(proc_id.clone(), BoxedMailboxSender::new(sender.clone()));
133
134 super::global_router().bind(proc_id.into(), proc.clone());
137
138 let agent = MeshAgent {
139 proc: proc.clone(),
140 remote: Remote::collect(),
141 sender,
142 rank: None, supervisor: None, };
145 let handle = proc.spawn::<Self>("mesh", agent).await?;
146 tracing::info!("bootstrap_end");
147 Ok((proc, handle))
148 }
149}
150
151#[async_trait]
152impl Actor for MeshAgent {
153 type Params = Self;
154
155 async fn new(params: Self::Params) -> Result<Self, anyhow::Error> {
156 Ok(params)
157 }
158
159 async fn init(&mut self, this: &Instance<Self>) -> Result<(), anyhow::Error> {
160 self.proc.set_supervision_coordinator(this.port())?;
161 Ok(())
162 }
163}
164
165#[async_trait]
166#[hyperactor::forward(MeshAgentMessage)]
167impl MeshAgentMessageHandler for MeshAgent {
168 async fn configure(
169 &mut self,
170 cx: &Context<Self>,
171 rank: usize,
172 forwarder: ChannelAddr,
173 supervisor: PortRef<ActorSupervisionEvent>,
174 address_book: HashMap<ProcId, ChannelAddr>,
175 configured: PortRef<usize>,
176 ) -> Result<(), anyhow::Error> {
177 self.supervisor = Some(supervisor);
181
182 let client = MailboxClient::new(channel::dial(forwarder)?);
185
186 let router = if std::env::var("HYPERACTOR_MESH_ROUTER_NO_GLOBAL_FALLBACK").is_err() {
190 let default = super::global_router().fallback(client.into_boxed());
191 DialMailboxRouter::new_with_default(default.into_boxed())
192 } else {
193 DialMailboxRouter::new_with_default(client.into_boxed())
194 };
195
196 for (proc_id, addr) in address_book {
197 router.bind(proc_id.into(), addr);
198 }
199
200 if self.sender.configure(router.into_boxed()) {
201 self.rank = Some(rank);
202 configured.send(cx, rank)?;
203 } else {
204 tracing::error!("tried to reconfigure mesh agent");
205 }
206 Ok(())
207 }
208
209 async fn gspawn(
210 &mut self,
211 cx: &Context<Self>,
212 actor_type: String,
213 actor_name: String,
214 params_data: Data,
215 status_port: PortRef<GspawnResult>,
216 ) -> Result<(), anyhow::Error> {
217 let actor_id = match self
218 .remote
219 .gspawn(&self.proc, &actor_type, &actor_name, params_data)
220 .await
221 {
222 Ok(id) => id,
223 Err(err) => {
224 status_port.send(cx, GspawnResult::Error(format!("gspawn failed: {}", err)))?;
225 return Err(anyhow::anyhow!("gspawn failed"));
226 }
227 };
228 let rank = match self.rank {
229 Some(rank) => rank,
230 None => {
231 let err = "tried to spawn on unconfigured proc";
232 status_port.send(cx, GspawnResult::Error(err.to_string()))?;
233 return Err(anyhow::anyhow!(err));
234 }
235 };
236 status_port.send(cx, GspawnResult::Success { rank, actor_id })?;
237 Ok(())
238 }
239
240 async fn stop_actor(
241 &mut self,
242 _cx: &Context<Self>,
243 actor_id: ActorId,
244 timeout_ms: u64,
245 ) -> Result<StopActorResult, anyhow::Error> {
246 tracing::info!("Stopping actor: {}", actor_id);
247
248 if let Some(mut status) = self.proc.stop_actor(&actor_id) {
249 match RealClock
250 .timeout(
251 tokio::time::Duration::from_millis(timeout_ms),
252 status.wait_for(|state: &ActorStatus| matches!(*state, ActorStatus::Stopped)),
253 )
254 .await
255 {
256 Ok(_) => Ok(StopActorResult::Success),
257 Err(_) => Ok(StopActorResult::Timeout),
258 }
259 } else {
260 Ok(StopActorResult::NotFound)
261 }
262 }
263}
264
265#[async_trait]
266impl Handler<ActorSupervisionEvent> for MeshAgent {
267 async fn handle(
268 &mut self,
269 cx: &Context<Self>,
270 event: ActorSupervisionEvent,
271 ) -> anyhow::Result<()> {
272 if let Some(supervisor) = &self.supervisor {
273 supervisor.send(cx, event)?;
274 } else {
275 tracing::error!(
276 "proc {}: could not propagate supervision event {:?}: crashing",
277 cx.self_id().proc_id(),
278 event
279 );
280
281 std::process::exit(1);
284 }
285 Ok(())
286 }
287}
288
289#[derive(Clone)]
292pub(crate) struct ReconfigurableMailboxSender {
293 state: Arc<RwLock<ReconfigurableMailboxSenderState>>,
294}
295
296impl std::fmt::Debug for ReconfigurableMailboxSender {
297 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
298 f.debug_struct("ReconfigurableMailboxSender").finish()
301 }
302}
303
304type Post = (MessageEnvelope, PortHandle<Undeliverable<MessageEnvelope>>);
305
306#[derive(EnumAsInner, Debug)]
307enum ReconfigurableMailboxSenderState {
308 Queueing(Mutex<Vec<Post>>),
309 Configured(BoxedMailboxSender),
310}
311
312impl ReconfigurableMailboxSender {
313 pub(crate) fn new() -> Self {
314 Self {
315 state: Arc::new(RwLock::new(ReconfigurableMailboxSenderState::Queueing(
316 Mutex::new(Vec::new()),
317 ))),
318 }
319 }
320
321 pub(crate) fn configure(&self, sender: BoxedMailboxSender) -> bool {
325 let mut state = self.state.write().unwrap();
326 if state.is_configured() {
327 return false;
328 }
329
330 let queued = replace(
331 &mut *state,
332 ReconfigurableMailboxSenderState::Configured(sender.clone()),
333 );
334
335 for (envelope, return_handle) in queued.into_queueing().unwrap().into_inner().unwrap() {
336 sender.post(envelope, return_handle);
337 }
338 *state = ReconfigurableMailboxSenderState::Configured(sender);
339 true
340 }
341}
342
343impl MailboxSender for ReconfigurableMailboxSender {
344 fn post(
345 &self,
346 envelope: MessageEnvelope,
347 return_handle: PortHandle<Undeliverable<MessageEnvelope>>,
348 ) {
349 match *self.state.read().unwrap() {
350 ReconfigurableMailboxSenderState::Queueing(ref queue) => {
351 queue.lock().unwrap().push((envelope, return_handle));
352 }
353 ReconfigurableMailboxSenderState::Configured(ref sender) => {
354 sender.post(envelope, return_handle);
355 }
356 }
357 }
358}
359
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::Mutex;

    use hyperactor::attrs::Attrs;
    use hyperactor::id;
    use hyperactor::mailbox::BoxedMailboxSender;
    use hyperactor::mailbox::Mailbox;
    use hyperactor::mailbox::MailboxSender;
    use hyperactor::mailbox::MessageEnvelope;
    use hyperactor::mailbox::PortHandle;
    use hyperactor::mailbox::Undeliverable;

    use super::*;

    /// Test double that records every posted envelope, in order, and ignores
    /// the return handle.
    #[derive(Debug, Clone)]
    struct QueueingMailboxSender {
        messages: Arc<Mutex<Vec<MessageEnvelope>>>,
    }

    impl QueueingMailboxSender {
        fn new() -> Self {
            let messages = Arc::new(Mutex::new(Vec::new()));
            Self { messages }
        }

        /// Snapshot of everything recorded so far.
        fn get_messages(&self) -> Vec<MessageEnvelope> {
            let recorded = self.messages.lock().unwrap();
            recorded.clone()
        }
    }

    impl MailboxSender for QueueingMailboxSender {
        fn post(
            &self,
            envelope: MessageEnvelope,
            _return_handle: PortHandle<Undeliverable<MessageEnvelope>>,
        ) {
            let mut recorded = self.messages.lock().unwrap();
            recorded.push(envelope);
        }
    }

    /// Build an envelope whose payload is the serialized `data`.
    fn envelope(data: u64) -> MessageEnvelope {
        let serialized = MessageEnvelope::serialize(
            id!(world[0].sender),
            id!(world[0].receiver[0][1]),
            &data,
            Attrs::new(),
        );
        serialized.unwrap()
    }

    /// Open a throwaway undeliverable port on a detached mailbox.
    fn return_handle() -> PortHandle<Undeliverable<MessageEnvelope>> {
        let mbox = Mailbox::new_detached(id!(test[0].test));
        let (port, _receiver) = mbox.open_port::<Undeliverable<MessageEnvelope>>();
        port
    }

    /// Decode the `u64` payload of every envelope recorded by `recorder`.
    fn received_payloads(recorder: &QueueingMailboxSender) -> Vec<u64> {
        recorder
            .get_messages()
            .iter()
            .map(|message| message.deserialized::<u64>().unwrap())
            .collect()
    }

    #[test]
    fn test_queueing_before_configure() {
        let sender = ReconfigurableMailboxSender::new();

        let recorder = QueueingMailboxSender::new();
        let target = BoxedMailboxSender::new(recorder.clone());

        let undeliverable = return_handle();
        for payload in [1u64, 2] {
            sender.post(envelope(payload), undeliverable.clone());
        }

        // Nothing reaches the target before configuration.
        assert!(recorder.get_messages().is_empty());

        sender.configure(target);

        // The queued messages are flushed in order.
        assert_eq!(received_payloads(&recorder), vec![1u64, 2]);
    }

    #[test]
    fn test_direct_delivery_after_configure() {
        let sender = ReconfigurableMailboxSender::new();

        let recorder = QueueingMailboxSender::new();
        let target = BoxedMailboxSender::new(recorder.clone());
        sender.configure(target);

        let undeliverable = return_handle();
        for payload in [3u64, 4] {
            sender.post(envelope(payload), undeliverable.clone());
        }

        assert_eq!(received_payloads(&recorder), vec![3u64, 4]);
    }

    #[test]
    fn test_multiple_configurations() {
        let sender = ReconfigurableMailboxSender::new();
        let target = BoxedMailboxSender::new(QueueingMailboxSender::new());

        // Only the first configuration is accepted.
        assert!(sender.configure(target.clone()));
        assert!(!sender.configure(target));
    }

    #[test]
    fn test_mixed_queueing_and_direct_delivery() {
        let sender = ReconfigurableMailboxSender::new();

        let recorder = QueueingMailboxSender::new();
        let target = BoxedMailboxSender::new(recorder.clone());

        let undeliverable = return_handle();
        for payload in [5u64, 6] {
            sender.post(envelope(payload), undeliverable.clone());
        }

        sender.configure(target);

        for payload in [7u64, 8] {
            sender.post(envelope(payload), undeliverable.clone());
        }

        // Queued messages come first, followed by the direct posts.
        assert_eq!(received_payloads(&recorder), vec![5u64, 6, 7, 8]);
    }
}