1use std::collections::HashMap;
12use std::mem::replace;
13use std::mem::take;
14use std::sync::Arc;
15use std::sync::Mutex;
16use std::sync::RwLock;
17use std::sync::RwLockWriteGuard;
18
19use async_trait::async_trait;
20use enum_as_inner::EnumAsInner;
21use hyperactor::Actor;
22use hyperactor::ActorHandle;
23use hyperactor::ActorId;
24use hyperactor::Bind;
25use hyperactor::Context;
26use hyperactor::Data;
27use hyperactor::HandleClient;
28use hyperactor::Handler;
29use hyperactor::Instance;
30use hyperactor::Named;
31use hyperactor::OncePortRef;
32use hyperactor::PortHandle;
33use hyperactor::PortRef;
34use hyperactor::ProcId;
35use hyperactor::RefClient;
36use hyperactor::Unbind;
37use hyperactor::actor::ActorStatus;
38use hyperactor::actor::remote::Remote;
39use hyperactor::channel;
40use hyperactor::channel::ChannelAddr;
41use hyperactor::clock::Clock;
42use hyperactor::clock::RealClock;
43use hyperactor::mailbox::BoxedMailboxSender;
44use hyperactor::mailbox::DialMailboxRouter;
45use hyperactor::mailbox::IntoBoxedMailboxSender;
46use hyperactor::mailbox::MailboxClient;
47use hyperactor::mailbox::MailboxSender;
48use hyperactor::mailbox::MessageEnvelope;
49use hyperactor::mailbox::Undeliverable;
50use hyperactor::proc::Proc;
51use hyperactor::supervision::ActorSupervisionEvent;
52use serde::Deserialize;
53use serde::Serialize;
54
55use crate::proc_mesh::SupervisionEventState;
56use crate::resource;
57use crate::v1::Name;
58
59#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Named)]
60pub enum GspawnResult {
61 Success { rank: usize, actor_id: ActorId },
62 Error(String),
63}
64
65#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Named)]
66pub enum StopActorResult {
67 Success,
68 Timeout,
69 NotFound,
70}
71
72#[derive(
73 Debug,
74 Clone,
75 PartialEq,
76 Serialize,
77 Deserialize,
78 Handler,
79 HandleClient,
80 RefClient,
81 Named
82)]
83pub(crate) enum MeshAgentMessage {
84 Configure {
86 rank: usize,
88 forwarder: ChannelAddr,
90 supervisor: Option<PortRef<ActorSupervisionEvent>>,
92 address_book: HashMap<ProcId, ChannelAddr>,
94 configured: PortRef<usize>,
97 record_supervision_events: bool,
99 },
100
101 Status {
102 status: PortRef<(usize, bool)>,
106 },
107
108 Gspawn {
110 actor_type: String,
112 actor_name: String,
114 params_data: Data,
116 status_port: PortRef<GspawnResult>,
118 },
119
120 StopActor {
122 actor_id: ActorId,
124 timeout_ms: u64,
126 #[reply]
128 stopped: OncePortRef<StopActorResult>,
129 },
130}
131
132#[derive(Debug, EnumAsInner, Default)]
134enum State {
135 UnconfiguredV0 {
136 sender: ReconfigurableMailboxSender,
137 },
138
139 ConfiguredV0 {
140 sender: ReconfigurableMailboxSender,
141 rank: usize,
142 supervisor: Option<PortRef<ActorSupervisionEvent>>,
143 },
144
145 V1,
146
147 #[default]
148 Invalid,
149}
150
151impl State {
152 fn rank(&self) -> Option<usize> {
153 match self {
154 State::ConfiguredV0 { rank, .. } => Some(*rank),
155 _ => None,
156 }
157 }
158
159 fn supervisor(&self) -> Option<PortRef<ActorSupervisionEvent>> {
160 match self {
161 State::ConfiguredV0 { supervisor, .. } => supervisor.clone(),
162 _ => None,
163 }
164 }
165}
166
167#[derive(Debug)]
169struct ActorInstanceState {
170 create_rank: usize,
171 spawn: Result<ActorId, anyhow::Error>,
172}
173
174#[derive(Debug)]
176#[hyperactor::export(
177 handlers=[
178 MeshAgentMessage,
179 resource::CreateOrUpdate<ActorSpec> { cast = true },
180 resource::GetState<ActorState> { cast = true },
181 resource::GetRankStatus { cast = true },
182 ]
183)]
184pub struct ProcMeshAgent {
185 proc: Proc,
186 remote: Remote,
187 state: State,
188 actor_states: HashMap<Name, ActorInstanceState>,
190 record_supervision_events: bool,
193 supervision_events: HashMap<ActorId, Vec<ActorSupervisionEvent>>,
196}
197
198impl ProcMeshAgent {
199 #[hyperactor::observe_result("MeshAgent")]
200 pub(crate) async fn bootstrap(
201 proc_id: ProcId,
202 ) -> Result<(Proc, ActorHandle<Self>), anyhow::Error> {
203 let sender = ReconfigurableMailboxSender::new();
204 let proc = Proc::new(proc_id.clone(), BoxedMailboxSender::new(sender.clone()));
205
206 super::router::global().bind(proc_id.into(), proc.clone());
209
210 let agent = ProcMeshAgent {
211 proc: proc.clone(),
212 remote: Remote::collect(),
213 state: State::UnconfiguredV0 { sender },
214 actor_states: HashMap::new(),
215 record_supervision_events: false,
216 supervision_events: HashMap::new(),
217 };
218 let handle = proc.spawn::<Self>("mesh", agent).await?;
219 Ok((proc, handle))
220 }
221
222 pub(crate) async fn boot_v1(proc: Proc) -> Result<ActorHandle<Self>, anyhow::Error> {
223 let agent = ProcMeshAgent {
224 proc: proc.clone(),
225 remote: Remote::collect(),
226 state: State::V1,
227 actor_states: HashMap::new(),
228 record_supervision_events: true,
229 supervision_events: HashMap::new(),
230 };
231 proc.spawn::<Self>("agent", agent).await
232 }
233}
234
235#[async_trait]
236impl Actor for ProcMeshAgent {
237 type Params = Self;
238
239 async fn new(params: Self::Params) -> Result<Self, anyhow::Error> {
240 Ok(params)
241 }
242
243 async fn init(&mut self, this: &Instance<Self>) -> Result<(), anyhow::Error> {
244 self.proc.set_supervision_coordinator(this.port())?;
245 Ok(())
246 }
247}
248
249#[async_trait]
250#[hyperactor::forward(MeshAgentMessage)]
251impl MeshAgentMessageHandler for ProcMeshAgent {
252 async fn configure(
253 &mut self,
254 cx: &Context<Self>,
255 rank: usize,
256 forwarder: ChannelAddr,
257 supervisor: Option<PortRef<ActorSupervisionEvent>>,
258 address_book: HashMap<ProcId, ChannelAddr>,
259 configured: PortRef<usize>,
260 record_supervision_events: bool,
261 ) -> Result<(), anyhow::Error> {
262 anyhow::ensure!(
263 self.state.is_unconfigured_v0(),
264 "mesh agent cannot be (re-)configured"
265 );
266 self.record_supervision_events = record_supervision_events;
267
268 let client = MailboxClient::new(channel::dial(forwarder)?);
271
272 let router = if std::env::var("HYPERACTOR_MESH_ROUTER_NO_GLOBAL_FALLBACK").is_err() {
276 let default = super::router::global().fallback(client.into_boxed());
277 DialMailboxRouter::new_with_default_direct_addressed_remote_only(default.into_boxed())
278 } else {
279 DialMailboxRouter::new_with_default_direct_addressed_remote_only(client.into_boxed())
280 };
281
282 for (proc_id, addr) in address_book {
283 router.bind(proc_id.into(), addr);
284 }
285
286 let sender = take(&mut self.state).into_unconfigured_v0().unwrap();
287 assert!(sender.configure(router.into_boxed()));
288
289 self.state = State::ConfiguredV0 {
293 sender,
294 rank,
295 supervisor,
296 };
297 configured.send(cx, rank)?;
298
299 Ok(())
300 }
301
302 async fn gspawn(
303 &mut self,
304 cx: &Context<Self>,
305 actor_type: String,
306 actor_name: String,
307 params_data: Data,
308 status_port: PortRef<GspawnResult>,
309 ) -> Result<(), anyhow::Error> {
310 anyhow::ensure!(
311 self.state.is_configured_v0(),
312 "mesh agent is not v0 configured"
313 );
314 let actor_id = match self
315 .remote
316 .gspawn(&self.proc, &actor_type, &actor_name, params_data)
317 .await
318 {
319 Ok(id) => id,
320 Err(err) => {
321 status_port.send(cx, GspawnResult::Error(format!("gspawn failed: {}", err)))?;
322 return Err(anyhow::anyhow!("gspawn failed"));
323 }
324 };
325 status_port.send(
326 cx,
327 GspawnResult::Success {
328 rank: self.state.rank().unwrap(),
329 actor_id,
330 },
331 )?;
332 Ok(())
333 }
334
335 async fn stop_actor(
336 &mut self,
337 _cx: &Context<Self>,
338 actor_id: ActorId,
339 timeout_ms: u64,
340 ) -> Result<StopActorResult, anyhow::Error> {
341 tracing::info!(
342 name = "StopActor",
343 actor_id = %actor_id,
344 actor_name = actor_id.name(),
345 );
346
347 if let Some(mut status) = self.proc.stop_actor(&actor_id) {
348 match RealClock
349 .timeout(
350 tokio::time::Duration::from_millis(timeout_ms),
351 status.wait_for(|state: &ActorStatus| matches!(*state, ActorStatus::Stopped)),
352 )
353 .await
354 {
355 Ok(_) => Ok(StopActorResult::Success),
356 Err(_) => Ok(StopActorResult::Timeout),
357 }
358 } else {
359 Ok(StopActorResult::NotFound)
360 }
361 }
362
363 async fn status(
364 &mut self,
365 cx: &Context<Self>,
366 status_port: PortRef<(usize, bool)>,
367 ) -> Result<(), anyhow::Error> {
368 match &self.state {
369 State::ConfiguredV0 { rank, .. } => {
370 status_port.send(cx, (*rank, true))?;
372 Ok(())
373 }
374 State::UnconfiguredV0 { .. } => {
375 Err(anyhow::anyhow!(
377 "status unavailable: v0 agent not configured (waiting for Configure)"
378 ))
379 }
380 State::V1 => {
381 Err(anyhow::anyhow!(
383 "status unsupported in v1/owned path (no rank)"
384 ))
385 }
386 State::Invalid => Err(anyhow::anyhow!(
387 "status unavailable: agent in invalid state"
388 )),
389 }
390 }
391}
392
393#[async_trait]
394impl Handler<ActorSupervisionEvent> for ProcMeshAgent {
395 async fn handle(
396 &mut self,
397 cx: &Context<Self>,
398 event: ActorSupervisionEvent,
399 ) -> anyhow::Result<()> {
400 if self.record_supervision_events {
401 tracing::info!("Received supervision event: {:?}, recording", event);
402 self.supervision_events
403 .entry(event.actor_id.clone())
404 .or_default()
405 .push(event.clone());
406 }
407 if let Some(supervisor) = self.state.supervisor() {
408 supervisor.send(cx, event)?;
409 } else if !self.record_supervision_events {
410 tracing::error!(
413 name = SupervisionEventState::SupervisionEventTransmitFailed.as_ref(),
414 "proc {}: could not propagate supervision event {:?}: crashing",
415 cx.self_id().proc_id(),
416 event
417 );
418
419 std::process::exit(1);
422 }
423 Ok(())
424 }
425}
426
427#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Named)]
431pub struct ActorSpec {
432 pub actor_type: String,
434 pub params_data: Data,
436}
437
438#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Named, Bind, Unbind)]
440pub struct ActorState {
441 pub actor_id: ActorId,
443 pub create_rank: usize,
445 pub supervision_events: Vec<ActorSupervisionEvent>,
447}
448
449#[async_trait]
450impl Handler<resource::CreateOrUpdate<ActorSpec>> for ProcMeshAgent {
451 async fn handle(
452 &mut self,
453 _cx: &Context<Self>,
454 create_or_update: resource::CreateOrUpdate<ActorSpec>,
455 ) -> anyhow::Result<()> {
456 if self.actor_states.contains_key(&create_or_update.name) {
457 return Ok(());
459 }
460
461 let ActorSpec {
462 actor_type,
463 params_data,
464 } = create_or_update.spec;
465 self.actor_states.insert(
466 create_or_update.name.clone(),
467 ActorInstanceState {
468 create_rank: create_or_update.rank.unwrap(),
469 spawn: self
470 .remote
471 .gspawn(
472 &self.proc,
473 &actor_type,
474 &create_or_update.name.to_string(),
475 params_data,
476 )
477 .await,
478 },
479 );
480
481 Ok(())
482 }
483}
484
485#[async_trait]
486impl Handler<resource::GetRankStatus> for ProcMeshAgent {
487 async fn handle(
488 &mut self,
489 cx: &Context<Self>,
490 get_rank_status: resource::GetRankStatus,
491 ) -> anyhow::Result<()> {
492 let (rank, status) = match self.actor_states.get(&get_rank_status.name) {
493 Some(ActorInstanceState {
494 spawn: Ok(actor_id),
495 create_rank,
496 }) => {
497 let supervision_events = self
498 .supervision_events
499 .get(actor_id)
500 .map_or_else(Vec::new, |a| a.clone());
501 (
502 *create_rank,
503 if supervision_events.is_empty() {
504 resource::Status::Running
505 } else {
506 resource::Status::Failed(format!(
507 "because of supervision events: {:?}",
508 supervision_events
509 ))
510 },
511 )
512 }
513 Some(ActorInstanceState {
514 spawn: Err(e),
515 create_rank,
516 }) => (*create_rank, resource::Status::Failed(e.to_string())),
517 None => (usize::MAX, resource::Status::NotExist),
519 };
520
521 get_rank_status.reply.send(cx, (rank, status).into())?;
522 Ok(())
523 }
524}
525
526#[async_trait]
527impl Handler<resource::GetState<ActorState>> for ProcMeshAgent {
528 async fn handle(
529 &mut self,
530 cx: &Context<Self>,
531 get_state: resource::GetState<ActorState>,
532 ) -> anyhow::Result<()> {
533 let state = match self.actor_states.get(&get_state.name) {
534 Some(ActorInstanceState {
535 create_rank,
536 spawn: Ok(actor_id),
537 }) => {
538 let supervision_events = self
539 .supervision_events
540 .get(actor_id)
541 .map_or_else(Vec::new, |a| a.clone());
542 let status = if supervision_events.is_empty() {
543 resource::Status::Running
544 } else {
545 resource::Status::Failed(format!(
546 "because of supervision events: {:?}",
547 supervision_events
548 ))
549 };
550 resource::State {
551 name: get_state.name.clone(),
552 status,
553 state: Some(ActorState {
554 actor_id: actor_id.clone(),
555 create_rank: *create_rank,
556 supervision_events,
557 }),
558 }
559 }
560 Some(ActorInstanceState { spawn: Err(e), .. }) => resource::State {
561 name: get_state.name.clone(),
562 status: resource::Status::Failed(e.to_string()),
563 state: None,
564 },
565 None => resource::State {
566 name: get_state.name.clone(),
567 status: resource::Status::NotExist,
568 state: None,
569 },
570 };
571
572 get_state.reply.send(cx, state)?;
573 Ok(())
574 }
575}
576
577#[derive(Clone)]
580pub(crate) struct ReconfigurableMailboxSender {
581 state: Arc<RwLock<ReconfigurableMailboxSenderState>>,
582}
583
584impl std::fmt::Debug for ReconfigurableMailboxSender {
585 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
586 f.debug_struct("ReconfigurableMailboxSender").finish()
589 }
590}
591
592pub(crate) struct ReconfigurableMailboxSenderInner<'a> {
593 guard: RwLockWriteGuard<'a, ReconfigurableMailboxSenderState>,
594}
595
596impl<'a> ReconfigurableMailboxSenderInner<'a> {
597 pub(crate) fn as_configured(&self) -> Option<&BoxedMailboxSender> {
598 self.guard.as_configured()
599 }
600}
601
602type Post = (MessageEnvelope, PortHandle<Undeliverable<MessageEnvelope>>);
603
604#[derive(EnumAsInner, Debug)]
605enum ReconfigurableMailboxSenderState {
606 Queueing(Mutex<Vec<Post>>),
607 Configured(BoxedMailboxSender),
608}
609
610impl ReconfigurableMailboxSender {
611 pub(crate) fn new() -> Self {
612 Self {
613 state: Arc::new(RwLock::new(ReconfigurableMailboxSenderState::Queueing(
614 Mutex::new(Vec::new()),
615 ))),
616 }
617 }
618
619 pub(crate) fn configure(&self, sender: BoxedMailboxSender) -> bool {
623 let mut state = self.state.write().unwrap();
624 if state.is_configured() {
625 return false;
626 }
627
628 let queued = replace(
629 &mut *state,
630 ReconfigurableMailboxSenderState::Configured(sender.clone()),
631 );
632
633 for (envelope, return_handle) in queued.into_queueing().unwrap().into_inner().unwrap() {
634 sender.post(envelope, return_handle);
635 }
636 *state = ReconfigurableMailboxSenderState::Configured(sender);
637 true
638 }
639
640 pub(crate) fn as_inner<'a>(
641 &'a self,
642 ) -> Result<ReconfigurableMailboxSenderInner<'a>, anyhow::Error> {
643 let state = self.state.write().unwrap();
644 if state.is_configured() {
645 Ok(ReconfigurableMailboxSenderInner { guard: state })
646 } else {
647 Err(anyhow::anyhow!("cannot get inner sender: not configured"))
648 }
649 }
650}
651
652impl MailboxSender for ReconfigurableMailboxSender {
653 fn post(
654 &self,
655 envelope: MessageEnvelope,
656 return_handle: PortHandle<Undeliverable<MessageEnvelope>>,
657 ) {
658 match *self.state.read().unwrap() {
659 ReconfigurableMailboxSenderState::Queueing(ref queue) => {
660 queue.lock().unwrap().push((envelope, return_handle));
661 }
662 ReconfigurableMailboxSenderState::Configured(ref sender) => {
663 sender.post(envelope, return_handle);
664 }
665 }
666 }
667
668 fn post_unchecked(
669 &self,
670 envelope: MessageEnvelope,
671 return_handle: PortHandle<Undeliverable<MessageEnvelope>>,
672 ) {
673 match *self.state.read().unwrap() {
674 ReconfigurableMailboxSenderState::Queueing(ref queue) => {
675 queue.lock().unwrap().push((envelope, return_handle));
676 }
677 ReconfigurableMailboxSenderState::Configured(ref sender) => {
678 sender.post_unchecked(envelope, return_handle);
679 }
680 }
681 }
682}
683
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::Mutex;

    use hyperactor::attrs::Attrs;
    use hyperactor::id;
    use hyperactor::mailbox::BoxedMailboxSender;
    use hyperactor::mailbox::Mailbox;
    use hyperactor::mailbox::MailboxSender;
    use hyperactor::mailbox::MessageEnvelope;
    use hyperactor::mailbox::PortHandle;
    use hyperactor::mailbox::Undeliverable;

    use super::*;

    /// Test double that records every envelope posted to it.
    #[derive(Debug, Clone)]
    struct QueueingMailboxSender {
        messages: Arc<Mutex<Vec<MessageEnvelope>>>,
    }

    impl QueueingMailboxSender {
        fn new() -> Self {
            Self {
                messages: Arc::new(Mutex::new(Vec::new())),
            }
        }

        /// Returns a snapshot of the envelopes received so far.
        fn get_messages(&self) -> Vec<MessageEnvelope> {
            self.messages.lock().unwrap().clone()
        }
    }

    impl MailboxSender for QueueingMailboxSender {
        fn post_unchecked(
            &self,
            envelope: MessageEnvelope,
            _return_handle: PortHandle<Undeliverable<MessageEnvelope>>,
        ) {
            self.messages.lock().unwrap().push(envelope);
        }
    }

    /// Builds an envelope whose payload is `data`.
    fn envelope(data: u64) -> MessageEnvelope {
        MessageEnvelope::serialize(
            id!(world[0].sender),
            id!(world[0].receiver[0][1]),
            &data,
            Attrs::new(),
        )
        .unwrap()
    }

    /// Opens a throwaway port to serve as the undeliverable-return handle.
    fn return_handle() -> PortHandle<Undeliverable<MessageEnvelope>> {
        let mbox = Mailbox::new_detached(id!(test[0].test));
        let (port, _receiver) = mbox.open_port::<Undeliverable<MessageEnvelope>>();
        port
    }

    /// Decodes the `u64` payloads recorded by `recorder`, in arrival order.
    fn payloads(recorder: &QueueingMailboxSender) -> Vec<u64> {
        recorder
            .get_messages()
            .iter()
            .map(|message| message.deserialized::<u64>().unwrap())
            .collect()
    }

    #[test]
    fn test_queueing_before_configure() {
        let sender = ReconfigurableMailboxSender::new();

        let recorder = QueueingMailboxSender::new();
        let boxed_sender = BoxedMailboxSender::new(recorder.clone());

        let handle = return_handle();
        for value in [1, 2] {
            sender.post(envelope(value), handle.clone());
        }

        // Nothing reaches the target before configuration.
        assert_eq!(recorder.get_messages().len(), 0);

        sender.configure(boxed_sender);

        // Queued posts are flushed in order once configured.
        assert_eq!(payloads(&recorder), vec![1, 2]);
    }

    #[test]
    fn test_direct_delivery_after_configure() {
        let sender = ReconfigurableMailboxSender::new();

        let recorder = QueueingMailboxSender::new();
        let boxed_sender = BoxedMailboxSender::new(recorder.clone());
        sender.configure(boxed_sender);

        let handle = return_handle();
        for value in [3, 4] {
            sender.post(envelope(value), handle.clone());
        }

        assert_eq!(payloads(&recorder), vec![3, 4]);
    }

    #[test]
    fn test_multiple_configurations() {
        let sender = ReconfigurableMailboxSender::new();
        let boxed_sender = BoxedMailboxSender::new(QueueingMailboxSender::new());

        // Only the first configuration takes effect.
        assert!(sender.configure(boxed_sender.clone()));
        assert!(!sender.configure(boxed_sender));
    }

    #[test]
    fn test_mixed_queueing_and_direct_delivery() {
        let sender = ReconfigurableMailboxSender::new();

        let recorder = QueueingMailboxSender::new();
        let boxed_sender = BoxedMailboxSender::new(recorder.clone());

        let handle = return_handle();
        for value in [5, 6] {
            sender.post(envelope(value), handle.clone());
        }

        sender.configure(boxed_sender);

        for value in [7, 8] {
            sender.post(envelope(value), handle.clone());
        }

        // Queued posts come first, followed by the direct ones.
        assert_eq!(payloads(&recorder), vec![5, 6, 7, 8]);
    }
}