hyperactor_multiprocess/
system_actor.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! System actor manages a system.
10
11use std::collections::BTreeSet;
12use std::collections::HashMap;
13use std::collections::HashSet;
14use std::collections::hash_map::Entry;
15use std::fmt::Display;
16use std::fmt::Formatter;
17use std::hash::Hash;
18use std::sync::Arc;
19use std::sync::LazyLock;
20use std::time::SystemTime;
21
22use async_trait::async_trait;
23use dashmap::DashMap;
24use enum_as_inner::EnumAsInner;
25use hyperactor::Actor;
26use hyperactor::ActorHandle;
27use hyperactor::ActorId;
28use hyperactor::ActorRef;
29use hyperactor::Context;
30use hyperactor::HandleClient;
31use hyperactor::Instance;
32use hyperactor::Named;
33use hyperactor::OncePortRef;
34use hyperactor::PortHandle;
35use hyperactor::PortRef;
36use hyperactor::ProcId;
37use hyperactor::RefClient;
38use hyperactor::WorldId;
39use hyperactor::actor::Handler;
40use hyperactor::channel::ChannelAddr;
41use hyperactor::channel::sim::SimAddr;
42use hyperactor::clock::Clock;
43use hyperactor::clock::ClockKind;
44use hyperactor::id;
45use hyperactor::mailbox::BoxedMailboxSender;
46use hyperactor::mailbox::DialMailboxRouter;
47use hyperactor::mailbox::MailboxSender;
48use hyperactor::mailbox::MailboxSenderError;
49use hyperactor::mailbox::MessageEnvelope;
50use hyperactor::mailbox::PortSender;
51use hyperactor::mailbox::Undeliverable;
52use hyperactor::mailbox::mailbox_admin_message::MailboxAdminMessage;
53use hyperactor::mailbox::monitored_return_handle;
54use hyperactor::proc::Proc;
55use hyperactor::reference::Index;
56use serde::Deserialize;
57use serde::Serialize;
58use tokio::time::Duration;
59use tokio::time::Instant;
60
61use super::proc_actor::ProcMessage;
62use crate::proc_actor::Environment;
63use crate::proc_actor::ProcActor;
64use crate::proc_actor::ProcStopResult;
65use crate::supervision::ProcStatus;
66use crate::supervision::ProcSupervisionMessage;
67use crate::supervision::ProcSupervisionState;
68use crate::supervision::WorldSupervisionMessage;
69use crate::supervision::WorldSupervisionState;
70
71/// A snapshot of a single proc.
72#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
73pub struct WorldSnapshotProcInfo {
74    /// The labels of the proc.
75    pub labels: HashMap<String, String>,
76}
77
78impl From<&ProcInfo> for WorldSnapshotProcInfo {
79    fn from(proc_info: &ProcInfo) -> Self {
80        Self {
81            labels: proc_info.labels.clone(),
82        }
83    }
84}
85
86/// A snapshot view of a world in the system.
87#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
88pub struct WorldSnapshot {
89    /// The host procs used to spawn procs in this world. Some caveats:
90    ///   1. The host procs are actually not in this world but in a different
91    ///      "shadow" world. The shadow world's ID can be told from the host
92    ///      ProcId.
93    ///   2. Not all host procs are captured here. This field only captures the
94    ///      hosts that joined before the world were created.
95    pub host_procs: HashSet<ProcId>,
96
97    /// The procs in this world.
98    pub procs: HashMap<ProcId, WorldSnapshotProcInfo>,
99
100    /// The status of the world.
101    pub status: WorldStatus,
102
103    /// Labels attached to this world. They can be used later to query
104    /// world(s) using system snapshot api.
105    pub labels: HashMap<String, String>,
106}
107
108impl WorldSnapshot {
109    fn from_world_filtered(world: &World, filter: &SystemSnapshotFilter) -> Self {
110        WorldSnapshot {
111            host_procs: world.state.host_map.keys().map(|h| &h.0).cloned().collect(),
112            procs: world
113                .state
114                .procs
115                .iter()
116                .map_while(|(k, v)| filter.proc_matches(v).then_some((k.clone(), v.into())))
117                .collect(),
118            status: world.state.status.clone(),
119            labels: world.labels.clone(),
120        }
121    }
122}
123
124/// A snapshot view of the system.
125#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Named)]
126pub struct SystemSnapshot {
127    /// Snapshots of all the worlds in this system.
128    pub worlds: HashMap<WorldId, WorldSnapshot>,
129    /// Execution ID of the system.
130    pub execution_id: String,
131}
132
133/// A filter used to filter the snapshot view of the system.
134#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Named, Default)]
135pub struct SystemSnapshotFilter {
136    /// The world ids to filter. Empty list matches all.
137    pub worlds: Vec<WorldId>,
138    /// World labels to filter. Empty matches all.
139    pub world_labels: HashMap<String, String>,
140    /// Proc labels to filter. Empty matches all.
141    pub proc_labels: HashMap<String, String>,
142}
143
144impl SystemSnapshotFilter {
145    /// Create an empty filter that matches everything.
146    pub fn all() -> Self {
147        Self {
148            worlds: Vec::new(),
149            world_labels: HashMap::new(),
150            proc_labels: HashMap::new(),
151        }
152    }
153
154    /// Whether the filter matches the given world.
155    fn world_matches(&self, world: &World) -> bool {
156        if !self.worlds.is_empty() && !self.worlds.contains(&world.world_id) {
157            return false;
158        }
159        Self::labels_match(&self.world_labels, &world.labels)
160    }
161
162    fn proc_matches(&self, proc_info: &ProcInfo) -> bool {
163        Self::labels_match(&self.proc_labels, &proc_info.labels)
164    }
165
166    /// Whether the filter matches the given proc labels.
167    fn labels_match(
168        filter_labels: &HashMap<String, String>,
169        labels: &HashMap<String, String>,
170    ) -> bool {
171        filter_labels.is_empty()
172            || filter_labels
173                .iter()
174                .all(|(k, v)| labels.contains_key(k) && labels.get(k).unwrap() == v)
175    }
176}
177
/// Request to refresh the state of every world, specifically checking
/// whether each one is unhealthy. A world that stays unhealthy for too long
/// is evicted.
#[derive(Clone, Debug, PartialEq)]
struct MaintainWorldHealth;
182
183/// The proc's lifecyle management mode.
184#[derive(Named, Debug, Clone, Serialize, Deserialize, PartialEq)]
185pub enum ProcLifecycleMode {
186    /// Proc is detached, its lifecycle isn't managed by the system.
187    Detached,
188    /// Proc's lifecycle is managed by the system, supervision is enabled for the proc.
189    ManagedBySystem,
190    /// The proc manages the lifecyle of the system, supervision is enabled for the proc.
191    /// System goes down when the proc stops.
192    ManagingSystem,
193}
194
195impl ProcLifecycleMode {
196    /// Whether the lifecycle mode indicates whether proc is managed by/managing system or not.
197    pub fn is_managed(&self) -> bool {
198        matches!(
199            self,
200            ProcLifecycleMode::ManagedBySystem | ProcLifecycleMode::ManagingSystem
201        )
202    }
203}
204
205/// System messages.
206#[derive(
207    hyperactor::Handler,
208    HandleClient,
209    RefClient,
210    Named,
211    Debug,
212    Clone,
213    Serialize,
214    Deserialize,
215    PartialEq
216)]
217pub enum SystemMessage {
218    /// Join the system at the given proc id.
219    Join {
220        /// The world that is being joined.
221        world_id: WorldId,
222        /// The proc id that is joining.
223        proc_id: ProcId,
224        /// Reference to the proc actor managing the proc.
225        proc_message_port: PortRef<ProcMessage>,
226        /// The channel address used to communicate with the proc.
227        proc_addr: ChannelAddr,
228        /// Arbitrary name/value pairs that can be used to identify the proc.
229        labels: HashMap<String, String>,
230        /// The lifecyle mode of the proc.
231        lifecycle_mode: ProcLifecycleMode,
232    },
233
234    /// Create a new world or update an existing world.
235    UpsertWorld {
236        /// The world id.
237        world_id: WorldId,
238        /// The shape of the world.
239        shape: Shape,
240        /// The number of procs per host.
241        num_procs_per_host: usize,
242        /// How to spawn procs in the world.
243        env: Environment,
244        /// Arbitrary name/value pairs that can be used to identify the world.
245        labels: HashMap<String, String>,
246    },
247
248    /// Return a snapshot view of this system. Used for debugging.
249    #[log_level(debug)]
250    Snapshot {
251        /// The filter used to filter the snapshot view.
252        filter: SystemSnapshotFilter,
253        /// Used to return the snapshot view to the caller.
254        #[reply]
255        ret: OncePortRef<SystemSnapshot>,
256    },
257
258    /// Start the shutdown process of everything in this system. It tries to
259    /// shutdown all the procs first, and then the system actor itself.
260    ///
261    /// Note this shutdown sequence is best effort, yet not guaranteed. It is
262    /// possible the system actor/proc might already stop, while the remote
263    /// procs are still in the middle of shutting down.
264    Stop {
265        /// List of worlds to stop. If provided, only the procs belonging to
266        /// the list of worlds are stopped, otherwise all worlds are stopped
267        /// including the system proc itself.
268        worlds: Option<Vec<WorldId>>,
269        /// The timeout used by ProcActor to stop the proc.
270        proc_timeout: Duration,
271        /// Used to return success to the caller.
272        reply_port: OncePortRef<()>,
273    },
274}
275
276/// Errors that can occur inside a system actor.
277#[derive(thiserror::Error, Debug)]
278pub enum SystemActorError {
279    /// A proc is trying to join before a world is created
280    #[error("procs cannot join uncreated world {0}")]
281    UnknownWorldId(WorldId),
282
283    /// Spawn procs failed
284    #[error("failed to spawn procs")]
285    SpawnProcsFailed(#[from] MailboxSenderError),
286
287    /// Host ID does not start with valid prefix.
288    #[error("invalid host {0}: does not start with prefix '{SHADOW_PREFIX}'")]
289    InvalidHostPrefix(HostId),
290
291    /// A host is trying to join the world which already has a joined host with the same ID.
292    #[error("host ID {0} already exists in world")]
293    DuplicatedHostId(HostId),
294
295    /// Trying to get the actor ref for a host that doesn't exist in a world.
296    #[error("host {0} does not exist in world")]
297    HostNotExist(HostId),
298}
299
300/// TODO: add missing doc
301#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
302pub enum Shape {
303    /// A definite N-dimensional shape of the world, the semantics of the shape can be defined
304    /// by the user (TODO: implement this), e.g. in a shape like [3, 2, 2], user will be able
305    /// to express things like dim 0: ai zone, dim 1: rack, dim 2: host.
306    Definite(Vec<usize>),
307    /// Shape is unknown.
308    Unknown,
309}
310
311/// TODO: Toss this world implementation away once we have
312/// a more clearly defined allocation API.
313/// Currently, each world in a system has two worlds beneath:
314/// the actual world and the shadow world. The shadow world
315/// is a world that is used to maintain hosts which in turn
316/// spawn procs for the world.
317/// This is needed in order to support the current scheduler implementation
318/// which does not support per-proc scheduling.
319///
320/// That means, each host is a proc in the shadow world. Each host proc spawns
321/// a number of procs for the actual world.
322#[derive(Debug, Clone, Serialize, Deserialize)]
323pub struct World {
324    /// The world id.
325    world_id: WorldId,
326    /// TODO: add misssing doc
327    scheduler_params: SchedulerParams,
328    /// Artbitrary labels attached to the world.
329    labels: HashMap<String, String>,
330    /// The state of the world.
331    state: WorldState,
332}
333
334#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
335struct Host {
336    num_procs_assigned: usize,
337    proc_message_port: PortRef<ProcMessage>,
338    host_rank: usize,
339}
340
341impl Host {
342    fn new(proc_message_port: PortRef<ProcMessage>, host_rank: usize) -> Self {
343        Self {
344            num_procs_assigned: 0,
345            proc_message_port,
346            host_rank,
347        }
348    }
349
350    fn get_assigned_procs(
351        &mut self,
352        world_id: &WorldId,
353        scheduler_params: &mut SchedulerParams,
354    ) -> Vec<ProcId> {
355        // Get Host from hosts given host_id else return empty vec
356        let mut proc_ids = Vec::new();
357
358        // The number of hosts that will be assigned a total of scheduler_params.num_procs_per_host
359        // procs. If scheduler_params.num_procs() is 31 and scheduler_params.num_procs_per_host is 8,
360        // then num_saturated_hosts == 3 even though total number of hosts will be 4.
361        let num_saturated_hosts =
362            scheduler_params.num_procs() / scheduler_params.num_procs_per_host;
363        // If num_saturated_hosts is less than total hosts, then the final host_rank will be equal
364        // to num_saturated_hosts, and should not be assigned the full scheduler_params.num_procs_per_host.
365        // Instead, we should only assign the remaining procs. So if num_procs is 31, num_procs_per_host is 8,
366        // then host_rank 3 should only be assigned 7 procs.
367        let num_scheduled = if self.host_rank == num_saturated_hosts {
368            scheduler_params.num_procs() % scheduler_params.num_procs_per_host
369        } else {
370            scheduler_params.num_procs_per_host
371        };
372
373        scheduler_params.num_procs_scheduled += num_scheduled;
374
375        for _ in 0..num_scheduled {
376            // Compute each proc id (which will become the RANK env var on each worker)
377            // based on host_rank, which is (optionally) assigned to each host at bootstrap
378            // time according to a sorted hostname file.
379            //
380            // More precisely, when a host process starts up, it gets its host rank from some
381            // global source of truth common to all host nodes. This source of truth could be
382            // a file or an env var. In order to be consistent with the SPMD world, assuming
383            // num_procs_per_host == N, we would want worker ranks 0 through N-1 on host 0;
384            // ranks N through 2N-1 on host 1; etc. So, for host H, we assign proc ids in the
385            // interval [H*N, (H+1)*N).
386            let rank =
387                self.host_rank * scheduler_params.num_procs_per_host + self.num_procs_assigned;
388            let proc_id = ProcId::Ranked(world_id.clone(), rank);
389            proc_ids.push(proc_id);
390            self.num_procs_assigned += 1;
391        }
392
393        proc_ids
394    }
395}
396
397#[derive(Debug, Clone, Serialize, Deserialize)]
398struct SchedulerParams {
399    shape: Shape,
400    num_procs_scheduled: usize,
401    num_procs_per_host: usize,
402    next_rank: Index,
403    env: Environment,
404}
405
406impl SchedulerParams {
407    fn num_procs(&self) -> usize {
408        match &self.shape {
409            Shape::Definite(v) => v.iter().product(),
410            Shape::Unknown => unimplemented!(),
411        }
412    }
413}
414
415/// A world id that is used to identify a host.
416pub type HostWorldId = WorldId;
417static SHADOW_PREFIX: &str = "host";
418
419/// A host id that is used to identify a host.
420#[derive(
421    Debug,
422    Serialize,
423    Deserialize,
424    Clone,
425    PartialEq,
426    Eq,
427    PartialOrd,
428    Hash,
429    Ord
430)]
431pub struct HostId(ProcId);
432impl HostId {
433    /// Creates a new HostId from a proc_id.
434    pub fn new(proc_id: ProcId) -> Result<Self, anyhow::Error> {
435        if !proc_id
436            .world_name()
437            .expect("proc must be ranked for world_name check")
438            .starts_with(SHADOW_PREFIX)
439        {
440            anyhow::bail!(
441                "proc_id {} is not a valid HostId because it does not start with {}",
442                proc_id,
443                SHADOW_PREFIX
444            )
445        }
446        Ok(Self(proc_id))
447    }
448}
449
450impl TryFrom<ProcId> for HostId {
451    type Error = anyhow::Error;
452
453    fn try_from(proc_id: ProcId) -> Result<Self, anyhow::Error> {
454        if !proc_id
455            .world_name()
456            .expect("proc must be ranked for world_name check")
457            .starts_with(SHADOW_PREFIX)
458        {
459            anyhow::bail!(
460                "proc_id {} is not a valid HostId because it does not start with {}",
461                proc_id,
462                SHADOW_PREFIX
463            )
464        }
465        Ok(Self(proc_id))
466    }
467}
468
469impl std::fmt::Display for HostId {
470    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
471        self.0.fmt(f)
472    }
473}
474
475type HostMap = HashMap<HostId, Host>;
476
477#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
478struct ProcInfo {
479    port_ref: PortRef<ProcMessage>,
480    labels: HashMap<String, String>,
481}
482
483#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
484struct WorldState {
485    host_map: HostMap,
486    procs: HashMap<ProcId, ProcInfo>,
487    status: WorldStatus,
488}
489
490/// A world status represents the different phases of a world.
491#[derive(Debug, Clone, Serialize, Deserialize, EnumAsInner, PartialEq)]
492pub enum WorldStatus {
493    /// Waiting for the world to be created. Accumulate joined hosts or procs while we're waiting.
494    AwaitingCreation,
495
496    /// World is created and enough procs based on the scheduler parameter.
497    /// All procs in the world are without failures.
498    Live,
499
500    /// World is created but it does not have enough procs or some procs are failing.
501    /// [`SystemTime`] contains the time when the world became unhealthy.
502    // Use SystemTime instead of Instant to avoid the issue of serialization.
503    Unhealthy(SystemTime),
504}
505
506impl Display for WorldStatus {
507    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
508        match self {
509            WorldStatus::AwaitingCreation => write!(f, "Awaiting Creation"),
510            WorldStatus::Live => write!(f, "Live"),
511            WorldStatus::Unhealthy(_) => write!(f, "Unhealthy"),
512        }
513    }
514}
515
516impl WorldState {
517    /// Gets the mutable ref to host_map.
518    fn get_host_map_mut(&mut self) -> &mut HostMap {
519        &mut self.host_map
520    }
521
522    /// Gets the ref to host_map.
523    fn get_host_map(&self) -> &HostMap {
524        &self.host_map
525    }
526}
527
528impl World {
529    fn new(
530        world_id: WorldId,
531        shape: Shape,
532        state: WorldState,
533        num_procs_per_host: usize,
534        env: Environment,
535        labels: HashMap<String, String>,
536    ) -> Result<Self, anyhow::Error> {
537        if world_id.name().starts_with(SHADOW_PREFIX) {
538            anyhow::bail!(
539                "world name {} cannot start with {}!",
540                world_id,
541                SHADOW_PREFIX
542            )
543        }
544        tracing::info!("creating world {}", world_id,);
545        Ok(Self {
546            world_id,
547            scheduler_params: SchedulerParams {
548                shape,
549                num_procs_per_host,
550                num_procs_scheduled: 0,
551                next_rank: 0,
552                env,
553            },
554            state,
555            labels,
556        })
557    }
558
559    fn get_real_world_id(proc_world_id: &WorldId) -> WorldId {
560        WorldId(
561            proc_world_id
562                .name()
563                .strip_prefix(SHADOW_PREFIX)
564                .unwrap_or(proc_world_id.name())
565                .to_string(),
566        )
567    }
568
569    fn is_host_world(world_id: &WorldId) -> bool {
570        world_id.name().starts_with(SHADOW_PREFIX)
571    }
572
573    #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `CastError`.
574    fn get_port_ref_from_host(
575        &self,
576        host_id: &HostId,
577    ) -> Result<PortRef<ProcMessage>, SystemActorError> {
578        let host_map = self.state.get_host_map();
579        // Get Host from hosts given proc_id
580        match host_map.get(host_id) {
581            Some(h) => Ok(h.proc_message_port.clone()),
582            None => Err(SystemActorError::HostNotExist(host_id.clone())),
583        }
584    }
585
586    /// Adds procs to the world.
587    #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SystemActorError`.
588    fn add_proc(
589        &mut self,
590        proc_id: ProcId,
591        proc_message_port: PortRef<ProcMessage>,
592        labels: HashMap<String, String>,
593    ) -> Result<(), SystemActorError> {
594        self.state.procs.insert(
595            proc_id,
596            ProcInfo {
597                port_ref: proc_message_port,
598                labels,
599            },
600        );
601        if self.state.status.is_unhealthy()
602            && self.state.procs.len() >= self.scheduler_params.num_procs()
603        {
604            self.state.status = WorldStatus::Live;
605            tracing::info!(
606                "world {}: ready to serve with {} procs",
607                self.world_id,
608                self.state.procs.len()
609            );
610        }
611        Ok(())
612    }
613
614    /// 1. Adds a host to the hosts map.
615    /// 2. Create executor procs for the host.
616    /// 3. Run necessary programs
617    async fn on_host_join(
618        &mut self,
619        host_id: HostId,
620        proc_message_port: PortRef<ProcMessage>,
621        router: &DialMailboxRouter,
622    ) -> Result<(), SystemActorError> {
623        let mut host_entry = match self.state.host_map.entry(host_id.clone()) {
624            Entry::Occupied(_) => {
625                return Err(SystemActorError::DuplicatedHostId(host_id));
626            }
627            Entry::Vacant(entry) => entry.insert_entry(Host::new(
628                proc_message_port.clone(),
629                host_id
630                    .0
631                    .rank()
632                    .expect("host proc must be ranked for rank access"),
633            )),
634        };
635
636        if self.state.status == WorldStatus::AwaitingCreation {
637            return Ok(());
638        }
639
640        let proc_ids = host_entry
641            .get_mut()
642            .get_assigned_procs(&self.world_id, &mut self.scheduler_params);
643
644        router.serialize_and_send(
645            &proc_message_port,
646            ProcMessage::SpawnProc {
647                env: self.scheduler_params.env.clone(),
648                world_id: self.world_id.clone(),
649                proc_ids,
650                world_size: self.scheduler_params.num_procs(),
651            },
652            monitored_return_handle(),
653        )?;
654        Ok(())
655    }
656
657    #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SystemActorError`.
658    fn get_hosts_to_procs(&mut self) -> Result<HashMap<HostId, Vec<ProcId>>, SystemActorError> {
659        // A map from host ID to scheduled proc IDs on this host.
660        let mut host_proc_map: HashMap<HostId, Vec<ProcId>> = HashMap::new();
661        let host_map = self.state.get_host_map_mut();
662        // Iterate over each entry in self.hosts
663        for (host_id, host) in host_map {
664            // Had to clone hosts in order to call schedule_procs
665            if host.num_procs_assigned == self.scheduler_params.num_procs_per_host {
666                continue;
667            }
668            let host_procs = host.get_assigned_procs(&self.world_id, &mut self.scheduler_params);
669            if host_procs.is_empty() {
670                continue;
671            }
672            host_proc_map.insert(host_id.clone(), host_procs);
673        }
674        Ok(host_proc_map)
675    }
676
677    async fn on_create(&mut self, router: &DialMailboxRouter) -> Result<(), anyhow::Error> {
678        let host_procs_map = self.get_hosts_to_procs()?;
679        for (host_id, procs_ids) in host_procs_map {
680            if procs_ids.is_empty() {
681                continue;
682            }
683
684            // REFACTOR(marius): remove
685            let world_id = procs_ids
686                .first()
687                .unwrap()
688                .clone()
689                .into_ranked()
690                .expect("proc must be ranked for world_id access")
691                .0
692                .clone();
693            // Open port ref
694            tracing::info!("spawning procs for host {:?}", host_id);
695            router.serialize_and_send(
696                // Get host proc!
697                &self.get_port_ref_from_host(&host_id)?,
698                ProcMessage::SpawnProc {
699                    env: self.scheduler_params.env.clone(),
700                    world_id,
701                    // REFACTOR(marius): remove
702                    proc_ids: procs_ids,
703                    world_size: self.scheduler_params.num_procs(),
704                },
705                monitored_return_handle(),
706            )?;
707        }
708        Ok(())
709    }
710}
711
712/// A mailbox router that forwards messages to their destinations and
713/// additionally reports the destination address back to the sender’s
714/// [`ProcActor`], allowing it to cache the address for future use.
715#[derive(Debug, Clone)]
716pub struct ReportingRouter {
717    router: DialMailboxRouter,
718    /// A record of cached addresses from dst_proc_id to HashSet(src_proc_id)
719    /// Right now only the proc_ids are recorded for updating purpose.
720    /// We can also cache the address here in the future.
721    address_cache: Arc<DashMap<ProcId, HashSet<ProcId>>>,
722}
723
724impl MailboxSender for ReportingRouter {
725    fn post_unchecked(
726        &self,
727        envelope: MessageEnvelope,
728        return_handle: PortHandle<Undeliverable<MessageEnvelope>>,
729    ) {
730        let ReportingRouter { router, .. } = self;
731        self.post_update_address(&envelope);
732        router.post_unchecked(envelope, return_handle);
733    }
734}
735
736impl ReportingRouter {
737    fn new() -> Self {
738        Self {
739            router: DialMailboxRouter::new(),
740            address_cache: Arc::new(DashMap::new()),
741        }
742    }
743    fn post_update_address(&self, envelope: &MessageEnvelope) {
744        let system_proc_id = id!(system[0]);
745        // These are edge cases that are at unlikely to come up in a
746        // well ordered system but in the event that they do we skip
747        // sending update address messages:
748        // - The sender ID is "unknown" (it makes no sense to remember
749        //   the address of an unknown sender)
750        // - The sender world is "user", which doesn't have a ProcActor running
751        //   to process the address update message.
752        // - The sender is the system (the system knows all addresses)
753        // - The destination is the system (every proc knows the
754        //   system address)
755        // - The sender and the destination are on the same proc (it
756        //   doesn't make sense to be dialing connections between them).
757        if envelope.sender().proc_id() == &id!(unknown[0])
758            || envelope.sender().proc_id().world_id() == Some(&id!(user))
759            || envelope.sender().proc_id() == &system_proc_id
760            || envelope.dest().actor_id().proc_id() == &system_proc_id
761            || envelope.sender().proc_id() == envelope.dest().actor_id().proc_id()
762        {
763            return;
764        }
765        let (dst_proc_id, dst_proc_addr) = self.dest_proc_id_and_address(envelope);
766        let Some(dst_proc_addr) = dst_proc_addr else {
767            tracing::warn!("unknown address for {}", &dst_proc_id);
768            return;
769        };
770
771        let sender_proc_id = envelope.sender().proc_id();
772        self.upsert_address_cache(sender_proc_id, &dst_proc_id);
773        // Sim addresses have a concept of directionality. When we notify a proc of an address we should
774        // use the proc's address as the source for the sim address.
775        let sender_address = self.router.lookup_addr(envelope.sender());
776        let dst_proc_addr =
777            if let (Some(ChannelAddr::Sim(sender_sim_addr)), ChannelAddr::Sim(dest_sim_addr)) =
778                (sender_address, &dst_proc_addr)
779            {
780                ChannelAddr::Sim(
781                    SimAddr::new_with_src(
782                        // source is the sender
783                        sender_sim_addr.addr().clone(),
784                        // dest remains unchanged
785                        dest_sim_addr.addr().clone(),
786                    )
787                    .unwrap(),
788                )
789            } else {
790                dst_proc_addr
791            };
792        self.serialize_and_send(
793            &self.proc_port_ref(sender_proc_id),
794            MailboxAdminMessage::UpdateAddress {
795                proc_id: dst_proc_id,
796                addr: dst_proc_addr,
797            },
798            monitored_return_handle(),
799        )
800        .expect("unexpected serialization failure")
801    }
802
803    /// broadcasts the address of the proc if there's any stale record that has been sent
804    /// out to senders before.
805    fn broadcast_addr(&self, dst_proc_id: &ProcId, dst_proc_addr: ChannelAddr) {
806        if let Some(r) = self.address_cache.get(dst_proc_id) {
807            for sender_proc_id in r.value() {
808                tracing::info!(
809                    "broadcasting address change to {} for {}: {}",
810                    sender_proc_id,
811                    dst_proc_id,
812                    dst_proc_addr
813                );
814                self.serialize_and_send(
815                    &self.proc_port_ref(sender_proc_id),
816                    MailboxAdminMessage::UpdateAddress {
817                        proc_id: dst_proc_id.clone(),
818                        addr: dst_proc_addr.clone(),
819                    },
820                    monitored_return_handle(),
821                )
822                .expect("unexpected serialization failure")
823            }
824        }
825    }
826
827    fn upsert_address_cache(&self, src_proc_id: &ProcId, dst_proc_id: &ProcId) {
828        self.address_cache
829            .entry(dst_proc_id.clone())
830            .and_modify(|src_proc_ids| {
831                src_proc_ids.insert(src_proc_id.clone());
832            })
833            .or_insert({
834                let mut set = HashSet::new();
835                set.insert(src_proc_id.clone());
836                set
837            });
838    }
839
840    fn dest_proc_id_and_address(
841        &self,
842        envelope: &MessageEnvelope,
843    ) -> (ProcId, Option<ChannelAddr>) {
844        let dest_proc_port_id = envelope.dest();
845        let dest_proc_actor_id = dest_proc_port_id.actor_id();
846        let dest_proc_id = dest_proc_actor_id.proc_id();
847        let dest_proc_addr = self.router.lookup_addr(dest_proc_actor_id);
848        (dest_proc_id.clone(), dest_proc_addr)
849    }
850
851    fn proc_port_ref(&self, proc_id: &ProcId) -> PortRef<MailboxAdminMessage> {
852        let proc_actor_id = ActorId(proc_id.clone(), "proc".to_string(), 0);
853        let proc_actor_ref = ActorRef::<ProcActor>::attest(proc_actor_id);
854        proc_actor_ref.port::<MailboxAdminMessage>()
855    }
856}
857
/// Parameters used to construct a [`SystemActor`]: the mailbox router it
/// uses and the timeouts that drive supervision and world eviction.
#[derive(Debug, Clone)]
pub struct SystemActorParams {
    // The router installed as the mailbox sender of the system proc when the
    // system actor is bootstrapped.
    mailbox_router: ReportingRouter,

    /// The duration to declare an actor dead if no supervision update received.
    supervision_update_timeout: Duration,

    /// The duration to evict an unhealthy world, after which a world fails supervision states.
    world_eviction_timeout: Duration,
}
869
870impl SystemActorParams {
871    /// Create a new system actor params.
872    pub fn new(supervision_update_timeout: Duration, world_eviction_timeout: Duration) -> Self {
873        Self {
874            mailbox_router: ReportingRouter::new(),
875            supervision_update_timeout,
876            world_eviction_timeout,
877        }
878    }
879}
880
/// Supervision bookkeeping for the whole system: the supervision info of each
/// world, keyed by world id, together with the timeout after which a proc
/// that has not sent a supervision update is marked as expired.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SystemSupervisionState {
    // A map from world id to world supervision state.
    supervision_map: HashMap<WorldId, WorldSupervisionInfo>,
    // Supervision expiration duration: a proc whose last heartbeat is older
    // than this is marked expired when its world is checked.
    supervision_update_timeout: Duration,
}
889
// Used to record when procs sent their last heartbeats. The two fields hold
// the same (time, proc) pairs and are updated together.
#[derive(Debug, Clone, Default)]
struct HeartbeatRecord {
    // This index is used to efficiently find expired procs: it is ordered by
    // heartbeat time, so the oldest heartbeats come first.
    // T208419148: Handle btree_index initialization during system actor recovery
    btree_index: BTreeSet<(Instant, ProcId)>,

    // Last time when proc was updated. Needed to locate the proc's current
    // entry in btree_index when a newer heartbeat replaces it.
    proc_last_update_time: HashMap<ProcId, Instant>,
}
900
901impl HeartbeatRecord {
902    // Update this proc's heartbeat record with timestamp as "now".
903    fn update(&mut self, proc_id: &ProcId, clock: &impl Clock) {
904        // Remove previous entry in btree_index if exists.
905        if let Some(update_time) = self.proc_last_update_time.get(proc_id) {
906            self.btree_index
907                .remove(&(update_time.clone(), proc_id.clone()));
908        }
909
910        // Insert new entry into btree_index.
911        let now = clock.now();
912        self.proc_last_update_time
913            .insert(proc_id.clone(), now.clone());
914        self.btree_index.insert((now.clone(), proc_id.clone()));
915    }
916
917    // Find all the procs with expired heartbeat, and mark them as expired in
918    // WorldSupervisionState.
919    fn mark_expired_procs(
920        &self,
921        state: &mut WorldSupervisionState,
922        clock: &impl Clock,
923        supervision_update_timeout: Duration,
924    ) {
925        // Update procs' live status.
926        let now = clock.now();
927        self.btree_index
928            .iter()
929            .take_while(|(last_update_time, _)| {
930                now > *last_update_time + supervision_update_timeout
931            })
932            .for_each(|(_, proc_id)| {
933                if let Some(proc_state) = state
934                    .procs
935                    .get_mut(&proc_id.rank().expect("proc must be ranked for rank access"))
936                {
937                    match proc_state.proc_health {
938                        ProcStatus::Alive => proc_state.proc_health = ProcStatus::Expired,
939                        // Do not overwrite the health of a proc already known to be unhealthy.
940                        _ => (),
941                    }
942                }
943            });
944    }
945}
946
// Everything the system actor tracks about the supervision of one world.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct WorldSupervisionInfo {
    // Supervision state of the world's procs.
    state: WorldSupervisionState,

    // The lifecycle mode of each proc, keyed by proc id.
    lifecycle_mode: HashMap<ProcId, ProcLifecycleMode>,

    // Heartbeat timestamps of the world's procs. Skipped by serde, so a
    // deserialized value starts from the default (empty) record.
    #[serde(skip)]
    heartbeat_record: HeartbeatRecord,
}
957
958impl WorldSupervisionInfo {
959    fn new() -> Self {
960        Self {
961            state: WorldSupervisionState {
962                procs: HashMap::new(),
963            },
964            lifecycle_mode: HashMap::new(),
965            heartbeat_record: HeartbeatRecord::default(),
966        }
967    }
968}
969
970impl SystemSupervisionState {
971    fn new(supervision_update_timeout: Duration) -> Self {
972        Self {
973            supervision_map: HashMap::new(),
974            supervision_update_timeout,
975        }
976    }
977
978    /// Create a proc supervision entry.
979    fn create(
980        &mut self,
981        proc_state: ProcSupervisionState,
982        lifecycle_mode: ProcLifecycleMode,
983        clock: &impl Clock,
984    ) {
985        if World::is_host_world(&proc_state.world_id) {
986            return;
987        }
988
989        let world = self
990            .supervision_map
991            .entry(proc_state.world_id.clone())
992            .or_insert_with(WorldSupervisionInfo::new);
993        world
994            .lifecycle_mode
995            .insert(proc_state.proc_id.clone(), lifecycle_mode);
996
997        self.update(proc_state, clock);
998    }
999
1000    /// Update a proc supervision entry.
1001    fn update(&mut self, proc_state: ProcSupervisionState, clock: &impl Clock) {
1002        if World::is_host_world(&proc_state.world_id) {
1003            return;
1004        }
1005
1006        let world = self
1007            .supervision_map
1008            .entry(proc_state.world_id.clone())
1009            .or_insert_with(WorldSupervisionInfo::new);
1010
1011        world.heartbeat_record.update(&proc_state.proc_id, clock);
1012
1013        // Update supervision map.
1014        if let Some(info) = world.state.procs.get_mut(
1015            &proc_state
1016                .proc_id
1017                .rank()
1018                .expect("proc must be ranked for proc state update"),
1019        ) {
1020            match info.proc_health {
1021                ProcStatus::Alive => info.proc_health = proc_state.proc_health,
1022                // Do not overwrite the health of a proc already known to be unhealthy.
1023                _ => (),
1024            }
1025            info.failed_actors.extend(proc_state.failed_actors);
1026        } else {
1027            world.state.procs.insert(
1028                proc_state
1029                    .proc_id
1030                    .rank()
1031                    .expect("proc must be ranked for rank access"),
1032                proc_state,
1033            );
1034        }
1035    }
1036
1037    /// Report the given proc's supervision state. If the proc is not in the map, do nothing.
1038    fn report(&mut self, proc_state: ProcSupervisionState, clock: &impl Clock) {
1039        if World::is_host_world(&proc_state.world_id) {
1040            return;
1041        }
1042
1043        let proc_id = proc_state.proc_id.clone();
1044        match self.supervision_map.entry(proc_state.world_id.clone()) {
1045            Entry::Occupied(mut world_supervision_info) => {
1046                match world_supervision_info
1047                    .get_mut()
1048                    .state
1049                    .procs
1050                    .entry(proc_id.rank().expect("proc must be ranked for rank access"))
1051                {
1052                    Entry::Occupied(_) => {
1053                        self.update(proc_state, clock);
1054                    }
1055                    Entry::Vacant(_) => {
1056                        tracing::error!("supervision not enabled for proc {}", &proc_id);
1057                    }
1058                }
1059            }
1060            Entry::Vacant(_) => {
1061                tracing::error!("supervision not enabled for proc {}", &proc_id);
1062            }
1063        }
1064    }
1065
1066    /// Get procs of a world with expired supervision updates, as well as procs with
1067    /// actor failures.
1068    fn get_world_with_failures(
1069        &mut self,
1070        world_id: &WorldId,
1071        clock: &impl Clock,
1072    ) -> Option<WorldSupervisionState> {
1073        if let Some(world) = self.supervision_map.get_mut(world_id) {
1074            world.heartbeat_record.mark_expired_procs(
1075                &mut world.state,
1076                clock,
1077                self.supervision_update_timeout,
1078            );
1079            // Get procs with failures.
1080            let mut world_state_copy = world.state.clone();
1081            // Only return failed procs if there is any
1082            world_state_copy
1083                .procs
1084                .retain(|_, proc_state| !proc_state.is_healthy());
1085            return Some(world_state_copy);
1086        }
1087        None
1088    }
1089
1090    fn is_world_healthy(&mut self, world_id: &WorldId, clock: &impl Clock) -> bool {
1091        self.get_world_with_failures(world_id, clock)
1092            .is_none_or(|state| WorldSupervisionState::is_healthy(&state))
1093    }
1094}
1095
// Progress of stopping one world.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct WorldStoppingState {
    // Procs that have been sent a stop message.
    stopping_procs: HashSet<ProcId>,
    // NOTE(review): presumably the procs that have confirmed they stopped;
    // this set is filled outside the code visible here -- confirm.
    stopped_procs: HashSet<ProcId>,
}
1101
/// A follow-up message the system actor schedules to itself, with a delay,
/// after it has asked procs to stop.
#[derive(Debug, Clone, PartialEq, EnumAsInner)]
enum SystemStopMessage {
    // Scheduled when the whole system was asked to stop.
    StopSystemActor,
    // Scheduled when only the listed worlds were asked to stop.
    EvictWorlds(Vec<WorldId>),
}
1108
/// The system actor manages the whole system. It is responsible for
/// managing the system's worlds, and for managing their constituent
/// procs. The system actor also provides a central mailbox that can
/// route messages to any live actor in the system.
#[derive(Debug, Clone)]
#[hyperactor::export(
    handlers = [
        SystemMessage,
        ProcSupervisionMessage,
        WorldSupervisionMessage,
    ],
)]
pub struct SystemActor {
    // Construction parameters: the mailbox router and the supervision and
    // eviction timeouts.
    params: SystemActorParams,
    // Supervision state of the supervised procs, grouped by world.
    supervision_state: SystemSupervisionState,
    // The worlds tracked by this actor, keyed by world id.
    worlds: HashMap<WorldId, World>,
    // A map from world id to stop state for worlds with inflight stop requests.
    worlds_to_stop: HashMap<WorldId, WorldStoppingState>,
    // Set once a stop of the whole system (rather than of specific worlds)
    // has been requested.
    shutting_down: bool,
}
1129
/// The well known ID of the world that hosts the system actor, it is always `system`.
pub static SYSTEM_WORLD: LazyLock<WorldId> = LazyLock::new(|| id!(system));

/// The well known ID of the system actor, it is always `system[0].root`.
static SYSTEM_ACTOR_ID: LazyLock<ActorId> = LazyLock::new(|| id!(system[0].root));

/// The actor ref corresponding to the well known system actor ID,
/// `system[0].root`. The ref is attested rather than obtained from a
/// running actor.
pub static SYSTEM_ACTOR_REF: LazyLock<ActorRef<SystemActor>> =
    LazyLock::new(|| ActorRef::attest(id!(system[0].root)));
1139
1140impl SystemActor {
1141    /// Adds a new world that's awaiting creation to the worlds.
1142    fn add_new_world(&mut self, world_id: WorldId) -> Result<(), anyhow::Error> {
1143        let world_state = WorldState {
1144            host_map: HashMap::new(),
1145            procs: HashMap::new(),
1146            status: WorldStatus::AwaitingCreation,
1147        };
1148        let world = World::new(
1149            world_id.clone(),
1150            Shape::Unknown,
1151            world_state,
1152            0,
1153            Environment::Local,
1154            HashMap::new(),
1155        )?;
1156        self.worlds.insert(world_id.clone(), world);
1157        Ok(())
1158    }
1159
1160    fn router(&self) -> &ReportingRouter {
1161        &self.params.mailbox_router
1162    }
1163
1164    /// Bootstrap the system actor. This will create a proc, spawn the actor
1165    /// on that proc, and then return the actor handle and the corresponding
1166    /// proc.
1167    pub async fn bootstrap(
1168        params: SystemActorParams,
1169    ) -> Result<(ActorHandle<SystemActor>, Proc), anyhow::Error> {
1170        Self::bootstrap_with_clock(params, ClockKind::default()).await
1171    }
1172
1173    /// Bootstrap the system actor with a specified clock.This will create a proc, spawn the actor
1174    /// on that proc, and then return the actor handle and the corresponding
1175    /// proc.
1176    pub async fn bootstrap_with_clock(
1177        params: SystemActorParams,
1178        clock: ClockKind,
1179    ) -> Result<(ActorHandle<SystemActor>, Proc), anyhow::Error> {
1180        let system_proc = Proc::new_with_clock(
1181            SYSTEM_ACTOR_ID.proc_id().clone(),
1182            BoxedMailboxSender::new(params.mailbox_router.clone()),
1183            clock,
1184        );
1185        let actor_handle = system_proc
1186            .spawn::<SystemActor>(SYSTEM_ACTOR_ID.name(), params)
1187            .await?;
1188
1189        Ok((actor_handle, system_proc))
1190    }
1191
1192    /// Evict a single world
1193    fn evict_world(&mut self, world_id: &WorldId) {
1194        self.worlds.remove(world_id);
1195        self.supervision_state.supervision_map.remove(world_id);
1196        // Remove all the addresses starting with the world_id as the prefix.
1197        self.params
1198            .mailbox_router
1199            .router
1200            .unbind(&world_id.clone().into());
1201    }
1202}
1203
1204#[async_trait]
1205impl Actor for SystemActor {
1206    type Params = SystemActorParams;
1207
1208    async fn new(params: SystemActorParams) -> Result<Self, anyhow::Error> {
1209        let supervision_update_timeout = params.supervision_update_timeout.clone();
1210        Ok(Self {
1211            params,
1212            supervision_state: SystemSupervisionState::new(supervision_update_timeout),
1213            worlds: HashMap::new(),
1214            worlds_to_stop: HashMap::new(),
1215            shutting_down: false,
1216        })
1217    }
1218
1219    async fn init(&mut self, cx: &Instance<Self>) -> Result<(), anyhow::Error> {
1220        // Start to periodically check the unhealthy worlds.
1221        cx.self_message_with_delay(MaintainWorldHealth {}, Duration::from_secs(0))?;
1222        Ok(())
1223    }
1224
1225    async fn handle_undeliverable_message(
1226        &mut self,
1227        _cx: &Instance<Self>,
1228        Undeliverable(envelope): Undeliverable<MessageEnvelope>,
1229    ) -> Result<(), anyhow::Error> {
1230        let to = envelope.dest().clone();
1231        let from = envelope.sender().clone();
1232        tracing::info!(
1233            "a message from {} to {} was undeliverable and returned to the system actor",
1234            from,
1235            to,
1236        );
1237
1238        // The channel to the receiver's proc is lost or can't be
1239        // established. Update the proc's supervision status
1240        // accordingly.
1241        let proc_id = to.actor_id().proc_id();
1242        let world_id = proc_id
1243            .world_id()
1244            .expect("proc must be ranked for world_id access");
1245        if let Some(world) = &mut self.supervision_state.supervision_map.get_mut(world_id) {
1246            if let Some(proc) = world
1247                .state
1248                .procs
1249                .get_mut(&proc_id.rank().expect("proc must be ranked for rank access"))
1250            {
1251                match proc.proc_health {
1252                    ProcStatus::Alive => proc.proc_health = ProcStatus::ConnectionFailure,
1253                    // Do not overwrite the health of a proc already
1254                    // known to be unhealthy.
1255                    _ => (),
1256                }
1257            } else {
1258                tracing::error!(
1259                    "can't update proc {} status because there isn't one",
1260                    proc_id
1261                );
1262            }
1263        } else {
1264            tracing::error!(
1265                "can't update world {} status because there isn't one",
1266                world_id
1267            );
1268        }
1269        Ok(())
1270    }
1271}
1272
1273///
1274/// +------+  spawns   +----+  joins   +-----+
1275/// | Proc |<----------|Host|--------->|World|
1276/// +------+           +----+          +-----+
1277///    |                                   ^
1278///    |          joins                    |
1279///    +-----------------------------------+
1280/// When bootstrapping the system,
1281///   1. hosts will join the world,
1282///   2. hosts will spawn (worker) procs,
1283///   3. procs will join the world
1284#[async_trait]
1285#[hyperactor::forward(SystemMessage)]
1286impl SystemMessageHandler for SystemActor {
1287    async fn join(
1288        &mut self,
1289        cx: &Context<Self>,
1290        world_id: WorldId,
1291        proc_id: ProcId,
1292        proc_message_port: PortRef<ProcMessage>,
1293        channel_addr: ChannelAddr,
1294        labels: HashMap<String, String>,
1295        lifecycle_mode: ProcLifecycleMode,
1296    ) -> Result<(), anyhow::Error> {
1297        tracing::info!("received join for proc {} in world {}", proc_id, world_id);
1298        // todo: check that proc_id is a user id
1299        self.router()
1300            .router
1301            .bind(proc_id.clone().into(), channel_addr.clone());
1302
1303        self.router().broadcast_addr(&proc_id, channel_addr.clone());
1304
1305        // TODO: handle potential undeliverable message return
1306        self.router().serialize_and_send(
1307            &proc_message_port,
1308            ProcMessage::Joined(),
1309            monitored_return_handle(),
1310        )?;
1311
1312        if lifecycle_mode.is_managed() {
1313            self.supervision_state.create(
1314                ProcSupervisionState {
1315                    world_id: world_id.clone(),
1316                    proc_id: proc_id.clone(),
1317                    proc_addr: channel_addr.clone(),
1318                    proc_health: ProcStatus::Alive,
1319                    failed_actors: Vec::new(),
1320                },
1321                lifecycle_mode.clone(),
1322                cx.clock(),
1323            );
1324        }
1325
1326        // If the proc's life cycle is not managed by system actor, system actor
1327        // doesn't need to track it in its "worlds" field.
1328        if lifecycle_mode != ProcLifecycleMode::ManagedBySystem {
1329            tracing::info!("ignoring join for proc {} in world {}", proc_id, world_id);
1330            return Ok(());
1331        }
1332
1333        let world_id = World::get_real_world_id(&world_id);
1334        if !self.worlds.contains_key(&world_id) {
1335            self.add_new_world(world_id.clone())?;
1336        }
1337        let world = self
1338            .worlds
1339            .get_mut(&world_id)
1340            .ok_or(anyhow::anyhow!("failed to get world from map"))?;
1341
1342        match HostId::try_from(proc_id.clone()) {
1343            Ok(host_id) => {
1344                tracing::info!("{}: adding host {}", world_id, host_id);
1345                return world
1346                    .on_host_join(
1347                        host_id,
1348                        proc_message_port,
1349                        &self.params.mailbox_router.router,
1350                    )
1351                    .await
1352                    .map_err(anyhow::Error::from);
1353            }
1354            // If it is not a host ID, it must be a regular proc ID. e.g.
1355            // worker procs spawned by the host proc actor.
1356            Err(_) => {
1357                tracing::info!("proc {} joined to world {}", &proc_id, &world_id,);
1358                // TODO(T207602936) add reconciliation machine to make sure
1359                // 1. only add procs that are created by the host
1360                // 2. retry upon failed proc creation by host.
1361                if let Err(e) = world.add_proc(proc_id.clone(), proc_message_port, labels) {
1362                    tracing::warn!(
1363                        "failed to add proc {} to world {}: {}",
1364                        &proc_id,
1365                        &world_id,
1366                        e
1367                    );
1368                }
1369            }
1370        };
1371        Ok(())
1372    }
1373
1374    async fn upsert_world(
1375        &mut self,
1376        cx: &Context<Self>,
1377        world_id: WorldId,
1378        shape: Shape,
1379        num_procs_per_host: usize,
1380        env: Environment,
1381        labels: HashMap<String, String>,
1382    ) -> Result<(), anyhow::Error> {
1383        tracing::info!("received upsert_world for world {}!", world_id);
1384        match self.worlds.get_mut(&world_id) {
1385            Some(world) => {
1386                tracing::info!("found existing world {}!", world_id);
1387                match &world.state.status {
1388                    WorldStatus::AwaitingCreation => {
1389                        world.scheduler_params.shape = shape;
1390                        world.scheduler_params.num_procs_per_host = num_procs_per_host;
1391                        world.scheduler_params.env = env;
1392                        world.state = WorldState {
1393                            host_map: world.state.host_map.clone(),
1394                            procs: world.state.procs.clone(),
1395                            status: if world.state.procs.len() < world.scheduler_params.num_procs()
1396                                || !self
1397                                    .supervision_state
1398                                    .is_world_healthy(&world_id, cx.clock())
1399                            {
1400                                WorldStatus::Unhealthy(cx.clock().system_time_now())
1401                            } else {
1402                                WorldStatus::Live
1403                            },
1404                        };
1405                        for (k, v) in labels {
1406                            if world.labels.contains_key(&k) {
1407                                anyhow::bail!("cannot overwrite world label: {}", k);
1408                            }
1409                            world.labels.insert(k.clone(), v.clone());
1410                        }
1411                    }
1412                    _ => {
1413                        anyhow::bail!("cannot modify world {}: already exists", world.world_id)
1414                    }
1415                }
1416
1417                world.on_create(&self.params.mailbox_router.router).await?;
1418                tracing::info!(
1419                    "modified parameters to world {} with shape: {:?} and labels {:?}",
1420                    &world.world_id,
1421                    world.scheduler_params.shape,
1422                    world.labels
1423                );
1424            }
1425            None => {
1426                let world = World::new(
1427                    world_id.clone(),
1428                    shape.clone(),
1429                    WorldState {
1430                        host_map: HashMap::new(),
1431                        procs: HashMap::new(),
1432                        status: WorldStatus::Unhealthy(cx.clock().system_time_now()),
1433                    },
1434                    num_procs_per_host,
1435                    env,
1436                    labels,
1437                )?;
1438                tracing::info!("new world {} added with shape: {:?}", world_id, &shape);
1439                self.worlds.insert(world_id, world);
1440            }
1441        };
1442        Ok(())
1443    }
1444
1445    async fn snapshot(
1446        &mut self,
1447        _cx: &Context<Self>,
1448        filter: SystemSnapshotFilter,
1449    ) -> Result<SystemSnapshot, anyhow::Error> {
1450        let world_snapshots = self
1451            .worlds
1452            .iter()
1453            .filter(|(_, world)| filter.world_matches(world))
1454            .map(|(world_id, world)| {
1455                (
1456                    world_id.clone(),
1457                    WorldSnapshot::from_world_filtered(world, &filter),
1458                )
1459            })
1460            .collect();
1461        Ok(SystemSnapshot {
1462            worlds: world_snapshots,
1463            execution_id: hyperactor_telemetry::env::execution_id(),
1464        })
1465    }
1466
1467    async fn stop(
1468        &mut self,
1469        cx: &Context<Self>,
1470        worlds: Option<Vec<WorldId>>,
1471        proc_timeout: Duration,
1472        reply_port: OncePortRef<()>,
1473    ) -> Result<(), anyhow::Error> {
1474        // TODO: this needn't be async
1475
1476        match &worlds {
1477            Some(world_ids) => {
1478                tracing::info!("stopping worlds: {:?}", world_ids);
1479            }
1480            None => {
1481                tracing::info!("stopping system actor and all worlds");
1482                self.shutting_down = true;
1483            }
1484        }
1485
1486        // If there's no worlds left to stop, shutdown now.
1487        if self.worlds.is_empty() && self.shutting_down {
1488            cx.stop()?;
1489            reply_port.send(cx, ())?;
1490            return Ok(());
1491        }
1492
1493        let mut world_ids = vec![];
1494        match &worlds {
1495            Some(worlds) => {
1496                // Stop only the specified worlds
1497                world_ids.extend(worlds.clone().into_iter().collect::<Vec<_>>());
1498            }
1499            None => {
1500                // Stop all worlds
1501                world_ids.extend(
1502                    self.worlds
1503                        .keys()
1504                        .filter(|x| x.name() != "user")
1505                        .cloned()
1506                        .collect::<Vec<_>>(),
1507                );
1508            }
1509        }
1510
1511        for world_id in &world_ids {
1512            if self.worlds_to_stop.contains_key(world_id) || !self.worlds.contains_key(world_id) {
1513                // The world is being stopped already.
1514                continue;
1515            }
1516            self.worlds_to_stop.insert(
1517                world_id.clone(),
1518                WorldStoppingState {
1519                    stopping_procs: HashSet::new(),
1520                    stopped_procs: HashSet::new(),
1521                },
1522            );
1523        }
1524
1525        let all_procs = self
1526            .worlds
1527            .iter()
1528            .filter(|(world_id, _)| match &worlds {
1529                Some(worlds_ids) => worlds_ids.contains(world_id),
1530                None => true,
1531            })
1532            .flat_map(|(_, world)| {
1533                world
1534                    .state
1535                    .host_map
1536                    .iter()
1537                    .map(|(host_id, host)| (host_id.0.clone(), host.proc_message_port.clone()))
1538                    .chain(
1539                        world
1540                            .state
1541                            .procs
1542                            .iter()
1543                            .map(|(proc_id, info)| (proc_id.clone(), info.port_ref.clone())),
1544                    )
1545                    .collect::<Vec<_>>()
1546            })
1547            .collect::<HashMap<_, _>>();
1548
1549        // Send Stop message to all processes known to the system. This is a best
1550        // effort, because the message might fail to deliver due to network
1551        // partition.
1552        for (proc_id, port) in all_procs.into_iter() {
1553            let stopping_state = self
1554                .worlds_to_stop
1555                .get_mut(&World::get_real_world_id(
1556                    proc_id
1557                        .world_id()
1558                        .expect("proc must be ranked for world_id access"),
1559                ))
1560                .unwrap();
1561            if !stopping_state.stopping_procs.insert(proc_id) {
1562                continue;
1563            }
1564
1565            // This is a hack. Due to T214365263, SystemActor cannot get reply
1566            // from a 2-way message when that message is sent from its handler.
1567            // As a result, we set the reply to a handle port, so that reply
1568            // can be processed as a separate message. See Handler<ProcStopResult>
1569            // for how the received reply is further processed.
1570            let reply_to = cx.port::<ProcStopResult>().bind().into_once();
1571            port.send(
1572                cx,
1573                ProcMessage::Stop {
1574                    timeout: proc_timeout,
1575                    reply_to,
1576                },
1577            )?;
1578        }
1579
1580        let stop_msg = match &worlds {
1581            Some(_) => SystemStopMessage::EvictWorlds(world_ids.clone()),
1582            None => SystemStopMessage::StopSystemActor {},
1583        };
1584
1585        // Schedule a message to stop the system actor itself.
1586        cx.self_message_with_delay(stop_msg, Duration::from_secs(8))?;
1587
1588        reply_port.send(cx, ())?;
1589        Ok(())
1590    }
1591}
1592
1593#[async_trait]
1594impl Handler<MaintainWorldHealth> for SystemActor {
1595    async fn handle(&mut self, cx: &Context<Self>, _: MaintainWorldHealth) -> anyhow::Result<()> {
1596        // TODO: this needn't be async
1597
1598        // Find the world with the oldest unhealthy time so we can schedule the next check.
1599        let mut next_check_delay = self.params.world_eviction_timeout;
1600        tracing::debug!("Checking world state. Got {} worlds", self.worlds.len());
1601
1602        for world in self.worlds.values_mut() {
1603            if world.state.status == WorldStatus::AwaitingCreation {
1604                continue;
1605            }
1606
1607            let Some(state) = self
1608                .supervision_state
1609                .get_world_with_failures(&world.world_id, cx.clock())
1610            else {
1611                tracing::debug!("world {} does not have failures, skipping.", world.world_id);
1612                continue;
1613            };
1614
1615            if state.is_healthy() {
1616                tracing::debug!(
1617                    "world {} with procs {:?} is healthy, skipping.",
1618                    world.world_id,
1619                    state
1620                        .procs
1621                        .values()
1622                        .map(|p| p.proc_id.clone())
1623                        .collect::<Vec<_>>()
1624                );
1625                continue;
1626            }
1627            // Some procs are not healthy, check if any of the proc should manage the system.
1628            for (_, proc_state) in state.procs.iter() {
1629                if proc_state.proc_health == ProcStatus::Alive {
1630                    tracing::debug!("proc {} is still alive.", proc_state.proc_id);
1631                    continue;
1632                }
1633                if self
1634                    .supervision_state
1635                    .supervision_map
1636                    .get(&world.world_id)
1637                    .and_then(|world| world.lifecycle_mode.get(&proc_state.proc_id))
1638                    .map_or(true, |mode| *mode != ProcLifecycleMode::ManagingSystem)
1639                {
1640                    tracing::debug!(
1641                        "proc {} with state {} does not manage system.",
1642                        proc_state.proc_id,
1643                        proc_state.proc_health
1644                    );
1645                    continue;
1646                }
1647
1648                tracing::error!(
1649                    "proc {}  is unhealthy, stop the system as the proc manages the system",
1650                    proc_state.proc_id
1651                );
1652
1653                // The proc has expired heartbeating and it manages the lifecycle of system, schedule system stop
1654                let (tx, _) = cx.open_once_port::<()>();
1655                cx.port().send(SystemMessage::Stop {
1656                    worlds: None,
1657                    proc_timeout: Duration::from_secs(5),
1658                    reply_port: tx.bind(),
1659                })?;
1660            }
1661
1662            if world.state.status == WorldStatus::Live {
1663                world.state.status = WorldStatus::Unhealthy(cx.clock().system_time_now());
1664            }
1665
1666            match &world.state.status {
1667                WorldStatus::Unhealthy(last_unhealthy_time) => {
1668                    let elapsed = last_unhealthy_time
1669                        .elapsed()
1670                        .inspect_err(|err| {
1671                            tracing::error!(
1672                                "failed to get elapsed time for unhealthy world {}: {}",
1673                                world.world_id,
1674                                err
1675                            )
1676                        })
1677                        .unwrap_or_else(|_| Duration::from_secs(0));
1678
1679                    if elapsed < self.params.world_eviction_timeout {
1680                        // We can live a bit longer still.
1681                        next_check_delay = std::cmp::min(
1682                            next_check_delay,
1683                            self.params.world_eviction_timeout - elapsed,
1684                        );
1685                    } else {
1686                        next_check_delay = Duration::from_secs(0);
1687                    }
1688                }
1689                _ => {
1690                    tracing::error!(
1691                        "find a failed world {} with healthy state {}",
1692                        world.world_id,
1693                        world.state.status
1694                    );
1695                    continue;
1696                }
1697            }
1698        }
1699        cx.self_message_with_delay(MaintainWorldHealth {}, next_check_delay)?;
1700
1701        Ok(())
1702    }
1703}
1704
1705#[async_trait]
1706impl Handler<ProcSupervisionMessage> for SystemActor {
1707    async fn handle(
1708        &mut self,
1709        cx: &Context<Self>,
1710        msg: ProcSupervisionMessage,
1711    ) -> anyhow::Result<()> {
1712        match msg {
1713            ProcSupervisionMessage::Update(state, reply_port) => {
1714                self.supervision_state.report(state, cx.clock());
1715                let _ = reply_port.send(cx, ());
1716            }
1717        }
1718        Ok(())
1719    }
1720}
1721
1722#[async_trait]
1723impl Handler<WorldSupervisionMessage> for SystemActor {
1724    async fn handle(
1725        &mut self,
1726        cx: &Context<Self>,
1727        msg: WorldSupervisionMessage,
1728    ) -> anyhow::Result<()> {
1729        match msg {
1730            WorldSupervisionMessage::State(world_id, reply_port) => {
1731                let world_state = self
1732                    .supervision_state
1733                    .get_world_with_failures(&world_id, cx.clock());
1734                // TODO: handle potential undeliverable message return
1735                let _ = reply_port.send(cx, world_state);
1736            }
1737        }
1738        Ok(())
1739    }
1740}
1741
// Temporary solution to allow SystemMessage::Stop to receive replies from
// 2-way messages. Can be removed after T214365263 is implemented.
1744#[async_trait]
1745impl Handler<ProcStopResult> for SystemActor {
1746    async fn handle(&mut self, cx: &Context<Self>, msg: ProcStopResult) -> anyhow::Result<()> {
1747        fn stopping_proc_msg<'a>(sprocs: impl Iterator<Item = &'a ProcId>) -> String {
1748            let sprocs = sprocs.collect::<Vec<_>>();
1749            if sprocs.is_empty() {
1750                return "no procs left".to_string();
1751            }
1752            let msg = sprocs
1753                .iter()
1754                .take(3)
1755                .map(|proc_id| proc_id.to_string())
1756                .collect::<Vec<_>>()
1757                .join(", ");
1758            if sprocs.len() > 3 {
1759                format!("remaining procs: {} and {} more", msg, sprocs.len() - 3)
1760            } else {
1761                format!("remaining procs: {}", msg)
1762            }
1763        }
1764        let mut world_stopped = false;
1765        let world_id = &msg
1766            .proc_id
1767            .clone()
1768            .into_ranked()
1769            .expect("proc must be ranked for world_id access")
1770            .0;
1771        if let Some(stopping_state) = self.worlds_to_stop.get_mut(world_id) {
1772            stopping_state.stopped_procs.insert(msg.proc_id.clone());
1773            tracing::debug!(
1774                "received stop response from {}: {} stopped actors, {} aborted actors: {}",
1775                msg.proc_id,
1776                msg.actors_stopped,
1777                msg.actors_aborted,
1778                stopping_proc_msg(
1779                    stopping_state
1780                        .stopping_procs
1781                        .difference(&stopping_state.stopped_procs)
1782                ),
1783            );
1784            world_stopped =
1785                stopping_state.stopping_procs.len() == stopping_state.stopped_procs.len();
1786        } else {
1787            tracing::warn!(
1788                "received stop response from {} but no inflight stopping request is found, possibly late response",
1789                msg.proc_id
1790            );
1791        }
1792
1793        if world_stopped {
1794            self.evict_world(world_id);
1795            self.worlds_to_stop.remove(world_id);
1796        }
1797
1798        if self.shutting_down && self.worlds.is_empty() {
1799            cx.stop()?;
1800        }
1801
1802        Ok(())
1803    }
1804}
1805
1806#[async_trait]
1807impl Handler<SystemStopMessage> for SystemActor {
1808    async fn handle(
1809        &mut self,
1810        cx: &Context<Self>,
1811        message: SystemStopMessage,
1812    ) -> anyhow::Result<()> {
1813        match message {
1814            SystemStopMessage::EvictWorlds(world_ids) => {
1815                for world_id in &world_ids {
1816                    if self.worlds_to_stop.contains_key(world_id) {
1817                        tracing::warn!(
1818                            "Waiting for world to stop timed out, evicting world anyways: {:?}",
1819                            world_id
1820                        );
1821                        self.evict_world(world_id);
1822                    }
1823                }
1824            }
1825            SystemStopMessage::StopSystemActor => {
1826                if self.worlds_to_stop.is_empty() {
1827                    tracing::warn!(
1828                        "waiting for all worlds to stop timed out, stopping the system actor and evicting the these worlds anyways: {:?}",
1829                        self.worlds_to_stop.keys()
1830                    );
1831                } else {
1832                    tracing::warn!(
1833                        "waiting for all worlds to stop timed out, stopping the system actor"
1834                    );
1835                }
1836
1837                cx.stop()?;
1838            }
1839        }
1840        Ok(())
1841    }
1842}
1843
1844#[cfg(test)]
1845mod tests {
1846    use std::assert_matches::assert_matches;
1847
1848    use anyhow::Result;
1849    use hyperactor::PortId;
1850    use hyperactor::actor::ActorStatus;
1851    use hyperactor::attrs::Attrs;
1852    use hyperactor::channel;
1853    use hyperactor::channel::ChannelTransport;
1854    use hyperactor::channel::Rx;
1855    use hyperactor::clock::Clock;
1856    use hyperactor::clock::RealClock;
1857    use hyperactor::data::Serialized;
1858    use hyperactor::mailbox::Mailbox;
1859    use hyperactor::mailbox::MailboxServer;
1860    use hyperactor::mailbox::MessageEnvelope;
1861    use hyperactor::mailbox::PortHandle;
1862    use hyperactor::mailbox::PortReceiver;
1863    use hyperactor::simnet;
1864    use hyperactor::test_utils::pingpong::PingPongActorParams;
1865
1866    use super::*;
1867    use crate::System;
1868
1869    struct MockHostActor {
1870        local_proc_id: ProcId,
1871        local_proc_addr: ChannelAddr,
1872        local_proc_message_port: PortHandle<ProcMessage>,
1873        local_proc_message_receiver: PortReceiver<ProcMessage>,
1874    }
1875
1876    async fn spawn_mock_host_actor(proc_world_id: WorldId, host_id: usize) -> MockHostActor {
1877        // Set up a local actor.
1878        let local_proc_id = ProcId::Ranked(
1879            WorldId(format!("{}{}", SHADOW_PREFIX, proc_world_id.name())),
1880            host_id,
1881        );
1882        let (local_proc_addr, local_proc_rx) =
1883            channel::serve::<MessageEnvelope>(ChannelAddr::any(ChannelTransport::Local)).unwrap();
1884        let local_proc_mbox = Mailbox::new_detached(local_proc_id.actor_id("test".to_string(), 0));
1885        let (local_proc_message_port, local_proc_message_receiver) = local_proc_mbox.open_port();
1886        let _local_proc_serve_handle = local_proc_mbox.clone().serve(local_proc_rx);
1887        MockHostActor {
1888            local_proc_id,
1889            local_proc_addr,
1890            local_proc_message_port,
1891            local_proc_message_receiver,
1892        }
1893    }
1894
1895    #[tokio::test]
1896    async fn test_supervision_state() {
1897        let mut sv = SystemSupervisionState::new(Duration::from_secs(1));
1898        let world_id = id!(world);
1899        let proc_id_0 = world_id.proc_id(0);
1900        let clock = ClockKind::Real(RealClock);
1901        sv.create(
1902            ProcSupervisionState {
1903                world_id: world_id.clone(),
1904                proc_addr: ChannelAddr::any(ChannelTransport::Local),
1905                proc_id: proc_id_0.clone(),
1906                proc_health: ProcStatus::Alive,
1907                failed_actors: Vec::new(),
1908            },
1909            ProcLifecycleMode::ManagedBySystem,
1910            &clock,
1911        );
1912        let actor_id = id!(world[1].actor);
1913        let proc_id_1 = actor_id.proc_id();
1914        sv.create(
1915            ProcSupervisionState {
1916                world_id: world_id.clone(),
1917                proc_addr: ChannelAddr::any(ChannelTransport::Local),
1918                proc_id: proc_id_1.clone(),
1919                proc_health: ProcStatus::Alive,
1920                failed_actors: Vec::new(),
1921            },
1922            ProcLifecycleMode::ManagedBySystem,
1923            &clock,
1924        );
1925        let world_id = id!(world);
1926
1927        let unknown_world_id = id!(unknow_world);
1928        let failures = sv.get_world_with_failures(&unknown_world_id, &clock);
1929        assert!(failures.is_none());
1930
1931        // No supervision expiration yet.
1932        let failures = sv.get_world_with_failures(&world_id, &clock);
1933        assert!(failures.is_some());
1934        assert_eq!(failures.unwrap().procs.len(), 0);
1935
1936        // One proc expired.
1937        RealClock.sleep(Duration::from_secs(2)).await;
1938        sv.report(
1939            ProcSupervisionState {
1940                world_id: world_id.clone(),
1941                proc_addr: ChannelAddr::any(ChannelTransport::Local),
1942                proc_id: proc_id_1.clone(),
1943                proc_health: ProcStatus::Alive,
1944                failed_actors: Vec::new(),
1945            },
1946            &clock,
1947        );
1948        let failures = sv.get_world_with_failures(&world_id, &clock);
1949        let procs = failures.unwrap().procs;
1950        assert_eq!(procs.len(), 1);
1951        assert!(
1952            procs.contains_key(
1953                &proc_id_0
1954                    .rank()
1955                    .expect("proc must be ranked for rank access")
1956            )
1957        );
1958
1959        // Actor failure happened to proc_1
1960        sv.report(
1961            ProcSupervisionState {
1962                world_id: world_id.clone(),
1963                proc_addr: ChannelAddr::any(ChannelTransport::Local),
1964                proc_id: proc_id_1.clone(),
1965                proc_health: ProcStatus::Alive,
1966                failed_actors: [(actor_id.clone(), ActorStatus::Failed("Actor failed".into()))]
1967                    .to_vec(),
1968            },
1969            &clock,
1970        );
1971
1972        let failures = sv.get_world_with_failures(&world_id, &clock);
1973        let procs = failures.unwrap().procs;
1974        assert_eq!(procs.len(), 2);
1975        assert!(
1976            procs.contains_key(
1977                &proc_id_0
1978                    .rank()
1979                    .expect("proc must be ranked for rank access")
1980            )
1981        );
1982        assert!(
1983            procs.contains_key(
1984                &proc_id_1
1985                    .rank()
1986                    .expect("proc must be ranked for rank access")
1987            )
1988        );
1989    }
1990
1991    #[tracing_test::traced_test]
1992    #[tokio::test]
1993    async fn test_host_join_before_world() {
1994        // Spins up a new world with 2 hosts, with 3 procs each.
1995        let params = SystemActorParams::new(Duration::from_secs(10), Duration::from_secs(10));
1996        let (system_actor_handle, _system_proc) = SystemActor::bootstrap(params).await.unwrap();
1997
1998        // Use a local proc actor to join the system.
1999        let mut host_actors: Vec<MockHostActor> = Vec::new();
2000
2001        let world_name = "test".to_string();
2002        let world_id = WorldId(world_name.clone());
2003        host_actors.push(spawn_mock_host_actor(world_id.clone(), 0).await);
2004        host_actors.push(spawn_mock_host_actor(world_id.clone(), 1).await);
2005
2006        for host_actor in host_actors.iter_mut() {
2007            // Join the world.
2008            system_actor_handle
2009                .send(SystemMessage::Join {
2010                    proc_id: host_actor.local_proc_id.clone(),
2011                    world_id: world_id.clone(),
2012                    proc_message_port: host_actor.local_proc_message_port.bind(),
2013                    proc_addr: host_actor.local_proc_addr.clone(),
2014                    labels: HashMap::new(),
2015                    lifecycle_mode: ProcLifecycleMode::ManagedBySystem,
2016                })
2017                .unwrap();
2018
2019            // We should get a joined message.
2020            // and a spawn proc message.
2021            assert_matches!(
2022                host_actor.local_proc_message_receiver.recv().await.unwrap(),
2023                ProcMessage::Joined()
2024            );
2025        }
2026
2027        // Create a new world message and send to system actor
2028        let num_procs = 6;
2029        let shape = Shape::Definite(vec![2, 3]);
2030        system_actor_handle
2031            .send(SystemMessage::UpsertWorld {
2032                world_id: world_id.clone(),
2033                shape,
2034                num_procs_per_host: 3,
2035                env: Environment::Local,
2036                labels: HashMap::new(),
2037            })
2038            .unwrap();
2039
2040        let mut all_procs: Vec<ProcId> = Vec::new();
2041        for host_actor in host_actors.iter_mut() {
2042            let m = host_actor.local_proc_message_receiver.recv().await.unwrap();
2043            match m {
2044                ProcMessage::SpawnProc {
2045                    env,
2046                    world_id,
2047                    mut proc_ids,
2048                    world_size,
2049                } => {
2050                    assert_eq!(world_id, WorldId(world_name.clone()));
2051                    assert_eq!(env, Environment::Local);
2052                    assert_eq!(world_size, num_procs);
2053                    all_procs.append(&mut proc_ids);
2054                }
2055                _ => std::panic!("Unexpected message type!"),
2056            }
2057        }
2058        // Check if all proc ids from 0 to num_procs - 1 are in the list
2059        assert_eq!(all_procs.len(), num_procs);
2060        all_procs.sort();
2061        for (i, proc) in all_procs.iter().enumerate() {
2062            assert_eq!(*proc, ProcId::Ranked(WorldId(world_name.clone()), i));
2063        }
2064    }
2065
2066    #[tokio::test]
2067    async fn test_host_join_after_world() {
2068        // Spins up a new world with 2 hosts, with 3 procs each.
2069        let params = SystemActorParams::new(Duration::from_secs(10), Duration::from_secs(10));
2070        let (system_actor_handle, _system_proc) = SystemActor::bootstrap(params).await.unwrap();
2071
2072        // Create a new world message and send to system actor
2073        let world_name = "test".to_string();
2074        let world_id = WorldId(world_name.clone());
2075        let num_procs = 6;
2076        let shape = Shape::Definite(vec![2, 3]);
2077        system_actor_handle
2078            .send(SystemMessage::UpsertWorld {
2079                world_id: world_id.clone(),
2080                shape,
2081                num_procs_per_host: 3,
2082                env: Environment::Local,
2083                labels: HashMap::new(),
2084            })
2085            .unwrap();
2086
2087        // Use a local proc actor to join the system.
2088        let mut host_actors: Vec<MockHostActor> = Vec::new();
2089
2090        host_actors.push(spawn_mock_host_actor(world_id.clone(), 0).await);
2091        host_actors.push(spawn_mock_host_actor(world_id.clone(), 1).await);
2092
2093        for host_actor in host_actors.iter_mut() {
2094            // Join the world.
2095            system_actor_handle
2096                .send(SystemMessage::Join {
2097                    proc_id: host_actor.local_proc_id.clone(),
2098                    world_id: world_id.clone(),
2099                    proc_message_port: host_actor.local_proc_message_port.bind(),
2100                    proc_addr: host_actor.local_proc_addr.clone(),
2101                    labels: HashMap::new(),
2102                    lifecycle_mode: ProcLifecycleMode::ManagedBySystem,
2103                })
2104                .unwrap();
2105
2106            // We should get a joined message.
2107            // and a spawn proc message.
2108            assert_matches!(
2109                host_actor.local_proc_message_receiver.recv().await.unwrap(),
2110                ProcMessage::Joined()
2111            );
2112        }
2113
2114        let mut all_procs: Vec<ProcId> = Vec::new();
2115        for host_actor in host_actors.iter_mut() {
2116            let m = host_actor.local_proc_message_receiver.recv().await.unwrap();
2117            match m {
2118                ProcMessage::SpawnProc {
2119                    env,
2120                    world_id,
2121                    mut proc_ids,
2122                    world_size,
2123                } => {
2124                    assert_eq!(world_id, WorldId(world_name.clone()));
2125                    assert_eq!(env, Environment::Local);
2126                    assert_eq!(world_size, num_procs);
2127                    all_procs.append(&mut proc_ids);
2128                }
2129                _ => std::panic!("Unexpected message type!"),
2130            }
2131        }
2132        // Check if all proc ids from 0 to num_procs - 1 are in the list
2133        assert_eq!(all_procs.len(), num_procs);
2134        all_procs.sort();
2135        for (i, proc) in all_procs.iter().enumerate() {
2136            assert_eq!(*proc, ProcId::Ranked(WorldId(world_name.clone()), i));
2137        }
2138    }
2139
2140    #[test]
2141    fn test_snapshot_filter() {
2142        let test_world = World::new(
2143            WorldId("test_world".to_string()),
2144            Shape::Definite(vec![1]),
2145            WorldState {
2146                host_map: HashMap::new(),
2147                procs: HashMap::new(),
2148                status: WorldStatus::Live,
2149            },
2150            1,
2151            Environment::Local,
2152            HashMap::from([("foo".to_string(), "bar".to_string())]),
2153        )
2154        .unwrap();
2155        // match all
2156        let filter = SystemSnapshotFilter::all();
2157        assert!(filter.world_matches(&test_world));
2158        assert!(SystemSnapshotFilter::labels_match(
2159            &HashMap::new(),
2160            &HashMap::from([("foo".to_string(), "bar".to_string())])
2161        ));
2162        // specific match
2163        let mut filter = SystemSnapshotFilter::all();
2164        filter.worlds = vec![WorldId("test_world".to_string())];
2165        assert!(filter.world_matches(&test_world));
2166        filter.worlds = vec![WorldId("unknow_world".to_string())];
2167        assert!(!filter.world_matches(&test_world));
2168        assert!(SystemSnapshotFilter::labels_match(
2169            &HashMap::from([("foo".to_string(), "baz".to_string())]),
2170            &HashMap::from([("foo".to_string(), "baz".to_string())]),
2171        ));
2172        assert!(!SystemSnapshotFilter::labels_match(
2173            &HashMap::from([("foo".to_string(), "bar".to_string())]),
2174            &HashMap::from([("foo".to_string(), "baz".to_string())]),
2175        ));
2176    }
2177
2178    #[tokio::test]
2179    async fn test_undeliverable_message_return() {
2180        // System can't send a message to a remote actor because the
2181        // proc connection is lost.
2182        use hyperactor::mailbox::MailboxClient;
2183        use hyperactor::test_utils::pingpong::PingPongActor;
2184        use hyperactor::test_utils::pingpong::PingPongMessage;
2185
2186        use crate::System;
2187        use crate::proc_actor::ProcActor;
2188        use crate::supervision::ProcSupervisor;
2189
2190        // Use temporary config for this test
2191        let config = hyperactor::config::global::lock();
2192        let _guard = config.override_key(
2193            hyperactor::config::MESSAGE_DELIVERY_TIMEOUT,
2194            Duration::from_secs(1),
2195        );
2196
2197        // Serve a system. Undeliverable messages encountered by the
2198        // mailbox server are returned to the system actor.
2199        let server_handle = System::serve(
2200            ChannelAddr::any(ChannelTransport::Tcp),
2201            Duration::from_secs(2), // supervision update timeout
2202            Duration::from_secs(2), // duration to evict an unhealthy world
2203        )
2204        .await
2205        .unwrap();
2206        let system_actor_handle = server_handle.system_actor_handle();
2207        let mut system = System::new(server_handle.local_addr().clone());
2208        let client = system.attach().await.unwrap();
2209
2210        // At this point there are no worlds.
2211        let snapshot = system_actor_handle
2212            .snapshot(&client, SystemSnapshotFilter::all())
2213            .await
2214            .unwrap();
2215        assert_eq!(snapshot.worlds.len(), 0);
2216
2217        // Create one.
2218        let world_id = id!(world);
2219        system_actor_handle
2220            .send(SystemMessage::UpsertWorld {
2221                world_id: world_id.clone(),
2222                shape: Shape::Definite(vec![1]),
2223                num_procs_per_host: 1,
2224                env: Environment::Local,
2225                labels: HashMap::new(),
2226            })
2227            .unwrap();
2228
2229        // Now we should know a world.
2230        let snapshot = system_actor_handle
2231            .snapshot(&client, SystemSnapshotFilter::all())
2232            .await
2233            .unwrap();
2234        assert_eq!(snapshot.worlds.len(), 1);
2235        // Check it's the world we think it is.
2236        assert!(snapshot.worlds.contains_key(&world_id));
2237        // It starts out unhealthy (todo: understand why).
2238        assert!(matches!(
2239            snapshot.worlds.get(&world_id).unwrap().status,
2240            WorldStatus::Unhealthy(_)
2241        ));
2242
2243        // Build a supervisor.
2244        let supervisor = system.attach().await.unwrap();
2245        let (sup_tx, _sup_rx) = supervisor.open_port::<ProcSupervisionMessage>();
2246        sup_tx.bind_to(ProcSupervisionMessage::port());
2247        let sup_ref = ActorRef::<ProcSupervisor>::attest(supervisor.self_id().clone());
2248
2249        // Construct a system sender.
2250        let system_sender = BoxedMailboxSender::new(MailboxClient::new(
2251            channel::dial(server_handle.local_addr().clone()).unwrap(),
2252        ));
2253        // Construct a proc forwarder in terms of the system sender.
2254        let proc_forwarder =
2255            BoxedMailboxSender::new(DialMailboxRouter::new_with_default(system_sender));
2256
2257        // Bootstrap proc 'world[0]', join the system.
2258        let proc_0 = Proc::new(world_id.proc_id(0), proc_forwarder.clone());
2259        let _proc_actor_0 = ProcActor::bootstrap_for_proc(
2260            proc_0.clone(),
2261            world_id.clone(),
2262            ChannelAddr::any(ChannelTransport::Tcp),
2263            server_handle.local_addr().clone(),
2264            sup_ref.clone(),
2265            Duration::from_millis(300), // supervision update interval
2266            HashMap::new(),
2267            ProcLifecycleMode::ManagedBySystem,
2268        )
2269        .await
2270        .unwrap();
2271        let proc_0_client = proc_0.attach("client").unwrap();
2272        let (proc_0_undeliverable_tx, _proc_0_undeliverable_rx) = proc_0_client.open_port();
2273
2274        // Bootstrap a second proc 'world[1]', join the system.
2275        let proc_1 = Proc::new(world_id.proc_id(1), proc_forwarder.clone());
2276        let proc_actor_1 = ProcActor::bootstrap_for_proc(
2277            proc_1.clone(),
2278            world_id.clone(),
2279            ChannelAddr::any(ChannelTransport::Tcp),
2280            server_handle.local_addr().clone(),
2281            sup_ref.clone(),
2282            Duration::from_millis(300), // supervision update interval
2283            HashMap::new(),
2284            ProcLifecycleMode::ManagedBySystem,
2285        )
2286        .await
2287        .unwrap();
2288        let proc_1_client = proc_1.attach("client").unwrap();
2289        let (proc_1_undeliverable_tx, mut _proc_1_undeliverable_rx) = proc_1_client.open_port();
2290
2291        // Spawn two actors 'ping' and 'pong' where 'ping' runs on
2292        // 'world[0]' and 'pong' on 'world[1]' (that is, not on the
2293        // same proc).
2294        let ping_params = PingPongActorParams::new(Some(proc_0_undeliverable_tx.bind()), None);
2295        let ping_handle = proc_0
2296            .spawn::<PingPongActor>("ping", ping_params)
2297            .await
2298            .unwrap();
2299        let pong_params = PingPongActorParams::new(Some(proc_1_undeliverable_tx.bind()), None);
2300        let pong_handle = proc_1
2301            .spawn::<PingPongActor>("pong", pong_params)
2302            .await
2303            .unwrap();
2304
2305        // Now kill pong's mailbox server making message delivery
2306        // between procs impossible.
2307        proc_actor_1.mailbox.stop("from testing");
2308        proc_actor_1.mailbox.await.unwrap().unwrap();
2309
2310        // That in itself shouldn't be a problem. Check the world
2311        // health now.
2312        let snapshot = system_actor_handle
2313            .snapshot(&client, SystemSnapshotFilter::all())
2314            .await
2315            .unwrap();
2316        assert_eq!(snapshot.worlds.len(), 1);
2317        assert!(snapshot.worlds.contains_key(&world_id));
2318        assert_eq!(
2319            snapshot.worlds.get(&world_id).unwrap().status,
2320            WorldStatus::Live
2321        );
2322
2323        // Have 'ping' send 'pong' a message.
2324        let ttl = 1_u64;
2325        let (game_over, on_game_over) = proc_0_client.open_once_port::<bool>();
2326        ping_handle
2327            .send(PingPongMessage(ttl, pong_handle.bind(), game_over.bind()))
2328            .unwrap();
2329
2330        // We expect message delivery failure prevents the game from
2331        // ending within the timeout.
2332        assert!(
2333            RealClock
2334                .timeout(tokio::time::Duration::from_secs(4), on_game_over.recv())
2335                .await
2336                .is_err()
2337        );
2338
2339        // By supervision, we expect the world should have
2340        // transitioned to unhealthy.
2341        let snapshot = system_actor_handle
2342            .snapshot(&client, SystemSnapshotFilter::all())
2343            .await
2344            .unwrap();
2345        assert_eq!(snapshot.worlds.len(), 1);
2346        assert!(matches!(
2347            snapshot.worlds.get(&world_id).unwrap().status,
2348            WorldStatus::Unhealthy(_)
2349        ));
2350    }
2351
2352    #[tokio::test]
2353    async fn test_stop_fast() -> Result<()> {
2354        let server_handle = System::serve(
2355            ChannelAddr::any(ChannelTransport::Tcp),
2356            Duration::from_secs(2), // supervision update timeout
2357            Duration::from_secs(2), // duration to evict an unhealthy world
2358        )
2359        .await?;
2360        let system_actor_handle = server_handle.system_actor_handle();
2361        let mut system = System::new(server_handle.local_addr().clone());
2362        let client = system.attach().await?;
2363
2364        // Create a new world message and send to system actor
2365        let (client_tx, client_rx) = client.open_once_port::<()>();
2366        system_actor_handle.send(SystemMessage::Stop {
2367            worlds: None,
2368            proc_timeout: Duration::from_secs(5),
2369            reply_port: client_tx.bind(),
2370        })?;
2371        client_rx.recv().await?;
2372
2373        // Check that it has stopped.
2374        let mut sys_status_rx = system_actor_handle.status();
2375        {
2376            let received = sys_status_rx.borrow_and_update();
2377            assert_eq!(*received, ActorStatus::Stopped);
2378        }
2379
2380        Ok(())
2381    }
2382
2383    #[tokio::test]
2384    async fn test_update_sim_address() {
2385        simnet::start();
2386
2387        let src_id = id!(proc[0].actor);
2388        let src_addr = ChannelAddr::Sim(SimAddr::new("unix!@src".parse().unwrap()).unwrap());
2389        let dst_addr = ChannelAddr::Sim(SimAddr::new("unix!@dst".parse().unwrap()).unwrap());
2390        let (_, mut rx) = channel::serve::<MessageEnvelope>(src_addr.clone()).unwrap();
2391
2392        let router = ReportingRouter::new();
2393
2394        router
2395            .router
2396            .bind(src_id.proc_id().clone().into(), src_addr);
2397        router.router.bind(id!(proc[1]).into(), dst_addr);
2398
2399        router.post_update_address(&MessageEnvelope::new(
2400            src_id,
2401            PortId(id!(proc[1].actor), 9999u64),
2402            Serialized::serialize(&1u64).unwrap(),
2403            Attrs::new(),
2404        ));
2405
2406        let envelope = rx.recv().await.unwrap();
2407        let admin_msg = envelope
2408            .data()
2409            .deserialized::<MailboxAdminMessage>()
2410            .unwrap();
2411        let MailboxAdminMessage::UpdateAddress {
2412            addr: ChannelAddr::Sim(addr),
2413            ..
2414        } = admin_msg
2415        else {
2416            panic!("Expected sim address");
2417        };
2418
2419        assert_eq!(addr.src().clone().unwrap().to_string(), "unix:@src");
2420        assert_eq!(addr.addr().to_string(), "unix:@dst");
2421    }
2422}