1use std::collections::HashMap;
10use std::collections::HashSet;
11use std::time::Duration;
12
13use anyhow::Result;
14use anyhow::anyhow;
15use clap::Args;
16use clap::Subcommand;
17use const_format::concatcp;
18use hyperactor::Mailbox;
19use hyperactor::actor::ActorHandle;
20use hyperactor::channel::ChannelAddr;
21use hyperactor::clock::Clock;
22use hyperactor::clock::RealClock;
23use hyperactor::mailbox::open_port;
24use hyperactor::reference::ActorId;
25use hyperactor::reference::ActorRef;
26use hyperactor::reference::GangId;
27use hyperactor::reference::Index;
28use hyperactor::reference::WorldId;
29use hyperactor_mesh::comm::CommActor;
30use hyperactor_multiprocess::System;
31use hyperactor_multiprocess::proc_actor::Environment;
32use hyperactor_multiprocess::proc_actor::ProcActor;
33use hyperactor_multiprocess::proc_actor::ProcMessageClient;
34use hyperactor_multiprocess::system_actor;
35use hyperactor_multiprocess::system_actor::ProcLifecycleMode;
36use hyperactor_multiprocess::system_actor::Shape;
37use hyperactor_multiprocess::system_actor::SystemMessageClient;
38use monarch_messages::worker::WorkerParams;
39use pyo3::prelude::*;
40use pyo3::types::PyType;
41use serde::Deserialize;
42use serde::Serialize;
43use tokio::task::JoinHandle;
44
45use crate::ControllerActor;
46use crate::ControllerParams;
47
/// Common prefix shared by every Monarch label key defined in this file.
pub static MONARCH_LABEL_PREFIX: &str = "monarch.meta.com/";
/// Prefix for proc-level label keys: `"proc."` followed by [`MONARCH_LABEL_PREFIX`].
static WORKER_LABEL_PREFIX: &str = concatcp!("proc.", MONARCH_LABEL_PREFIX);
/// Proc label key whose value names the role of a proc.
static LABEL_NAME_ROLE: &str = concatcp!(WORKER_LABEL_PREFIX, "role");
/// Value stored under [`LABEL_NAME_ROLE`] for the controller (see `bootstrap_controller`).
static LABEL_VALUE_ROLE_CONTROLLER: &str = "controller";
/// Value stored under [`LABEL_NAME_ROLE`] for host procs (see `spawn_host`).
static LABEL_VALUE_ROLE_HOST: &str = "host";
/// Controller label key whose value is the worker world id rendered as a string.
static LABEL_NAME_WORKER_WORLD: &str = concatcp!(WORKER_LABEL_PREFIX, "workerWorld");
/// Actor name used to build the comm actor reference on the controller's proc.
static COMM_ACTOR_NAME: &str = "comm";

/// Prefix for world-level label keys: `"world."` followed by [`MONARCH_LABEL_PREFIX`].
pub static WORLD_LABEL_PREFIX: &str = concatcp!("world.", MONARCH_LABEL_PREFIX);
/// World label key set to `"1"` on the worker world created by `create_world`.
static LABEL_NAME_WORKER: &str = concatcp!(WORLD_LABEL_PREFIX, "worker");
/// World label key whose value is the controller actor id rendered as a string.
static LABEL_NAME_CONTROLLER_ACTOR_ID: &str = concatcp!(WORLD_LABEL_PREFIX, "controllerActorId");
72
// Arguments for running a controller; consumed by `run`, which forwards them to
// `spawn_controller`. Also exposed to Python as a pyclass.
//
// NOTE: plain `//` comments are used throughout this type on purpose: clap turns `///`
// doc comments on derived types into CLI help text, which would change program output.
#[derive(Clone, Debug, Serialize, Deserialize, Args)]
#[pyclass(module = "monarch._rust_bindings.controller.bootstrap")]
pub struct ControllerCommand {
    // Id of the worker world; parsed into a `WorldId` by `run`.
    #[arg(long)]
    pub worker_world: String,

    // Address of the system; parsed into a `ChannelAddr` by `run`.
    #[arg(long)]
    pub system_addr: String,

    // Actor id of the controller; parsed into an `ActorId` by `run`.
    // Defaults to "controller[0].root".
    #[arg(long, default_value_t = String::from("controller[0].root"))]
    pub controller_actor_id: String,

    // Number of worker procs. Used as the shape of the worker world and as the number of
    // worker actors to spawn.
    #[arg(long)]
    pub world_size: usize,

    // Forwarded to `upsert_world` when the worker world is created. Defaults to 8.
    #[arg(long, default_value_t = 8)]
    pub num_procs_per_host: usize,

    // Name given to the spawned worker actors and to the worker gang. Defaults to "worker".
    #[arg(long, default_value_t = String::from("worker"))]
    pub worker_name: String,

    // When set, the worker world is created with `Environment::Exec { program }`;
    // otherwise `Environment::Local` is used.
    #[arg(long)]
    pub program: Option<String>,

    // Supervision query interval in seconds, passed to the controller as a `Duration`.
    // Defaults to 2.
    #[arg(long, default_value_t = 2)]
    pub supervision_query_interval_in_sec: u64,

    // Supervision update interval in seconds, passed to `ControllerActor::bootstrap` as a
    // `Duration`. Defaults to 2.
    #[arg(long, default_value_t = 2)]
    pub supervision_update_interval_in_sec: u64,

    // Worker progress check interval in seconds, passed to the controller as a `Duration`.
    // Defaults to 10.
    #[arg(long, default_value_t = 10)]
    pub worker_progress_check_interval_in_sec: u64,

    // Operation timeout in seconds, passed to the controller as a `Duration`.
    // Defaults to 120.
    #[arg(long, default_value_t = 120)]
    pub operation_timeout_in_sec: u64,

    // Forwarded unchanged to `ControllerParams`. Defaults to 100.
    #[arg(long, default_value_t = 100)]
    pub operations_per_worker_progress_request: u64,

    // Forwarded unchanged to `ControllerParams`. Defaults to false.
    #[arg(long, default_value_t = false)]
    pub fail_on_worker_timeout: bool,

    // When true, workers are spawned without a device index; otherwise they get
    // device index 0. Defaults to false.
    #[arg(long, default_value_t = false)]
    pub is_cpu_worker: bool,

    // Extra `KEY=value` labels (parsed by `parse_key_val`) merged into the labels passed to
    // `ControllerActor::bootstrap`.
    #[arg(long, value_parser=parse_key_val)]
    pub extra_proc_labels: Option<Vec<(String, String)>>,
}
144
#[pymethods]
impl ControllerCommand {
    // Python constructor. Every argument is keyword-only (the pyo3 signature starts with `*`)
    // and none has a default, so Python callers must supply all of them. Each argument is
    // stored unchanged in the field of the same name.
    #[new]
    #[pyo3(signature = (*, worker_world, system_addr, controller_actor_id, world_size, num_procs_per_host, worker_name, program, supervision_query_interval_in_sec, supervision_update_interval_in_sec, worker_progress_check_interval_in_sec, operation_timeout_in_sec, operations_per_worker_progress_request, fail_on_worker_timeout, is_cpu_worker, extra_proc_labels))]
    fn new(
        worker_world: String,
        system_addr: String,
        controller_actor_id: String,
        world_size: usize,
        num_procs_per_host: usize,
        worker_name: String,
        program: Option<String>,
        supervision_query_interval_in_sec: u64,
        supervision_update_interval_in_sec: u64,
        worker_progress_check_interval_in_sec: u64,
        operation_timeout_in_sec: u64,
        operations_per_worker_progress_request: u64,
        fail_on_worker_timeout: bool,
        is_cpu_worker: bool,
        extra_proc_labels: Option<Vec<(String, String)>>,
    ) -> Self {
        Self {
            worker_world,
            system_addr,
            controller_actor_id,
            world_size,
            num_procs_per_host,
            worker_name,
            program,
            supervision_query_interval_in_sec,
            supervision_update_interval_in_sec,
            worker_progress_check_interval_in_sec,
            operation_timeout_in_sec,
            operations_per_worker_progress_request,
            fail_on_worker_timeout,
            is_cpu_worker,
            extra_proc_labels,
        }
    }
}
185
// The kinds of task that `run` can launch. Usable both as a clap subcommand and as a
// Python class.
//
// NOTE: plain `//` comments are used throughout this type on purpose: clap turns `///`
// doc comments on derived types into CLI help text, which would change program output.
#[derive(Clone, Debug, Serialize, Deserialize, Subcommand)]
#[pyclass(module = "monarch._rust_bindings.controller.bootstrap")]
pub enum RunCommand {
    // Serve a system (handled by `spawn_system`).
    System {
        // Address to serve the system on; parsed into a `ChannelAddr` by `run`.
        #[arg(long)]
        system_addr: String,

        // Supervision update timeout in seconds, forwarded to `System::serve`.
        // Defaults to 20.
        #[arg(long, default_value_t = 20)]
        supervision_update_timeout_in_sec: u64,

        // World eviction timeout in seconds, forwarded to `System::serve`. Defaults to 10.
        #[arg(long, default_value_t = 10)]
        world_eviction_timeout_in_sec: u64,
    },

    // Join the system as a host proc (handled by `spawn_host`).
    Host {
        // Address of the system to join; parsed into a `ChannelAddr` by `run`.
        #[arg(long)]
        system_addr: String,

        // Id of the host world; parsed into a `WorldId` by `run`.
        #[arg(long)]
        host_world: String,

        // Rank of this host within `host_world`; together they form the proc id.
        #[arg(long)]
        host_rank: Index,

        // Supervision update interval in seconds, forwarded to `ProcActor::bootstrap`.
        // Defaults to 2.
        #[arg(long, default_value_t = 2)]
        supervision_update_interval_in_sec: u64,
    },

    // Run a controller together with its worker world (handled by `spawn_controller`).
    Controller(ControllerCommand),
}
228
/// Request message that can be serialized to JSON (see `to_json`) and is exposed to Python
/// as a frozen class.
// NOTE(review): presumably this is what a client sends to a controller server process;
// the consumer is not visible in this file — confirm.
#[pyclass(frozen, module = "monarch._rust_bindings.controller.bootstrap")]
#[derive(Debug, Serialize, Deserialize)]
pub enum ControllerServerRequest {
    /// Carries a [`RunCommand`] to execute.
    Run(RunCommand),
    /// Carries no payload.
    Exit(),
}
235
236#[pymethods]
237impl ControllerServerRequest {
238 fn to_json(&self) -> PyResult<String> {
239 Ok(serde_json::to_string(self).map_err(|e| anyhow!(e))?)
240 }
241
242 fn __str__(&self) -> String {
243 format!("{:?}", self)
244 }
245}
246
/// Response message that can be deserialized from JSON (see `from_json`) and is exposed to
/// Python as a frozen class.
#[pyclass(frozen, module = "monarch._rust_bindings.controller.bootstrap")]
#[derive(Debug, Serialize, Deserialize)]
pub enum ControllerServerResponse {
    /// Carries an optional error message.
    // NOTE(review): presumably `error` is `None` on success — confirm with the producer.
    Finished { error: Option<String> },
}
252
253#[pymethods]
254impl ControllerServerResponse {
255 #[classmethod]
256 fn from_json(_: &Bound<'_, PyType>, json: &str) -> PyResult<Self> {
257 Ok(serde_json::from_str(json).map_err(|e| anyhow!(e))?)
258 }
259
260 fn __str__(&self) -> String {
261 format!("{:?}", self)
262 }
263}
264
265pub fn run(command: RunCommand) -> Result<JoinHandle<Result<(), anyhow::Error>>> {
268 Ok(match command {
269 RunCommand::System {
270 system_addr,
271 supervision_update_timeout_in_sec,
272 world_eviction_timeout_in_sec,
273 } => tokio::spawn(spawn_system(
274 system_addr.parse()?,
275 Duration::from_secs(supervision_update_timeout_in_sec),
276 Duration::from_secs(world_eviction_timeout_in_sec),
277 )),
278 RunCommand::Host {
279 system_addr,
280 host_world,
281 host_rank,
282 supervision_update_interval_in_sec,
283 } => tokio::spawn(spawn_host(
284 system_addr.parse()?,
285 host_world.parse()?,
286 host_rank,
287 Duration::from_secs(supervision_update_interval_in_sec),
288 )),
289 RunCommand::Controller(ControllerCommand {
290 worker_world,
291 system_addr,
292 controller_actor_id,
293 world_size,
294 num_procs_per_host,
295 worker_name,
296 program,
297 supervision_query_interval_in_sec,
298 supervision_update_interval_in_sec,
299 worker_progress_check_interval_in_sec,
300 operation_timeout_in_sec,
301 operations_per_worker_progress_request,
302 is_cpu_worker,
303 extra_proc_labels,
304 fail_on_worker_timeout,
305 }) => tokio::spawn(spawn_controller(
306 system_addr.parse()?,
307 controller_actor_id.parse()?,
308 world_size,
309 num_procs_per_host,
310 worker_world.parse()?,
311 worker_name,
312 program,
313 Duration::from_secs(supervision_query_interval_in_sec),
314 Duration::from_secs(supervision_update_interval_in_sec),
315 Duration::from_secs(worker_progress_check_interval_in_sec),
316 Duration::from_secs(operation_timeout_in_sec),
317 operations_per_worker_progress_request,
318 is_cpu_worker,
319 extra_proc_labels,
320 fail_on_worker_timeout,
321 )),
322 })
323}
324
325async fn spawn_system(
327 system_addr: ChannelAddr,
328 supervision_update_timeout: Duration,
329 world_eviction_timeout: Duration,
330) -> anyhow::Result<()> {
331 tracing::info!("spawning system");
332
333 let handle = System::serve(
334 system_addr.clone(),
335 supervision_update_timeout,
336 world_eviction_timeout,
337 )
338 .await?;
339 tracing::info!("system serve: {}", handle.local_addr());
340
341 handle.system_actor_handle().clone().await;
343
344 tracing::info!("system actor exited");
345
346 Ok(())
347}
348
349#[tracing::instrument(skip_all)]
351async fn spawn_host(
352 system_addr: ChannelAddr,
353 host_world_id: WorldId,
354 host_rank: Index,
355 supervision_update_interval: Duration,
356) -> anyhow::Result<()> {
357 tracing::info!("spawning host actor");
358
359 let proc_id = host_world_id.proc_id(host_rank);
360 let host_addr = ChannelAddr::any(system_addr.transport());
361
362 let bootstrap = ProcActor::bootstrap(
363 proc_id.clone(),
364 host_world_id.clone(),
365 host_addr,
366 system_addr,
367 supervision_update_interval,
368 HashMap::from([(
369 LABEL_NAME_ROLE.to_string(),
370 LABEL_VALUE_ROLE_HOST.to_string(),
371 )]),
372 ProcLifecycleMode::ManagedBySystem,
373 )
374 .await?;
375 tracing::info!(
376 "{}: joined; host actor: {}",
377 proc_id,
378 bootstrap.proc_actor.actor_id()
379 );
380
381 bootstrap.proc_actor.await;
383
384 Ok(())
385}
386
387#[tracing::instrument(skip_all)]
395async fn spawn_controller(
396 system_addr: ChannelAddr,
397 controller_actor_id: ActorId,
398 num_procs: usize,
399 num_procs_per_host: usize,
400 worker_world_id: WorldId,
401 worker_name: String,
402 program: Option<String>,
403 supervision_query_interval: Duration,
404 supervision_update_interval: Duration,
405 worker_progress_check_interval: Duration,
406 operation_timeout: Duration,
407 operations_per_worker_progress_request: u64,
408 is_cpu_worker: bool,
409 extra_proc_labels: Option<Vec<(String, String)>>,
410 fail_on_worker_timeout: bool,
411) -> anyhow::Result<()> {
412 tracing::info!("spawning controller");
413
414 let mut system = hyperactor_multiprocess::System::new(system_addr.clone());
415 let client = system.attach().await.unwrap();
416
417 self::create_world(
418 client.clone(),
419 controller_actor_id.clone(),
420 num_procs,
421 num_procs_per_host,
422 worker_world_id.clone(),
423 program,
424 )
425 .await?;
426 let handle = self::bootstrap_controller(
427 system_addr,
428 None, controller_actor_id.clone(),
430 num_procs,
431 worker_world_id.clone(),
432 worker_name.clone(),
433 supervision_query_interval,
434 supervision_update_interval,
435 worker_progress_check_interval,
436 operation_timeout,
437 operations_per_worker_progress_request,
438 extra_proc_labels,
439 fail_on_worker_timeout,
440 )
441 .await?;
442
443 self::spawn_worker_actors(
444 client.clone(),
445 controller_actor_id.clone(),
446 num_procs,
447 worker_world_id,
448 worker_name,
449 is_cpu_worker,
450 )
451 .await?;
452
453 system_actor::SYSTEM_ACTOR_REF
456 .upsert_world(
457 &client,
458 WorldId(controller_actor_id.world_name().to_string()),
459 Shape::Definite(vec![1]),
460 1,
461 Environment::Local,
462 HashMap::new(),
463 )
464 .await?;
465 tracing::info!(
466 "created new controller world {}",
467 controller_actor_id.world_name()
468 );
469
470 handle.await;
472
473 tracing::info!("controller actor exited");
474
475 Ok(())
476}
477
478pub async fn bootstrap_controller(
482 system_addr: ChannelAddr,
483 listen_addr: Option<ChannelAddr>,
484 controller_actor_id: ActorId,
485 num_procs: usize,
486 worker_world_id: WorldId,
487 worker_name: String,
488 supervision_query_interval: Duration,
489 supervision_update_interval: Duration,
490 worker_progress_check_interval: Duration,
491 operation_timeout: Duration,
492 operations_per_worker_progress_request: u64,
493 extra_controller_labels: Option<Vec<(String, String)>>,
494 fail_on_worker_timeout: bool,
495) -> anyhow::Result<ActorHandle<ProcActor>> {
496 let listen_addr = listen_addr.unwrap_or(ChannelAddr::any(system_addr.transport()));
497 let mut controller_labels = HashMap::from([
498 (
499 LABEL_NAME_ROLE.to_string(),
500 LABEL_VALUE_ROLE_CONTROLLER.to_string(),
501 ),
502 (
503 LABEL_NAME_WORKER_WORLD.to_string(),
504 worker_world_id.to_string(),
505 ),
506 ]);
507 tracing::info!("controller labels: {:?}", extra_controller_labels);
508 if let Some(extra_controller_labels) = extra_controller_labels {
509 controller_labels.extend(extra_controller_labels);
510 }
511 let (handle, actor_ref) = ControllerActor::bootstrap(
512 controller_actor_id.clone(),
513 listen_addr,
514 system_addr,
515 ControllerParams {
516 world_size: num_procs,
517 comm_actor_ref: ActorRef::<CommActor>::attest(
518 controller_actor_id.proc_id().actor_id(COMM_ACTOR_NAME, 0),
519 ),
520 worker_gang_ref: GangId(worker_world_id.clone(), worker_name.clone()).into(),
521 supervision_query_interval,
522 worker_progress_check_interval,
523 operation_timeout,
524 operations_per_worker_progress_request,
525 fail_on_worker_timeout,
526 },
527 supervision_update_interval,
528 controller_labels,
529 )
530 .await?;
531 tracing::info!("controller starts with id: {}", actor_ref.actor_id());
532
533 Ok(handle)
534}
535
536async fn create_world(
537 client: Mailbox,
538 controller_actor_id: ActorId,
539 num_procs: usize,
540 num_procs_per_host: usize,
541 worker_world_id: WorldId,
542 program: Option<String>,
543) -> anyhow::Result<()> {
544 system_actor::SYSTEM_ACTOR_REF
545 .upsert_world(
546 &client,
547 worker_world_id.clone(),
548 Shape::Definite(vec![num_procs]),
549 num_procs_per_host,
550 match program {
551 Some(program) => Environment::Exec { program },
552 None => Environment::Local,
553 },
554 HashMap::from([
555 (LABEL_NAME_WORKER.to_string(), "1".to_string()),
556 (
557 LABEL_NAME_CONTROLLER_ACTOR_ID.to_string(),
558 controller_actor_id.to_string(),
559 ),
560 ]),
561 )
562 .await?;
563 tracing::info!("created new worker world {}", worker_world_id);
564
565 let timeout = hyperactor::config::global::get(hyperactor::config::MESSAGE_DELIVERY_TIMEOUT);
567 tracing::info!("waiting for worker world {} to be alive", worker_world_id);
568 loop {
569 let snapshot = RealClock
570 .timeout(timeout, async {
571 system_actor::SYSTEM_ACTOR_REF
572 .snapshot(
573 &client,
574 system_actor::SystemSnapshotFilter {
575 worlds: vec![worker_world_id.clone()],
576 world_labels: HashMap::new(),
577 proc_labels: HashMap::new(),
578 },
579 )
580 .await
581 })
582 .await?;
583 let snapshot = snapshot?;
584 if let Some(world) = snapshot.worlds.get(&worker_world_id) {
585 if world.status.is_live() {
586 break;
587 }
588 }
589 RealClock.sleep(Duration::from_millis(10)).await;
590 }
591 tracing::info!(
592 "worker world {} is alive; spawning {} worker actors",
593 worker_world_id,
594 num_procs
595 );
596 Ok(())
597}
598
599async fn spawn_worker_actors(
600 client: Mailbox,
601 controller_actor_id: ActorId,
602 num_procs: usize,
603 worker_world_id: WorldId,
604 worker_name: String,
605 is_cpu_worker: bool,
606) -> anyhow::Result<()> {
607 let (spawned_port, mut spawned_receiver) = open_port(&client);
609 for rank in 0..num_procs {
610 let param = WorkerParams {
611 world_size: num_procs,
612 rank,
614 device_index: if is_cpu_worker { None } else { Some(0) },
617 controller_actor: ActorRef::attest(controller_actor_id.clone()),
618 };
619 let worker_proc =
620 ActorRef::<ProcActor>::attest(worker_world_id.proc_id(rank).actor_id("proc", 0));
621
622 worker_proc
623 .spawn(
624 &client,
625 "monarch_tensor_worker::WorkerActor".to_owned(),
627 worker_name.clone(),
628 bincode::serialize(¶m)?,
629 spawned_port.bind(),
630 )
631 .await?;
632 }
633 let mut spawned = HashSet::new();
634 while spawned.len() < num_procs {
635 spawned.insert(spawned_receiver.recv().await?);
636 }
637 tracing::info!("spawned {} worker actors", num_procs);
638
639 Ok(())
640}
641
642pub fn parse_key_val(s: &str) -> anyhow::Result<(String, String)> {
643 match s.split_once('=') {
644 None => Err(anyhow::anyhow!("invalid KEY=value: no `=` found in `{s}`")),
645 Some((a, b)) => Ok((a.to_owned(), b.to_owned())),
646 }
647}
648
/// Adds the Python-visible classes defined in this file to `controller_mod`.
///
/// Registers, in order: `ControllerServerRequest`, `ControllerServerResponse`,
/// `RunCommand` and `ControllerCommand`.
///
/// # Errors
///
/// Returns the first error reported by `add_class`; classes after the failing one are not
/// registered.
pub fn register_python_bindings(controller_mod: &Bound<'_, PyModule>) -> PyResult<()> {
    controller_mod.add_class::<ControllerServerRequest>()?;
    controller_mod.add_class::<ControllerServerResponse>()?;
    controller_mod.add_class::<RunCommand>()?;
    controller_mod.add_class::<ControllerCommand>()?;
    Ok(())
}