controller/
bootstrap.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::collections::HashMap;
10use std::collections::HashSet;
11use std::time::Duration;
12
13use anyhow::Result;
14use anyhow::anyhow;
15use clap::Args;
16use clap::Subcommand;
17use const_format::concatcp;
18use hyperactor::Mailbox;
19use hyperactor::actor::ActorHandle;
20use hyperactor::channel::ChannelAddr;
21use hyperactor::clock::Clock;
22use hyperactor::clock::RealClock;
23use hyperactor::mailbox::open_port;
24use hyperactor::reference::ActorId;
25use hyperactor::reference::ActorRef;
26use hyperactor::reference::GangId;
27use hyperactor::reference::Index;
28use hyperactor::reference::WorldId;
29use hyperactor_mesh::comm::CommActor;
30use hyperactor_multiprocess::System;
31use hyperactor_multiprocess::proc_actor::Environment;
32use hyperactor_multiprocess::proc_actor::ProcActor;
33use hyperactor_multiprocess::proc_actor::ProcMessageClient;
34use hyperactor_multiprocess::system_actor;
35use hyperactor_multiprocess::system_actor::ProcLifecycleMode;
36use hyperactor_multiprocess::system_actor::Shape;
37use hyperactor_multiprocess::system_actor::SystemMessageClient;
38use monarch_messages::worker::WorkerParams;
39use pyo3::prelude::*;
40use pyo3::types::PyType;
41use serde::Deserialize;
42use serde::Serialize;
43use tokio::task::JoinHandle;
44
45use crate::ControllerActor;
46use crate::ControllerParams;
47
/// Domain name for all monarch reserved labels.
pub static MONARCH_LABEL_PREFIX: &str = "monarch.meta.com/";
/// Prefix for all monarch reserved labels for procs
/// (`"proc."` followed by [`MONARCH_LABEL_PREFIX`]).
static WORKER_LABEL_PREFIX: &str = concatcp!("proc.", MONARCH_LABEL_PREFIX);
/// Label name indicating the role of a proc.
static LABEL_NAME_ROLE: &str = concatcp!(WORKER_LABEL_PREFIX, "role");
/// Label value indicating proc role is controller.
static LABEL_VALUE_ROLE_CONTROLLER: &str = "controller";
/// Label value indicating proc role is host.
static LABEL_VALUE_ROLE_HOST: &str = "host";
/// Label indicating the worker world for a given controller to allow
/// for backreferencing.
static LABEL_NAME_WORKER_WORLD: &str = concatcp!(WORKER_LABEL_PREFIX, "workerWorld");
/// The global name used for comm actors.
static COMM_ACTOR_NAME: &str = "comm";

/// Prefix for all monarch reserved labels for worlds
/// (`"world."` followed by [`MONARCH_LABEL_PREFIX`]).
pub static WORLD_LABEL_PREFIX: &str = concatcp!("world.", MONARCH_LABEL_PREFIX);
/// Label indicating if a given world is a worker world. A value of "1" indicates
/// a worker world. This allows us to query all worker worlds in the system.
static LABEL_NAME_WORKER: &str = concatcp!(WORLD_LABEL_PREFIX, "worker");
/// Label indicating the controller actor id for a given worker world. This allows
/// us to query all worker worlds and communicate with their controllers.
static LABEL_NAME_CONTROLLER_ACTOR_ID: &str = concatcp!(WORLD_LABEL_PREFIX, "controllerActorId");
72
73#[derive(Clone, Debug, Serialize, Deserialize, Args)]
74#[pyclass(module = "monarch._rust_bindings.controller.bootstrap")]
75pub struct ControllerCommand {
76    /// The worker world to create
77    #[arg(long)]
78    pub worker_world: String,
79
80    /// The system address to bootstrap with.
81    #[arg(long)]
82    pub system_addr: String,
83
84    /// The controller actor id to give to.
85    #[arg(long, default_value_t = String::from("controller[0].root"))]
86    pub controller_actor_id: String,
87
88    // Global world size for this job
89    #[arg(long)]
90    pub world_size: usize,
91
92    /// The number of processes per host.
93    #[arg(long, default_value_t = 8)]
94    pub num_procs_per_host: usize,
95
96    /// The worker name.
97    #[arg(long, default_value_t = String::from("worker"))]
98    pub worker_name: String,
99
100    /// The worker program to execute for each process. It is not needed if worker procs
101    /// are directly launched without management from host actors.
102    #[arg(long)]
103    pub program: Option<String>,
104
105    /// The supervision check interval in seconds. It indicates how often the controller
106    /// will poll system actor to check the status of all procs/actors in a world. This
107    /// decides how fast the client could observe a failure in the system.
108    #[arg(long, default_value_t = 2)]
109    pub supervision_query_interval_in_sec: u64,
110
111    /// The supervision update interval in seconds, it indiciates how often the controller
112    /// proc should report its supervision status to the system.
113    #[arg(long, default_value_t = 2)]
114    pub supervision_update_interval_in_sec: u64,
115
116    /// The worker progress check interval in seconds, it indicates how often the controller
117    /// will check that progress is being made.
118    #[arg(long, default_value_t = 10)]
119    pub worker_progress_check_interval_in_sec: u64,
120
121    /// The operation timeout duration interval in seconds, it indicates how long we will allow
122    /// progress to stall for before letting the client know that worker(s) may be stuck.
123    #[arg(long, default_value_t = 120)]
124    pub operation_timeout_in_sec: u64,
125
126    /// The number of operations invoked before we proactively check worker progress. If a large number
127    /// of operations are invoked all at once, it is expected that it will take a while for all operations
128    /// to complete so we want to inject progress requests at a higher frequency to check if we are making progress
129    #[arg(long, default_value_t = 100)]
130    pub operations_per_worker_progress_request: u64,
131
132    /// If the controller should propagate a failure to the client if the workers become stuck.
133    #[arg(long, default_value_t = false)]
134    pub fail_on_worker_timeout: bool,
135
136    /// If to launch the workers for CPU-only devices.
137    #[arg(long, default_value_t = false)]
138    pub is_cpu_worker: bool,
139
140    /// Proc metadata which will be available through system.
141    #[arg(long, value_parser=parse_key_val)]
142    pub extra_proc_labels: Option<Vec<(String, String)>>,
143}
144
145#[pymethods]
146impl ControllerCommand {
147    #[new]
148    #[pyo3(signature = (*, worker_world, system_addr, controller_actor_id, world_size, num_procs_per_host, worker_name, program, supervision_query_interval_in_sec, supervision_update_interval_in_sec, worker_progress_check_interval_in_sec, operation_timeout_in_sec, operations_per_worker_progress_request, fail_on_worker_timeout, is_cpu_worker, extra_proc_labels))]
149    fn new(
150        worker_world: String,
151        system_addr: String,
152        controller_actor_id: String,
153        world_size: usize,
154        num_procs_per_host: usize,
155        worker_name: String,
156        program: Option<String>,
157        supervision_query_interval_in_sec: u64,
158        supervision_update_interval_in_sec: u64,
159        worker_progress_check_interval_in_sec: u64,
160        operation_timeout_in_sec: u64,
161        operations_per_worker_progress_request: u64,
162        fail_on_worker_timeout: bool,
163        is_cpu_worker: bool,
164        extra_proc_labels: Option<Vec<(String, String)>>,
165    ) -> Self {
166        Self {
167            worker_world,
168            system_addr,
169            controller_actor_id,
170            world_size,
171            num_procs_per_host,
172            worker_name,
173            program,
174            supervision_query_interval_in_sec,
175            supervision_update_interval_in_sec,
176            worker_progress_check_interval_in_sec,
177            operation_timeout_in_sec,
178            operations_per_worker_progress_request,
179            fail_on_worker_timeout,
180            is_cpu_worker,
181            extra_proc_labels,
182        }
183    }
184}
185
/// The different types of hyperactor to launch based on the subcommands.
/// The ones for System / Host should probably be moved to the hyperactor
/// multiprocess crate.
#[derive(Clone, Debug, Serialize, Deserialize, Subcommand)]
#[pyclass(module = "monarch._rust_bindings.controller.bootstrap")]
pub enum RunCommand {
    /// Launch the system actor.
    System {
        /// The system address to bootstrap with.
        #[arg(long)]
        system_addr: String,

        /// The supervision update timeout in seconds. A proc is considered dead if system
        /// doesn't get any supervision update from it within this timeout.
        #[arg(long, default_value_t = 20)]
        supervision_update_timeout_in_sec: u64,

        /// Evict a world if it has been unhealthy for this many seconds.
        #[arg(long, default_value_t = 10)]
        world_eviction_timeout_in_sec: u64,
    },

    /// Launch a host actor that joins the given host world.
    Host {
        /// The system address to bootstrap with.
        #[arg(long)]
        system_addr: String,

        /// The host world to create.
        #[arg(long)]
        host_world: String,

        /// The host rank; i.e., the index of the host in the world.
        #[arg(long)]
        host_rank: Index,

        /// The supervision update interval in seconds, it indicates how often a proc should
        /// report its supervision status to the system.
        #[arg(long, default_value_t = 2)]
        supervision_update_interval_in_sec: u64,
    },

    /// Launch a controller; see `ControllerCommand` for the arguments.
    Controller(ControllerCommand),
}
228
/// A serializable request, exposed to Python as a frozen class.
///
/// NOTE(review): presumably sent from Python to a controller server
/// process; the receiving side is not part of this file.
#[pyclass(frozen, module = "monarch._rust_bindings.controller.bootstrap")]
#[derive(Debug, Serialize, Deserialize)]
pub enum ControllerServerRequest {
    /// Carries the [`RunCommand`] to execute.
    Run(RunCommand),
    /// Carries no payload.
    Exit(),
}
235
236#[pymethods]
237impl ControllerServerRequest {
238    fn to_json(&self) -> PyResult<String> {
239        Ok(serde_json::to_string(self).map_err(|e| anyhow!(e))?)
240    }
241
242    fn __str__(&self) -> String {
243        format!("{:?}", self)
244    }
245}
246
/// A serializable response, exposed to Python as a frozen class.
#[pyclass(frozen, module = "monarch._rust_bindings.controller.bootstrap")]
#[derive(Debug, Serialize, Deserialize)]
pub enum ControllerServerResponse {
    /// Carries an optional error message.
    /// NOTE(review): presumably `None` means the command finished
    /// successfully — confirm against the producer of this response.
    Finished { error: Option<String> },
}
252
253#[pymethods]
254impl ControllerServerResponse {
255    #[classmethod]
256    fn from_json(_: &Bound<'_, PyType>, json: &str) -> PyResult<Self> {
257        Ok(serde_json::from_str(json).map_err(|e| anyhow!(e))?)
258    }
259
260    fn __str__(&self) -> String {
261        format!("{:?}", self)
262    }
263}
264
265/// A helper function to launch the system, host, or controller actors.
266/// Returns the handle to be waited on.
267pub fn run(command: RunCommand) -> Result<JoinHandle<Result<(), anyhow::Error>>> {
268    Ok(match command {
269        RunCommand::System {
270            system_addr,
271            supervision_update_timeout_in_sec,
272            world_eviction_timeout_in_sec,
273        } => tokio::spawn(spawn_system(
274            system_addr.parse()?,
275            Duration::from_secs(supervision_update_timeout_in_sec),
276            Duration::from_secs(world_eviction_timeout_in_sec),
277        )),
278        RunCommand::Host {
279            system_addr,
280            host_world,
281            host_rank,
282            supervision_update_interval_in_sec,
283        } => tokio::spawn(spawn_host(
284            system_addr.parse()?,
285            host_world.parse()?,
286            host_rank,
287            Duration::from_secs(supervision_update_interval_in_sec),
288        )),
289        RunCommand::Controller(ControllerCommand {
290            worker_world,
291            system_addr,
292            controller_actor_id,
293            world_size,
294            num_procs_per_host,
295            worker_name,
296            program,
297            supervision_query_interval_in_sec,
298            supervision_update_interval_in_sec,
299            worker_progress_check_interval_in_sec,
300            operation_timeout_in_sec,
301            operations_per_worker_progress_request,
302            is_cpu_worker,
303            extra_proc_labels,
304            fail_on_worker_timeout,
305        }) => tokio::spawn(spawn_controller(
306            system_addr.parse()?,
307            controller_actor_id.parse()?,
308            world_size,
309            num_procs_per_host,
310            worker_world.parse()?,
311            worker_name,
312            program,
313            Duration::from_secs(supervision_query_interval_in_sec),
314            Duration::from_secs(supervision_update_interval_in_sec),
315            Duration::from_secs(worker_progress_check_interval_in_sec),
316            Duration::from_secs(operation_timeout_in_sec),
317            operations_per_worker_progress_request,
318            is_cpu_worker,
319            extra_proc_labels,
320            fail_on_worker_timeout,
321        )),
322    })
323}
324
325/// Spawn the system actor
326async fn spawn_system(
327    system_addr: ChannelAddr,
328    supervision_update_timeout: Duration,
329    world_eviction_timeout: Duration,
330) -> anyhow::Result<()> {
331    tracing::info!("spawning system");
332
333    let handle = System::serve(
334        system_addr.clone(),
335        supervision_update_timeout,
336        world_eviction_timeout,
337    )
338    .await?;
339    tracing::info!("system serve: {}", handle.local_addr());
340
341    // This will not end until the system actor is stopped.
342    handle.system_actor_handle().clone().await;
343
344    tracing::info!("system actor exited");
345
346    Ok(())
347}
348
349/// Spawn the host actor
350#[tracing::instrument(skip_all)]
351async fn spawn_host(
352    system_addr: ChannelAddr,
353    host_world_id: WorldId,
354    host_rank: Index,
355    supervision_update_interval: Duration,
356) -> anyhow::Result<()> {
357    tracing::info!("spawning host actor");
358
359    let proc_id = host_world_id.proc_id(host_rank);
360    let host_addr = ChannelAddr::any(system_addr.transport());
361
362    let bootstrap = ProcActor::bootstrap(
363        proc_id.clone(),
364        host_world_id.clone(),
365        host_addr,
366        system_addr,
367        supervision_update_interval,
368        HashMap::from([(
369            LABEL_NAME_ROLE.to_string(),
370            LABEL_VALUE_ROLE_HOST.to_string(),
371        )]),
372        ProcLifecycleMode::ManagedBySystem,
373    )
374    .await?;
375    tracing::info!(
376        "{}: joined; host actor: {}",
377        proc_id,
378        bootstrap.proc_actor.actor_id()
379    );
380
381    // This will not end until the proc actor is stopped.
382    bootstrap.proc_actor.await;
383
384    Ok(())
385}
386
387/// Spawn the controller actor. The order of bootstrap is:
388/// 1. Create the new worker world.
389/// 2. Check if the worker world is alive
390/// 3. Spawn the controller proc and actor.
391/// 4. Spawn all the worker actors and wait for them to be ready.
392/// 5. Create the new controller world. The client is able to send traffic
393///    only after both the controller and worker worlds are alive.
394#[tracing::instrument(skip_all)]
395async fn spawn_controller(
396    system_addr: ChannelAddr,
397    controller_actor_id: ActorId,
398    num_procs: usize,
399    num_procs_per_host: usize,
400    worker_world_id: WorldId,
401    worker_name: String,
402    program: Option<String>,
403    supervision_query_interval: Duration,
404    supervision_update_interval: Duration,
405    worker_progress_check_interval: Duration,
406    operation_timeout: Duration,
407    operations_per_worker_progress_request: u64,
408    is_cpu_worker: bool,
409    extra_proc_labels: Option<Vec<(String, String)>>,
410    fail_on_worker_timeout: bool,
411) -> anyhow::Result<()> {
412    tracing::info!("spawning controller");
413
414    let mut system = hyperactor_multiprocess::System::new(system_addr.clone());
415    let client = system.attach().await.unwrap();
416
417    self::create_world(
418        client.clone(),
419        controller_actor_id.clone(),
420        num_procs,
421        num_procs_per_host,
422        worker_world_id.clone(),
423        program,
424    )
425    .await?;
426    let handle = self::bootstrap_controller(
427        system_addr,
428        None, // listen_addr
429        controller_actor_id.clone(),
430        num_procs,
431        worker_world_id.clone(),
432        worker_name.clone(),
433        supervision_query_interval,
434        supervision_update_interval,
435        worker_progress_check_interval,
436        operation_timeout,
437        operations_per_worker_progress_request,
438        extra_proc_labels,
439        fail_on_worker_timeout,
440    )
441    .await?;
442
443    self::spawn_worker_actors(
444        client.clone(),
445        controller_actor_id.clone(),
446        num_procs,
447        worker_world_id,
448        worker_name,
449        is_cpu_worker,
450    )
451    .await?;
452
453    // Controller will join its own world.
454    // This will announce itself as live so the client can observe it.
455    system_actor::SYSTEM_ACTOR_REF
456        .upsert_world(
457            &client,
458            WorldId(controller_actor_id.world_name().to_string()),
459            Shape::Definite(vec![1]),
460            1,
461            Environment::Local,
462            HashMap::new(),
463        )
464        .await?;
465    tracing::info!(
466        "created new controller world {}",
467        controller_actor_id.world_name()
468    );
469
470    // This will not end until the system actor is stopped.
471    handle.await;
472
473    tracing::info!("controller actor exited");
474
475    Ok(())
476}
477
478/// Bootstraps the controller actor.
479/// Listen address is optional. If not provided, it will be assigned with a random available
480/// address that has the same transport as the system address.
481pub async fn bootstrap_controller(
482    system_addr: ChannelAddr,
483    listen_addr: Option<ChannelAddr>,
484    controller_actor_id: ActorId,
485    num_procs: usize,
486    worker_world_id: WorldId,
487    worker_name: String,
488    supervision_query_interval: Duration,
489    supervision_update_interval: Duration,
490    worker_progress_check_interval: Duration,
491    operation_timeout: Duration,
492    operations_per_worker_progress_request: u64,
493    extra_controller_labels: Option<Vec<(String, String)>>,
494    fail_on_worker_timeout: bool,
495) -> anyhow::Result<ActorHandle<ProcActor>> {
496    let listen_addr = listen_addr.unwrap_or(ChannelAddr::any(system_addr.transport()));
497    let mut controller_labels = HashMap::from([
498        (
499            LABEL_NAME_ROLE.to_string(),
500            LABEL_VALUE_ROLE_CONTROLLER.to_string(),
501        ),
502        (
503            LABEL_NAME_WORKER_WORLD.to_string(),
504            worker_world_id.to_string(),
505        ),
506    ]);
507    tracing::info!("controller labels: {:?}", extra_controller_labels);
508    if let Some(extra_controller_labels) = extra_controller_labels {
509        controller_labels.extend(extra_controller_labels);
510    }
511    let (handle, actor_ref) = ControllerActor::bootstrap(
512        controller_actor_id.clone(),
513        listen_addr,
514        system_addr,
515        ControllerParams {
516            world_size: num_procs,
517            comm_actor_ref: ActorRef::<CommActor>::attest(
518                controller_actor_id.proc_id().actor_id(COMM_ACTOR_NAME, 0),
519            ),
520            worker_gang_ref: GangId(worker_world_id.clone(), worker_name.clone()).into(),
521            supervision_query_interval,
522            worker_progress_check_interval,
523            operation_timeout,
524            operations_per_worker_progress_request,
525            fail_on_worker_timeout,
526        },
527        supervision_update_interval,
528        controller_labels,
529    )
530    .await?;
531    tracing::info!("controller starts with id: {}", actor_ref.actor_id());
532
533    Ok(handle)
534}
535
536async fn create_world(
537    client: Mailbox,
538    controller_actor_id: ActorId,
539    num_procs: usize,
540    num_procs_per_host: usize,
541    worker_world_id: WorldId,
542    program: Option<String>,
543) -> anyhow::Result<()> {
544    system_actor::SYSTEM_ACTOR_REF
545        .upsert_world(
546            &client,
547            worker_world_id.clone(),
548            Shape::Definite(vec![num_procs]),
549            num_procs_per_host,
550            match program {
551                Some(program) => Environment::Exec { program },
552                None => Environment::Local,
553            },
554            HashMap::from([
555                (LABEL_NAME_WORKER.to_string(), "1".to_string()),
556                (
557                    LABEL_NAME_CONTROLLER_ACTOR_ID.to_string(),
558                    controller_actor_id.to_string(),
559                ),
560            ]),
561        )
562        .await?;
563    tracing::info!("created new worker world {}", worker_world_id);
564
565    // Wait for all the worker procs to join the worker world.
566    let timeout = hyperactor::config::global::get(hyperactor::config::MESSAGE_DELIVERY_TIMEOUT);
567    tracing::info!("waiting for worker world {} to be alive", worker_world_id);
568    loop {
569        let snapshot = RealClock
570            .timeout(timeout, async {
571                system_actor::SYSTEM_ACTOR_REF
572                    .snapshot(
573                        &client,
574                        system_actor::SystemSnapshotFilter {
575                            worlds: vec![worker_world_id.clone()],
576                            world_labels: HashMap::new(),
577                            proc_labels: HashMap::new(),
578                        },
579                    )
580                    .await
581            })
582            .await?;
583        let snapshot = snapshot?;
584        if let Some(world) = snapshot.worlds.get(&worker_world_id) {
585            if world.status.is_live() {
586                break;
587            }
588        }
589        RealClock.sleep(Duration::from_millis(10)).await;
590    }
591    tracing::info!(
592        "worker world {} is alive; spawning {} worker actors",
593        worker_world_id,
594        num_procs
595    );
596    Ok(())
597}
598
599async fn spawn_worker_actors(
600    client: Mailbox,
601    controller_actor_id: ActorId,
602    num_procs: usize,
603    worker_world_id: WorldId,
604    worker_name: String,
605    is_cpu_worker: bool,
606) -> anyhow::Result<()> {
607    // Bootstrap worker actors and wait for them to be ready.
608    let (spawned_port, mut spawned_receiver) = open_port(&client);
609    for rank in 0..num_procs {
610        let param = WorkerParams {
611            world_size: num_procs,
612            // Rank assignment is consistent with proc indices.
613            rank,
614            // TODO: We never use device index during Monarch bootstrap.
615            // Instead, CUDA_VISIBLE_DEVICES is used for workers to access CUDA devices.
616            device_index: if is_cpu_worker { None } else { Some(0) },
617            controller_actor: ActorRef::attest(controller_actor_id.clone()),
618        };
619        let worker_proc =
620            ActorRef::<ProcActor>::attest(worker_world_id.proc_id(rank).actor_id("proc", 0));
621
622        worker_proc
623            .spawn(
624                &client,
625                // Use explicit actor type to avoid the WorkActor dependency.
626                "monarch_tensor_worker::WorkerActor".to_owned(),
627                worker_name.clone(),
628                bincode::serialize(&param)?,
629                spawned_port.bind(),
630            )
631            .await?;
632    }
633    let mut spawned = HashSet::new();
634    while spawned.len() < num_procs {
635        spawned.insert(spawned_receiver.recv().await?);
636    }
637    tracing::info!("spawned {} worker actors", num_procs);
638
639    Ok(())
640}
641
642pub fn parse_key_val(s: &str) -> anyhow::Result<(String, String)> {
643    match s.split_once('=') {
644        None => Err(anyhow::anyhow!("invalid KEY=value: no `=` found in `{s}`")),
645        Some((a, b)) => Ok((a.to_owned(), b.to_owned())),
646    }
647}
648
/// Registers the Python-visible classes defined in this file
/// (`ControllerServerRequest`, `ControllerServerResponse`, `RunCommand` and
/// `ControllerCommand`) on `controller_mod`.
///
/// # Errors
///
/// Returns the first error reported by `add_class`; classes after the
/// failing one are not registered.
pub fn register_python_bindings(controller_mod: &Bound<'_, PyModule>) -> PyResult<()> {
    controller_mod.add_class::<ControllerServerRequest>()?;
    controller_mod.add_class::<ControllerServerResponse>()?;
    controller_mod.add_class::<RunCommand>()?;
    controller_mod.add_class::<ControllerCommand>()?;
    Ok(())
}