hyperactor_mesh/
mesh_controller.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::collections::HashMap;
10use std::collections::HashSet;
11use std::fmt::Debug;
12
13use async_trait::async_trait;
14use hyperactor::Actor;
15use hyperactor::Bind;
16use hyperactor::Context;
17use hyperactor::Handler;
18use hyperactor::Instance;
19use hyperactor::Unbind;
20use hyperactor::actor::ActorError;
21use hyperactor::actor::ActorErrorKind;
22use hyperactor::actor::ActorStatus;
23use hyperactor::actor::Referable;
24use hyperactor::actor::handle_undeliverable_message;
25use hyperactor::context;
26use hyperactor::kv_pairs;
27use hyperactor::mailbox::MessageEnvelope;
28use hyperactor::mailbox::Undeliverable;
29use hyperactor::reference as hyperactor_reference;
30use hyperactor::supervision::ActorSupervisionEvent;
31use hyperactor_config::CONFIG;
32use hyperactor_config::ConfigAttr;
33use hyperactor_config::Flattrs;
34use hyperactor_config::attrs::declare_attrs;
35use hyperactor_telemetry::declare_static_counter;
36use ndslice::ViewExt;
37use ndslice::view::CollectMeshExt;
38use ndslice::view::Point;
39use ndslice::view::Ranked;
40use serde::Deserialize;
41use serde::Serialize;
42use tokio::time::Duration;
43use typeuri::Named;
44
45use crate::Name;
46use crate::ValueMesh;
47use crate::actor_mesh::ActorMeshRef;
48use crate::bootstrap::ProcStatus;
49use crate::casting::update_undeliverable_envelope_for_casting;
50use crate::host_mesh::HostMeshRef;
51use crate::proc_agent::ActorState;
52use crate::proc_agent::MESH_ORPHAN_TIMEOUT;
53use crate::proc_mesh::ProcMeshRef;
54use crate::resource;
55use crate::supervision::MeshFailure;
56use crate::supervision::Unhealthy;
57
/// Actor name for `ActorMeshController` when spawned as a named child.
pub const ACTOR_MESH_CONTROLLER_NAME: &str = "actor_mesh_controller";

declare_attrs! {
    /// Time between checks of actor states to create supervision events for
    /// owners. The longer this is, the longer it will take to detect a failure
    /// and report it to all subscribers; however, shorter intervals will send
    /// more frequent messages and heartbeats just to see everything is still running.
    /// The default is chosen to balance these two objectives.
    /// This also controls how frequently the healthy heartbeat is sent out to
    /// subscribers if there are no failures encountered.
    /// Overridable via the HYPERACTOR_MESH_SUPERVISION_POLL_FREQUENCY env var.
    @meta(CONFIG = ConfigAttr::new(
        Some("HYPERACTOR_MESH_SUPERVISION_POLL_FREQUENCY".to_string()),
        None,
    ))
    pub attr SUPERVISION_POLL_FREQUENCY: Duration = Duration::from_secs(10);
}

// Counts supervision-loop iterations that began later than expected; see
// Handler<CheckState> for how the stall is detected and logged.
declare_static_counter!(
    ACTOR_MESH_CONTROLLER_SUPERVISION_STALLS,
    "actor.actor_mesh_controller.num_stalls"
);
80
/// Aggregated supervision state for one controlled actor mesh: per-rank
/// statuses, the latest unhealthy event, and the notification endpoints
/// (owner port plus subscriber ports).
#[derive(Debug)]
struct HealthState {
    /// The status of each actor in the controlled mesh.
    /// TODO: replace with ValueMesh?
    statuses: HashMap<Point, resource::Status>,
    /// The most recent unhealthy event observed, if any. `None` means no
    /// failure or stop has been recorded yet.
    unhealthy_event: Option<Unhealthy>,
    /// Supervision events recorded per rank that has crashed.
    crashed_ranks: HashMap<usize, ActorSupervisionEvent>,
    /// The unique owner of this actor. Only failures (not stops) are
    /// forwarded to the owner; see `send_state_change`.
    owner: Option<hyperactor_reference::PortRef<MeshFailure>>,
    /// A set of subscribers to send messages to when events are encountered.
    subscribers: HashSet<hyperactor_reference::PortRef<Option<MeshFailure>>>,
}
93
94impl HealthState {
95    fn new(
96        statuses: HashMap<Point, resource::Status>,
97        owner: Option<hyperactor_reference::PortRef<MeshFailure>>,
98    ) -> Self {
99        Self {
100            statuses,
101            unhealthy_event: None,
102            crashed_ranks: HashMap::new(),
103            owner,
104            subscribers: HashSet::new(),
105        }
106    }
107}
108
/// Subscribe me to updates about a mesh. If a duplicate is subscribed, only a single
/// message is sent.
/// Will periodically send `None` if there are no failures on the mesh. This guarantees
/// the listener that the controller is still alive. Make sure to filter such events
/// out as not useful.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Named, Bind, Unbind)]
pub struct Subscribe(pub hyperactor_reference::PortRef<Option<MeshFailure>>);
116
/// Unsubscribe me from future updates about a mesh. Should be the same port used in
/// the Subscribe message.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Named, Bind, Unbind)]
pub struct Unsubscribe(pub hyperactor_reference::PortRef<Option<MeshFailure>>);
121
/// Query the number of active supervision subscribers on this controller.
/// The reply port is bound so it can be rewritten when cast across the mesh.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Named, Bind, Unbind)]
pub struct GetSubscriberCount(#[binding(include)] pub hyperactor_reference::PortRef<usize>);
125
/// Check state of the actors in the mesh. This is used as a self message to
/// periodically check.
/// Stores the next time we expect to start running a check state message.
/// Used to check for stalls in message handling (handler compares this
/// expected time against the actual current time on receipt).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Named, Bind, Unbind)]
pub struct CheckState(pub std::time::SystemTime);
132
/// The implementation of monitoring works as follows:
/// * ActorMesh and ActorMeshRef subscribe for updates from this controller,
///   which aggregates events from all owned actors.
/// * The monitor continuously polls for new events. When new events are
///   found, it sends messages to all subscribers
/// * In addition to sending to subscribers, the owner is an automatic subscriber
///   that also has to handle the events.
#[hyperactor::export(handlers = [
    Subscribe,
    Unsubscribe,
    GetSubscriberCount,
    resource::CreateOrUpdate<resource::mesh::Spec<()>> { cast = true },
    resource::GetState<resource::mesh::State<()>> { cast = true },
    resource::Stop { cast = true },
])]
pub struct ActorMeshController<A>
where
    A: Referable,
{
    /// The actor mesh being supervised.
    mesh: ActorMeshRef<A>,
    /// Human-readable name used when attributing supervision events.
    supervision_display_name: String,
    // Shared health state for the monitor and responding to queries.
    health_state: HealthState,
    // The monitor which continuously runs in the background to refresh the state
    // of actors.
    // If None, the actor it monitors has already stopped.
    monitor: Option<()>,
}
161
/// Participate in the generic resource-mesh protocol. The controller carries
/// no spec or state payload of its own, so both are unit.
impl<A: Referable> resource::mesh::Mesh for ActorMeshController<A> {
    type Spec = ();
    type State = ();
}
166
167impl<A: Referable> ActorMeshController<A> {
168    /// Create a new mesh controller based on the provided reference.
169    pub(crate) fn new(
170        mesh: ActorMeshRef<A>,
171        supervision_display_name: Option<String>,
172        port: Option<hyperactor_reference::PortRef<MeshFailure>>,
173        initial_statuses: ValueMesh<resource::Status>,
174    ) -> Self {
175        let supervision_display_name =
176            supervision_display_name.unwrap_or_else(|| mesh.name().to_string());
177        Self {
178            mesh,
179            supervision_display_name,
180            health_state: HealthState::new(initial_statuses.iter().collect(), port),
181            monitor: None,
182        }
183    }
184
185    async fn stop(
186        &self,
187        cx: &impl context::Actor,
188        reason: String,
189    ) -> crate::Result<ValueMesh<resource::Status>> {
190        // Cannot use "ActorMesh::stop" as it tries to message the controller, which is this actor.
191        self.mesh
192            .proc_mesh()
193            .stop_actor_by_name(cx, self.mesh.name().clone(), reason)
194            .await
195    }
196
197    fn self_check_state_message(&self, cx: &Instance<Self>) -> Result<(), ActorError> {
198        // Only schedule a self message if the monitor has not been dropped.
199        if self.monitor.is_some() {
200            // Save when we expect the next check state message, so we can automatically
201            // detect stalls as they accumulate.
202            let delay = hyperactor_config::global::get(SUPERVISION_POLL_FREQUENCY);
203            cx.self_message_with_delay(CheckState(std::time::SystemTime::now() + delay), delay)
204        } else {
205            Ok(())
206        }
207    }
208}
209
declare_attrs! {
    /// If present in a message header, the message is from an ActorMeshController
    /// to a subscriber and can be safely dropped if it is returned as undeliverable.
    /// Checked in `handle_undeliverable_message` to prune dead subscribers.
    pub attr ACTOR_MESH_SUBSCRIBER_MESSAGE: bool;
}
215
216fn send_subscriber_message(
217    cx: &impl context::Actor,
218    subscriber: &hyperactor_reference::PortRef<Option<MeshFailure>>,
219    message: MeshFailure,
220) {
221    let mut headers = Flattrs::new();
222    headers.set(ACTOR_MESH_SUBSCRIBER_MESSAGE, true);
223    if let Err(error) = subscriber.send_with_headers(cx, headers, Some(message.clone())) {
224        tracing::warn!(
225            event = %message,
226            "failed to send supervision event to subscriber {}: {}",
227            subscriber.port_id(),
228            error
229        );
230    } else {
231        tracing::info!(event = %message, "sent supervision failure message to subscriber {}", subscriber.port_id());
232    }
233}
234
235impl<A: Referable> Debug for ActorMeshController<A> {
236    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
237        f.debug_struct("MeshController")
238            .field("mesh", &self.mesh)
239            .field("health_state", &self.health_state)
240            .field("monitor", &self.monitor)
241            .finish()
242    }
243}
244
#[async_trait]
impl<A: Referable> Actor for ActorMeshController<A> {
    /// Mark this actor as a system actor, arm the monitor, and schedule the
    /// first `CheckState` poll.
    async fn init(&mut self, this: &Instance<Self>) -> Result<(), anyhow::Error> {
        this.set_system();
        // Start the monitor task.
        // There's a shared monitor for all whole mesh ref. Note that slices do
        // not share the health state. This is fine because requerying a slice
        // of a mesh will still return any failed state.
        self.monitor = Some(());
        self.self_check_state_message(this)?;
        let owner = if let Some(owner) = &self.health_state.owner {
            owner.to_string()
        } else {
            String::from("None")
        };
        tracing::info!(actor_id = %this.self_id(), %owner, "started mesh controller for {}", self.mesh.name());
        Ok(())
    }

    /// On teardown, stop the controlled actor mesh — but only if the monitor
    /// is still armed (i.e. no explicit Stop was processed before cleanup).
    async fn cleanup(
        &mut self,
        this: &Instance<Self>,
        _err: Option<&ActorError>,
    ) -> Result<(), anyhow::Error> {
        // If the monitor hasn't been dropped yet, send a stop message to the
        // proc mesh. `take()` makes this idempotent with Handler<Stop>.
        if self.monitor.take().is_some() {
            tracing::info!(actor_id = %this.self_id(), actor_mesh = %self.mesh.name(), "starting cleanup for ActorMeshController, stopping actor mesh");
            self.stop(this, "actor mesh controller cleanup".to_string())
                .await?;
        }
        Ok(())
    }

    /// Intercept undeliverable messages. Bounced subscriber notifications
    /// (tagged with ACTOR_MESH_SUBSCRIBER_MESSAGE) cause the dead subscriber
    /// to be pruned; anything else goes to the default handler.
    async fn handle_undeliverable_message(
        &mut self,
        cx: &Instance<Self>,
        mut envelope: Undeliverable<MessageEnvelope>,
    ) -> Result<(), anyhow::Error> {
        // Update the destination in case this was a casting message.
        envelope = update_undeliverable_envelope_for_casting(envelope);
        if let Some(true) = envelope.0.headers().get(ACTOR_MESH_SUBSCRIBER_MESSAGE) {
            // Remove from the subscriber list (if it existed) so we don't
            // send to this subscriber again.
            // NOTE: The only part of the port that is used for equality checks is
            // the port id, so create a new one just for the comparison.
            let dest_port_id = envelope.0.dest().clone();
            let port = hyperactor_reference::PortRef::<Option<MeshFailure>>::attest(dest_port_id);
            let did_exist = self.health_state.subscribers.remove(&port);
            if did_exist {
                tracing::debug!(
                    actor_id = %cx.self_id(),
                    num_subscribers = self.health_state.subscribers.len(),
                    "ActorMeshController: handle_undeliverable_message: removed subscriber {} from mesh controller",
                    port.port_id()
                );
            }
            Ok(())
        } else {
            handle_undeliverable_message(cx, envelope)
        }
    }
}
308
309#[async_trait]
310impl<A: Referable> Handler<Subscribe> for ActorMeshController<A> {
311    async fn handle(&mut self, cx: &Context<Self>, message: Subscribe) -> anyhow::Result<()> {
312        // If we can't send a message to a subscriber, the subscriber might be gone.
313        // That shouldn't cause this actor to exit.
314        // This is handled by the handle_undeliverable_message method.
315        match &self.health_state.unhealthy_event {
316            None => {}
317            // For an adverse event like stopped or crashed, send a notification
318            // immediately. This represents an initial bad state, if subscribing
319            // to an already-dead mesh.
320            Some(Unhealthy::StreamClosed(msg)) => {
321                send_subscriber_message(cx, &message.0, msg.clone());
322            }
323            Some(Unhealthy::Crashed(msg)) => {
324                send_subscriber_message(cx, &message.0, msg.clone());
325            }
326        }
327        let port_id = message.0.port_id().clone();
328        if self.health_state.subscribers.insert(message.0) {
329            tracing::debug!(actor_id = %cx.self_id(), num_subscribers = self.health_state.subscribers.len(), "added subscriber {} to mesh controller", port_id);
330        }
331        Ok(())
332    }
333}
334
335#[async_trait]
336impl<A: Referable> Handler<Unsubscribe> for ActorMeshController<A> {
337    async fn handle(&mut self, cx: &Context<Self>, message: Unsubscribe) -> anyhow::Result<()> {
338        if self.health_state.subscribers.remove(&message.0) {
339            tracing::debug!(actor_id = %cx.self_id(), num_subscribers = self.health_state.subscribers.len(), "removed subscriber {} from mesh controller", message.0.port_id());
340        }
341        Ok(())
342    }
343}
344
345#[async_trait]
346impl<A: Referable> Handler<GetSubscriberCount> for ActorMeshController<A> {
347    async fn handle(
348        &mut self,
349        cx: &Context<Self>,
350        message: GetSubscriberCount,
351    ) -> anyhow::Result<()> {
352        message.0.send(cx, self.health_state.subscribers.len())?;
353        Ok(())
354    }
355}
356
#[async_trait]
impl<A: Referable> Handler<resource::CreateOrUpdate<resource::mesh::Spec<()>>>
    for ActorMeshController<A>
{
    /// Currently a no-op as there's nothing to create or update (the spec is
    /// unit), but allows ActorMeshController to implement the resource mesh
    /// behavior.
    async fn handle(
        &mut self,
        _cx: &Context<Self>,
        _message: resource::CreateOrUpdate<resource::mesh::Spec<()>>,
    ) -> anyhow::Result<()> {
        Ok(())
    }
}
371
372#[async_trait]
373impl<A: Referable> Handler<resource::GetState<resource::mesh::State<()>>>
374    for ActorMeshController<A>
375{
376    async fn handle(
377        &mut self,
378        cx: &Context<Self>,
379        message: resource::GetState<resource::mesh::State<()>>,
380    ) -> anyhow::Result<()> {
381        let status = if let Some(Unhealthy::Crashed(e)) = &self.health_state.unhealthy_event {
382            resource::Status::Failed(e.to_string())
383        } else if let Some(Unhealthy::StreamClosed(_)) = &self.health_state.unhealthy_event {
384            resource::Status::Stopped
385        } else {
386            resource::Status::Running
387        };
388        let statuses = &self.health_state.statuses;
389        let mut statuses = statuses.clone().into_iter().collect::<Vec<_>>();
390        statuses.sort_by_key(|(p, _)| p.rank());
391        let statuses: ValueMesh<resource::Status> =
392            statuses
393                .into_iter()
394                .map(|(_, s)| s)
395                .collect_mesh::<ValueMesh<_>>(self.mesh.region().clone())?;
396        let state = resource::mesh::State {
397            statuses,
398            state: (),
399        };
400        message.reply.send(
401            cx,
402            resource::State {
403                name: message.name,
404                status,
405                state: Some(state),
406            },
407        )?;
408        Ok(())
409    }
410}
411
#[async_trait]
impl<A: Referable> Handler<resource::Stop> for ActorMeshController<A> {
    /// Stop the controlled actor mesh: disarm the monitor, record a
    /// StreamClosed unhealthy event, notify subscribers (but not the owner,
    /// who requested the stop), then forward the stop to the proc mesh and
    /// reconcile the resulting per-rank statuses into the health state.
    async fn handle(&mut self, cx: &Context<Self>, message: resource::Stop) -> anyhow::Result<()> {
        let mesh = &self.mesh;
        let mesh_name = mesh.name();
        tracing::info!(
            name = "ActorMeshControllerStatus",
            %mesh_name,
            reason = %message.reason,
            "stopping actor mesh"
        );
        // Run the drop on the monitor loop. The actors will not change state
        // after this point, because they will be stopped.
        // This message is idempotent because multiple stops only send out one
        // set of messages to subscribers.
        if self.monitor.take().is_none() {
            tracing::debug!(actor_id = %cx.self_id(), actor_mesh = %mesh_name, "duplicate stop request, actor mesh is already stopped");
            return Ok(());
        }
        tracing::info!(actor_id = %cx.self_id(), actor_mesh = %mesh_name, "forwarding stop request from ActorMeshController to proc mesh");

        // Let the client know that the controller has stopped. Since the monitor
        // is cancelled, it will not alert the owner or the subscribers.
        // We use a placeholder rank to get an actor id, but really there should
        // be a stop event for every rank in the mesh. Since every rank has the
        // same owner, we assume the rank doesn't matter, and the owner can just
        // assume the stop happened on all actors.
        let rank = 0usize;
        let event = ActorSupervisionEvent::new(
            // Use an actor id from the mesh.
            mesh.get(rank).unwrap().actor_id().clone(),
            None,
            ActorStatus::Stopped("ActorMeshController received explicit stop request".to_string()),
            None,
        );
        let failure_message = MeshFailure {
            actor_mesh_name: Some(mesh_name.to_string()),
            // Rank = none means it affects the whole mesh.
            rank: None,
            event,
        };
        self.health_state.unhealthy_event = Some(Unhealthy::StreamClosed(failure_message.clone()));
        // We don't send a message to the owner on stops, because only the owner
        // can request a stop. We just send to subscribers instead, as they did
        // not request the stop themselves.
        for subscriber in self.health_state.subscribers.iter() {
            send_subscriber_message(cx, subscriber, failure_message.clone());
        }

        // max_rank and extent are only needed for the deprecated RankedValues.
        // TODO: add cmp::Ord to Point for a max() impl.
        let max_rank = self.health_state.statuses.keys().map(|p| p.rank()).max();
        let extent = self
            .health_state
            .statuses
            .keys()
            .next()
            .map(|p| p.extent().clone());
        // Send a stop message to the ProcAgent for these actors.
        match self.stop(cx, message.reason.clone()).await {
            Ok(statuses) => {
                // All stops successful, set actor status on health state.
                // NOTE(review): and_modify only updates existing entries;
                // a rank absent from `statuses` would be silently skipped —
                // confirm every rank is pre-populated at construction.
                for (rank, status) in statuses.iter() {
                    self.health_state
                        .statuses
                        .entry(rank)
                        .and_modify(move |s| *s = status);
                }
            }
            // An ActorStopError means some actors didn't reach the stopped state.
            Err(crate::Error::ActorStopError { statuses }) => {
                // If there are no states yet, nothing to update.
                if let Some(max_rank) = max_rank {
                    let extent = extent.expect("no actors in mesh");
                    for (rank, status) in statuses.materialized_iter(max_rank).enumerate() {
                        *self
                            .health_state
                            .statuses
                            .get_mut(&extent.point_of_rank(rank).expect("illegal rank"))
                            .unwrap() = status.clone();
                    }
                }
            }
            // Other error types should be reported as supervision errors.
            Err(e) => {
                return Err(e.into());
            }
        }

        tracing::info!(actor_id = %cx.self_id(), actor_mesh = %mesh_name, "stopped mesh");
        Ok(())
    }
}
505
506/// Like send_state_change, but when there was no state change that occurred.
507/// Will send a None message to subscribers, and there is no state to change.
508/// Is not sent to the owner, because the owner is only watching for failures.
509/// Should be called once every so often so subscribers can discern the difference
510/// between "no messages because no errors" and "no messages because controller died".
511/// Without sending these hearbeats, subscribers will assume the mesh is dead.
512fn send_heartbeat(cx: &impl context::Actor, health_state: &HealthState) {
513    tracing::debug!(
514        num_subscribers = health_state.subscribers.len(),
515        "sending heartbeat to subscribers",
516    );
517
518    for subscriber in health_state.subscribers.iter() {
519        let mut headers = Flattrs::new();
520        headers.set(ACTOR_MESH_SUBSCRIBER_MESSAGE, true);
521        if let Err(e) = subscriber.send_with_headers(cx, headers, None) {
522            tracing::warn!(subscriber = %subscriber.port_id(), "error sending heartbeat message: {:?}", e);
523        }
524    }
525}
526
/// Sends a MeshFailure to the owner and subscribers of this mesh,
/// and changes the health state stored unhealthy_event.
/// Owners are sent a message only for Failure events, not for Stopped events.
/// Subscribers are sent both Stopped and Failure events.
fn send_state_change(
    cx: &impl context::Actor,
    rank: usize,
    event: ActorSupervisionEvent,
    mesh_name: &Name,
    is_proc_stopped: bool,
    health_state: &mut HealthState,
) {
    // This does not include the Stopped status, which is a state that occurs when the
    // user calls stop() on a proc or actor mesh.
    let is_failed = event.is_error();
    if is_failed {
        tracing::warn!(
            name = "SupervisionEvent",
            actor_mesh = %mesh_name,
            %event,
            "detected supervision error on monitored mesh: name={mesh_name}",
        );
    } else {
        tracing::debug!(
            name = "SupervisionEvent",
            actor_mesh = %mesh_name,
            %event,
            "detected non-error supervision event on monitored mesh: name={mesh_name}",
        );
    }

    let failure_message = MeshFailure {
        actor_mesh_name: Some(mesh_name.to_string()),
        rank: Some(rank),
        event: event.clone(),
    };
    // Record the event before notifying anyone, so later GetState queries and
    // late subscribers observe the failure.
    health_state.crashed_ranks.insert(rank, event.clone());
    // A proc-level stop maps to StreamClosed; anything else is Crashed.
    health_state.unhealthy_event = Some(if is_proc_stopped {
        Unhealthy::StreamClosed(failure_message.clone())
    } else {
        Unhealthy::Crashed(failure_message.clone())
    });
    // Send a notification to the owning actor of this mesh, if there is one.
    // Don't send a message to the owner for non-failure events such as "stopped".
    // Those events are always initiated by the owner, who don't need to be
    // told that they were stopped.
    if is_failed {
        if let Some(owner) = &health_state.owner {
            if let Err(error) = owner.send(cx, failure_message.clone()) {
                // NOTE(review): `error` appears both as a structured field and
                // in the message body, so it is logged twice — intentional?
                tracing::warn!(
                    name = "SupervisionEvent",
                    actor_mesh = %mesh_name,
                    %event,
                    %error,
                    "failed to send supervision event to owner {}: {}. dropping event",
                    owner.port_id(),
                    error
                );
            } else {
                tracing::info!(actor_mesh = %mesh_name, %event, "sent supervision failure message to owner {}", owner.port_id());
            }
        }
    }
    // Subscribers get all messages, even for non-failures like Stopped, because
    // they need to know if the owner stopped the mesh.
    for subscriber in health_state.subscribers.iter() {
        send_subscriber_message(cx, subscriber, failure_message.clone());
    }
}
596
597fn actor_state_to_supervision_events(
598    state: resource::State<ActorState>,
599) -> (usize, Vec<ActorSupervisionEvent>) {
600    let (rank, actor_id, events) = match state.state {
601        Some(inner) => (
602            inner.create_rank,
603            Some(inner.actor_id),
604            inner.supervision_events.clone(),
605        ),
606        None => (0, None, vec![]),
607    };
608    let events = match state.status {
609        // If the actor was killed, it might not have a Failed status
610        // or supervision events, and it can't tell us which rank
611        resource::Status::NotExist | resource::Status::Stopped | resource::Status::Timeout(_) => {
612            // it was.
613            if !events.is_empty() {
614                events
615            } else {
616                vec![ActorSupervisionEvent::new(
617                    actor_id.expect("actor_id is None"),
618                    None,
619                    ActorStatus::Stopped(
620                        format!(
621                            "actor status is {}; actor may have been killed",
622                            state.status
623                        )
624                        .to_string(),
625                    ),
626                    None,
627                )]
628            }
629        }
630        resource::Status::Failed(_) => events,
631        // All other states are successful.
632        _ => vec![],
633    };
634    (rank, events)
635}
636
637/// Map a process-level [`ProcStatus`] to an actor-level [`ActorStatus`].
638///
639/// When the supervision poll discovers that a process is terminating, this
640/// function decides whether to treat it as a clean stop or a failure.
641/// Notably, [`ProcStatus::Stopping`] (SIGTERM sent, process not yet exited)
642/// is mapped to [`ActorStatus::Stopped`] rather than [`ActorStatus::Failed`]
643/// so that a graceful shutdown in progress does not trigger unhandled
644/// supervision errors.
645fn proc_status_to_actor_status(proc_status: Option<ProcStatus>) -> ActorStatus {
646    match proc_status {
647        Some(ProcStatus::Stopped { exit_code: 0, .. }) => {
648            ActorStatus::Stopped("process exited cleanly".to_string())
649        }
650        Some(ProcStatus::Stopped { exit_code, .. }) => ActorStatus::Failed(
651            ActorErrorKind::Generic(format!("process exited with non-zero code {}", exit_code)),
652        ),
653        // Stopping is a transient state during graceful shutdown. Treat it the
654        // same as a clean stop rather than a failure.
655        Some(ProcStatus::Stopping { .. }) => {
656            ActorStatus::Stopped("process is stopping".to_string())
657        }
658        // Conservatively treat lack of status as stopped
659        None => ActorStatus::Stopped("no status received from process".to_string()),
660        Some(status) => ActorStatus::Failed(ActorErrorKind::Generic(format!(
661            "process failure: {}",
662            status
663        ))),
664    }
665}
666
667fn format_system_time(time: std::time::SystemTime) -> String {
668    let datetime: chrono::DateTime<chrono::Local> = time.into();
669    datetime.format("%Y-%m-%d %H:%M:%S").to_string()
670}
671
672#[async_trait]
673impl<A: Referable> Handler<CheckState> for ActorMeshController<A> {
674    /// Checks actor states and reschedules as a self-message.
675    ///
676    /// When any actor in this mesh changes state,
677    /// including once for the initial state of all actors, send a message to the
678    /// owners and subscribers of this mesh.
679    /// The receivers will get a MeshFailure. The created rank is
680    /// the original rank of the actor on the mesh, not the rank after
681    /// slicing.
682    ///
683    /// * SUPERVISION_POLL_FREQUENCY controls how frequently to poll.
684    /// * self-messaging stops when self.monitor is set to None.
685    async fn handle(
686        &mut self,
687        cx: &Context<Self>,
688        CheckState(expected_time): CheckState,
689    ) -> Result<(), anyhow::Error> {
690        // This implementation polls every "time_between_checks" duration, checking
691        // for changes in the actor states. It can be improved in two ways:
692        // 1. Use accumulation, to get *any* actor with a change in state, not *all*
693        //    actors.
694        // 2. Use a push-based mode instead of polling.
695        // Wait in between checking to avoid using too much network.
696
697        // Check for stalls in the supervision loop. These delays can cause the
698        // subscribers to think the controller is dead.
699        // Allow a little slack time to avoid logging for innocuous delays.
700        // If it's greater than 2x the expected time, log a warning.
701        if std::time::SystemTime::now()
702            > expected_time + hyperactor_config::global::get(SUPERVISION_POLL_FREQUENCY)
703        {
704            // Current time is included by default in the log message.
705            let expected_time = format_system_time(expected_time);
706            // Track in both metrics and tracing.
707            ACTOR_MESH_CONTROLLER_SUPERVISION_STALLS.add(1, kv_pairs!("actor_id" => cx.self_id().to_string(), "expected_time" => expected_time.clone()));
708            tracing::warn!(
709                actor_id = %cx.self_id(),
710                "Handler<CheckState> is being stalled, expected at {}",
711                expected_time,
712            );
713        }
714        let mesh = &self.mesh;
715        let supervision_display_name = &self.supervision_display_name;
716        // First check if the proc mesh is dead before trying to query their agents.
717        let proc_states = mesh.proc_mesh().proc_states(cx).await;
718        if let Err(e) = proc_states {
719            send_state_change(
720                cx,
721                0,
722                ActorSupervisionEvent::new(
723                    cx.self_id().clone(),
724                    None,
725                    ActorStatus::generic_failure(format!(
726                        "unable to query for proc states: {:?}",
727                        e
728                    )),
729                    None,
730                ),
731                mesh.name(),
732                false,
733                &mut self.health_state,
734            );
735            self.self_check_state_message(cx)?;
736            return Ok(());
737        }
738        if let Some(proc_states) = proc_states.unwrap() {
739            // Check if the proc mesh is still alive.
740            if let Some((point, state)) = proc_states
741                .iter()
742                .find(|(_rank, state)| state.status.is_terminating())
743            {
744                // TODO: allow "actor supervision event" to be general, and
745                // make the proc failure the cause. It is a hack to try to determine
746                // the correct status based on process exit status.
747                let actor_status =
748                    proc_status_to_actor_status(state.state.and_then(|s| s.proc_status));
749                let display_name = crate::actor_display_name(supervision_display_name, &point);
750                send_state_change(
751                    cx,
752                    point.rank(),
753                    ActorSupervisionEvent::new(
754                        // Attribute this to the monitored actor, even if the underlying
755                        // cause is a proc_failure. We propagate the cause explicitly.
756                        mesh.get(point.rank()).unwrap().actor_id().clone(),
757                        Some(format!("{} was running on a process which", display_name)),
758                        actor_status,
759                        None,
760                    ),
761                    mesh.name(),
762                    true,
763                    &mut self.health_state,
764                );
765                self.self_check_state_message(cx)?;
766                return Ok(());
767            }
768        }
769
770        // Now that we know the proc mesh is alive, check for actor state changes.
771        let orphan_timeout = hyperactor_config::global::get(MESH_ORPHAN_TIMEOUT);
772        let keepalive = if orphan_timeout.is_zero() {
773            None
774        } else {
775            Some(std::time::SystemTime::now() + orphan_timeout)
776        };
777        let events = mesh.actor_states_with_keepalive(cx, keepalive).await;
778        if let Err(e) = events {
779            send_state_change(
780                cx,
781                0,
782                ActorSupervisionEvent::new(
783                    cx.self_id().clone(),
784                    Some(supervision_display_name.clone()),
785                    ActorStatus::generic_failure(format!(
786                        "unable to query for actor states: {:?}",
787                        e
788                    )),
789                    None,
790                ),
791                mesh.name(),
792                false,
793                &mut self.health_state,
794            );
795            self.self_check_state_message(cx)?;
796            return Ok(());
797        }
798        // If there was any state change, we don't need to send a heartbeat.
799        let mut did_send_state_change = false;
800        // True if any rank is in a terminal status. Once that is true, no more
801        // heartbeats are sent.
802        let mut is_terminal = false;
803        // This returned point is the created rank, *not* the rank of
804        // the possibly sliced input mesh.
805        for (point, state) in events.unwrap().iter() {
806            let mut is_new = false;
807            let entry = self
808                .health_state
809                .statuses
810                .entry(point.clone())
811                .or_insert_with(|| {
812                    tracing::debug!(
813                        "PythonActorMeshImpl: received initial state: point={:?}, state={:?}",
814                        point,
815                        state
816                    );
817                    let (_rank, events) = actor_state_to_supervision_events(state.clone());
818                    // Wait for next event if the change in state produced no supervision events.
819                    if !events.is_empty() {
820                        is_new = true;
821                    }
822                    state.status.clone()
823                });
824            // If the status of any rank is terminal, we don't want to send
825            // a heartbeat message.
826            if !is_terminal && entry.is_terminating() {
827                is_terminal = true;
828            }
829            // If this actor is new, or the state changed, send a message to the owner.
830            let (rank, event) = if is_new {
831                let (rank, events) = actor_state_to_supervision_events(state.clone());
832                if events.is_empty() {
833                    continue;
834                }
835                (rank, events[0].clone())
836            } else if *entry != state.status {
837                tracing::debug!(
838                    "PythonActorMeshImpl: received state change event: point={:?}, old_state={:?}, new_state={:?}",
839                    point,
840                    entry,
841                    state
842                );
843                let (rank, events) = actor_state_to_supervision_events(state.clone());
844                if events.is_empty() {
845                    continue;
846                }
847                *entry = state.status;
848                (rank, events[0].clone())
849            } else {
850                continue;
851            };
852            did_send_state_change = true;
853            send_state_change(cx, rank, event, mesh.name(), false, &mut self.health_state);
854        }
855        if !did_send_state_change && !is_terminal {
856            // No state change, but subscribers need to be sent a message
857            // every so often so they know the controller is still alive.
858            // Send a "no state change" message.
859            // Only if the last state for any actor in this mesh is not a terminal state.
860            send_heartbeat(cx, &self.health_state);
861        }
862
863        // If all ranks are in a terminal state, we don't need to continue checking,
864        // as statuses cannot change.
865        // Any new subscribers will get an immediate message saying the mesh is stopped.
866        let all_ranks_terminal = self
867            .health_state
868            .statuses
869            .values()
870            .all(|s| s.is_terminating());
871        if !all_ranks_terminal {
872            // Schedule a self send after a waiting period.
873            self.self_check_state_message(cx)?;
874        } else {
875            // There's no need to send a stop message during cleanup if all the
876            // ranks are already terminal.
877            self.monitor.take();
878        }
879        return Ok(());
880    }
881}
882
/// Controller actor tied to the lifetime of a [`ProcMeshRef`].
///
/// When this actor is cleaned up, it stops every proc in the mesh it
/// controls (via the mesh's host mesh, if any) — see the `Actor::cleanup`
/// implementation below.
#[derive(Debug)]
#[hyperactor::export]
pub(crate) struct ProcMeshController {
    // The proc mesh whose lifecycle this controller manages.
    mesh: ProcMeshRef,
}
888
889impl ProcMeshController {
890    /// Create a new proc controller based on the provided reference.
891    pub(crate) fn new(mesh: ProcMeshRef) -> Self {
892        Self { mesh }
893    }
894}
895
896#[async_trait]
897impl Actor for ProcMeshController {
898    async fn init(&mut self, this: &Instance<Self>) -> Result<(), anyhow::Error> {
899        this.set_system();
900        Ok(())
901    }
902
903    async fn cleanup(
904        &mut self,
905        this: &Instance<Self>,
906        _err: Option<&ActorError>,
907    ) -> Result<(), anyhow::Error> {
908        // Cannot use "ProcMesh::stop" as it's only defined on ProcMesh, not ProcMeshRef.
909        let names = self
910            .mesh
911            .proc_ids()
912            .collect::<Vec<hyperactor_reference::ProcId>>();
913        let region = self.mesh.region().clone();
914        if let Some(hosts) = self.mesh.hosts() {
915            hosts
916                .stop_proc_mesh(
917                    this,
918                    self.mesh.name(),
919                    names,
920                    region,
921                    "proc mesh controller cleanup".to_string(),
922                )
923                .await
924        } else {
925            Ok(())
926        }
927    }
928}
929
/// Controller actor tied to the lifetime of a [`HostMeshRef`].
///
/// When this actor is cleaned up, it performs a best-effort shutdown of
/// every host in the mesh it controls — see the `Actor::cleanup`
/// implementation below.
#[derive(Debug)]
#[hyperactor::export]
pub(crate) struct HostMeshController {
    // The host mesh whose lifecycle this controller manages.
    mesh: HostMeshRef,
}
935
936impl HostMeshController {
937    /// Create a new host controller based on the provided reference.
938    pub(crate) fn new(mesh: HostMeshRef) -> Self {
939        Self { mesh }
940    }
941}
942
943#[async_trait]
944impl Actor for HostMeshController {
945    async fn init(&mut self, this: &Instance<Self>) -> Result<(), anyhow::Error> {
946        this.set_system();
947        Ok(())
948    }
949
950    async fn cleanup(
951        &mut self,
952        this: &Instance<Self>,
953        _err: Option<&ActorError>,
954    ) -> Result<(), anyhow::Error> {
955        // Cannot use "HostMesh::shutdown" as it's only defined on HostMesh, not HostMeshRef.
956        for host in self.mesh.values() {
957            if let Err(e) = host.shutdown(this).await {
958                tracing::warn!(host = %host, error = %e, "host shutdown failed");
959            }
960        }
961        Ok(())
962    }
963}
964
/// Tests for the mesh controllers: orphan cleanup via keepalive expiry,
/// cleanup after an uncleanly-crashed controller process, and the
/// `ProcStatus` -> `ActorStatus` mapping performed by
/// `proc_status_to_actor_status`.
#[cfg(test)]
mod tests {
    use std::ops::Deref;
    use std::time::Duration;

    use hyperactor::actor::ActorStatus;
    use ndslice::Extent;
    use ndslice::ViewExt;

    use super::SUPERVISION_POLL_FREQUENCY;
    use super::proc_status_to_actor_status;
    use crate::ActorMesh;
    use crate::Name;
    use crate::bootstrap::ProcStatus;
    use crate::proc_agent::MESH_ORPHAN_TIMEOUT;
    use crate::resource;
    use crate::supervision::MeshFailure;
    use crate::test_utils::local_host_mesh;
    use crate::testactor;
    use crate::testing;

    /// Verify that actors spawned without a controller are cleaned up
    /// when their keepalive expiry lapses. We:
    ///   1. Enable the orphan timeout on the `ProcMeshAgent`.
    ///   2. Spawn actors as *system actors* (no `ActorMeshController`).
    ///   3. Send a single keepalive with a short expiry time.
    ///   4. Wait for the expiry to pass and `SelfCheck` to fire.
    ///   5. Assert that the actors are now stopped.
    #[tokio::test]
    async fn test_orphaned_actors_are_cleaned_up() {
        let config = hyperactor_config::global::lock();
        // Short orphan timeout so SelfCheck fires frequently.
        let _orphan = config.override_key(MESH_ORPHAN_TIMEOUT, Duration::from_secs(1));

        let instance = testing::instance();
        let host_mesh = local_host_mesh(2).await;
        let proc_mesh = host_mesh
            .spawn(instance, "test", Extent::unity())
            .await
            .unwrap();

        let actor_name = Name::new("orphan_test").unwrap();
        // Spawn as a system actor so no controller is created. This lets us
        // control keepalive messages directly without the controller
        // interfering.
        let actor_mesh: ActorMesh<testactor::TestActor> = proc_mesh
            .spawn_with_name(instance, actor_name.clone(), &(), None, true)
            .await
            .unwrap();
        assert!(
            actor_mesh.deref().extent().num_ranks() > 0,
            "should have spawned at least one actor"
        );

        // Send a keepalive with a short expiry. This is what the
        // ActorMeshController would normally do on each supervision poll.
        let states = proc_mesh
            .actor_states_with_keepalive(
                instance,
                actor_name.clone(),
                Some(std::time::SystemTime::now() + Duration::from_secs(2)),
            )
            .await
            .unwrap();
        // All actors should be running right now.
        for state in states.values() {
            assert_eq!(
                state.status,
                resource::Status::Running,
                "actor should be running before expiry"
            );
        }

        // Wait long enough for the expiry to pass and at least one
        // SelfCheck cycle to fire. With MESH_ORPHAN_TIMEOUT = 1s and
        // expiry in 2s, by around 4s at least two SelfCheck cycles will
        // have elapsed after the expiry.
        tokio::time::sleep(Duration::from_secs(5)).await;

        // Query again, this time *without* a keepalive so we don't
        // extend the expiry.
        let states = proc_mesh
            .actor_states(instance, actor_name.clone())
            .await
            .unwrap();
        for state in states.values() {
            assert_eq!(
                state.status,
                resource::Status::Stopped,
                "actor should be stopped after keepalive expiry"
            );
        }
    }

    /// Create a multi-process host mesh that propagates the current
    /// process's config overrides to child processes via Bootstrap.
    #[cfg(fbcode_build)]
    async fn host_mesh_with_config(n: usize) -> crate::host_mesh::HostMesh {
        use hyperactor::channel::ChannelTransport;
        use tokio::process::Command;

        let program = crate::testresource::get("monarch/hyperactor_mesh/bootstrap");
        let mut host_addrs = vec![];
        for _ in 0..n {
            host_addrs.push(ChannelTransport::Unix.any());
        }

        // Launch one bootstrap child process per host address; each child
        // inherits the parent's config overrides through Bootstrap::Host.
        for host in host_addrs.iter() {
            let mut cmd = Command::new(program.clone());
            let boot = crate::Bootstrap::Host {
                addr: host.clone(),
                command: None,
                config: Some(hyperactor_config::global::attrs()),
                exit_on_shutdown: false,
            };
            boot.to_env(&mut cmd);
            cmd.kill_on_drop(false);
            // SAFETY: pre_exec sets PR_SET_PDEATHSIG so the child is
            // cleaned up if the parent (test) process dies.
            unsafe {
                cmd.pre_exec(crate::bootstrap::install_pdeathsig_kill);
            }
            cmd.spawn().unwrap();
        }

        let host_mesh = crate::HostMeshRef::from_hosts(Name::new("test").unwrap(), host_addrs);
        crate::host_mesh::HostMesh::take(host_mesh)
    }

    /// Verify that actors are cleaned up via the orphan timeout when the
    /// `ActorMeshController`'s process crashes. Unlike the system-actor test
    /// above, this spawns actors through a real controller (via `WrapperActor`)
    /// and then kills the controller's process uncleanly with `ProcessExit`.
    /// The agents on the surviving proc mesh detect the expired keepalive
    /// and stop the actors.
    #[tokio::test]
    #[cfg(fbcode_build)]
    async fn test_orphaned_actors_cleaned_up_on_controller_crash() {
        let config = hyperactor_config::global::lock();
        let _orphan = config.override_key(MESH_ORPHAN_TIMEOUT, Duration::from_secs(2));
        let _poll = config.override_key(SUPERVISION_POLL_FREQUENCY, Duration::from_secs(1));

        let instance = testing::instance();
        let num_replicas = 2;

        // Host mesh for the test actors (these survive the crash).
        // host_mesh_with_config propagates config overrides to child
        // processes via Bootstrap, so agents boot with
        // MESH_ORPHAN_TIMEOUT=2s and start the SelfCheck loop.
        let mut actor_hm = host_mesh_with_config(num_replicas).await;
        let actor_proc_mesh = actor_hm
            .spawn(instance, "actors", Extent::unity())
            .await
            .unwrap();

        // Host mesh for the wrapper + controller (will be killed).
        let mut controller_hm = host_mesh_with_config(1).await;
        let controller_proc_mesh = controller_hm
            .spawn(instance, "controller", Extent::unity())
            .await
            .unwrap();

        let child_name = Name::new("orphan_child").unwrap();

        // Supervision port required by WrapperActor params.
        let (supervision_port, _supervision_receiver) = instance.open_port::<MeshFailure>();
        let supervisor = supervision_port.bind();

        // Spawn WrapperActor on controller_proc_mesh. Its init() spawns
        // ActorMesh<TestActor> on actor_proc_mesh with a real
        // ActorMeshController co-located on the controller's process.
        let wrapper_mesh: ActorMesh<testactor::WrapperActor> = controller_proc_mesh
            .spawn(
                instance,
                "wrapper",
                &(
                    actor_proc_mesh.deref().clone(),
                    supervisor,
                    child_name.clone(),
                ),
            )
            .await
            .unwrap();

        // Give the controller time to run at least one CheckState cycle
        // (polling every 1s) so it sends KeepaliveGetState to the agents.
        tokio::time::sleep(Duration::from_secs(3)).await;

        // Verify actors are running before the crash.
        let states = actor_proc_mesh
            .actor_states(instance, child_name.clone())
            .await
            .unwrap();
        for state in states.values() {
            assert_eq!(
                state.status,
                resource::Status::Running,
                "actor should be running before controller crash"
            );
        }

        // Kill the controller's process uncleanly. send_to_children: false
        // means only the WrapperActor's process exits; the TestActors on
        // actor_proc_mesh survive.
        wrapper_mesh
            .cast(
                instance,
                testactor::CauseSupervisionEvent {
                    kind: testactor::SupervisionEventType::ProcessExit(1),
                    send_to_children: false,
                },
            )
            .unwrap();

        // Wait for:
        //  - keepalive expiry (2s from last CheckState)
        //  - at least one SelfCheck cycle (every 2s)
        //  - margin for processing
        tokio::time::sleep(Duration::from_secs(8)).await;

        // Actors should now be stopped via the orphan timeout.
        let states = actor_proc_mesh
            .actor_states(instance, child_name.clone())
            .await
            .unwrap();
        for state in states.values() {
            assert_eq!(
                state.status,
                resource::Status::Stopped,
                "actor should be stopped after controller crash and orphan timeout"
            );
        }

        let _ = actor_hm.shutdown(instance).await;
        let _ = controller_hm.shutdown(instance).await;
    }

    /// A clean (exit code 0) process stop maps to `ActorStatus::Stopped`.
    #[test]
    fn test_proc_status_to_actor_status_stopped_cleanly() {
        let status = proc_status_to_actor_status(Some(ProcStatus::Stopped {
            exit_code: 0,
            stderr_tail: vec![],
        }));
        assert!(
            matches!(status, ActorStatus::Stopped(ref msg) if msg.contains("cleanly")),
            "expected Stopped, got {:?}",
            status
        );
    }

    /// A non-zero exit code maps to `ActorStatus::Failed`.
    #[test]
    fn test_proc_status_to_actor_status_nonzero_exit() {
        let status = proc_status_to_actor_status(Some(ProcStatus::Stopped {
            exit_code: 1,
            stderr_tail: vec![],
        }));
        assert!(
            matches!(status, ActorStatus::Failed(_)),
            "expected Failed, got {:?}",
            status
        );
    }

    /// A proc in the process of stopping is not treated as a failure.
    #[test]
    fn test_proc_status_to_actor_status_stopping_is_not_a_failure() {
        let status = proc_status_to_actor_status(Some(ProcStatus::Stopping {
            started_at: std::time::SystemTime::now(),
        }));
        assert!(
            matches!(status, ActorStatus::Stopped(ref msg) if msg.contains("stopping")),
            "expected Stopped, got {:?}",
            status
        );
    }

    /// Absence of a proc status maps to `ActorStatus::Stopped`.
    #[test]
    fn test_proc_status_to_actor_status_none() {
        let status = proc_status_to_actor_status(None);
        assert!(
            matches!(status, ActorStatus::Stopped(_)),
            "expected Stopped, got {:?}",
            status
        );
    }

    /// A signal-killed proc maps to `ActorStatus::Failed`.
    #[test]
    fn test_proc_status_to_actor_status_killed() {
        let status = proc_status_to_actor_status(Some(ProcStatus::Killed {
            signal: 9,
            core_dumped: false,
        }));
        assert!(
            matches!(status, ActorStatus::Failed(_)),
            "expected Failed, got {:?}",
            status
        );
    }

    /// An explicitly failed proc maps to `ActorStatus::Failed`.
    #[test]
    fn test_proc_status_to_actor_status_failed() {
        let status = proc_status_to_actor_status(Some(ProcStatus::Failed {
            reason: "oom".to_string(),
        }));
        assert!(
            matches!(status, ActorStatus::Failed(_)),
            "expected Failed, got {:?}",
            status
        );
    }
}