1use std::collections::HashMap;
10use std::collections::HashSet;
11use std::fmt::Debug;
12
13use async_trait::async_trait;
14use hyperactor::Actor;
15use hyperactor::Bind;
16use hyperactor::Context;
17use hyperactor::Handler;
18use hyperactor::Instance;
19use hyperactor::Unbind;
20use hyperactor::actor::ActorError;
21use hyperactor::actor::ActorErrorKind;
22use hyperactor::actor::ActorStatus;
23use hyperactor::actor::Referable;
24use hyperactor::actor::handle_undeliverable_message;
25use hyperactor::context;
26use hyperactor::kv_pairs;
27use hyperactor::mailbox::MessageEnvelope;
28use hyperactor::mailbox::Undeliverable;
29use hyperactor::reference as hyperactor_reference;
30use hyperactor::supervision::ActorSupervisionEvent;
31use hyperactor_config::CONFIG;
32use hyperactor_config::ConfigAttr;
33use hyperactor_config::Flattrs;
34use hyperactor_config::attrs::declare_attrs;
35use hyperactor_telemetry::declare_static_counter;
36use ndslice::ViewExt;
37use ndslice::view::CollectMeshExt;
38use ndslice::view::Point;
39use ndslice::view::Ranked;
40use serde::Deserialize;
41use serde::Serialize;
42use tokio::time::Duration;
43use typeuri::Named;
44
45use crate::Name;
46use crate::ValueMesh;
47use crate::actor_mesh::ActorMeshRef;
48use crate::bootstrap::ProcStatus;
49use crate::casting::update_undeliverable_envelope_for_casting;
50use crate::host_mesh::HostMeshRef;
51use crate::proc_agent::ActorState;
52use crate::proc_agent::MESH_ORPHAN_TIMEOUT;
53use crate::proc_mesh::ProcMeshRef;
54use crate::resource;
55use crate::supervision::MeshFailure;
56use crate::supervision::Unhealthy;
57
/// Well-known actor name under which an actor mesh's controller is registered.
pub const ACTOR_MESH_CONTROLLER_NAME: &str = "actor_mesh_controller";
60
declare_attrs! {
    // How often the controller's CheckState supervision loop polls proc and
    // actor state. Overridable via the environment variable named below.
    @meta(CONFIG = ConfigAttr::new(
        Some("HYPERACTOR_MESH_SUPERVISION_POLL_FREQUENCY".to_string()),
        None,
    ))
    pub attr SUPERVISION_POLL_FREQUENCY: Duration = Duration::from_secs(10);
}
75
// Counts CheckState handler invocations that ran more than one poll period
// later than their expected time (i.e. the supervision loop stalled).
declare_static_counter!(
    ACTOR_MESH_CONTROLLER_SUPERVISION_STALLS,
    "actor.actor_mesh_controller.num_stalls"
);
80
/// Aggregated supervision bookkeeping for one actor mesh.
#[derive(Debug)]
struct HealthState {
    // Last observed status for each point (rank) in the mesh.
    statuses: HashMap<Point, resource::Status>,
    // Set once the mesh is considered unhealthy; replayed to late subscribers.
    unhealthy_event: Option<Unhealthy>,
    // Supervision events recorded per rank that reported a state change.
    crashed_ranks: HashMap<usize, ActorSupervisionEvent>,
    // Optional owner port that receives failure (error-only) notifications.
    owner: Option<hyperactor_reference::PortRef<MeshFailure>>,
    // Ports subscribed to failure and heartbeat notifications.
    subscribers: HashSet<hyperactor_reference::PortRef<Option<MeshFailure>>>,
}
93
94impl HealthState {
95 fn new(
96 statuses: HashMap<Point, resource::Status>,
97 owner: Option<hyperactor_reference::PortRef<MeshFailure>>,
98 ) -> Self {
99 Self {
100 statuses,
101 unhealthy_event: None,
102 crashed_ranks: HashMap::new(),
103 owner,
104 subscribers: HashSet::new(),
105 }
106 }
107}
108
/// Request to register a port for mesh failure and heartbeat notifications.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Named, Bind, Unbind)]
pub struct Subscribe(pub hyperactor_reference::PortRef<Option<MeshFailure>>);
116
/// Request to remove a previously subscribed notification port.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Named, Bind, Unbind)]
pub struct Unsubscribe(pub hyperactor_reference::PortRef<Option<MeshFailure>>);
121
/// Request for the current number of subscribers; the answer is sent on the
/// carried reply port.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Named, Bind, Unbind)]
pub struct GetSubscriberCount(#[binding(include)] pub hyperactor_reference::PortRef<usize>);
125
/// Self-message driving the supervision poll loop; carries the wall-clock
/// time at which the poll was expected to run, for stall detection.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Named, Bind, Unbind)]
pub struct CheckState(pub std::time::SystemTime);
132
/// Controller actor that supervises one actor mesh: it polls proc and actor
/// state, tracks per-rank health, and fans failure events out to an optional
/// owner port and a set of subscribers.
#[hyperactor::export(handlers = [
    Subscribe,
    Unsubscribe,
    GetSubscriberCount,
    resource::CreateOrUpdate<resource::mesh::Spec<()>> { cast = true },
    resource::GetState<resource::mesh::State<()>> { cast = true },
    resource::Stop { cast = true },
])]
pub struct ActorMeshController<A>
where
    A: Referable,
{
    // The actor mesh being supervised.
    mesh: ActorMeshRef<A>,
    // Human-readable name used in supervision events and logs.
    supervision_display_name: String,
    // Health and subscriber bookkeeping for the mesh.
    health_state: HealthState,
    // Present while the supervision poll loop is active; taken on stop.
    monitor: Option<()>,
}
161
// Marker impl: this controller manages a mesh resource with unit spec/state.
impl<A: Referable> resource::mesh::Mesh for ActorMeshController<A> {
    type Spec = ();
    type State = ();
}
166
167impl<A: Referable> ActorMeshController<A> {
168 pub(crate) fn new(
170 mesh: ActorMeshRef<A>,
171 supervision_display_name: Option<String>,
172 port: Option<hyperactor_reference::PortRef<MeshFailure>>,
173 initial_statuses: ValueMesh<resource::Status>,
174 ) -> Self {
175 let supervision_display_name =
176 supervision_display_name.unwrap_or_else(|| mesh.name().to_string());
177 Self {
178 mesh,
179 supervision_display_name,
180 health_state: HealthState::new(initial_statuses.iter().collect(), port),
181 monitor: None,
182 }
183 }
184
185 async fn stop(
186 &self,
187 cx: &impl context::Actor,
188 reason: String,
189 ) -> crate::Result<ValueMesh<resource::Status>> {
190 self.mesh
192 .proc_mesh()
193 .stop_actor_by_name(cx, self.mesh.name().clone(), reason)
194 .await
195 }
196
197 fn self_check_state_message(&self, cx: &Instance<Self>) -> Result<(), ActorError> {
198 if self.monitor.is_some() {
200 let delay = hyperactor_config::global::get(SUPERVISION_POLL_FREQUENCY);
203 cx.self_message_with_delay(CheckState(std::time::SystemTime::now() + delay), delay)
204 } else {
205 Ok(())
206 }
207 }
208}
209
declare_attrs! {
    // Header flag marking messages sent to subscribers, so that
    // undeliverable returns can be recognized and the subscriber pruned.
    pub attr ACTOR_MESH_SUBSCRIBER_MESSAGE: bool;
}
215
216fn send_subscriber_message(
217 cx: &impl context::Actor,
218 subscriber: &hyperactor_reference::PortRef<Option<MeshFailure>>,
219 message: MeshFailure,
220) {
221 let mut headers = Flattrs::new();
222 headers.set(ACTOR_MESH_SUBSCRIBER_MESSAGE, true);
223 if let Err(error) = subscriber.send_with_headers(cx, headers, Some(message.clone())) {
224 tracing::warn!(
225 event = %message,
226 "failed to send supervision event to subscriber {}: {}",
227 subscriber.port_id(),
228 error
229 );
230 } else {
231 tracing::info!(event = %message, "sent supervision failure message to subscriber {}", subscriber.port_id());
232 }
233}
234
235impl<A: Referable> Debug for ActorMeshController<A> {
236 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
237 f.debug_struct("MeshController")
238 .field("mesh", &self.mesh)
239 .field("health_state", &self.health_state)
240 .field("monitor", &self.monitor)
241 .finish()
242 }
243}
244
#[async_trait]
impl<A: Referable> Actor for ActorMeshController<A> {
    /// Start supervising: mark the actor as a system actor, arm the monitor,
    /// and schedule the first CheckState poll.
    async fn init(&mut self, this: &Instance<Self>) -> Result<(), anyhow::Error> {
        this.set_system();
        self.monitor = Some(());
        self.self_check_state_message(this)?;
        let owner = if let Some(owner) = &self.health_state.owner {
            owner.to_string()
        } else {
            String::from("None")
        };
        tracing::info!(actor_id = %this.self_id(), %owner, "started mesh controller for {}", self.mesh.name());
        Ok(())
    }

    /// On teardown, stop the supervised actor mesh — but only if the monitor
    /// is still armed (i.e. no explicit Stop was handled earlier).
    async fn cleanup(
        &mut self,
        this: &Instance<Self>,
        _err: Option<&ActorError>,
    ) -> Result<(), anyhow::Error> {
        if self.monitor.take().is_some() {
            tracing::info!(actor_id = %this.self_id(), actor_mesh = %self.mesh.name(), "starting cleanup for ActorMeshController, stopping actor mesh");
            self.stop(this, "actor mesh controller cleanup".to_string())
                .await?;
        }
        Ok(())
    }

    /// Returned (undeliverable) messages carrying the subscriber header are
    /// treated as a dead subscriber and pruned; everything else is handed to
    /// the default undeliverable handling.
    async fn handle_undeliverable_message(
        &mut self,
        cx: &Instance<Self>,
        mut envelope: Undeliverable<MessageEnvelope>,
    ) -> Result<(), anyhow::Error> {
        envelope = update_undeliverable_envelope_for_casting(envelope);
        if let Some(true) = envelope.0.headers().get(ACTOR_MESH_SUBSCRIBER_MESSAGE) {
            // Reconstruct the subscriber port from the failed destination and
            // drop it from the set so we stop sending to it.
            let dest_port_id = envelope.0.dest().clone();
            let port = hyperactor_reference::PortRef::<Option<MeshFailure>>::attest(dest_port_id);
            let did_exist = self.health_state.subscribers.remove(&port);
            if did_exist {
                tracing::debug!(
                    actor_id = %cx.self_id(),
                    num_subscribers = self.health_state.subscribers.len(),
                    "ActorMeshController: handle_undeliverable_message: removed subscriber {} from mesh controller",
                    port.port_id()
                );
            }
            Ok(())
        } else {
            handle_undeliverable_message(cx, envelope)
        }
    }
}
308
309#[async_trait]
310impl<A: Referable> Handler<Subscribe> for ActorMeshController<A> {
311 async fn handle(&mut self, cx: &Context<Self>, message: Subscribe) -> anyhow::Result<()> {
312 match &self.health_state.unhealthy_event {
316 None => {}
317 Some(Unhealthy::StreamClosed(msg)) => {
321 send_subscriber_message(cx, &message.0, msg.clone());
322 }
323 Some(Unhealthy::Crashed(msg)) => {
324 send_subscriber_message(cx, &message.0, msg.clone());
325 }
326 }
327 let port_id = message.0.port_id().clone();
328 if self.health_state.subscribers.insert(message.0) {
329 tracing::debug!(actor_id = %cx.self_id(), num_subscribers = self.health_state.subscribers.len(), "added subscriber {} to mesh controller", port_id);
330 }
331 Ok(())
332 }
333}
334
335#[async_trait]
336impl<A: Referable> Handler<Unsubscribe> for ActorMeshController<A> {
337 async fn handle(&mut self, cx: &Context<Self>, message: Unsubscribe) -> anyhow::Result<()> {
338 if self.health_state.subscribers.remove(&message.0) {
339 tracing::debug!(actor_id = %cx.self_id(), num_subscribers = self.health_state.subscribers.len(), "removed subscriber {} from mesh controller", message.0.port_id());
340 }
341 Ok(())
342 }
343}
344
345#[async_trait]
346impl<A: Referable> Handler<GetSubscriberCount> for ActorMeshController<A> {
347 async fn handle(
348 &mut self,
349 cx: &Context<Self>,
350 message: GetSubscriberCount,
351 ) -> anyhow::Result<()> {
352 message.0.send(cx, self.health_state.subscribers.len())?;
353 Ok(())
354 }
355}
356
#[async_trait]
impl<A: Referable> Handler<resource::CreateOrUpdate<resource::mesh::Spec<()>>>
    for ActorMeshController<A>
{
    /// No-op: the controller is constructed with its (unit) spec already
    /// applied, so there is nothing to create or update here.
    async fn handle(
        &mut self,
        _cx: &Context<Self>,
        _message: resource::CreateOrUpdate<resource::mesh::Spec<()>>,
    ) -> anyhow::Result<()> {
        Ok(())
    }
}
371
372#[async_trait]
373impl<A: Referable> Handler<resource::GetState<resource::mesh::State<()>>>
374 for ActorMeshController<A>
375{
376 async fn handle(
377 &mut self,
378 cx: &Context<Self>,
379 message: resource::GetState<resource::mesh::State<()>>,
380 ) -> anyhow::Result<()> {
381 let status = if let Some(Unhealthy::Crashed(e)) = &self.health_state.unhealthy_event {
382 resource::Status::Failed(e.to_string())
383 } else if let Some(Unhealthy::StreamClosed(_)) = &self.health_state.unhealthy_event {
384 resource::Status::Stopped
385 } else {
386 resource::Status::Running
387 };
388 let statuses = &self.health_state.statuses;
389 let mut statuses = statuses.clone().into_iter().collect::<Vec<_>>();
390 statuses.sort_by_key(|(p, _)| p.rank());
391 let statuses: ValueMesh<resource::Status> =
392 statuses
393 .into_iter()
394 .map(|(_, s)| s)
395 .collect_mesh::<ValueMesh<_>>(self.mesh.region().clone())?;
396 let state = resource::mesh::State {
397 statuses,
398 state: (),
399 };
400 message.reply.send(
401 cx,
402 resource::State {
403 name: message.name,
404 status,
405 state: Some(state),
406 },
407 )?;
408 Ok(())
409 }
410}
411
#[async_trait]
impl<A: Referable> Handler<resource::Stop> for ActorMeshController<A> {
    /// Handle an explicit stop request: mark the mesh unhealthy
    /// (StreamClosed), notify subscribers, forward the stop to the proc
    /// mesh, and fold the resulting per-rank statuses back into the local
    /// bookkeeping. Duplicate stop requests are ignored.
    async fn handle(&mut self, cx: &Context<Self>, message: resource::Stop) -> anyhow::Result<()> {
        let mesh = &self.mesh;
        let mesh_name = mesh.name();
        tracing::info!(
            name = "ActorMeshControllerStatus",
            %mesh_name,
            reason = %message.reason,
            "stopping actor mesh"
        );
        // Taking the monitor both disarms the poll loop and makes repeated
        // stop requests idempotent.
        if self.monitor.take().is_none() {
            tracing::debug!(actor_id = %cx.self_id(), actor_mesh = %mesh_name, "duplicate stop request, actor mesh is already stopped");
            return Ok(());
        }
        tracing::info!(actor_id = %cx.self_id(), actor_mesh = %mesh_name, "forwarding stop request from ActorMeshController to proc mesh");

        // Synthesize a supervision event (attributed to rank 0's actor) that
        // records the explicit stop, and broadcast it to subscribers.
        let rank = 0usize;
        let event = ActorSupervisionEvent::new(
            mesh.get(rank).unwrap().actor_id().clone(),
            None,
            ActorStatus::Stopped("ActorMeshController received explicit stop request".to_string()),
            None,
        );
        let failure_message = MeshFailure {
            actor_mesh_name: Some(mesh_name.to_string()),
            rank: None,
            event,
        };
        self.health_state.unhealthy_event = Some(Unhealthy::StreamClosed(failure_message.clone()));
        for subscriber in self.health_state.subscribers.iter() {
            send_subscriber_message(cx, subscriber, failure_message.clone());
        }

        // Snapshot rank bounds/extent before awaiting so that error statuses
        // can be mapped back onto points afterwards.
        let max_rank = self.health_state.statuses.keys().map(|p| p.rank()).max();
        let extent = self
            .health_state
            .statuses
            .keys()
            .next()
            .map(|p| p.extent().clone());
        match self.stop(cx, message.reason.clone()).await {
            Ok(statuses) => {
                // Record the final status of every rank that was stopped.
                for (rank, status) in statuses.iter() {
                    self.health_state
                        .statuses
                        .entry(rank)
                        .and_modify(move |s| *s = status);
                }
            }
            Err(crate::Error::ActorStopError { statuses }) => {
                // Partial failure: materialize the reported statuses over the
                // full rank range and overwrite the local entries.
                if let Some(max_rank) = max_rank {
                    let extent = extent.expect("no actors in mesh");
                    for (rank, status) in statuses.materialized_iter(max_rank).enumerate() {
                        *self
                            .health_state
                            .statuses
                            .get_mut(&extent.point_of_rank(rank).expect("illegal rank"))
                            .unwrap() = status.clone();
                    }
                }
            }
            Err(e) => {
                return Err(e.into());
            }
        }

        tracing::info!(actor_id = %cx.self_id(), actor_mesh = %mesh_name, "stopped mesh");
        Ok(())
    }
}
505
506fn send_heartbeat(cx: &impl context::Actor, health_state: &HealthState) {
513 tracing::debug!(
514 num_subscribers = health_state.subscribers.len(),
515 "sending heartbeat to subscribers",
516 );
517
518 for subscriber in health_state.subscribers.iter() {
519 let mut headers = Flattrs::new();
520 headers.set(ACTOR_MESH_SUBSCRIBER_MESSAGE, true);
521 if let Err(e) = subscriber.send_with_headers(cx, headers, None) {
522 tracing::warn!(subscriber = %subscriber.port_id(), "error sending heartbeat message: {:?}", e);
523 }
524 }
525}
526
/// Record a supervision event for `rank` and fan it out: update the
/// crashed-ranks and unhealthy-event bookkeeping, notify the owner (for
/// error events only), and notify every subscriber.
fn send_state_change(
    cx: &impl context::Actor,
    rank: usize,
    event: ActorSupervisionEvent,
    mesh_name: &Name,
    is_proc_stopped: bool,
    health_state: &mut HealthState,
) {
    let is_failed = event.is_error();
    if is_failed {
        tracing::warn!(
            name = "SupervisionEvent",
            actor_mesh = %mesh_name,
            %event,
            "detected supervision error on monitored mesh: name={mesh_name}",
        );
    } else {
        tracing::debug!(
            name = "SupervisionEvent",
            actor_mesh = %mesh_name,
            %event,
            "detected non-error supervision event on monitored mesh: name={mesh_name}",
        );
    }

    let failure_message = MeshFailure {
        actor_mesh_name: Some(mesh_name.to_string()),
        rank: Some(rank),
        event: event.clone(),
    };
    // NOTE(review): the rank is recorded in crashed_ranks and the mesh is
    // marked unhealthy even for non-error events — confirm this is intended.
    health_state.crashed_ranks.insert(rank, event.clone());
    health_state.unhealthy_event = Some(if is_proc_stopped {
        Unhealthy::StreamClosed(failure_message.clone())
    } else {
        Unhealthy::Crashed(failure_message.clone())
    });
    // The owner only hears about actual failures; send errors are logged and
    // the event dropped.
    if is_failed {
        if let Some(owner) = &health_state.owner {
            if let Err(error) = owner.send(cx, failure_message.clone()) {
                tracing::warn!(
                    name = "SupervisionEvent",
                    actor_mesh = %mesh_name,
                    %event,
                    %error,
                    "failed to send supervision event to owner {}: {}. dropping event",
                    owner.port_id(),
                    error
                );
            } else {
                tracing::info!(actor_mesh = %mesh_name, %event, "sent supervision failure message to owner {}", owner.port_id());
            }
        }
    }
    // Subscribers hear about every state change, error or not.
    for subscriber in health_state.subscribers.iter() {
        send_subscriber_message(cx, subscriber, failure_message.clone());
    }
}
596
597fn actor_state_to_supervision_events(
598 state: resource::State<ActorState>,
599) -> (usize, Vec<ActorSupervisionEvent>) {
600 let (rank, actor_id, events) = match state.state {
601 Some(inner) => (
602 inner.create_rank,
603 Some(inner.actor_id),
604 inner.supervision_events.clone(),
605 ),
606 None => (0, None, vec![]),
607 };
608 let events = match state.status {
609 resource::Status::NotExist | resource::Status::Stopped | resource::Status::Timeout(_) => {
612 if !events.is_empty() {
614 events
615 } else {
616 vec![ActorSupervisionEvent::new(
617 actor_id.expect("actor_id is None"),
618 None,
619 ActorStatus::Stopped(
620 format!(
621 "actor status is {}; actor may have been killed",
622 state.status
623 )
624 .to_string(),
625 ),
626 None,
627 )]
628 }
629 }
630 resource::Status::Failed(_) => events,
631 _ => vec![],
633 };
634 (rank, events)
635}
636
637fn proc_status_to_actor_status(proc_status: Option<ProcStatus>) -> ActorStatus {
646 match proc_status {
647 Some(ProcStatus::Stopped { exit_code: 0, .. }) => {
648 ActorStatus::Stopped("process exited cleanly".to_string())
649 }
650 Some(ProcStatus::Stopped { exit_code, .. }) => ActorStatus::Failed(
651 ActorErrorKind::Generic(format!("process exited with non-zero code {}", exit_code)),
652 ),
653 Some(ProcStatus::Stopping { .. }) => {
656 ActorStatus::Stopped("process is stopping".to_string())
657 }
658 None => ActorStatus::Stopped("no status received from process".to_string()),
660 Some(status) => ActorStatus::Failed(ActorErrorKind::Generic(format!(
661 "process failure: {}",
662 status
663 ))),
664 }
665}
666
667fn format_system_time(time: std::time::SystemTime) -> String {
668 let datetime: chrono::DateTime<chrono::Local> = time.into();
669 datetime.format("%Y-%m-%d %H:%M:%S").to_string()
670}
671
#[async_trait]
impl<A: Referable> Handler<CheckState> for ActorMeshController<A> {
    /// One iteration of the supervision poll loop: detect stalls, inspect
    /// proc states for terminating processes, then diff per-actor states
    /// against the local bookkeeping, emitting state-change events or a
    /// heartbeat. Reschedules itself unless every rank has terminated.
    async fn handle(
        &mut self,
        cx: &Context<Self>,
        CheckState(expected_time): CheckState,
    ) -> Result<(), anyhow::Error> {
        // Stall detection: we expected to run at `expected_time`; being more
        // than one poll period late is counted and logged.
        if std::time::SystemTime::now()
            > expected_time + hyperactor_config::global::get(SUPERVISION_POLL_FREQUENCY)
        {
            let expected_time = format_system_time(expected_time);
            ACTOR_MESH_CONTROLLER_SUPERVISION_STALLS.add(1, kv_pairs!("actor_id" => cx.self_id().to_string(), "expected_time" => expected_time.clone()));
            tracing::warn!(
                actor_id = %cx.self_id(),
                "Handler<CheckState> is being stalled, expected at {}",
                expected_time,
            );
        }
        let mesh = &self.mesh;
        let supervision_display_name = &self.supervision_display_name;
        // First, check proc-level health: a failed query or a terminating
        // process is reported as a state change and ends this iteration.
        let proc_states = mesh.proc_mesh().proc_states(cx).await;
        if let Err(e) = proc_states {
            send_state_change(
                cx,
                0,
                ActorSupervisionEvent::new(
                    cx.self_id().clone(),
                    None,
                    ActorStatus::generic_failure(format!(
                        "unable to query for proc states: {:?}",
                        e
                    )),
                    None,
                ),
                mesh.name(),
                false,
                &mut self.health_state,
            );
            self.self_check_state_message(cx)?;
            return Ok(());
        }
        if let Some(proc_states) = proc_states.unwrap() {
            if let Some((point, state)) = proc_states
                .iter()
                .find(|(_rank, state)| state.status.is_terminating())
            {
                // A process hosting this mesh is going away: attribute the
                // derived actor status to that point's actor.
                let actor_status =
                    proc_status_to_actor_status(state.state.and_then(|s| s.proc_status));
                let display_name = crate::actor_display_name(supervision_display_name, &point);
                send_state_change(
                    cx,
                    point.rank(),
                    ActorSupervisionEvent::new(
                        mesh.get(point.rank()).unwrap().actor_id().clone(),
                        Some(format!("{} was running on a process which", display_name)),
                        actor_status,
                        None,
                    ),
                    mesh.name(),
                    true,
                    &mut self.health_state,
                );
                self.self_check_state_message(cx)?;
                return Ok(());
            }
        }

        // Query per-actor states, refreshing the orphan keepalive unless the
        // orphan timeout is disabled (zero).
        let orphan_timeout = hyperactor_config::global::get(MESH_ORPHAN_TIMEOUT);
        let keepalive = if orphan_timeout.is_zero() {
            None
        } else {
            Some(std::time::SystemTime::now() + orphan_timeout)
        };
        let events = mesh.actor_states_with_keepalive(cx, keepalive).await;
        if let Err(e) = events {
            send_state_change(
                cx,
                0,
                ActorSupervisionEvent::new(
                    cx.self_id().clone(),
                    Some(supervision_display_name.clone()),
                    ActorStatus::generic_failure(format!(
                        "unable to query for actor states: {:?}",
                        e
                    )),
                    None,
                ),
                mesh.name(),
                false,
                &mut self.health_state,
            );
            self.self_check_state_message(cx)?;
            return Ok(());
        }
        let mut did_send_state_change = false;
        let mut is_terminal = false;
        // Diff each point's reported state against the stored status. A
        // point seen for the first time that already carries supervision
        // events counts as "new"; an existing point only produces events
        // when its status changed.
        for (point, state) in events.unwrap().iter() {
            let mut is_new = false;
            let entry = self
                .health_state
                .statuses
                .entry(point.clone())
                .or_insert_with(|| {
                    tracing::debug!(
                        "PythonActorMeshImpl: received initial state: point={:?}, state={:?}",
                        point,
                        state
                    );
                    let (_rank, events) = actor_state_to_supervision_events(state.clone());
                    if !events.is_empty() {
                        is_new = true;
                    }
                    state.status.clone()
                });
            if !is_terminal && entry.is_terminating() {
                is_terminal = true;
            }
            let (rank, event) = if is_new {
                let (rank, events) = actor_state_to_supervision_events(state.clone());
                if events.is_empty() {
                    continue;
                }
                (rank, events[0].clone())
            } else if *entry != state.status {
                tracing::debug!(
                    "PythonActorMeshImpl: received state change event: point={:?}, old_state={:?}, new_state={:?}",
                    point,
                    entry,
                    state
                );
                let (rank, events) = actor_state_to_supervision_events(state.clone());
                if events.is_empty() {
                    continue;
                }
                *entry = state.status;
                (rank, events[0].clone())
            } else {
                continue;
            };
            did_send_state_change = true;
            send_state_change(cx, rank, event, mesh.name(), false, &mut self.health_state);
        }
        // Quiet iteration on a live mesh: emit a heartbeat so subscribers can
        // distinguish "healthy and silent" from "controller gone".
        if !did_send_state_change && !is_terminal {
            send_heartbeat(cx, &self.health_state);
        }

        // Keep polling until every rank has reached a terminal status, then
        // disarm the monitor.
        let all_ranks_terminal = self
            .health_state
            .statuses
            .values()
            .all(|s| s.is_terminating());
        if !all_ranks_terminal {
            self.self_check_state_message(cx)?;
        } else {
            self.monitor.take();
        }
        return Ok(());
    }
}
882
/// Controller actor responsible for tearing down a proc mesh when it is
/// cleaned up.
#[derive(Debug)]
#[hyperactor::export]
pub(crate) struct ProcMeshController {
    // The proc mesh to stop on cleanup.
    mesh: ProcMeshRef,
}
888
impl ProcMeshController {
    /// Create a controller for the given proc mesh.
    pub(crate) fn new(mesh: ProcMeshRef) -> Self {
        Self { mesh }
    }
}
895
896#[async_trait]
897impl Actor for ProcMeshController {
898 async fn init(&mut self, this: &Instance<Self>) -> Result<(), anyhow::Error> {
899 this.set_system();
900 Ok(())
901 }
902
903 async fn cleanup(
904 &mut self,
905 this: &Instance<Self>,
906 _err: Option<&ActorError>,
907 ) -> Result<(), anyhow::Error> {
908 let names = self
910 .mesh
911 .proc_ids()
912 .collect::<Vec<hyperactor_reference::ProcId>>();
913 let region = self.mesh.region().clone();
914 if let Some(hosts) = self.mesh.hosts() {
915 hosts
916 .stop_proc_mesh(
917 this,
918 self.mesh.name(),
919 names,
920 region,
921 "proc mesh controller cleanup".to_string(),
922 )
923 .await
924 } else {
925 Ok(())
926 }
927 }
928}
929
/// Controller actor responsible for shutting down the hosts of a host mesh
/// when it is cleaned up.
#[derive(Debug)]
#[hyperactor::export]
pub(crate) struct HostMeshController {
    // The host mesh whose hosts are shut down on cleanup.
    mesh: HostMeshRef,
}
935
impl HostMeshController {
    /// Create a controller for the given host mesh.
    pub(crate) fn new(mesh: HostMeshRef) -> Self {
        Self { mesh }
    }
}
942
943#[async_trait]
944impl Actor for HostMeshController {
945 async fn init(&mut self, this: &Instance<Self>) -> Result<(), anyhow::Error> {
946 this.set_system();
947 Ok(())
948 }
949
950 async fn cleanup(
951 &mut self,
952 this: &Instance<Self>,
953 _err: Option<&ActorError>,
954 ) -> Result<(), anyhow::Error> {
955 for host in self.mesh.values() {
957 if let Err(e) = host.shutdown(this).await {
958 tracing::warn!(host = %host, error = %e, "host shutdown failed");
959 }
960 }
961 Ok(())
962 }
963}
964
965#[cfg(test)]
966mod tests {
967 use std::ops::Deref;
968 use std::time::Duration;
969
970 use hyperactor::actor::ActorStatus;
971 use ndslice::Extent;
972 use ndslice::ViewExt;
973
974 use super::SUPERVISION_POLL_FREQUENCY;
975 use super::proc_status_to_actor_status;
976 use crate::ActorMesh;
977 use crate::Name;
978 use crate::bootstrap::ProcStatus;
979 use crate::proc_agent::MESH_ORPHAN_TIMEOUT;
980 use crate::resource;
981 use crate::supervision::MeshFailure;
982 use crate::test_utils::local_host_mesh;
983 use crate::testactor;
984 use crate::testing;
985
    /// End-to-end check that actors whose keepalive lapses are reaped: with
    /// MESH_ORPHAN_TIMEOUT set to 1s, a spawned actor mesh should be Running
    /// while the keepalive is fresh and Stopped once it expires.
    #[tokio::test]
    async fn test_orphaned_actors_are_cleaned_up() {
        let config = hyperactor_config::global::lock();
        let _orphan = config.override_key(MESH_ORPHAN_TIMEOUT, Duration::from_secs(1));

        let instance = testing::instance();
        let host_mesh = local_host_mesh(2).await;
        let proc_mesh = host_mesh
            .spawn(instance, "test", Extent::unity())
            .await
            .unwrap();

        let actor_name = Name::new("orphan_test").unwrap();
        let actor_mesh: ActorMesh<testactor::TestActor> = proc_mesh
            .spawn_with_name(instance, actor_name.clone(), &(), None, true)
            .await
            .unwrap();
        assert!(
            actor_mesh.deref().extent().num_ranks() > 0,
            "should have spawned at least one actor"
        );

        // Query with a 2s keepalive: the actors should still be alive.
        let states = proc_mesh
            .actor_states_with_keepalive(
                instance,
                actor_name.clone(),
                Some(std::time::SystemTime::now() + Duration::from_secs(2)),
            )
            .await
            .unwrap();
        for state in states.values() {
            assert_eq!(
                state.status,
                resource::Status::Running,
                "actor should be running before expiry"
            );
        }

        // Wait long enough for the keepalive to lapse and the orphan reaper
        // to stop the actors.
        tokio::time::sleep(Duration::from_secs(5)).await;

        let states = proc_mesh
            .actor_states(instance, actor_name.clone())
            .await
            .unwrap();
        for state in states.values() {
            assert_eq!(
                state.status,
                resource::Status::Stopped,
                "actor should be stopped after keepalive expiry"
            );
        }
    }
1058
    /// Spawn `n` out-of-process bootstrap hosts (propagating the current
    /// global config to each) and assemble them into a HostMesh.
    #[cfg(fbcode_build)]
    async fn host_mesh_with_config(n: usize) -> crate::host_mesh::HostMesh {
        use hyperactor::channel::ChannelTransport;
        use tokio::process::Command;

        let program = crate::testresource::get("monarch/hyperactor_mesh/bootstrap");
        let mut host_addrs = vec![];
        for _ in 0..n {
            host_addrs.push(ChannelTransport::Unix.any());
        }

        for host in host_addrs.iter() {
            let mut cmd = Command::new(program.clone());
            let boot = crate::Bootstrap::Host {
                addr: host.clone(),
                command: None,
                config: Some(hyperactor_config::global::attrs()),
                exit_on_shutdown: false,
            };
            boot.to_env(&mut cmd);
            cmd.kill_on_drop(false);
            // SAFETY: pre_exec requires the closure to be async-signal-safe;
            // install_pdeathsig_kill is presumed to satisfy this — confirm
            // against its definition if it changes.
            unsafe {
                cmd.pre_exec(crate::bootstrap::install_pdeathsig_kill);
            }
            cmd.spawn().unwrap();
        }

        let host_mesh = crate::HostMeshRef::from_hosts(Name::new("test").unwrap(), host_addrs);
        crate::host_mesh::HostMesh::take(host_mesh)
    }
1093
    /// Verifies that child actors are reaped after their controller crashes:
    /// a wrapper actor on a separate controller host mesh spawns children on
    /// the actor host mesh; crashing the wrapper's process leaves the
    /// children orphaned, and the orphan timeout must then stop them.
    #[tokio::test]
    #[cfg(fbcode_build)]
    async fn test_orphaned_actors_cleaned_up_on_controller_crash() {
        let config = hyperactor_config::global::lock();
        let _orphan = config.override_key(MESH_ORPHAN_TIMEOUT, Duration::from_secs(2));
        let _poll = config.override_key(SUPERVISION_POLL_FREQUENCY, Duration::from_secs(1));

        let instance = testing::instance();
        let num_replicas = 2;

        // Hosts for the (soon-to-be-orphaned) child actors.
        let mut actor_hm = host_mesh_with_config(num_replicas).await;
        let actor_proc_mesh = actor_hm
            .spawn(instance, "actors", Extent::unity())
            .await
            .unwrap();

        // A separate host mesh for the controller so it can be crashed
        // independently of the children.
        let mut controller_hm = host_mesh_with_config(1).await;
        let controller_proc_mesh = controller_hm
            .spawn(instance, "controller", Extent::unity())
            .await
            .unwrap();

        let child_name = Name::new("orphan_child").unwrap();

        let (supervision_port, _supervision_receiver) = instance.open_port::<MeshFailure>();
        let supervisor = supervision_port.bind();

        // The wrapper spawns the child actors on `actor_proc_mesh`.
        let wrapper_mesh: ActorMesh<testactor::WrapperActor> = controller_proc_mesh
            .spawn(
                instance,
                "wrapper",
                &(
                    actor_proc_mesh.deref().clone(),
                    supervisor,
                    child_name.clone(),
                ),
            )
            .await
            .unwrap();

        // Give the wrapper time to spawn its children.
        tokio::time::sleep(Duration::from_secs(3)).await;

        let states = actor_proc_mesh
            .actor_states(instance, child_name.clone())
            .await
            .unwrap();
        for state in states.values() {
            assert_eq!(
                state.status,
                resource::Status::Running,
                "actor should be running before controller crash"
            );
        }

        // Crash the wrapper's process; its children lose their controller.
        wrapper_mesh
            .cast(
                instance,
                testactor::CauseSupervisionEvent {
                    kind: testactor::SupervisionEventType::ProcessExit(1),
                    send_to_children: false,
                },
            )
            .unwrap();

        // Wait past the orphan timeout (plus polling slack) for the reaper.
        tokio::time::sleep(Duration::from_secs(8)).await;

        let states = actor_proc_mesh
            .actor_states(instance, child_name.clone())
            .await
            .unwrap();
        for state in states.values() {
            assert_eq!(
                state.status,
                resource::Status::Stopped,
                "actor should be stopped after controller crash and orphan timeout"
            );
        }

        let _ = actor_hm.shutdown(instance).await;
        let _ = controller_hm.shutdown(instance).await;
    }
1201
1202 #[test]
1203 fn test_proc_status_to_actor_status_stopped_cleanly() {
1204 let status = proc_status_to_actor_status(Some(ProcStatus::Stopped {
1205 exit_code: 0,
1206 stderr_tail: vec![],
1207 }));
1208 assert!(
1209 matches!(status, ActorStatus::Stopped(ref msg) if msg.contains("cleanly")),
1210 "expected Stopped, got {:?}",
1211 status
1212 );
1213 }
1214
1215 #[test]
1216 fn test_proc_status_to_actor_status_nonzero_exit() {
1217 let status = proc_status_to_actor_status(Some(ProcStatus::Stopped {
1218 exit_code: 1,
1219 stderr_tail: vec![],
1220 }));
1221 assert!(
1222 matches!(status, ActorStatus::Failed(_)),
1223 "expected Failed, got {:?}",
1224 status
1225 );
1226 }
1227
1228 #[test]
1229 fn test_proc_status_to_actor_status_stopping_is_not_a_failure() {
1230 let status = proc_status_to_actor_status(Some(ProcStatus::Stopping {
1231 started_at: std::time::SystemTime::now(),
1232 }));
1233 assert!(
1234 matches!(status, ActorStatus::Stopped(ref msg) if msg.contains("stopping")),
1235 "expected Stopped, got {:?}",
1236 status
1237 );
1238 }
1239
1240 #[test]
1241 fn test_proc_status_to_actor_status_none() {
1242 let status = proc_status_to_actor_status(None);
1243 assert!(
1244 matches!(status, ActorStatus::Stopped(_)),
1245 "expected Stopped, got {:?}",
1246 status
1247 );
1248 }
1249
1250 #[test]
1251 fn test_proc_status_to_actor_status_killed() {
1252 let status = proc_status_to_actor_status(Some(ProcStatus::Killed {
1253 signal: 9,
1254 core_dumped: false,
1255 }));
1256 assert!(
1257 matches!(status, ActorStatus::Failed(_)),
1258 "expected Failed, got {:?}",
1259 status
1260 );
1261 }
1262
1263 #[test]
1264 fn test_proc_status_to_actor_status_failed() {
1265 let status = proc_status_to_actor_status(Some(ProcStatus::Failed {
1266 reason: "oom".to_string(),
1267 }));
1268 assert!(
1269 matches!(status, ActorStatus::Failed(_)),
1270 "expected Failed, got {:?}",
1271 status
1272 );
1273 }
1274}