hyperactor_multiprocess/
system_actor.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! System actor manages a system.
10
11use std::collections::BTreeSet;
12use std::collections::HashMap;
13use std::collections::HashSet;
14use std::collections::hash_map::Entry;
15use std::fmt::Display;
16use std::fmt::Formatter;
17use std::hash::Hash;
18use std::sync::Arc;
19use std::sync::LazyLock;
20use std::time::SystemTime;
21
22use async_trait::async_trait;
23use dashmap::DashMap;
24use enum_as_inner::EnumAsInner;
25use hyperactor::Actor;
26use hyperactor::ActorHandle;
27use hyperactor::ActorId;
28use hyperactor::ActorRef;
29use hyperactor::Context;
30use hyperactor::HandleClient;
31use hyperactor::Instance;
32use hyperactor::Named;
33use hyperactor::OncePortRef;
34use hyperactor::PortHandle;
35use hyperactor::PortRef;
36use hyperactor::ProcId;
37use hyperactor::RefClient;
38use hyperactor::WorldId;
39use hyperactor::actor::Handler;
40use hyperactor::channel::ChannelAddr;
41use hyperactor::channel::sim::SimAddr;
42use hyperactor::clock::Clock;
43use hyperactor::clock::ClockKind;
44use hyperactor::id;
45use hyperactor::mailbox::BoxedMailboxSender;
46use hyperactor::mailbox::DialMailboxRouter;
47use hyperactor::mailbox::MailboxSender;
48use hyperactor::mailbox::MailboxSenderError;
49use hyperactor::mailbox::MessageEnvelope;
50use hyperactor::mailbox::PortSender;
51use hyperactor::mailbox::Undeliverable;
52use hyperactor::mailbox::mailbox_admin_message::MailboxAdminMessage;
53use hyperactor::mailbox::monitored_return_handle;
54use hyperactor::proc::Proc;
55use hyperactor::reference::Index;
56use serde::Deserialize;
57use serde::Serialize;
58use tokio::time::Duration;
59use tokio::time::Instant;
60
61use super::proc_actor::ProcMessage;
62use crate::proc_actor::Environment;
63use crate::proc_actor::ProcActor;
64use crate::proc_actor::ProcStopResult;
65use crate::supervision::ProcStatus;
66use crate::supervision::ProcSupervisionMessage;
67use crate::supervision::ProcSupervisionState;
68use crate::supervision::WorldSupervisionMessage;
69use crate::supervision::WorldSupervisionState;
70
/// A snapshot of a single proc, as exposed inside a [`WorldSnapshot`].
///
/// It is built from the internal `ProcInfo` record and carries only the
/// proc's labels.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WorldSnapshotProcInfo {
    /// The labels of the proc.
    pub labels: HashMap<String, String>,
}
77
78impl From<&ProcInfo> for WorldSnapshotProcInfo {
79    fn from(proc_info: &ProcInfo) -> Self {
80        Self {
81            labels: proc_info.labels.clone(),
82        }
83    }
84}
85
/// A snapshot view of a world in the system.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WorldSnapshot {
    /// The host procs used to spawn procs in this world. Some caveats:
    ///   1. The host procs are actually not in this world but in a different
    ///      "shadow" world. The shadow world's ID can be told from the host
    ///      ProcId.
    ///   2. Not all host procs are captured here. This field only captures the
    ///      hosts that joined before the world was created.
    pub host_procs: HashSet<ProcId>,

    /// The procs in this world, keyed by proc id. When the snapshot is taken
    /// with a filter, only procs accepted by the filter are included.
    pub procs: HashMap<ProcId, WorldSnapshotProcInfo>,

    /// The status of the world at the time the snapshot was taken.
    pub status: WorldStatus,

    /// Labels attached to this world. They can be used later to query
    /// world(s) using the system snapshot API.
    pub labels: HashMap<String, String>,
}
107
108impl WorldSnapshot {
109    fn from_world_filtered(world: &World, filter: &SystemSnapshotFilter) -> Self {
110        WorldSnapshot {
111            host_procs: world.state.host_map.keys().map(|h| &h.0).cloned().collect(),
112            procs: world
113                .state
114                .procs
115                .iter()
116                .map_while(|(k, v)| filter.proc_matches(v).then_some((k.clone(), v.into())))
117                .collect(),
118            status: world.state.status.clone(),
119            labels: world.labels.clone(),
120        }
121    }
122}
123
/// A snapshot view of the system.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Named)]
pub struct SystemSnapshot {
    /// Snapshots of the worlds in this system, keyed by world id.
    pub worlds: HashMap<WorldId, WorldSnapshot>,
    /// Execution ID of the system.
    pub execution_id: String,
}
132
/// A filter used to filter the snapshot view of the system.
///
/// A world is selected only if it satisfies both `worlds` and `world_labels`;
/// `proc_labels` is then applied to each proc of a selected world. A label
/// filter matches when every key/value pair in it is present in the target's
/// labels.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Named, Default)]
pub struct SystemSnapshotFilter {
    /// The world ids to filter. Empty list matches all.
    pub worlds: Vec<WorldId>,
    /// World labels to filter. Empty matches all.
    pub world_labels: HashMap<String, String>,
    /// Proc labels to filter. Empty matches all.
    pub proc_labels: HashMap<String, String>,
}
143
144impl SystemSnapshotFilter {
145    /// Create an empty filter that matches everything.
146    pub fn all() -> Self {
147        Self {
148            worlds: Vec::new(),
149            world_labels: HashMap::new(),
150            proc_labels: HashMap::new(),
151        }
152    }
153
154    /// Whether the filter matches the given world.
155    fn world_matches(&self, world: &World) -> bool {
156        if !self.worlds.is_empty() && !self.worlds.contains(&world.world_id) {
157            return false;
158        }
159        Self::labels_match(&self.world_labels, &world.labels)
160    }
161
162    fn proc_matches(&self, proc_info: &ProcInfo) -> bool {
163        Self::labels_match(&self.proc_labels, &proc_info.labels)
164    }
165
166    /// Whether the filter matches the given proc labels.
167    fn labels_match(
168        filter_labels: &HashMap<String, String>,
169        labels: &HashMap<String, String>,
170    ) -> bool {
171        filter_labels.is_empty()
172            || filter_labels
173                .iter()
174                .all(|(k, v)| labels.contains_key(k) && labels.get(k).unwrap() == v)
175    }
176}
177
/// Update the states of worlds, specifically checking if they are unhealthy.
/// Evict the world if it is unhealthy for too long.
///
/// NOTE(review): presumably sent to the system actor itself on a timer; the
/// handler is not visible in this part of the file — confirm.
#[derive(Debug, Clone, PartialEq)]
struct MaintainWorldHealth;
182
/// The proc's lifecycle management mode.
#[derive(Named, Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ProcLifecycleMode {
    /// Proc is detached, its lifecycle isn't managed by the system.
    Detached,
    /// Proc's lifecycle is managed by the system, supervision is enabled for the proc.
    ManagedBySystem,
    /// The proc manages the lifecycle of the system, supervision is enabled for the proc.
    /// System goes down when the proc stops.
    ManagingSystem,
}
194
195impl ProcLifecycleMode {
196    /// Whether the lifecycle mode indicates whether proc is managed by/managing system or not.
197    pub fn is_managed(&self) -> bool {
198        matches!(
199            self,
200            ProcLifecycleMode::ManagedBySystem | ProcLifecycleMode::ManagingSystem
201        )
202    }
203}
204
/// System messages.
#[derive(
    hyperactor::Handler,
    HandleClient,
    RefClient,
    Named,
    Debug,
    Clone,
    Serialize,
    Deserialize,
    PartialEq
)]
pub enum SystemMessage {
    /// Join the system at the given proc id.
    Join {
        /// The world that is being joined.
        world_id: WorldId,
        /// The proc id that is joining.
        proc_id: ProcId,
        /// Port used to send [`ProcMessage`]s to the proc actor managing the proc.
        proc_message_port: PortRef<ProcMessage>,
        /// The channel address used to communicate with the proc.
        proc_addr: ChannelAddr,
        /// Arbitrary name/value pairs that can be used to identify the proc.
        labels: HashMap<String, String>,
        /// The lifecycle mode of the proc.
        lifecycle_mode: ProcLifecycleMode,
    },

    /// Create a new world or update an existing world.
    UpsertWorld {
        /// The world id.
        world_id: WorldId,
        /// The shape of the world.
        shape: Shape,
        /// The number of procs per host.
        num_procs_per_host: usize,
        /// How to spawn procs in the world.
        env: Environment,
        /// Arbitrary name/value pairs that can be used to identify the world.
        labels: HashMap<String, String>,
    },

    /// Return a snapshot view of this system. Used for debugging.
    #[log_level(debug)]
    Snapshot {
        /// The filter used to filter the snapshot view.
        filter: SystemSnapshotFilter,
        /// Used to return the snapshot view to the caller.
        #[reply]
        ret: OncePortRef<SystemSnapshot>,
    },

    /// Start the shutdown process of everything in this system. It tries to
    /// shut down all the procs first, and then the system actor itself.
    ///
    /// Note this shutdown sequence is best effort and not guaranteed. It is
    /// possible that the system actor/proc has already stopped while the
    /// remote procs are still in the middle of shutting down.
    Stop {
        /// List of worlds to stop. If provided, only the procs belonging to
        /// the list of worlds are stopped, otherwise all worlds are stopped
        /// including the system proc itself.
        worlds: Option<Vec<WorldId>>,
        /// The timeout used by ProcActor to stop the proc.
        proc_timeout: Duration,
        /// Used to return success to the caller.
        reply_port: OncePortRef<()>,
    },
}
275
/// Errors that can occur inside a system actor.
#[derive(thiserror::Error, Debug)]
pub enum SystemActorError {
    /// A proc is trying to join before a world is created.
    #[error("procs cannot join uncreated world {0}")]
    UnknownWorldId(WorldId),

    /// Sending the spawn-procs request failed; wraps the underlying mailbox
    /// sender error.
    #[error("failed to spawn procs")]
    SpawnProcsFailed(#[from] MailboxSenderError),

    /// Host ID does not start with the shadow-world prefix (`SHADOW_PREFIX`).
    #[error("invalid host {0}: does not start with prefix '{SHADOW_PREFIX}'")]
    InvalidHostPrefix(HostId),

    /// A host is trying to join the world which already has a joined host with the same ID.
    #[error("host ID {0} already exists in world")]
    DuplicatedHostId(HostId),

    /// Trying to get the port ref for a host that doesn't exist in a world.
    #[error("host {0} does not exist in world")]
    HostNotExist(HostId),
}
299
/// The shape of a world. For a definite shape, the total number of procs in
/// the world is the product of its dimensions.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum Shape {
    /// A definite N-dimensional shape of the world, the semantics of the shape can be defined
    /// by the user (TODO: implement this), e.g. in a shape like [3, 2, 2], user will be able
    /// to express things like dim 0: ai zone, dim 1: rack, dim 2: host.
    Definite(Vec<usize>),
    /// Shape is unknown. Computing the number of procs for such a shape is
    /// currently unimplemented.
    Unknown,
}
310
/// TODO: Toss this world implementation away once we have
/// a more clearly defined allocation API.
/// Currently, each world in a system has two worlds beneath:
/// the actual world and the shadow world. The shadow world
/// is a world that is used to maintain hosts which in turn
/// spawn procs for the world.
/// This is needed in order to support the current scheduler implementation
/// which does not support per-proc scheduling.
///
/// That means, each host is a proc in the shadow world. Each host proc spawns
/// a number of procs for the actual world.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct World {
    /// The world id.
    world_id: WorldId,
    /// Parameters used to schedule procs onto hosts: the world shape, the
    /// number of procs per host and the spawn environment.
    scheduler_params: SchedulerParams,
    /// Arbitrary labels attached to the world.
    labels: HashMap<String, String>,
    /// The state of the world.
    state: WorldState,
}
333
/// Bookkeeping for a single host (a proc in the shadow world) that spawns
/// procs for the actual world.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
struct Host {
    /// Number of proc ids handed out to this host so far.
    num_procs_assigned: usize,
    /// Port used to send [`ProcMessage`]s (e.g. spawn requests) to the host.
    proc_message_port: PortRef<ProcMessage>,
    /// Rank of the host; determines the range of proc ranks it is assigned.
    host_rank: usize,
}
340
341impl Host {
342    fn new(proc_message_port: PortRef<ProcMessage>, host_rank: usize) -> Self {
343        Self {
344            num_procs_assigned: 0,
345            proc_message_port,
346            host_rank,
347        }
348    }
349
350    fn get_assigned_procs(
351        &mut self,
352        world_id: &WorldId,
353        scheduler_params: &mut SchedulerParams,
354    ) -> Vec<ProcId> {
355        // Get Host from hosts given host_id else return empty vec
356        let mut proc_ids = Vec::new();
357
358        // The number of hosts that will be assigned a total of scheduler_params.num_procs_per_host
359        // procs. If scheduler_params.num_procs() is 31 and scheduler_params.num_procs_per_host is 8,
360        // then num_saturated_hosts == 3 even though total number of hosts will be 4.
361        let num_saturated_hosts =
362            scheduler_params.num_procs() / scheduler_params.num_procs_per_host;
363        // If num_saturated_hosts is less than total hosts, then the final host_rank will be equal
364        // to num_saturated_hosts, and should not be assigned the full scheduler_params.num_procs_per_host.
365        // Instead, we should only assign the remaining procs. So if num_procs is 31, num_procs_per_host is 8,
366        // then host_rank 3 should only be assigned 7 procs.
367        let num_scheduled = if self.host_rank == num_saturated_hosts {
368            scheduler_params.num_procs() % scheduler_params.num_procs_per_host
369        } else {
370            scheduler_params.num_procs_per_host
371        };
372
373        scheduler_params.num_procs_scheduled += num_scheduled;
374
375        for _ in 0..num_scheduled {
376            // Compute each proc id (which will become the RANK env var on each worker)
377            // based on host_rank, which is (optionally) assigned to each host at bootstrap
378            // time according to a sorted hostname file.
379            //
380            // More precisely, when a host process starts up, it gets its host rank from some
381            // global source of truth common to all host nodes. This source of truth could be
382            // a file or an env var. In order to be consistent with the SPMD world, assuming
383            // num_procs_per_host == N, we would want worker ranks 0 through N-1 on host 0;
384            // ranks N through 2N-1 on host 1; etc. So, for host H, we assign proc ids in the
385            // interval [H*N, (H+1)*N).
386            let rank =
387                self.host_rank * scheduler_params.num_procs_per_host + self.num_procs_assigned;
388            let proc_id = ProcId::Ranked(world_id.clone(), rank);
389            proc_ids.push(proc_id);
390            self.num_procs_assigned += 1;
391        }
392
393        proc_ids
394    }
395}
396
/// Parameters and counters used when scheduling a world's procs onto hosts.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SchedulerParams {
    /// The shape of the world; its dimensions determine the total proc count.
    shape: Shape,
    /// Running total of procs that have been handed out to hosts.
    num_procs_scheduled: usize,
    /// Number of procs a fully loaded host is assigned.
    num_procs_per_host: usize,
    /// Initialised to 0 when the world is created.
    /// NOTE(review): not read anywhere in the visible part of this file —
    /// confirm whether it is still needed.
    next_rank: Index,
    /// How to spawn procs in the world; forwarded in spawn requests.
    env: Environment,
}
405
406impl SchedulerParams {
407    fn num_procs(&self) -> usize {
408        match &self.shape {
409            Shape::Definite(v) => v.iter().product(),
410            Shape::Unknown => unimplemented!(),
411        }
412    }
413}
414
/// A world id that is used to identify a host.
pub type HostWorldId = WorldId;
/// Name prefix that marks a world as a "shadow" (host) world. World names
/// starting with this prefix are rejected when creating a regular world, and
/// a proc may only become a [`HostId`] if its world name starts with it.
static SHADOW_PREFIX: &str = "host";
418
/// A host id that is used to identify a host.
///
/// Wraps the [`ProcId`] of the host proc; construction checks that the proc's
/// world name starts with the shadow-world prefix.
#[derive(
    Debug,
    Serialize,
    Deserialize,
    Clone,
    PartialEq,
    Eq,
    PartialOrd,
    Hash,
    Ord
)]
pub struct HostId(ProcId);
432impl HostId {
433    /// Creates a new HostId from a proc_id.
434    pub fn new(proc_id: ProcId) -> Result<Self, anyhow::Error> {
435        if !proc_id
436            .world_name()
437            .expect("proc must be ranked for world_name check")
438            .starts_with(SHADOW_PREFIX)
439        {
440            anyhow::bail!(
441                "proc_id {} is not a valid HostId because it does not start with {}",
442                proc_id,
443                SHADOW_PREFIX
444            )
445        }
446        Ok(Self(proc_id))
447    }
448}
449
450impl TryFrom<ProcId> for HostId {
451    type Error = anyhow::Error;
452
453    fn try_from(proc_id: ProcId) -> Result<Self, anyhow::Error> {
454        if !proc_id
455            .world_name()
456            .expect("proc must be ranked for world_name check")
457            .starts_with(SHADOW_PREFIX)
458        {
459            anyhow::bail!(
460                "proc_id {} is not a valid HostId because it does not start with {}",
461                proc_id,
462                SHADOW_PREFIX
463            )
464        }
465        Ok(Self(proc_id))
466    }
467}
468
469impl std::fmt::Display for HostId {
470    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
471        self.0.fmt(f)
472    }
473}
474
/// The hosts that have joined a world, keyed by host id.
type HostMap = HashMap<HostId, Host>;
476
/// Internal record kept for each proc that has been added to a world.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
struct ProcInfo {
    /// Port used to send [`ProcMessage`]s to the proc.
    port_ref: PortRef<ProcMessage>,
    /// Labels attached to the proc; matched against snapshot filters.
    labels: HashMap<String, String>,
}
482
/// The mutable state of a world: its hosts, its procs and its status.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
struct WorldState {
    /// Hosts that have joined (procs of the shadow world), keyed by host id.
    host_map: HostMap,
    /// Procs that have been added to the world, keyed by proc id.
    procs: HashMap<ProcId, ProcInfo>,
    /// The current lifecycle phase of the world.
    status: WorldStatus,
}
489
/// A world status represents the different phases of a world.
#[derive(Debug, Clone, Serialize, Deserialize, EnumAsInner, PartialEq)]
pub enum WorldStatus {
    /// Waiting for the world to be created. Accumulate joined hosts or procs while we're waiting.
    AwaitingCreation,

    /// World is created and has enough procs based on the scheduler parameters.
    /// All procs in the world are without failures.
    Live,

    /// World is created but it does not have enough procs or some procs are failing.
    /// [`SystemTime`] contains the time when the world became unhealthy.
    // Use SystemTime instead of Instant to avoid the issue of serialization.
    Unhealthy(SystemTime),
}
505
506impl Display for WorldStatus {
507    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
508        match self {
509            WorldStatus::AwaitingCreation => write!(f, "Awaiting Creation"),
510            WorldStatus::Live => write!(f, "Live"),
511            WorldStatus::Unhealthy(_) => write!(f, "Unhealthy"),
512        }
513    }
514}
515
516impl WorldState {
517    /// Gets the mutable ref to host_map.
518    fn get_host_map_mut(&mut self) -> &mut HostMap {
519        &mut self.host_map
520    }
521
522    /// Gets the ref to host_map.
523    fn get_host_map(&self) -> &HostMap {
524        &self.host_map
525    }
526}
527
528impl World {
529    fn new(
530        world_id: WorldId,
531        shape: Shape,
532        state: WorldState,
533        num_procs_per_host: usize,
534        env: Environment,
535        labels: HashMap<String, String>,
536    ) -> Result<Self, anyhow::Error> {
537        if world_id.name().starts_with(SHADOW_PREFIX) {
538            anyhow::bail!(
539                "world name {} cannot start with {}!",
540                world_id,
541                SHADOW_PREFIX
542            )
543        }
544        tracing::info!("creating world {}", world_id,);
545        Ok(Self {
546            world_id,
547            scheduler_params: SchedulerParams {
548                shape,
549                num_procs_per_host,
550                num_procs_scheduled: 0,
551                next_rank: 0,
552                env,
553            },
554            state,
555            labels,
556        })
557    }
558
559    fn get_real_world_id(proc_world_id: &WorldId) -> WorldId {
560        WorldId(
561            proc_world_id
562                .name()
563                .strip_prefix(SHADOW_PREFIX)
564                .unwrap_or(proc_world_id.name())
565                .to_string(),
566        )
567    }
568
569    fn is_host_world(world_id: &WorldId) -> bool {
570        world_id.name().starts_with(SHADOW_PREFIX)
571    }
572
573    fn get_port_ref_from_host(
574        &self,
575        host_id: &HostId,
576    ) -> Result<PortRef<ProcMessage>, SystemActorError> {
577        let host_map = self.state.get_host_map();
578        // Get Host from hosts given proc_id
579        match host_map.get(host_id) {
580            Some(h) => Ok(h.proc_message_port.clone()),
581            None => Err(SystemActorError::HostNotExist(host_id.clone())),
582        }
583    }
584
585    /// Adds procs to the world.
586    fn add_proc(
587        &mut self,
588        proc_id: ProcId,
589        proc_message_port: PortRef<ProcMessage>,
590        labels: HashMap<String, String>,
591    ) -> Result<(), SystemActorError> {
592        self.state.procs.insert(
593            proc_id,
594            ProcInfo {
595                port_ref: proc_message_port,
596                labels,
597            },
598        );
599        if self.state.status.is_unhealthy()
600            && self.state.procs.len() >= self.scheduler_params.num_procs()
601        {
602            self.state.status = WorldStatus::Live;
603            tracing::info!(
604                "world {}: ready to serve with {} procs",
605                self.world_id,
606                self.state.procs.len()
607            );
608        }
609        Ok(())
610    }
611
612    /// 1. Adds a host to the hosts map.
613    /// 2. Create executor procs for the host.
614    /// 3. Run necessary programs
615    async fn on_host_join(
616        &mut self,
617        host_id: HostId,
618        proc_message_port: PortRef<ProcMessage>,
619        router: &DialMailboxRouter,
620    ) -> Result<(), SystemActorError> {
621        let mut host_entry = match self.state.host_map.entry(host_id.clone()) {
622            Entry::Occupied(_) => {
623                return Err(SystemActorError::DuplicatedHostId(host_id));
624            }
625            Entry::Vacant(entry) => entry.insert_entry(Host::new(
626                proc_message_port.clone(),
627                host_id
628                    .0
629                    .rank()
630                    .expect("host proc must be ranked for rank access"),
631            )),
632        };
633
634        if self.state.status == WorldStatus::AwaitingCreation {
635            return Ok(());
636        }
637
638        let proc_ids = host_entry
639            .get_mut()
640            .get_assigned_procs(&self.world_id, &mut self.scheduler_params);
641
642        router.serialize_and_send(
643            &proc_message_port,
644            ProcMessage::SpawnProc {
645                env: self.scheduler_params.env.clone(),
646                world_id: self.world_id.clone(),
647                proc_ids,
648                world_size: self.scheduler_params.num_procs(),
649            },
650            monitored_return_handle(),
651        )?;
652        Ok(())
653    }
654
655    fn get_hosts_to_procs(&mut self) -> Result<HashMap<HostId, Vec<ProcId>>, SystemActorError> {
656        // A map from host ID to scheduled proc IDs on this host.
657        let mut host_proc_map: HashMap<HostId, Vec<ProcId>> = HashMap::new();
658        let host_map = self.state.get_host_map_mut();
659        // Iterate over each entry in self.hosts
660        for (host_id, host) in host_map {
661            // Had to clone hosts in order to call schedule_procs
662            if host.num_procs_assigned == self.scheduler_params.num_procs_per_host {
663                continue;
664            }
665            let host_procs = host.get_assigned_procs(&self.world_id, &mut self.scheduler_params);
666            if host_procs.is_empty() {
667                continue;
668            }
669            host_proc_map.insert(host_id.clone(), host_procs);
670        }
671        Ok(host_proc_map)
672    }
673
674    async fn on_create(&mut self, router: &DialMailboxRouter) -> Result<(), anyhow::Error> {
675        let host_procs_map = self.get_hosts_to_procs()?;
676        for (host_id, procs_ids) in host_procs_map {
677            if procs_ids.is_empty() {
678                continue;
679            }
680
681            // REFACTOR(marius): remove
682            let world_id = procs_ids
683                .first()
684                .unwrap()
685                .clone()
686                .into_ranked()
687                .expect("proc must be ranked for world_id access")
688                .0
689                .clone();
690            // Open port ref
691            tracing::info!("spawning procs for host {:?}", host_id);
692            router.serialize_and_send(
693                // Get host proc!
694                &self.get_port_ref_from_host(&host_id)?,
695                ProcMessage::SpawnProc {
696                    env: self.scheduler_params.env.clone(),
697                    world_id,
698                    // REFACTOR(marius): remove
699                    proc_ids: procs_ids,
700                    world_size: self.scheduler_params.num_procs(),
701                },
702                monitored_return_handle(),
703            )?;
704        }
705        Ok(())
706    }
707}
708
/// A mailbox router that forwards messages to their destinations and
/// additionally reports the destination address back to the sender’s
/// [`ProcActor`], allowing it to cache the address for future use.
#[derive(Debug, Clone)]
pub struct ReportingRouter {
    /// The underlying router that performs the actual delivery and address
    /// lookups.
    router: DialMailboxRouter,
    /// A record of cached addresses from dst_proc_id to HashSet(src_proc_id).
    /// Right now only the proc_ids are recorded, so that address changes can
    /// be broadcast to the senders that were told about the old address.
    /// We can also cache the address here in the future.
    address_cache: Arc<DashMap<ProcId, HashSet<ProcId>>>,
}
720
721impl MailboxSender for ReportingRouter {
722    fn post_unchecked(
723        &self,
724        envelope: MessageEnvelope,
725        return_handle: PortHandle<Undeliverable<MessageEnvelope>>,
726    ) {
727        let ReportingRouter { router, .. } = self;
728        self.post_update_address(&envelope);
729        router.post_unchecked(envelope, return_handle);
730    }
731}
732
733impl ReportingRouter {
734    fn new() -> Self {
735        Self {
736            router: DialMailboxRouter::new(),
737            address_cache: Arc::new(DashMap::new()),
738        }
739    }
    /// Tells the sender's proc which address the destination proc of
    /// `envelope` can be reached at, and records that the sender has been
    /// told, so that later address changes can be broadcast to it.
    ///
    /// Nothing is sent for the edge cases listed below, or when the
    /// destination address is not known to the inner router.
    ///
    /// # Panics
    ///
    /// Panics if sending the update fails ("unexpected serialization
    /// failure") or if a simulated address cannot be rebuilt with the sender
    /// as its source.
    fn post_update_address(&self, envelope: &MessageEnvelope) {
        let system_proc_id = id!(system[0]);
        // These are edge cases that are unlikely to come up in a
        // well ordered system but in the event that they do we skip
        // sending update address messages:
        // - The sender ID is "unknown" (it makes no sense to remember
        //   the address of an unknown sender)
        // - The sender world is "user", which doesn't have a ProcActor running
        //   to process the address update message.
        // - The sender is the system (the system knows all addresses)
        // - The destination is the system (every proc knows the
        //   system address)
        // - The sender and the destination are on the same proc (it
        //   doesn't make sense to be dialing connections between them).
        if envelope.sender().proc_id() == &id!(unknown[0])
            || envelope.sender().proc_id().world_id() == Some(&id!(user))
            || envelope.sender().proc_id() == &system_proc_id
            || envelope.dest().actor_id().proc_id() == &system_proc_id
            || envelope.sender().proc_id() == envelope.dest().actor_id().proc_id()
        {
            return;
        }
        let (dst_proc_id, dst_proc_addr) = self.dest_proc_id_and_address(envelope);
        let Some(dst_proc_addr) = dst_proc_addr else {
            tracing::warn!("unknown address for {}", &dst_proc_id);
            return;
        };

        // Remember that this sender has been given the destination's address.
        let sender_proc_id = envelope.sender().proc_id();
        self.upsert_address_cache(sender_proc_id, &dst_proc_id);
        // Sim addresses have a concept of directionality. When we notify a proc of an address we should
        // use the proc's address as the source for the sim address.
        let sender_address = self.router.lookup_addr(envelope.sender());
        let dst_proc_addr =
            if let (Some(ChannelAddr::Sim(sender_sim_addr)), ChannelAddr::Sim(dest_sim_addr)) =
                (sender_address, &dst_proc_addr)
            {
                ChannelAddr::Sim(
                    SimAddr::new_with_src(
                        // source is the sender
                        sender_sim_addr.addr().clone(),
                        // dest remains unchanged
                        dest_sim_addr.addr().clone(),
                    )
                    .unwrap(),
                )
            } else {
                // Non-sim addresses are reported unchanged.
                dst_proc_addr
            };
        self.serialize_and_send(
            &self.proc_port_ref(sender_proc_id),
            MailboxAdminMessage::UpdateAddress {
                proc_id: dst_proc_id,
                addr: dst_proc_addr,
            },
            monitored_return_handle(),
        )
        .expect("unexpected serialization failure")
    }
799
800    /// broadcasts the address of the proc if there's any stale record that has been sent
801    /// out to senders before.
802    fn broadcast_addr(&self, dst_proc_id: &ProcId, dst_proc_addr: ChannelAddr) {
803        if let Some(r) = self.address_cache.get(dst_proc_id) {
804            for sender_proc_id in r.value() {
805                tracing::info!(
806                    "broadcasting address change to {} for {}: {}",
807                    sender_proc_id,
808                    dst_proc_id,
809                    dst_proc_addr
810                );
811                self.serialize_and_send(
812                    &self.proc_port_ref(sender_proc_id),
813                    MailboxAdminMessage::UpdateAddress {
814                        proc_id: dst_proc_id.clone(),
815                        addr: dst_proc_addr.clone(),
816                    },
817                    monitored_return_handle(),
818                )
819                .expect("unexpected serialization failure")
820            }
821        }
822    }
823
824    fn upsert_address_cache(&self, src_proc_id: &ProcId, dst_proc_id: &ProcId) {
825        self.address_cache
826            .entry(dst_proc_id.clone())
827            .and_modify(|src_proc_ids| {
828                src_proc_ids.insert(src_proc_id.clone());
829            })
830            .or_insert({
831                let mut set = HashSet::new();
832                set.insert(src_proc_id.clone());
833                set
834            });
835    }
836
837    fn dest_proc_id_and_address(
838        &self,
839        envelope: &MessageEnvelope,
840    ) -> (ProcId, Option<ChannelAddr>) {
841        let dest_proc_port_id = envelope.dest();
842        let dest_proc_actor_id = dest_proc_port_id.actor_id();
843        let dest_proc_id = dest_proc_actor_id.proc_id();
844        let dest_proc_addr = self.router.lookup_addr(dest_proc_actor_id);
845        (dest_proc_id.clone(), dest_proc_addr)
846    }
847
848    fn proc_port_ref(&self, proc_id: &ProcId) -> PortRef<MailboxAdminMessage> {
849        let proc_actor_id = ActorId(proc_id.clone(), "proc".to_string(), 0);
850        let proc_actor_ref = ActorRef::<ProcActor>::attest(proc_actor_id);
851        proc_actor_ref.port::<MailboxAdminMessage>()
852    }
853}
854
/// Parameters used to configure a system actor: its mailbox router and its
/// supervision timeouts.
#[derive(Debug, Clone)]
pub struct SystemActorParams {
    /// The router used to deliver messages and report addresses to senders.
    mailbox_router: ReportingRouter,

    /// The duration to declare an actor dead if no supervision update received.
    supervision_update_timeout: Duration,

    /// The duration to evict an unhealthy world, after which a world fails supervision states.
    world_eviction_timeout: Duration,
}
866
867impl SystemActorParams {
868    /// Create a new system actor params.
869    pub fn new(supervision_update_timeout: Duration, world_eviction_timeout: Duration) -> Self {
870        Self {
871            mailbox_router: ReportingRouter::new(),
872            supervision_update_timeout,
873            world_eviction_timeout,
874        }
875    }
876}
877
/// Supervision bookkeeping for the whole system: one [`WorldSupervisionInfo`]
/// per supervised world, keyed by world id, plus the timeout after which a
/// proc without a supervision update is marked expired.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SystemSupervisionState {
    // A map from world id to world supervision state.
    supervision_map: HashMap<WorldId, WorldSupervisionInfo>,
    // Supervision expiration duration.
    supervision_update_timeout: Duration,
}
886
// Used to record when procs sent their last heartbeats.
//
// The same (timestamp, proc) information is kept in two shapes: a map for
// lookup by proc id, and an ordered set for scanning from the oldest
// heartbeat upwards. `update` keeps the two in sync.
#[derive(Debug, Clone, Default)]
struct HeartbeatRecord {
    // This index is used to efficiently find expired procs.
    // Entries are ordered by timestamp first, so the oldest heartbeats come
    // first when iterating.
    // T208419148: Handle btree_index initialization during system actor recovery
    btree_index: BTreeSet<(Instant, ProcId)>,

    // Last time when proc was updated.
    proc_last_update_time: HashMap<ProcId, Instant>,
}
897
898impl HeartbeatRecord {
899    // Update this proc's heartbeat record with timestamp as "now".
900    fn update(&mut self, proc_id: &ProcId, clock: &impl Clock) {
901        // Remove previous entry in btree_index if exists.
902        if let Some(update_time) = self.proc_last_update_time.get(proc_id) {
903            self.btree_index
904                .remove(&(update_time.clone(), proc_id.clone()));
905        }
906
907        // Insert new entry into btree_index.
908        let now = clock.now();
909        self.proc_last_update_time
910            .insert(proc_id.clone(), now.clone());
911        self.btree_index.insert((now.clone(), proc_id.clone()));
912    }
913
914    // Find all the procs with expired heartbeat, and mark them as expired in
915    // WorldSupervisionState.
916    fn mark_expired_procs(
917        &self,
918        state: &mut WorldSupervisionState,
919        clock: &impl Clock,
920        supervision_update_timeout: Duration,
921    ) {
922        // Update procs' live status.
923        let now = clock.now();
924        self.btree_index
925            .iter()
926            .take_while(|(last_update_time, _)| {
927                now > *last_update_time + supervision_update_timeout
928            })
929            .for_each(|(_, proc_id)| {
930                if let Some(proc_state) = state
931                    .procs
932                    .get_mut(&proc_id.rank().expect("proc must be ranked for rank access"))
933                {
934                    match proc_state.proc_health {
935                        ProcStatus::Alive => proc_state.proc_health = ProcStatus::Expired,
936                        // Do not overwrite the health of a proc already known to be unhealthy.
937                        _ => (),
938                    }
939                }
940            });
941    }
942}
943
// Everything the system actor tracks about one supervised world.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct WorldSupervisionInfo {
    // Supervision state of the world's procs.
    state: WorldSupervisionState,

    // The lifecycle mode of each proc in this world, keyed by proc id.
    lifecycle_mode: HashMap<ProcId, ProcLifecycleMode>,

    // Heartbeat timestamps for the world's procs. Excluded from
    // (de)serialization via `serde(skip)`.
    #[serde(skip)]
    heartbeat_record: HeartbeatRecord,
}
954
955impl WorldSupervisionInfo {
956    fn new() -> Self {
957        Self {
958            state: WorldSupervisionState {
959                procs: HashMap::new(),
960            },
961            lifecycle_mode: HashMap::new(),
962            heartbeat_record: HeartbeatRecord::default(),
963        }
964    }
965}
966
967impl SystemSupervisionState {
968    fn new(supervision_update_timeout: Duration) -> Self {
969        Self {
970            supervision_map: HashMap::new(),
971            supervision_update_timeout,
972        }
973    }
974
975    /// Create a proc supervision entry.
976    fn create(
977        &mut self,
978        proc_state: ProcSupervisionState,
979        lifecycle_mode: ProcLifecycleMode,
980        clock: &impl Clock,
981    ) {
982        if World::is_host_world(&proc_state.world_id) {
983            return;
984        }
985
986        let world = self
987            .supervision_map
988            .entry(proc_state.world_id.clone())
989            .or_insert_with(WorldSupervisionInfo::new);
990        world
991            .lifecycle_mode
992            .insert(proc_state.proc_id.clone(), lifecycle_mode);
993
994        self.update(proc_state, clock);
995    }
996
997    /// Update a proc supervision entry.
998    fn update(&mut self, proc_state: ProcSupervisionState, clock: &impl Clock) {
999        if World::is_host_world(&proc_state.world_id) {
1000            return;
1001        }
1002
1003        let world = self
1004            .supervision_map
1005            .entry(proc_state.world_id.clone())
1006            .or_insert_with(WorldSupervisionInfo::new);
1007
1008        world.heartbeat_record.update(&proc_state.proc_id, clock);
1009
1010        // Update supervision map.
1011        if let Some(info) = world.state.procs.get_mut(
1012            &proc_state
1013                .proc_id
1014                .rank()
1015                .expect("proc must be ranked for proc state update"),
1016        ) {
1017            match info.proc_health {
1018                ProcStatus::Alive => info.proc_health = proc_state.proc_health,
1019                // Do not overwrite the health of a proc already known to be unhealthy.
1020                _ => (),
1021            }
1022            info.failed_actors.extend(proc_state.failed_actors);
1023        } else {
1024            world.state.procs.insert(
1025                proc_state
1026                    .proc_id
1027                    .rank()
1028                    .expect("proc must be ranked for rank access"),
1029                proc_state,
1030            );
1031        }
1032    }
1033
1034    /// Report the given proc's supervision state. If the proc is not in the map, do nothing.
1035    fn report(&mut self, proc_state: ProcSupervisionState, clock: &impl Clock) {
1036        if World::is_host_world(&proc_state.world_id) {
1037            return;
1038        }
1039
1040        let proc_id = proc_state.proc_id.clone();
1041        match self.supervision_map.entry(proc_state.world_id.clone()) {
1042            Entry::Occupied(mut world_supervision_info) => {
1043                match world_supervision_info
1044                    .get_mut()
1045                    .state
1046                    .procs
1047                    .entry(proc_id.rank().expect("proc must be ranked for rank access"))
1048                {
1049                    Entry::Occupied(_) => {
1050                        self.update(proc_state, clock);
1051                    }
1052                    Entry::Vacant(_) => {
1053                        tracing::error!("supervision not enabled for proc {}", &proc_id);
1054                    }
1055                }
1056            }
1057            Entry::Vacant(_) => {
1058                tracing::error!("supervision not enabled for proc {}", &proc_id);
1059            }
1060        }
1061    }
1062
1063    /// Get procs of a world with expired supervision updates, as well as procs with
1064    /// actor failures.
1065    fn get_world_with_failures(
1066        &mut self,
1067        world_id: &WorldId,
1068        clock: &impl Clock,
1069    ) -> Option<WorldSupervisionState> {
1070        if let Some(world) = self.supervision_map.get_mut(world_id) {
1071            world.heartbeat_record.mark_expired_procs(
1072                &mut world.state,
1073                clock,
1074                self.supervision_update_timeout,
1075            );
1076            // Get procs with failures.
1077            let mut world_state_copy = world.state.clone();
1078            // Only return failed procs if there is any
1079            world_state_copy
1080                .procs
1081                .retain(|_, proc_state| !proc_state.is_healthy());
1082            return Some(world_state_copy);
1083        }
1084        None
1085    }
1086
1087    fn is_world_healthy(&mut self, world_id: &WorldId, clock: &impl Clock) -> bool {
1088        self.get_world_with_failures(world_id, clock)
1089            .is_none_or(|state| WorldSupervisionState::is_healthy(&state))
1090    }
1091}
1092
// Progress of an in-flight stop request for one world.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct WorldStoppingState {
    // Procs that have been sent a `ProcMessage::Stop`.
    stopping_procs: HashSet<ProcId>,
    // NOTE(review): presumably the procs that have confirmed they stopped;
    // this set is not populated in the stop handler itself -- confirm against
    // the `ProcStopResult` handler.
    stopped_procs: HashSet<ProcId>,
}
1098
/// A message to stop the system actor.
///
/// The system actor sends this to itself, with a delay, at the end of its
/// `stop` handler.
#[derive(Debug, Clone, PartialEq, EnumAsInner)]
enum SystemStopMessage {
    /// Scheduled when `stop` was requested without a list of worlds.
    StopSystemActor,
    /// Scheduled when `stop` was requested for specific worlds; carries the
    /// ids of those worlds.
    EvictWorlds(Vec<WorldId>),
}
1105
/// The system actor manages the whole system. It is responsible for
/// managing the systems' worlds, and for managing their constituent
/// procs. The system actor also provides a central mailbox that can
/// route messages to any live actor in the system.
#[derive(Debug, Clone)]
#[hyperactor::export(
    handlers = [
        SystemMessage,
        ProcSupervisionMessage,
        WorldSupervisionMessage,
    ],
)]
pub struct SystemActor {
    // The parameters the actor was created with (router and timeouts).
    params: SystemActorParams,
    // Supervision bookkeeping for the procs of every supervised world.
    supervision_state: SystemSupervisionState,
    // The worlds known to the system actor, keyed by world id.
    worlds: HashMap<WorldId, World>,
    // A map from world id to stop state for worlds with an inflight stop request.
    worlds_to_stop: HashMap<WorldId, WorldStoppingState>,
    // Set once a stop of the system actor and all worlds has been requested.
    shutting_down: bool,
}
1126
/// The well known ID of the world that hosts the system actor, it is always `system`.
pub static SYSTEM_WORLD: LazyLock<WorldId> = LazyLock::new(|| id!(system));

/// The well known ID of the system actor, it is always `system[0].root`.
static SYSTEM_ACTOR_ID: LazyLock<ActorId> = LazyLock::new(|| id!(system[0].root));

/// The ref corresponding to the well known system actor ID, `system[0].root`.
/// The ref is attested from that ID, not obtained from a running actor.
pub static SYSTEM_ACTOR_REF: LazyLock<ActorRef<SystemActor>> =
    LazyLock::new(|| ActorRef::attest(id!(system[0].root)));
1136
1137impl SystemActor {
1138    /// Adds a new world that's awaiting creation to the worlds.
1139    fn add_new_world(&mut self, world_id: WorldId) -> Result<(), anyhow::Error> {
1140        let world_state = WorldState {
1141            host_map: HashMap::new(),
1142            procs: HashMap::new(),
1143            status: WorldStatus::AwaitingCreation,
1144        };
1145        let world = World::new(
1146            world_id.clone(),
1147            Shape::Unknown,
1148            world_state,
1149            0,
1150            Environment::Local,
1151            HashMap::new(),
1152        )?;
1153        self.worlds.insert(world_id.clone(), world);
1154        Ok(())
1155    }
1156
    /// The mailbox router held in this actor's params.
    fn router(&self) -> &ReportingRouter {
        &self.params.mailbox_router
    }
1160
    /// Bootstrap the system actor. This will create a proc, spawn the actor
    /// on that proc, and then return the actor handle and the corresponding
    /// proc.
    ///
    /// Equivalent to [`Self::bootstrap_with_clock`] with the default
    /// [`ClockKind`].
    pub async fn bootstrap(
        params: SystemActorParams,
    ) -> Result<(ActorHandle<SystemActor>, Proc), anyhow::Error> {
        Self::bootstrap_with_clock(params, ClockKind::default()).await
    }
1169
1170    /// Bootstrap the system actor with a specified clock.This will create a proc, spawn the actor
1171    /// on that proc, and then return the actor handle and the corresponding
1172    /// proc.
1173    pub async fn bootstrap_with_clock(
1174        params: SystemActorParams,
1175        clock: ClockKind,
1176    ) -> Result<(ActorHandle<SystemActor>, Proc), anyhow::Error> {
1177        let system_proc = Proc::new_with_clock(
1178            SYSTEM_ACTOR_ID.proc_id().clone(),
1179            BoxedMailboxSender::new(params.mailbox_router.clone()),
1180            clock,
1181        );
1182        let actor_handle = system_proc
1183            .spawn::<SystemActor>(SYSTEM_ACTOR_ID.name(), params)
1184            .await?;
1185
1186        Ok((actor_handle, system_proc))
1187    }
1188
1189    /// Evict a single world
1190    fn evict_world(&mut self, world_id: &WorldId) {
1191        self.worlds.remove(world_id);
1192        self.supervision_state.supervision_map.remove(world_id);
1193        // Remove all the addresses starting with the world_id as the prefix.
1194        self.params
1195            .mailbox_router
1196            .router
1197            .unbind(&world_id.clone().into());
1198    }
1199}
1200
1201#[async_trait]
1202impl Actor for SystemActor {
1203    type Params = SystemActorParams;
1204
1205    async fn new(params: SystemActorParams) -> Result<Self, anyhow::Error> {
1206        let supervision_update_timeout = params.supervision_update_timeout.clone();
1207        Ok(Self {
1208            params,
1209            supervision_state: SystemSupervisionState::new(supervision_update_timeout),
1210            worlds: HashMap::new(),
1211            worlds_to_stop: HashMap::new(),
1212            shutting_down: false,
1213        })
1214    }
1215
1216    async fn init(&mut self, cx: &Instance<Self>) -> Result<(), anyhow::Error> {
1217        // Start to periodically check the unhealthy worlds.
1218        cx.self_message_with_delay(MaintainWorldHealth {}, Duration::from_secs(0))?;
1219        Ok(())
1220    }
1221
1222    async fn handle_undeliverable_message(
1223        &mut self,
1224        _cx: &Instance<Self>,
1225        Undeliverable(envelope): Undeliverable<MessageEnvelope>,
1226    ) -> Result<(), anyhow::Error> {
1227        let to = envelope.dest().clone();
1228        let from = envelope.sender().clone();
1229        tracing::info!(
1230            "a message from {} to {} was undeliverable and returned to the system actor",
1231            from,
1232            to,
1233        );
1234
1235        // The channel to the receiver's proc is lost or can't be
1236        // established. Update the proc's supervision status
1237        // accordingly.
1238        let proc_id = to.actor_id().proc_id();
1239        let world_id = proc_id
1240            .world_id()
1241            .expect("proc must be ranked for world_id access");
1242        if let Some(world) = &mut self.supervision_state.supervision_map.get_mut(world_id) {
1243            if let Some(proc) = world
1244                .state
1245                .procs
1246                .get_mut(&proc_id.rank().expect("proc must be ranked for rank access"))
1247            {
1248                match proc.proc_health {
1249                    ProcStatus::Alive => proc.proc_health = ProcStatus::ConnectionFailure,
1250                    // Do not overwrite the health of a proc already
1251                    // known to be unhealthy.
1252                    _ => (),
1253                }
1254            } else {
1255                tracing::error!(
1256                    "can't update proc {} status because there isn't one",
1257                    proc_id
1258                );
1259            }
1260        } else {
1261            tracing::error!(
1262                "can't update world {} status because there isn't one",
1263                world_id
1264            );
1265        }
1266        Ok(())
1267    }
1268}
1269
1270///
1271/// +------+  spawns   +----+  joins   +-----+
1272/// | Proc |<----------|Host|--------->|World|
1273/// +------+           +----+          +-----+
1274///    |                                   ^
1275///    |          joins                    |
1276///    +-----------------------------------+
1277/// When bootstrapping the system,
1278///   1. hosts will join the world,
1279///   2. hosts will spawn (worker) procs,
1280///   3. procs will join the world
1281#[async_trait]
1282#[hyperactor::forward(SystemMessage)]
1283impl SystemMessageHandler for SystemActor {
1284    async fn join(
1285        &mut self,
1286        cx: &Context<Self>,
1287        world_id: WorldId,
1288        proc_id: ProcId,
1289        proc_message_port: PortRef<ProcMessage>,
1290        channel_addr: ChannelAddr,
1291        labels: HashMap<String, String>,
1292        lifecycle_mode: ProcLifecycleMode,
1293    ) -> Result<(), anyhow::Error> {
1294        tracing::info!("received join for proc {} in world {}", proc_id, world_id);
1295        // todo: check that proc_id is a user id
1296        self.router()
1297            .router
1298            .bind(proc_id.clone().into(), channel_addr.clone());
1299
1300        self.router().broadcast_addr(&proc_id, channel_addr.clone());
1301
1302        // TODO: handle potential undeliverable message return
1303        self.router().serialize_and_send(
1304            &proc_message_port,
1305            ProcMessage::Joined(),
1306            monitored_return_handle(),
1307        )?;
1308
1309        if lifecycle_mode.is_managed() {
1310            self.supervision_state.create(
1311                ProcSupervisionState {
1312                    world_id: world_id.clone(),
1313                    proc_id: proc_id.clone(),
1314                    proc_addr: channel_addr.clone(),
1315                    proc_health: ProcStatus::Alive,
1316                    failed_actors: Vec::new(),
1317                },
1318                lifecycle_mode.clone(),
1319                cx.clock(),
1320            );
1321        }
1322
1323        // If the proc's life cycle is not managed by system actor, system actor
1324        // doesn't need to track it in its "worlds" field.
1325        if lifecycle_mode != ProcLifecycleMode::ManagedBySystem {
1326            tracing::info!("ignoring join for proc {} in world {}", proc_id, world_id);
1327            return Ok(());
1328        }
1329
1330        let world_id = World::get_real_world_id(&world_id);
1331        if !self.worlds.contains_key(&world_id) {
1332            self.add_new_world(world_id.clone())?;
1333        }
1334        let world = self
1335            .worlds
1336            .get_mut(&world_id)
1337            .ok_or(anyhow::anyhow!("failed to get world from map"))?;
1338
1339        match HostId::try_from(proc_id.clone()) {
1340            Ok(host_id) => {
1341                tracing::info!("{}: adding host {}", world_id, host_id);
1342                return world
1343                    .on_host_join(
1344                        host_id,
1345                        proc_message_port,
1346                        &self.params.mailbox_router.router,
1347                    )
1348                    .await
1349                    .map_err(anyhow::Error::from);
1350            }
1351            // If it is not a host ID, it must be a regular proc ID. e.g.
1352            // worker procs spawned by the host proc actor.
1353            Err(_) => {
1354                tracing::info!("proc {} joined to world {}", &proc_id, &world_id,);
1355                // TODO(T207602936) add reconciliation machine to make sure
1356                // 1. only add procs that are created by the host
1357                // 2. retry upon failed proc creation by host.
1358                if let Err(e) = world.add_proc(proc_id.clone(), proc_message_port, labels) {
1359                    tracing::warn!(
1360                        "failed to add proc {} to world {}: {}",
1361                        &proc_id,
1362                        &world_id,
1363                        e
1364                    );
1365                }
1366            }
1367        };
1368        Ok(())
1369    }
1370
1371    async fn upsert_world(
1372        &mut self,
1373        cx: &Context<Self>,
1374        world_id: WorldId,
1375        shape: Shape,
1376        num_procs_per_host: usize,
1377        env: Environment,
1378        labels: HashMap<String, String>,
1379    ) -> Result<(), anyhow::Error> {
1380        tracing::info!("received upsert_world for world {}!", world_id);
1381        match self.worlds.get_mut(&world_id) {
1382            Some(world) => {
1383                tracing::info!("found existing world {}!", world_id);
1384                match &world.state.status {
1385                    WorldStatus::AwaitingCreation => {
1386                        world.scheduler_params.shape = shape;
1387                        world.scheduler_params.num_procs_per_host = num_procs_per_host;
1388                        world.scheduler_params.env = env;
1389                        world.state = WorldState {
1390                            host_map: world.state.host_map.clone(),
1391                            procs: world.state.procs.clone(),
1392                            status: if world.state.procs.len() < world.scheduler_params.num_procs()
1393                                || !self
1394                                    .supervision_state
1395                                    .is_world_healthy(&world_id, cx.clock())
1396                            {
1397                                WorldStatus::Unhealthy(cx.clock().system_time_now())
1398                            } else {
1399                                WorldStatus::Live
1400                            },
1401                        };
1402                        for (k, v) in labels {
1403                            if world.labels.contains_key(&k) {
1404                                anyhow::bail!("cannot overwrite world label: {}", k);
1405                            }
1406                            world.labels.insert(k.clone(), v.clone());
1407                        }
1408                    }
1409                    _ => {
1410                        anyhow::bail!("cannot modify world {}: already exists", world.world_id)
1411                    }
1412                }
1413
1414                world.on_create(&self.params.mailbox_router.router).await?;
1415                tracing::info!(
1416                    "modified parameters to world {} with shape: {:?} and labels {:?}",
1417                    &world.world_id,
1418                    world.scheduler_params.shape,
1419                    world.labels
1420                );
1421            }
1422            None => {
1423                let world = World::new(
1424                    world_id.clone(),
1425                    shape.clone(),
1426                    WorldState {
1427                        host_map: HashMap::new(),
1428                        procs: HashMap::new(),
1429                        status: WorldStatus::Unhealthy(cx.clock().system_time_now()),
1430                    },
1431                    num_procs_per_host,
1432                    env,
1433                    labels,
1434                )?;
1435                tracing::info!("new world {} added with shape: {:?}", world_id, &shape);
1436                self.worlds.insert(world_id, world);
1437            }
1438        };
1439        Ok(())
1440    }
1441
1442    async fn snapshot(
1443        &mut self,
1444        _cx: &Context<Self>,
1445        filter: SystemSnapshotFilter,
1446    ) -> Result<SystemSnapshot, anyhow::Error> {
1447        let world_snapshots = self
1448            .worlds
1449            .iter()
1450            .filter(|(_, world)| filter.world_matches(world))
1451            .map(|(world_id, world)| {
1452                (
1453                    world_id.clone(),
1454                    WorldSnapshot::from_world_filtered(world, &filter),
1455                )
1456            })
1457            .collect();
1458        Ok(SystemSnapshot {
1459            worlds: world_snapshots,
1460            execution_id: hyperactor_telemetry::env::execution_id(),
1461        })
1462    }
1463
1464    async fn stop(
1465        &mut self,
1466        cx: &Context<Self>,
1467        worlds: Option<Vec<WorldId>>,
1468        proc_timeout: Duration,
1469        reply_port: OncePortRef<()>,
1470    ) -> Result<(), anyhow::Error> {
1471        // TODO: this needn't be async
1472
1473        match &worlds {
1474            Some(world_ids) => {
1475                tracing::info!("stopping worlds: {:?}", world_ids);
1476            }
1477            None => {
1478                tracing::info!("stopping system actor and all worlds");
1479                self.shutting_down = true;
1480            }
1481        }
1482
1483        // If there's no worlds left to stop, shutdown now.
1484        if self.worlds.is_empty() && self.shutting_down {
1485            cx.stop()?;
1486            reply_port.send(cx, ())?;
1487            return Ok(());
1488        }
1489
1490        let mut world_ids = vec![];
1491        match &worlds {
1492            Some(worlds) => {
1493                // Stop only the specified worlds
1494                world_ids.extend(worlds.clone().into_iter().collect::<Vec<_>>());
1495            }
1496            None => {
1497                // Stop all worlds
1498                world_ids.extend(
1499                    self.worlds
1500                        .keys()
1501                        .filter(|x| x.name() != "user")
1502                        .cloned()
1503                        .collect::<Vec<_>>(),
1504                );
1505            }
1506        }
1507
1508        for world_id in &world_ids {
1509            if self.worlds_to_stop.contains_key(world_id) || !self.worlds.contains_key(world_id) {
1510                // The world is being stopped already.
1511                continue;
1512            }
1513            self.worlds_to_stop.insert(
1514                world_id.clone(),
1515                WorldStoppingState {
1516                    stopping_procs: HashSet::new(),
1517                    stopped_procs: HashSet::new(),
1518                },
1519            );
1520        }
1521
1522        let all_procs = self
1523            .worlds
1524            .iter()
1525            .filter(|(world_id, _)| match &worlds {
1526                Some(worlds_ids) => worlds_ids.contains(world_id),
1527                None => true,
1528            })
1529            .flat_map(|(_, world)| {
1530                world
1531                    .state
1532                    .host_map
1533                    .iter()
1534                    .map(|(host_id, host)| (host_id.0.clone(), host.proc_message_port.clone()))
1535                    .chain(
1536                        world
1537                            .state
1538                            .procs
1539                            .iter()
1540                            .map(|(proc_id, info)| (proc_id.clone(), info.port_ref.clone())),
1541                    )
1542                    .collect::<Vec<_>>()
1543            })
1544            .collect::<HashMap<_, _>>();
1545
1546        // Send Stop message to all processes known to the system. This is a best
1547        // effort, because the message might fail to deliver due to network
1548        // partition.
1549        for (proc_id, port) in all_procs.into_iter() {
1550            let stopping_state = self
1551                .worlds_to_stop
1552                .get_mut(&World::get_real_world_id(
1553                    proc_id
1554                        .world_id()
1555                        .expect("proc must be ranked for world_id access"),
1556                ))
1557                .unwrap();
1558            if !stopping_state.stopping_procs.insert(proc_id) {
1559                continue;
1560            }
1561
1562            // This is a hack. Due to T214365263, SystemActor cannot get reply
1563            // from a 2-way message when that message is sent from its handler.
1564            // As a result, we set the reply to a handle port, so that reply
1565            // can be processed as a separate message. See Handler<ProcStopResult>
1566            // for how the received reply is further processed.
1567            let reply_to = cx.port::<ProcStopResult>().bind().into_once();
1568            port.send(
1569                cx,
1570                ProcMessage::Stop {
1571                    timeout: proc_timeout,
1572                    reply_to,
1573                },
1574            )?;
1575        }
1576
1577        let stop_msg = match &worlds {
1578            Some(_) => SystemStopMessage::EvictWorlds(world_ids.clone()),
1579            None => SystemStopMessage::StopSystemActor {},
1580        };
1581
1582        // Schedule a message to stop the system actor itself.
1583        cx.self_message_with_delay(stop_msg, Duration::from_secs(8))?;
1584
1585        reply_port.send(cx, ())?;
1586        Ok(())
1587    }
1588}
1589
1590#[async_trait]
1591impl Handler<MaintainWorldHealth> for SystemActor {
1592    async fn handle(&mut self, cx: &Context<Self>, _: MaintainWorldHealth) -> anyhow::Result<()> {
1593        // TODO: this needn't be async
1594
1595        // Find the world with the oldest unhealthy time so we can schedule the next check.
1596        let mut next_check_delay = self.params.world_eviction_timeout;
1597        tracing::debug!("Checking world state. Got {} worlds", self.worlds.len());
1598
1599        for world in self.worlds.values_mut() {
1600            if world.state.status == WorldStatus::AwaitingCreation {
1601                continue;
1602            }
1603
1604            let Some(state) = self
1605                .supervision_state
1606                .get_world_with_failures(&world.world_id, cx.clock())
1607            else {
1608                tracing::debug!("world {} does not have failures, skipping.", world.world_id);
1609                continue;
1610            };
1611
1612            if state.is_healthy() {
1613                tracing::debug!(
1614                    "world {} with procs {:?} is healthy, skipping.",
1615                    world.world_id,
1616                    state
1617                        .procs
1618                        .values()
1619                        .map(|p| p.proc_id.clone())
1620                        .collect::<Vec<_>>()
1621                );
1622                continue;
1623            }
1624            // Some procs are not healthy, check if any of the proc should manage the system.
1625            for (_, proc_state) in state.procs.iter() {
1626                if proc_state.proc_health == ProcStatus::Alive {
1627                    tracing::debug!("proc {} is still alive.", proc_state.proc_id);
1628                    continue;
1629                }
1630                if self
1631                    .supervision_state
1632                    .supervision_map
1633                    .get(&world.world_id)
1634                    .and_then(|world| world.lifecycle_mode.get(&proc_state.proc_id))
1635                    .map_or(true, |mode| *mode != ProcLifecycleMode::ManagingSystem)
1636                {
1637                    tracing::debug!(
1638                        "proc {} with state {} does not manage system.",
1639                        proc_state.proc_id,
1640                        proc_state.proc_health
1641                    );
1642                    continue;
1643                }
1644
1645                tracing::error!(
1646                    "proc {}  is unhealthy, stop the system as the proc manages the system",
1647                    proc_state.proc_id
1648                );
1649
1650                // The proc has expired heartbeating and it manages the lifecycle of system, schedule system stop
1651                let (tx, _) = cx.open_once_port::<()>();
1652                cx.port().send(SystemMessage::Stop {
1653                    worlds: None,
1654                    proc_timeout: Duration::from_secs(5),
1655                    reply_port: tx.bind(),
1656                })?;
1657            }
1658
1659            if world.state.status == WorldStatus::Live {
1660                world.state.status = WorldStatus::Unhealthy(cx.clock().system_time_now());
1661            }
1662
1663            match &world.state.status {
1664                WorldStatus::Unhealthy(last_unhealthy_time) => {
1665                    let elapsed = last_unhealthy_time
1666                        .elapsed()
1667                        .inspect_err(|err| {
1668                            tracing::error!(
1669                                "failed to get elapsed time for unhealthy world {}: {}",
1670                                world.world_id,
1671                                err
1672                            )
1673                        })
1674                        .unwrap_or_else(|_| Duration::from_secs(0));
1675
1676                    if elapsed < self.params.world_eviction_timeout {
1677                        // We can live a bit longer still.
1678                        next_check_delay = std::cmp::min(
1679                            next_check_delay,
1680                            self.params.world_eviction_timeout - elapsed,
1681                        );
1682                    } else {
1683                        next_check_delay = Duration::from_secs(0);
1684                    }
1685                }
1686                _ => {
1687                    tracing::error!(
1688                        "find a failed world {} with healthy state {}",
1689                        world.world_id,
1690                        world.state.status
1691                    );
1692                    continue;
1693                }
1694            }
1695        }
1696        cx.self_message_with_delay(MaintainWorldHealth {}, next_check_delay)?;
1697
1698        Ok(())
1699    }
1700}
1701
1702#[async_trait]
1703impl Handler<ProcSupervisionMessage> for SystemActor {
1704    async fn handle(
1705        &mut self,
1706        cx: &Context<Self>,
1707        msg: ProcSupervisionMessage,
1708    ) -> anyhow::Result<()> {
1709        match msg {
1710            ProcSupervisionMessage::Update(state, reply_port) => {
1711                self.supervision_state.report(state, cx.clock());
1712                let _ = reply_port.send(cx, ());
1713            }
1714        }
1715        Ok(())
1716    }
1717}
1718
1719#[async_trait]
1720impl Handler<WorldSupervisionMessage> for SystemActor {
1721    async fn handle(
1722        &mut self,
1723        cx: &Context<Self>,
1724        msg: WorldSupervisionMessage,
1725    ) -> anyhow::Result<()> {
1726        match msg {
1727            WorldSupervisionMessage::State(world_id, reply_port) => {
1728                let world_state = self
1729                    .supervision_state
1730                    .get_world_with_failures(&world_id, cx.clock());
1731                // TODO: handle potential undeliverable message return
1732                let _ = reply_port.send(cx, world_state);
1733            }
1734        }
1735        Ok(())
1736    }
1737}
1738
1739// Temporary solution to allow SystemMessage::Stop receive replies from 2-way
1740// messages. Can be remove after T214365263 is implemented.
1741#[async_trait]
1742impl Handler<ProcStopResult> for SystemActor {
1743    async fn handle(&mut self, cx: &Context<Self>, msg: ProcStopResult) -> anyhow::Result<()> {
1744        fn stopping_proc_msg<'a>(sprocs: impl Iterator<Item = &'a ProcId>) -> String {
1745            let sprocs = sprocs.collect::<Vec<_>>();
1746            if sprocs.is_empty() {
1747                return "no procs left".to_string();
1748            }
1749            let msg = sprocs
1750                .iter()
1751                .take(3)
1752                .map(|proc_id| proc_id.to_string())
1753                .collect::<Vec<_>>()
1754                .join(", ");
1755            if sprocs.len() > 3 {
1756                format!("remaining procs: {} and {} more", msg, sprocs.len() - 3)
1757            } else {
1758                format!("remaining procs: {}", msg)
1759            }
1760        }
1761        let mut world_stopped = false;
1762        let world_id = &msg
1763            .proc_id
1764            .clone()
1765            .into_ranked()
1766            .expect("proc must be ranked for world_id access")
1767            .0;
1768        if let Some(stopping_state) = self.worlds_to_stop.get_mut(world_id) {
1769            stopping_state.stopped_procs.insert(msg.proc_id.clone());
1770            tracing::debug!(
1771                "received stop response from {}: {} stopped actors, {} aborted actors: {}",
1772                msg.proc_id,
1773                msg.actors_stopped,
1774                msg.actors_aborted,
1775                stopping_proc_msg(
1776                    stopping_state
1777                        .stopping_procs
1778                        .difference(&stopping_state.stopped_procs)
1779                ),
1780            );
1781            world_stopped =
1782                stopping_state.stopping_procs.len() == stopping_state.stopped_procs.len();
1783        } else {
1784            tracing::warn!(
1785                "received stop response from {} but no inflight stopping request is found, possibly late response",
1786                msg.proc_id
1787            );
1788        }
1789
1790        if world_stopped {
1791            self.evict_world(world_id);
1792            self.worlds_to_stop.remove(world_id);
1793        }
1794
1795        if self.shutting_down && self.worlds.is_empty() {
1796            cx.stop()?;
1797        }
1798
1799        Ok(())
1800    }
1801}
1802
1803#[async_trait]
1804impl Handler<SystemStopMessage> for SystemActor {
1805    async fn handle(
1806        &mut self,
1807        cx: &Context<Self>,
1808        message: SystemStopMessage,
1809    ) -> anyhow::Result<()> {
1810        match message {
1811            SystemStopMessage::EvictWorlds(world_ids) => {
1812                for world_id in &world_ids {
1813                    if self.worlds_to_stop.contains_key(world_id) {
1814                        tracing::warn!(
1815                            "Waiting for world to stop timed out, evicting world anyways: {:?}",
1816                            world_id
1817                        );
1818                        self.evict_world(world_id);
1819                    }
1820                }
1821            }
1822            SystemStopMessage::StopSystemActor => {
1823                if self.worlds_to_stop.is_empty() {
1824                    tracing::warn!(
1825                        "waiting for all worlds to stop timed out, stopping the system actor and evicting the these worlds anyways: {:?}",
1826                        self.worlds_to_stop.keys()
1827                    );
1828                } else {
1829                    tracing::warn!(
1830                        "waiting for all worlds to stop timed out, stopping the system actor"
1831                    );
1832                }
1833
1834                cx.stop()?;
1835            }
1836        }
1837        Ok(())
1838    }
1839}
1840
1841#[cfg(test)]
1842mod tests {
1843    use std::assert_matches::assert_matches;
1844
1845    use anyhow::Result;
1846    use hyperactor::PortId;
1847    use hyperactor::actor::ActorStatus;
1848    use hyperactor::attrs::Attrs;
1849    use hyperactor::channel;
1850    use hyperactor::channel::ChannelTransport;
1851    use hyperactor::channel::Rx;
1852    use hyperactor::channel::TcpMode;
1853    use hyperactor::clock::Clock;
1854    use hyperactor::clock::RealClock;
1855    use hyperactor::data::Serialized;
1856    use hyperactor::mailbox::Mailbox;
1857    use hyperactor::mailbox::MailboxServer;
1858    use hyperactor::mailbox::MessageEnvelope;
1859    use hyperactor::mailbox::PortHandle;
1860    use hyperactor::mailbox::PortReceiver;
1861    use hyperactor::simnet;
1862    use hyperactor::test_utils::pingpong::PingPongActorParams;
1863
1864    use super::*;
1865    use crate::System;
1866
    /// Test fixture standing in for a host that joins the system; built by
    /// `spawn_mock_host_actor`.
    struct MockHostActor {
        // Proc id the tests pass as `proc_id` in `SystemMessage::Join`.
        local_proc_id: ProcId,
        // Address of the channel served for this mock; passed as `proc_addr`
        // in `SystemMessage::Join`.
        local_proc_addr: ChannelAddr,
        // Port whose bound reference is handed to the system actor as
        // `proc_message_port`.
        local_proc_message_port: PortHandle<ProcMessage>,
        // Receiving end of `local_proc_message_port`; the tests read
        // `ProcMessage`s from it.
        local_proc_message_receiver: PortReceiver<ProcMessage>,
    }
1873
1874    async fn spawn_mock_host_actor(proc_world_id: WorldId, host_id: usize) -> MockHostActor {
1875        // Set up a local actor.
1876        let local_proc_id = ProcId::Ranked(
1877            WorldId(format!("{}{}", SHADOW_PREFIX, proc_world_id.name())),
1878            host_id,
1879        );
1880        let (local_proc_addr, local_proc_rx) =
1881            channel::serve::<MessageEnvelope>(ChannelAddr::any(ChannelTransport::Local)).unwrap();
1882        let local_proc_mbox = Mailbox::new_detached(local_proc_id.actor_id("test".to_string(), 0));
1883        let (local_proc_message_port, local_proc_message_receiver) = local_proc_mbox.open_port();
1884        let _local_proc_serve_handle = local_proc_mbox.clone().serve(local_proc_rx);
1885        MockHostActor {
1886            local_proc_id,
1887            local_proc_addr,
1888            local_proc_message_port,
1889            local_proc_message_receiver,
1890        }
1891    }
1892
1893    #[tokio::test]
1894    async fn test_supervision_state() {
1895        let mut sv = SystemSupervisionState::new(Duration::from_secs(1));
1896        let world_id = id!(world);
1897        let proc_id_0 = world_id.proc_id(0);
1898        let clock = ClockKind::Real(RealClock);
1899        sv.create(
1900            ProcSupervisionState {
1901                world_id: world_id.clone(),
1902                proc_addr: ChannelAddr::any(ChannelTransport::Local),
1903                proc_id: proc_id_0.clone(),
1904                proc_health: ProcStatus::Alive,
1905                failed_actors: Vec::new(),
1906            },
1907            ProcLifecycleMode::ManagedBySystem,
1908            &clock,
1909        );
1910        let actor_id = id!(world[1].actor);
1911        let proc_id_1 = actor_id.proc_id();
1912        sv.create(
1913            ProcSupervisionState {
1914                world_id: world_id.clone(),
1915                proc_addr: ChannelAddr::any(ChannelTransport::Local),
1916                proc_id: proc_id_1.clone(),
1917                proc_health: ProcStatus::Alive,
1918                failed_actors: Vec::new(),
1919            },
1920            ProcLifecycleMode::ManagedBySystem,
1921            &clock,
1922        );
1923        let world_id = id!(world);
1924
1925        let unknown_world_id = id!(unknow_world);
1926        let failures = sv.get_world_with_failures(&unknown_world_id, &clock);
1927        assert!(failures.is_none());
1928
1929        // No supervision expiration yet.
1930        let failures = sv.get_world_with_failures(&world_id, &clock);
1931        assert!(failures.is_some());
1932        assert_eq!(failures.unwrap().procs.len(), 0);
1933
1934        // One proc expired.
1935        RealClock.sleep(Duration::from_secs(2)).await;
1936        sv.report(
1937            ProcSupervisionState {
1938                world_id: world_id.clone(),
1939                proc_addr: ChannelAddr::any(ChannelTransport::Local),
1940                proc_id: proc_id_1.clone(),
1941                proc_health: ProcStatus::Alive,
1942                failed_actors: Vec::new(),
1943            },
1944            &clock,
1945        );
1946        let failures = sv.get_world_with_failures(&world_id, &clock);
1947        let procs = failures.unwrap().procs;
1948        assert_eq!(procs.len(), 1);
1949        assert!(
1950            procs.contains_key(
1951                &proc_id_0
1952                    .rank()
1953                    .expect("proc must be ranked for rank access")
1954            )
1955        );
1956
1957        // Actor failure happened to proc_1
1958        sv.report(
1959            ProcSupervisionState {
1960                world_id: world_id.clone(),
1961                proc_addr: ChannelAddr::any(ChannelTransport::Local),
1962                proc_id: proc_id_1.clone(),
1963                proc_health: ProcStatus::Alive,
1964                failed_actors: [(
1965                    actor_id.clone(),
1966                    ActorStatus::generic_failure("Actor failed"),
1967                )]
1968                .to_vec(),
1969            },
1970            &clock,
1971        );
1972
1973        let failures = sv.get_world_with_failures(&world_id, &clock);
1974        let procs = failures.unwrap().procs;
1975        assert_eq!(procs.len(), 2);
1976        assert!(
1977            procs.contains_key(
1978                &proc_id_0
1979                    .rank()
1980                    .expect("proc must be ranked for rank access")
1981            )
1982        );
1983        assert!(
1984            procs.contains_key(
1985                &proc_id_1
1986                    .rank()
1987                    .expect("proc must be ranked for rank access")
1988            )
1989        );
1990    }
1991
1992    #[tracing_test::traced_test]
1993    #[tokio::test]
1994    async fn test_host_join_before_world() {
1995        // Spins up a new world with 2 hosts, with 3 procs each.
1996        let params = SystemActorParams::new(Duration::from_secs(10), Duration::from_secs(10));
1997        let (system_actor_handle, _system_proc) = SystemActor::bootstrap(params).await.unwrap();
1998
1999        // Use a local proc actor to join the system.
2000        let mut host_actors: Vec<MockHostActor> = Vec::new();
2001
2002        let world_name = "test".to_string();
2003        let world_id = WorldId(world_name.clone());
2004        host_actors.push(spawn_mock_host_actor(world_id.clone(), 0).await);
2005        host_actors.push(spawn_mock_host_actor(world_id.clone(), 1).await);
2006
2007        for host_actor in host_actors.iter_mut() {
2008            // Join the world.
2009            system_actor_handle
2010                .send(SystemMessage::Join {
2011                    proc_id: host_actor.local_proc_id.clone(),
2012                    world_id: world_id.clone(),
2013                    proc_message_port: host_actor.local_proc_message_port.bind(),
2014                    proc_addr: host_actor.local_proc_addr.clone(),
2015                    labels: HashMap::new(),
2016                    lifecycle_mode: ProcLifecycleMode::ManagedBySystem,
2017                })
2018                .unwrap();
2019
2020            // We should get a joined message.
2021            // and a spawn proc message.
2022            assert_matches!(
2023                host_actor.local_proc_message_receiver.recv().await.unwrap(),
2024                ProcMessage::Joined()
2025            );
2026        }
2027
2028        // Create a new world message and send to system actor
2029        let num_procs = 6;
2030        let shape = Shape::Definite(vec![2, 3]);
2031        system_actor_handle
2032            .send(SystemMessage::UpsertWorld {
2033                world_id: world_id.clone(),
2034                shape,
2035                num_procs_per_host: 3,
2036                env: Environment::Local,
2037                labels: HashMap::new(),
2038            })
2039            .unwrap();
2040
2041        let mut all_procs: Vec<ProcId> = Vec::new();
2042        for host_actor in host_actors.iter_mut() {
2043            let m = host_actor.local_proc_message_receiver.recv().await.unwrap();
2044            match m {
2045                ProcMessage::SpawnProc {
2046                    env,
2047                    world_id,
2048                    mut proc_ids,
2049                    world_size,
2050                } => {
2051                    assert_eq!(world_id, WorldId(world_name.clone()));
2052                    assert_eq!(env, Environment::Local);
2053                    assert_eq!(world_size, num_procs);
2054                    all_procs.append(&mut proc_ids);
2055                }
2056                _ => std::panic!("Unexpected message type!"),
2057            }
2058        }
2059        // Check if all proc ids from 0 to num_procs - 1 are in the list
2060        assert_eq!(all_procs.len(), num_procs);
2061        all_procs.sort();
2062        for (i, proc) in all_procs.iter().enumerate() {
2063            assert_eq!(*proc, ProcId::Ranked(WorldId(world_name.clone()), i));
2064        }
2065    }
2066
2067    #[tokio::test]
2068    async fn test_host_join_after_world() {
2069        // Spins up a new world with 2 hosts, with 3 procs each.
2070        let params = SystemActorParams::new(Duration::from_secs(10), Duration::from_secs(10));
2071        let (system_actor_handle, _system_proc) = SystemActor::bootstrap(params).await.unwrap();
2072
2073        // Create a new world message and send to system actor
2074        let world_name = "test".to_string();
2075        let world_id = WorldId(world_name.clone());
2076        let num_procs = 6;
2077        let shape = Shape::Definite(vec![2, 3]);
2078        system_actor_handle
2079            .send(SystemMessage::UpsertWorld {
2080                world_id: world_id.clone(),
2081                shape,
2082                num_procs_per_host: 3,
2083                env: Environment::Local,
2084                labels: HashMap::new(),
2085            })
2086            .unwrap();
2087
2088        // Use a local proc actor to join the system.
2089        let mut host_actors: Vec<MockHostActor> = Vec::new();
2090
2091        host_actors.push(spawn_mock_host_actor(world_id.clone(), 0).await);
2092        host_actors.push(spawn_mock_host_actor(world_id.clone(), 1).await);
2093
2094        for host_actor in host_actors.iter_mut() {
2095            // Join the world.
2096            system_actor_handle
2097                .send(SystemMessage::Join {
2098                    proc_id: host_actor.local_proc_id.clone(),
2099                    world_id: world_id.clone(),
2100                    proc_message_port: host_actor.local_proc_message_port.bind(),
2101                    proc_addr: host_actor.local_proc_addr.clone(),
2102                    labels: HashMap::new(),
2103                    lifecycle_mode: ProcLifecycleMode::ManagedBySystem,
2104                })
2105                .unwrap();
2106
2107            // We should get a joined message.
2108            // and a spawn proc message.
2109            assert_matches!(
2110                host_actor.local_proc_message_receiver.recv().await.unwrap(),
2111                ProcMessage::Joined()
2112            );
2113        }
2114
2115        let mut all_procs: Vec<ProcId> = Vec::new();
2116        for host_actor in host_actors.iter_mut() {
2117            let m = host_actor.local_proc_message_receiver.recv().await.unwrap();
2118            match m {
2119                ProcMessage::SpawnProc {
2120                    env,
2121                    world_id,
2122                    mut proc_ids,
2123                    world_size,
2124                } => {
2125                    assert_eq!(world_id, WorldId(world_name.clone()));
2126                    assert_eq!(env, Environment::Local);
2127                    assert_eq!(world_size, num_procs);
2128                    all_procs.append(&mut proc_ids);
2129                }
2130                _ => std::panic!("Unexpected message type!"),
2131            }
2132        }
2133        // Check if all proc ids from 0 to num_procs - 1 are in the list
2134        assert_eq!(all_procs.len(), num_procs);
2135        all_procs.sort();
2136        for (i, proc) in all_procs.iter().enumerate() {
2137            assert_eq!(*proc, ProcId::Ranked(WorldId(world_name.clone()), i));
2138        }
2139    }
2140
2141    #[test]
2142    fn test_snapshot_filter() {
2143        let test_world = World::new(
2144            WorldId("test_world".to_string()),
2145            Shape::Definite(vec![1]),
2146            WorldState {
2147                host_map: HashMap::new(),
2148                procs: HashMap::new(),
2149                status: WorldStatus::Live,
2150            },
2151            1,
2152            Environment::Local,
2153            HashMap::from([("foo".to_string(), "bar".to_string())]),
2154        )
2155        .unwrap();
2156        // match all
2157        let filter = SystemSnapshotFilter::all();
2158        assert!(filter.world_matches(&test_world));
2159        assert!(SystemSnapshotFilter::labels_match(
2160            &HashMap::new(),
2161            &HashMap::from([("foo".to_string(), "bar".to_string())])
2162        ));
2163        // specific match
2164        let mut filter = SystemSnapshotFilter::all();
2165        filter.worlds = vec![WorldId("test_world".to_string())];
2166        assert!(filter.world_matches(&test_world));
2167        filter.worlds = vec![WorldId("unknow_world".to_string())];
2168        assert!(!filter.world_matches(&test_world));
2169        assert!(SystemSnapshotFilter::labels_match(
2170            &HashMap::from([("foo".to_string(), "baz".to_string())]),
2171            &HashMap::from([("foo".to_string(), "baz".to_string())]),
2172        ));
2173        assert!(!SystemSnapshotFilter::labels_match(
2174            &HashMap::from([("foo".to_string(), "bar".to_string())]),
2175            &HashMap::from([("foo".to_string(), "baz".to_string())]),
2176        ));
2177    }
2178
    // End-to-end check: after the mailbox server of one proc is stopped, a
    // message sent to an actor on that proc does not complete the ping-pong
    // game, and the world is subsequently reported as unhealthy.
    #[tokio::test]
    async fn test_undeliverable_message_return() {
        // System can't send a message to a remote actor because the
        // proc connection is lost.
        use hyperactor::mailbox::MailboxClient;
        use hyperactor::test_utils::pingpong::PingPongActor;
        use hyperactor::test_utils::pingpong::PingPongMessage;

        use crate::System;
        use crate::proc_actor::ProcActor;
        use crate::supervision::ProcSupervisor;

        // Use temporary config for this test
        // NOTE(review): `_guard` presumably restores the previous value when
        // dropped at the end of the test; confirm against the config API.
        let config = hyperactor::config::global::lock();
        let _guard = config.override_key(
            hyperactor::config::MESSAGE_DELIVERY_TIMEOUT,
            Duration::from_secs(1),
        );

        // Serve a system. Undeliverable messages encountered by the
        // mailbox server are returned to the system actor.
        let server_handle = System::serve(
            ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)),
            Duration::from_secs(2), // supervision update timeout
            Duration::from_secs(2), // duration to evict an unhealthy world
        )
        .await
        .unwrap();
        let system_actor_handle = server_handle.system_actor_handle();
        let mut system = System::new(server_handle.local_addr().clone());
        let client = system.attach().await.unwrap();

        // At this point there are no worlds.
        let snapshot = system_actor_handle
            .snapshot(&client, SystemSnapshotFilter::all())
            .await
            .unwrap();
        assert_eq!(snapshot.worlds.len(), 0);

        // Create one.
        let world_id = id!(world);
        system_actor_handle
            .send(SystemMessage::UpsertWorld {
                world_id: world_id.clone(),
                shape: Shape::Definite(vec![1]),
                num_procs_per_host: 1,
                env: Environment::Local,
                labels: HashMap::new(),
            })
            .unwrap();

        // Now we should know a world.
        let snapshot = system_actor_handle
            .snapshot(&client, SystemSnapshotFilter::all())
            .await
            .unwrap();
        assert_eq!(snapshot.worlds.len(), 1);
        // Check it's the world we think it is.
        assert!(snapshot.worlds.contains_key(&world_id));
        // It starts out unhealthy (todo: understand why).
        assert!(matches!(
            snapshot.worlds.get(&world_id).unwrap().status,
            WorldStatus::Unhealthy(_)
        ));

        // Build a supervisor.
        // The receiver half is bound but never read in this test.
        let supervisor = system.attach().await.unwrap();
        let (_sup_tx, _sup_rx) = supervisor.bind_actor_port::<ProcSupervisionMessage>();
        let sup_ref = ActorRef::<ProcSupervisor>::attest(supervisor.self_id().clone());

        // Construct a system sender.
        let system_sender = BoxedMailboxSender::new(MailboxClient::new(
            channel::dial(server_handle.local_addr().clone()).unwrap(),
        ));
        // Construct a proc forwarder in terms of the system sender.
        let proc_forwarder =
            BoxedMailboxSender::new(DialMailboxRouter::new_with_default(system_sender));

        // Bootstrap proc 'world[0]', join the system.
        let proc_0 = Proc::new(world_id.proc_id(0), proc_forwarder.clone());
        let _proc_actor_0 = ProcActor::bootstrap_for_proc(
            proc_0.clone(),
            world_id.clone(),
            ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)),
            server_handle.local_addr().clone(),
            sup_ref.clone(),
            Duration::from_millis(300), // supervision update interval
            HashMap::new(),
            ProcLifecycleMode::ManagedBySystem,
        )
        .await
        .unwrap();
        let proc_0_client = proc_0.attach("client").unwrap();
        let (proc_0_undeliverable_tx, _proc_0_undeliverable_rx) = proc_0_client.open_port();

        // Bootstrap a second proc 'world[1]', join the system.
        // Its handle is kept (unlike proc 0's) because its mailbox server is
        // stopped further below.
        let proc_1 = Proc::new(world_id.proc_id(1), proc_forwarder.clone());
        let proc_actor_1 = ProcActor::bootstrap_for_proc(
            proc_1.clone(),
            world_id.clone(),
            ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)),
            server_handle.local_addr().clone(),
            sup_ref.clone(),
            Duration::from_millis(300), // supervision update interval
            HashMap::new(),
            ProcLifecycleMode::ManagedBySystem,
        )
        .await
        .unwrap();
        let proc_1_client = proc_1.attach("client").unwrap();
        let (proc_1_undeliverable_tx, mut _proc_1_undeliverable_rx) = proc_1_client.open_port();

        // Spawn two actors 'ping' and 'pong' where 'ping' runs on
        // 'world[0]' and 'pong' on 'world[1]' (that is, not on the
        // same proc).
        let ping_params = PingPongActorParams::new(Some(proc_0_undeliverable_tx.bind()), None);
        let ping_handle = proc_0
            .spawn::<PingPongActor>("ping", ping_params)
            .await
            .unwrap();
        let pong_params = PingPongActorParams::new(Some(proc_1_undeliverable_tx.bind()), None);
        let pong_handle = proc_1
            .spawn::<PingPongActor>("pong", pong_params)
            .await
            .unwrap();

        // Now kill pong's mailbox server making message delivery
        // between procs impossible.
        proc_actor_1.mailbox.stop("from testing");
        proc_actor_1.mailbox.await.unwrap().unwrap();

        // That in itself shouldn't be a problem. Check the world
        // health now.
        let snapshot = system_actor_handle
            .snapshot(&client, SystemSnapshotFilter::all())
            .await
            .unwrap();
        assert_eq!(snapshot.worlds.len(), 1);
        assert!(snapshot.worlds.contains_key(&world_id));
        assert_eq!(
            snapshot.worlds.get(&world_id).unwrap().status,
            WorldStatus::Live
        );

        // Have 'ping' send 'pong' a message.
        let ttl = 1_u64;
        let (game_over, on_game_over) = proc_0_client.open_once_port::<bool>();
        ping_handle
            .send(PingPongMessage(ttl, pong_handle.bind(), game_over.bind()))
            .unwrap();

        // We expect message delivery failure prevents the game from
        // ending within the timeout.
        // The 4s wait is longer than the 1s delivery timeout and the 2s
        // supervision/eviction timeouts configured above.
        assert!(
            RealClock
                .timeout(tokio::time::Duration::from_secs(4), on_game_over.recv())
                .await
                .is_err()
        );

        // By supervision, we expect the world should have
        // transitioned to unhealthy.
        let snapshot = system_actor_handle
            .snapshot(&client, SystemSnapshotFilter::all())
            .await
            .unwrap();
        assert_eq!(snapshot.worlds.len(), 1);
        assert!(matches!(
            snapshot.worlds.get(&world_id).unwrap().status,
            WorldStatus::Unhealthy(_)
        ));
    }
2351
2352    #[tokio::test]
2353    async fn test_stop_fast() -> Result<()> {
2354        let server_handle = System::serve(
2355            ChannelAddr::any(ChannelTransport::Tcp(TcpMode::Hostname)),
2356            Duration::from_secs(2), // supervision update timeout
2357            Duration::from_secs(2), // duration to evict an unhealthy world
2358        )
2359        .await?;
2360        let system_actor_handle = server_handle.system_actor_handle();
2361        let mut system = System::new(server_handle.local_addr().clone());
2362        let client = system.attach().await?;
2363
2364        // Create a new world message and send to system actor
2365        let (client_tx, client_rx) = client.open_once_port::<()>();
2366        system_actor_handle.send(SystemMessage::Stop {
2367            worlds: None,
2368            proc_timeout: Duration::from_secs(5),
2369            reply_port: client_tx.bind(),
2370        })?;
2371        client_rx.recv().await?;
2372
2373        // Check that it has stopped.
2374        let mut sys_status_rx = system_actor_handle.status();
2375        {
2376            let received = sys_status_rx.borrow_and_update();
2377            assert_eq!(*received, ActorStatus::Stopped);
2378        }
2379
2380        Ok(())
2381    }
2382
2383    #[tokio::test]
2384    async fn test_update_sim_address() {
2385        simnet::start();
2386
2387        let src_id = id!(proc[0].actor);
2388        let src_addr = ChannelAddr::Sim(SimAddr::new("unix!@src".parse().unwrap()).unwrap());
2389        let dst_addr = ChannelAddr::Sim(SimAddr::new("unix!@dst".parse().unwrap()).unwrap());
2390        let (_, mut rx) = channel::serve::<MessageEnvelope>(src_addr.clone()).unwrap();
2391
2392        let router = ReportingRouter::new();
2393
2394        router
2395            .router
2396            .bind(src_id.proc_id().clone().into(), src_addr);
2397        router.router.bind(id!(proc[1]).into(), dst_addr);
2398
2399        router.post_update_address(&MessageEnvelope::new(
2400            src_id,
2401            PortId(id!(proc[1].actor), 9999u64),
2402            Serialized::serialize(&1u64).unwrap(),
2403            Attrs::new(),
2404        ));
2405
2406        let envelope = rx.recv().await.unwrap();
2407        let admin_msg = envelope
2408            .data()
2409            .deserialized::<MailboxAdminMessage>()
2410            .unwrap();
2411        let MailboxAdminMessage::UpdateAddress {
2412            addr: ChannelAddr::Sim(addr),
2413            ..
2414        } = admin_msg
2415        else {
2416            panic!("Expected sim address");
2417        };
2418
2419        assert_eq!(addr.src().clone().unwrap().to_string(), "unix:@src");
2420        assert_eq!(addr.addr().to_string(), "unix:@dst");
2421    }
2422}