1use std::collections::HashMap;
10use std::collections::HashSet;
11use std::fmt::Debug;
12
13use async_trait::async_trait;
14use hyperactor::Actor;
15use hyperactor::Bind;
16use hyperactor::Context;
17use hyperactor::Handler;
18use hyperactor::Instance;
19use hyperactor::PortRef;
20use hyperactor::ProcId;
21use hyperactor::Unbind;
22use hyperactor::actor::ActorError;
23use hyperactor::actor::ActorErrorKind;
24use hyperactor::actor::ActorStatus;
25use hyperactor::actor::Referable;
26use hyperactor::actor::handle_undeliverable_message;
27use hyperactor::context;
28use hyperactor::mailbox::MessageEnvelope;
29use hyperactor::mailbox::Undeliverable;
30use hyperactor::supervision::ActorSupervisionEvent;
31use hyperactor_config::CONFIG;
32use hyperactor_config::ConfigAttr;
33use hyperactor_config::attrs::declare_attrs;
34use ndslice::ViewExt;
35use ndslice::view::CollectMeshExt;
36use ndslice::view::Ranked;
37use serde::Deserialize;
38use serde::Serialize;
39use tokio::time::Duration;
40use typeuri::Named;
41
42use crate::actor_mesh::update_undeliverable_envelope_for_casting;
43use crate::bootstrap::ProcStatus;
44use crate::proc_mesh::mesh_agent::ActorState;
45use crate::resource;
46use crate::supervision::MeshFailure;
47use crate::supervision::Unhealthy;
48use crate::v1;
49use crate::v1::Name;
50use crate::v1::ValueMesh;
51use crate::v1::actor_mesh::ActorMeshRef;
52use crate::v1::host_mesh::HostMeshRef;
53use crate::v1::proc_mesh::ProcMeshRef;
54use crate::v1::view::Point;
55
declare_attrs! {
    // Delay between successive `CheckState` polls of a supervised actor mesh
    // (read by `ActorMeshController::self_check_state_message`). Defaults to
    // 3 seconds; `env_name` exposes it as
    // HYPERACTOR_MESH_SUPERVISION_POLL_FREQUENCY.
    @meta(CONFIG = ConfigAttr {
        env_name: Some("HYPERACTOR_MESH_SUPERVISION_POLL_FREQUENCY".to_string()),
        py_name: None,
    })
    pub attr SUPERVISION_POLL_FREQUENCY: Duration = Duration::from_secs(3);
}
68
69#[derive(Debug)]
70struct HealthState {
71 statuses: HashMap<Point, resource::Status>,
74 unhealthy_event: Option<Unhealthy>,
75 crashed_ranks: HashMap<usize, ActorSupervisionEvent>,
76 owner: Option<PortRef<MeshFailure>>,
78 subscribers: HashSet<PortRef<Option<MeshFailure>>>,
80 undeliverable_subscribers: HashSet<PortRef<Option<MeshFailure>>>,
86}
87
88impl HealthState {
89 fn new(
90 statuses: HashMap<Point, resource::Status>,
91 owner: Option<PortRef<MeshFailure>>,
92 ) -> Self {
93 Self {
94 statuses,
95 unhealthy_event: None,
96 crashed_ranks: HashMap::new(),
97 owner,
98 subscribers: HashSet::new(),
99 undeliverable_subscribers: HashSet::new(),
100 }
101 }
102}
103
/// Registers a port with an `ActorMeshController`. The port then receives
/// `Some(MeshFailure)` for each reported state change and `None` as a heartbeat.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Named, Bind, Unbind)]
pub struct Subscribe(pub PortRef<Option<MeshFailure>>);

/// Removes a port previously registered with `Subscribe`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Named, Bind, Unbind)]
pub struct Unsubscribe(pub PortRef<Option<MeshFailure>>);

/// Self-addressed, delayed message that triggers one supervision poll of the
/// monitored mesh.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Named, Bind, Unbind)]
pub struct CheckState();
121
/// Actor that supervises an actor mesh.
///
/// It polls proc and actor states on a timer and forwards failures to its owner
/// and subscribers. It stops the mesh's actors on `resource::Stop`, or on
/// cleanup if no stop happened first.
#[hyperactor::export(handlers = [
    Subscribe,
    Unsubscribe,
    resource::CreateOrUpdate<resource::mesh::Spec<()>> { cast = true },
    resource::GetState<resource::mesh::State<()>> { cast = true },
    resource::Stop { cast = true },
])]
pub struct ActorMeshController<A>
where
    A: Referable,
{
    /// The supervised mesh.
    mesh: ActorMeshRef<A>,
    /// Name used in supervision messages; defaults to the mesh name.
    supervision_display_name: String,
    /// Statuses, subscribers, owner and last failure of the mesh.
    health_state: HealthState,
    /// `Some` while monitoring (set in `init`). It is taken on stop/cleanup so
    /// the poll loop stops rescheduling and the stop runs only once.
    monitor: Option<()>,
}
149
/// The controller carries no spec or state payload of its own; both are `()`.
impl<A: Referable> resource::mesh::Mesh for ActorMeshController<A> {
    type Spec = ();
    type State = ();
}
154
155impl<A: Referable> ActorMeshController<A> {
156 pub(crate) fn new(
158 mesh: ActorMeshRef<A>,
159 supervision_display_name: Option<String>,
160 port: Option<PortRef<MeshFailure>>,
161 initial_statuses: ValueMesh<resource::Status>,
162 ) -> Self {
163 let supervision_display_name =
164 supervision_display_name.unwrap_or_else(|| mesh.name().to_string());
165 Self {
166 mesh,
167 supervision_display_name,
168 health_state: HealthState::new(initial_statuses.iter().collect(), port),
169 monitor: None,
170 }
171 }
172
173 async fn stop(&self, cx: &impl context::Actor) -> v1::Result<ValueMesh<resource::Status>> {
174 self.mesh
176 .proc_mesh()
177 .stop_actor_by_name(cx, self.mesh.name().clone())
178 .await
179 }
180
181 fn self_check_state_message(&self, cx: &Instance<Self>) -> Result<(), ActorError> {
182 if self.monitor.is_some() {
184 cx.self_message_with_delay(
185 CheckState {},
186 hyperactor_config::global::get(SUPERVISION_POLL_FREQUENCY),
187 )
188 } else {
189 Ok(())
190 }
191 }
192}
193
194fn send_subscriber_message(
195 cx: &impl context::Actor,
196 subscriber: &PortRef<Option<MeshFailure>>,
197 message: MeshFailure,
198) {
199 if let Err(error) = subscriber.send(cx, Some(message.clone())) {
200 tracing::warn!(
201 event = %message,
202 "failed to send supervision event to subscriber {}: {}",
203 subscriber.port_id(),
204 error
205 );
206 } else {
207 tracing::info!(event = %message, "sent supervision failure message to subscriber {}", subscriber.port_id());
208 }
209}
210
211impl<A: Referable> Debug for ActorMeshController<A> {
212 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
213 f.debug_struct("MeshController")
214 .field("mesh", &self.mesh)
215 .field("health_state", &self.health_state)
216 .field("monitor", &self.monitor)
217 .finish()
218 }
219}
220
221#[async_trait]
222impl<A: Referable> Actor for ActorMeshController<A> {
223 async fn init(&mut self, this: &Instance<Self>) -> Result<(), anyhow::Error> {
224 self.monitor = Some(());
229 self.self_check_state_message(this)?;
230 tracing::info!(actor = %this.self_id(), "started mesh controller for {}", self.mesh.name());
231 Ok(())
232 }
233
234 async fn cleanup(
235 &mut self,
236 this: &Instance<Self>,
237 _err: Option<&ActorError>,
238 ) -> Result<(), anyhow::Error> {
239 if self.monitor.take().is_some() {
242 self.stop(this).await?;
243 }
244 Ok(())
245 }
246
247 async fn handle_undeliverable_message(
248 &mut self,
249 cx: &Instance<Self>,
250 envelope: Undeliverable<MessageEnvelope>,
251 ) -> Result<(), anyhow::Error> {
252 let dest_port_id = envelope.clone().into_inner().dest().clone();
255 let port = PortRef::<Option<MeshFailure>>::attest(dest_port_id);
256 let did_exist = self.health_state.subscribers.remove(&port);
259 if did_exist {
260 tracing::debug!(
261 actor = %cx.self_id(),
262 "ActorMeshController: handle_undeliverable_message: removed subscriber {} from mesh controller",
263 port.port_id()
264 );
265 self.health_state.undeliverable_subscribers.insert(port);
267 Ok(())
268 } else if self.health_state.undeliverable_subscribers.contains(&port) {
269 Ok(())
274 } else {
275 handle_undeliverable_message(cx, update_undeliverable_envelope_for_casting(envelope))
278 }
279 }
280}
281
282#[async_trait]
283impl<A: Referable> Handler<Subscribe> for ActorMeshController<A> {
284 async fn handle(&mut self, cx: &Context<Self>, message: Subscribe) -> anyhow::Result<()> {
285 match &self.health_state.unhealthy_event {
289 None => {}
290 Some(Unhealthy::StreamClosed(msg)) => {
294 send_subscriber_message(cx, &message.0, msg.clone());
295 }
296 Some(Unhealthy::Crashed(msg)) => {
297 send_subscriber_message(cx, &message.0, msg.clone());
298 }
299 }
300 self.health_state.subscribers.insert(message.0);
301 Ok(())
302 }
303}
304
305#[async_trait]
306impl<A: Referable> Handler<Unsubscribe> for ActorMeshController<A> {
307 async fn handle(&mut self, _cx: &Context<Self>, message: Unsubscribe) -> anyhow::Result<()> {
308 self.health_state.subscribers.remove(&message.0);
309 Ok(())
310 }
311}
312
#[async_trait]
impl<A: Referable> Handler<resource::CreateOrUpdate<resource::mesh::Spec<()>>>
    for ActorMeshController<A>
{
    /// Accepts `CreateOrUpdate` as a no-op. The controller's spec is `()`, so
    /// there is nothing here to apply.
    async fn handle(
        &mut self,
        _cx: &Context<Self>,
        _message: resource::CreateOrUpdate<resource::mesh::Spec<()>>,
    ) -> anyhow::Result<()> {
        Ok(())
    }
}
327
328#[async_trait]
329impl<A: Referable> Handler<resource::GetState<resource::mesh::State<()>>>
330 for ActorMeshController<A>
331{
332 async fn handle(
333 &mut self,
334 cx: &Context<Self>,
335 message: resource::GetState<resource::mesh::State<()>>,
336 ) -> anyhow::Result<()> {
337 let status = if let Some(Unhealthy::Crashed(e)) = &self.health_state.unhealthy_event {
338 resource::Status::Failed(e.to_string())
339 } else if let Some(Unhealthy::StreamClosed(_)) = &self.health_state.unhealthy_event {
340 resource::Status::Stopped
341 } else {
342 resource::Status::Running
343 };
344 let statuses = &self.health_state.statuses;
345 let mut statuses = statuses.clone().into_iter().collect::<Vec<_>>();
346 statuses.sort_by_key(|(p, _)| p.rank());
347 let statuses: ValueMesh<resource::Status> =
348 statuses
349 .into_iter()
350 .map(|(_, s)| s)
351 .collect_mesh::<ValueMesh<_>>(self.mesh.region().clone())?;
352 let state = resource::mesh::State {
353 statuses,
354 state: (),
355 };
356 message.reply.send(
357 cx,
358 resource::State {
359 name: message.name,
360 status,
361 state: Some(state),
362 },
363 )?;
364 Ok(())
365 }
366}
367
#[async_trait]
impl<A: Referable> Handler<resource::Stop> for ActorMeshController<A> {
    /// Stops the supervised mesh's actors, at most once.
    ///
    /// Steps:
    /// 1. Record a `StreamClosed` failure (a `Stopped` event attributed to the
    ///    rank-0 actor) and send it to every subscriber.
    /// 2. Ask the proc mesh to stop the actors.
    /// 3. Update the per-point statuses from its reply.
    ///
    /// A partial stop (`ActorStopError`) still updates statuses. Any other stop
    /// error is returned.
    async fn handle(&mut self, cx: &Context<Self>, _message: resource::Stop) -> anyhow::Result<()> {
        let mesh = &self.mesh;
        let mesh_name = mesh.name();
        // `monitor` is `Some` only until the first stop or cleanup. Taking it
        // also keeps the `CheckState` loop from rescheduling itself.
        if self.monitor.take().is_none() {
            tracing::debug!(actor = %cx.self_id(), %mesh_name, "duplicate stop request, actor mesh is already stopped");
            return Ok(());
        }
        // Reset the set of ports whose bounced messages are silently ignored.
        self.health_state.undeliverable_subscribers.clear();

        // The stop event is attributed to the rank-0 actor.
        // NOTE(review): `unwrap()` panics if the mesh has no rank 0 (empty
        // mesh). Confirm that cannot happen here.
        let rank = 0usize;
        let event = ActorSupervisionEvent::new(
            mesh.get(rank).unwrap().actor_id().clone(),
            None,
            ActorStatus::Stopped,
            None,
        );
        let message = MeshFailure {
            actor_mesh_name: Some(mesh_name.to_string()),
            rank: None,
            event,
        };
        // Remember the closure so later `Subscribe`/`GetState` calls see it,
        // then notify current subscribers.
        self.health_state.unhealthy_event = Some(Unhealthy::StreamClosed(message.clone()));
        for subscriber in self.health_state.subscribers.iter() {
            send_subscriber_message(cx, subscriber, message.clone());
        }

        // Capture the highest tracked rank and the extent before stopping. The
        // partial-failure branch uses them to map ranks back to points.
        let max_rank = self.health_state.statuses.keys().map(|p| p.rank()).max();
        let extent = self
            .health_state
            .statuses
            .keys()
            .next()
            .map(|p| p.extent().clone());
        match self.stop(cx).await {
            Ok(statuses) => {
                // Overwrite statuses only for points already tracked;
                // `and_modify` never inserts new entries.
                for (rank, status) in statuses.iter() {
                    self.health_state
                        .statuses
                        .entry(rank)
                        .and_modify(move |s| *s = status);
                }
            }
            Err(v1::Error::ActorStopError { statuses }) => {
                // Some actors did not stop cleanly; record what was reported.
                // NOTE(review): `materialized_iter` receives `max_rank`, the
                // highest rank index rather than a count. Confirm whether it
                // expects an exclusive bound.
                if let Some(max_rank) = max_rank {
                    let extent = extent.expect("no actors in mesh");
                    for (rank, status) in statuses.materialized_iter(max_rank).enumerate() {
                        *self
                            .health_state
                            .statuses
                            .get_mut(&extent.point_of_rank(rank).expect("illegal rank"))
                            .unwrap() = status.clone();
                    }
                }
            }
            Err(e) => {
                // Any other failure to stop is propagated to the caller.
                return Err(e.into());
            }
        }

        tracing::info!(actor = %cx.self_id(), %mesh_name, "stopped mesh");
        Ok(())
    }
}
457
458fn send_heartbeat(cx: &impl context::Actor, health_state: &HealthState) {
465 tracing::debug!("sending heartbeat to subscribers");
466
467 for subscriber in health_state.subscribers.iter() {
468 if let Err(e) = subscriber.send(cx, None) {
469 tracing::warn!(subscriber = %subscriber.port_id(), "error sending heartbeat message: {:?}", e);
470 }
471 }
472}
473
474fn send_state_change(
479 cx: &impl context::Actor,
480 rank: usize,
481 event: ActorSupervisionEvent,
482 mesh_name: &Name,
483 is_proc_stopped: bool,
484 health_state: &mut HealthState,
485) {
486 let is_failed = event.is_error();
489 if is_failed {
490 tracing::warn!(
491 name = "SupervisionEvent",
492 %mesh_name,
493 %event,
494 "detected supervision error on monitored mesh: name={mesh_name}",
495 );
496 } else {
497 tracing::debug!(
498 name = "SupervisionEvent",
499 %mesh_name,
500 %event,
501 "detected non-error supervision event on monitored mesh: name={mesh_name}",
502 );
503 }
504
505 let failure_message = MeshFailure {
506 actor_mesh_name: Some(mesh_name.to_string()),
507 rank: Some(rank),
508 event: event.clone(),
509 };
510 health_state.crashed_ranks.insert(rank, event.clone());
511 health_state.unhealthy_event = Some(if is_proc_stopped {
512 Unhealthy::StreamClosed(failure_message.clone())
513 } else {
514 Unhealthy::Crashed(failure_message.clone())
515 });
516 if is_failed {
521 if let Some(owner) = &health_state.owner {
522 if let Err(error) = owner.send(cx, failure_message.clone()) {
523 tracing::warn!(
524 name = "SupervisionEvent",
525 %mesh_name,
526 %event,
527 %error,
528 "failed to send supervision event to owner {}: {}. dropping event",
529 owner.port_id(),
530 error
531 );
532 } else {
533 tracing::info!(%mesh_name, %event, "sent supervision failure message to owner {}", owner.port_id());
534 }
535 }
536 }
537 for subscriber in health_state.subscribers.iter() {
540 send_subscriber_message(cx, subscriber, failure_message.clone());
541 }
542}
543
544fn actor_state_to_supervision_events(
545 state: resource::State<ActorState>,
546) -> (usize, Vec<ActorSupervisionEvent>) {
547 let (rank, actor_id, events) = match state.state {
548 Some(inner) => (
549 inner.create_rank,
550 Some(inner.actor_id),
551 inner.supervision_events.clone(),
552 ),
553 None => (0, None, vec![]),
554 };
555 let events = match state.status {
556 resource::Status::NotExist | resource::Status::Stopped | resource::Status::Timeout(_) => {
559 if !events.is_empty() {
561 events
562 } else {
563 vec![ActorSupervisionEvent::new(
564 actor_id.expect("actor_id is None"),
565 None,
566 ActorStatus::Stopped,
567 None,
568 )]
569 }
570 }
571 resource::Status::Failed(_) => events,
572 _ => vec![],
574 };
575 (rank, events)
576}
577
#[async_trait]
impl<A: Referable> Handler<CheckState> for ActorMeshController<A> {
    /// Performs one supervision poll of the monitored mesh and reschedules itself.
    ///
    /// 1. Query proc states. A query error is reported as a failure at rank 0,
    ///    attributed to this controller.
    /// 2. If any proc is terminating, report the first such point as a stream
    ///    closure and end this poll.
    /// 3. Otherwise query actor states. Report each point whose status changed,
    ///    or that is seen for the first time with supervision events. Unchanged,
    ///    non-terminating points trigger a heartbeat.
    ///
    /// Every path ends with `self_check_state_message`, which reschedules only
    /// while `monitor` is still set.
    async fn handle(&mut self, cx: &Context<Self>, _: CheckState) -> Result<(), anyhow::Error> {
        let mesh = &self.mesh;
        let supervision_display_name = &self.supervision_display_name;
        // Reset the set of ports whose bounced messages are silently ignored.
        self.health_state.undeliverable_subscribers.clear();
        let proc_states = mesh.proc_mesh().proc_states(cx).await;
        if let Err(e) = proc_states {
            send_state_change(
                cx,
                0,
                ActorSupervisionEvent::new(
                    cx.self_id().clone(),
                    None,
                    ActorStatus::generic_failure(format!(
                        "unable to query for proc states: {:?}",
                        e
                    )),
                    None,
                ),
                mesh.name(),
                false,
                &mut self.health_state,
            );
            self.self_check_state_message(cx)?;
            return Ok(());
        }
        if let Some(proc_states) = proc_states.unwrap() {
            // Only the first terminating proc found is reported in this poll.
            // NOTE(review): a proc that stays terminating is presumably found
            // again on every later poll and re-reported to the owner and
            // subscribers. Confirm duplicates are intended.
            if let Some((point, state)) = proc_states
                .iter()
                .find(|(_rank, state)| state.status.is_terminating())
            {
                // A clean stop, a kill by signal 15, or a missing proc status
                // counts as a normal stop. Anything else is a process failure.
                let actor_status = match state.state.and_then(|s| s.proc_status) {
                    Some(ProcStatus::Stopped { .. })
                    | Some(ProcStatus::Killed { signal: 15, .. })
                    | None => ActorStatus::Stopped,

                    Some(status) => ActorStatus::Failed(ActorErrorKind::Generic(format!(
                        "process failure: {}",
                        status
                    ))),
                };
                // Splice the point's coordinates into the display name, just
                // before its last '>' if it has one, otherwise at the end.
                let display_name = if !point.is_empty() {
                    let coords_display = point.format_as_dict();
                    if let Some(pos) = supervision_display_name.rfind('>') {
                        format!(
                            "{}{}{}",
                            &supervision_display_name[..pos],
                            coords_display,
                            &supervision_display_name[pos..]
                        )
                    } else {
                        format!("{}{}", supervision_display_name, coords_display)
                    }
                } else {
                    supervision_display_name.clone()
                };
                send_state_change(
                    cx,
                    point.rank(),
                    ActorSupervisionEvent::new(
                        mesh.get(point.rank()).unwrap().actor_id().clone(),
                        Some(format!("{} was running on a process which", display_name)),
                        actor_status,
                        None,
                    ),
                    mesh.name(),
                    true,
                    &mut self.health_state,
                );
                self.self_check_state_message(cx)?;
                return Ok(());
            }
        }

        let events = mesh.actor_states(cx).await;
        if let Err(e) = events {
            send_state_change(
                cx,
                0,
                ActorSupervisionEvent::new(
                    cx.self_id().clone(),
                    Some(supervision_display_name.clone()),
                    ActorStatus::generic_failure(format!(
                        "unable to query for actor states: {:?}",
                        e
                    )),
                    None,
                ),
                mesh.name(),
                false,
                &mut self.health_state,
            );
            self.self_check_state_message(cx)?;
            return Ok(());
        }
        for (point, state) in events.unwrap().iter() {
            // Set inside `or_insert_with` when a point is seen for the first
            // time and already carries supervision events.
            let mut is_new = false;
            let entry = self
                .health_state
                .statuses
                .entry(point.clone())
                .or_insert_with(|| {
                    tracing::debug!(
                        "PythonActorMeshImpl: received initial state: point={:?}, state={:?}",
                        point,
                        state
                    );
                    let (_rank, events) = actor_state_to_supervision_events(state.clone());
                    if !events.is_empty() {
                        is_new = true;
                    }
                    state.status.clone()
                });
            // Only the first supervision event of a point is reported.
            let (rank, event) = if is_new {
                let (rank, events) = actor_state_to_supervision_events(state.clone());
                if events.is_empty() {
                    continue;
                }
                (rank, events[0].clone())
            } else if *entry != state.status {
                tracing::debug!(
                    "PythonActorMeshImpl: received state change event: point={:?}, old_state={:?}, new_state={:?}",
                    point,
                    entry,
                    state
                );
                let (rank, events) = actor_state_to_supervision_events(state.clone());
                if events.is_empty() {
                    // The stored status is left untouched, so this change is
                    // examined again on the next poll.
                    continue;
                }
                *entry = state.status;
                (rank, events[0].clone())
            } else {
                // NOTE(review): one heartbeat goes out per unchanged,
                // non-terminating point, so subscribers may receive several
                // heartbeats per poll.
                if !entry.is_terminating() {
                    send_heartbeat(cx, &self.health_state);
                }
                continue;
            };
            send_state_change(cx, rank, event, mesh.name(), false, &mut self.health_state);
        }

        self.self_check_state_message(cx)?;
        return Ok(());
    }
}
764
/// Actor that, when cleaned up, asks the hosts of its proc mesh to stop that
/// mesh's procs.
#[derive(Debug)]
#[hyperactor::export]
pub(crate) struct ProcMeshController {
    /// The proc mesh this controller tears down.
    mesh: ProcMeshRef,
}

impl ProcMeshController {
    /// Creates a controller for `mesh`.
    pub(crate) fn new(mesh: ProcMeshRef) -> Self {
        Self { mesh }
    }
}
777
778#[async_trait]
779impl Actor for ProcMeshController {
780 async fn cleanup(
781 &mut self,
782 this: &Instance<Self>,
783 _err: Option<&ActorError>,
784 ) -> Result<(), anyhow::Error> {
785 let names = self.mesh.proc_ids().collect::<Vec<ProcId>>();
787 let region = self.mesh.region().clone();
788 if let Some(hosts) = self.mesh.hosts() {
789 hosts
790 .stop_proc_mesh(this, self.mesh.name(), names, region)
791 .await
792 } else {
793 Ok(())
794 }
795 }
796}
797
/// Actor that shuts down every host of its host mesh when it is cleaned up.
#[derive(Debug)]
#[hyperactor::export]
pub(crate) struct HostMeshController {
    /// The host mesh this controller shuts down.
    mesh: HostMeshRef,
}

impl HostMeshController {
    /// Creates a controller for `mesh`.
    pub(crate) fn new(mesh: HostMeshRef) -> Self {
        Self { mesh }
    }
}
810
811#[async_trait]
812impl Actor for HostMeshController {
813 async fn cleanup(
814 &mut self,
815 this: &Instance<Self>,
816 _err: Option<&ActorError>,
817 ) -> Result<(), anyhow::Error> {
818 for host in self.mesh.values() {
820 if let Err(e) = host.shutdown(this).await {
821 tracing::warn!(host = %host, error = %e, "host shutdown failed");
822 }
823 }
824 Ok(())
825 }
826}