hyperactor_mesh/
bootstrap.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::sync::Arc;
10use std::time::Duration;
11
12use hyperactor::ActorRef;
13use hyperactor::Named;
14use hyperactor::ProcId;
15use hyperactor::channel;
16use hyperactor::channel::ChannelAddr;
17use hyperactor::channel::ChannelTransport;
18use hyperactor::channel::Rx;
19use hyperactor::channel::Tx;
20use hyperactor::clock::Clock;
21use hyperactor::clock::RealClock;
22use hyperactor::mailbox::MailboxServer;
23use hyperactor::proc::Proc;
24use serde::Deserialize;
25use serde::Serialize;
26use tokio::sync::Mutex;
27
28use crate::proc_mesh::mesh_agent::MeshAgent;
29
/// Name of the environment variable from which [`bootstrap`] reads the
/// channel address it dials to reach the allocator.
pub const BOOTSTRAP_ADDR_ENV: &str = "HYPERACTOR_MESH_BOOTSTRAP_ADDR";
/// Name of the environment variable from which [`bootstrap`] reads the
/// index of this process; the index tags every message sent back to the
/// allocator.
pub const BOOTSTRAP_INDEX_ENV: &str = "HYPERACTOR_MESH_INDEX";
/// Presumably the name of the environment variable carrying the client's
/// trace id; it is not read anywhere in this file.
pub const CLIENT_TRACE_ID_ENV: &str = "MONARCH_CLIENT_TRACE_ID";
/// A channel used by each process to receive its own stdout and stderr.
/// Only the parent process can obtain stdout and stderr, so they need to
/// be streamed back to the process itself.
pub(crate) const BOOTSTRAP_LOG_CHANNEL: &str = "BOOTSTRAP_LOG_CHANNEL";
37
38/// Messages sent from the process to the allocator. This is an envelope
39/// containing the index of the process (i.e., its "address" assigned by
40/// the allocator), along with the control message in question.
41#[derive(Debug, Clone, Serialize, Deserialize, Named)]
42pub(crate) struct Process2Allocator(pub usize, pub Process2AllocatorMessage);
43
44/// Control messages sent from processes to the allocator.
45#[derive(Debug, Clone, Serialize, Deserialize, Named)]
46pub(crate) enum Process2AllocatorMessage {
47    /// Initialize a process2allocator session. The process is
48    /// listening on the provided channel address, to which
49    /// [`Allocator2Process`] messages are sent.
50    Hello(ChannelAddr),
51
52    /// A proc with the provided ID was started. Its mailbox is
53    /// served at the provided channel address. Procs are started
54    /// after instruction by the allocator through the corresponding
55    /// [`Allocator2Process`] message.
56    StartedProc(ProcId, ActorRef<MeshAgent>, ChannelAddr),
57
58    Heartbeat,
59}
60
61/// Messages sent from the allocator to a process.
62#[derive(Debug, Clone, Serialize, Deserialize, Named)]
63pub(crate) enum Allocator2Process {
64    /// Request to start a new proc with the provided ID, listening
65    /// to an address on the indicated channel transport.
66    StartProc(ProcId, ChannelTransport),
67
68    /// A request for the process to shut down its procs and exit the
69    /// process with the provided code.
70    StopAndExit(i32),
71
72    /// A request for the process to immediately exit with the provided
73    /// exit code
74    Exit(i32),
75}
76
77async fn exit_if_missed_heartbeat(bootstrap_index: usize, bootstrap_addr: ChannelAddr) {
78    let tx = match channel::dial(bootstrap_addr.clone()) {
79        Ok(tx) => tx,
80
81        Err(err) => {
82            tracing::error!(
83                "Failed to establish heartbeat connection to allocator, exiting! (addr: {:?}): {}",
84                bootstrap_addr,
85                err
86            );
87            std::process::exit(1);
88        }
89    };
90    tracing::info!(
91        "Heartbeat connection established to allocator (idx: {bootstrap_index}, addr: {bootstrap_addr:?})",
92    );
93    loop {
94        RealClock.sleep(Duration::from_secs(5)).await;
95
96        let result = tx
97            .send(Process2Allocator(
98                bootstrap_index,
99                Process2AllocatorMessage::Heartbeat,
100            ))
101            .await;
102
103        if let Err(err) = result {
104            tracing::error!(
105                "Heartbeat failed to allocator, exiting! (addr: {:?}): {}",
106                bootstrap_addr,
107                err
108            );
109            std::process::exit(1);
110        }
111    }
112}
113
114/// Entry point to processes managed by hyperactor_mesh. This advertises the process
115/// to a bootstrap server, and receives instructions to manage the lifecycle(s) of
116/// procs within this process.
117///
118/// If bootstrap returns any error, it is defunct from the point of view of hyperactor_mesh,
119/// and the process should likely exit:
120///
121/// ```ignore
122/// let err = hyperactor_mesh::bootstrap().await;
123/// tracing::error("could not bootstrap mesh process: {}", err);
124/// std::process::exit(1);
125/// ```
126///
127/// Use [`bootstrap_or_die`] to implement this behavior directly.
128pub async fn bootstrap() -> anyhow::Error {
129    pub async fn go() -> Result<(), anyhow::Error> {
130        let procs = Arc::new(Mutex::new(Vec::<Proc>::new()));
131        let procs_for_cleanup = procs.clone();
132        let _cleanup_guard = hyperactor::register_signal_cleanup_scoped(Box::pin(async move {
133            for proc_to_stop in procs_for_cleanup.lock().await.iter_mut() {
134                if let Err(err) = proc_to_stop
135                    .destroy_and_wait(Duration::from_millis(10), None)
136                    .await
137                {
138                    tracing::error!(
139                        "error while stopping proc {}: {}",
140                        proc_to_stop.proc_id(),
141                        err
142                    );
143                }
144            }
145        }));
146
147        let bootstrap_addr: ChannelAddr = std::env::var(BOOTSTRAP_ADDR_ENV)
148            .map_err(|err| anyhow::anyhow!("read `{}`: {}", BOOTSTRAP_ADDR_ENV, err))?
149            .parse()?;
150        let bootstrap_index: usize = std::env::var(BOOTSTRAP_INDEX_ENV)
151            .map_err(|err| anyhow::anyhow!("read `{}`: {}", BOOTSTRAP_INDEX_ENV, err))?
152            .parse()?;
153        let listen_addr = ChannelAddr::any(bootstrap_addr.transport());
154        let (serve_addr, mut rx) = channel::serve(listen_addr).await?;
155        let tx = channel::dial(bootstrap_addr.clone())?;
156
157        tx.send(Process2Allocator(
158            bootstrap_index,
159            Process2AllocatorMessage::Hello(serve_addr),
160        ))
161        .await?;
162
163        tokio::spawn(exit_if_missed_heartbeat(bootstrap_index, bootstrap_addr));
164
165        loop {
166            let _ = hyperactor::tracing::info_span!("wait_for_next_message_from_mesh_agent");
167            match rx.recv().await? {
168                Allocator2Process::StartProc(proc_id, listen_transport) => {
169                    let (proc, mesh_agent) = MeshAgent::bootstrap(proc_id.clone()).await?;
170                    let (proc_addr, proc_rx) =
171                        channel::serve(ChannelAddr::any(listen_transport)).await?;
172                    let handle = proc.clone().serve(proc_rx);
173                    drop(handle); // linter appeasement; it is safe to drop this future
174                    tx.send(Process2Allocator(
175                        bootstrap_index,
176                        Process2AllocatorMessage::StartedProc(
177                            proc_id.clone(),
178                            mesh_agent.bind(),
179                            proc_addr,
180                        ),
181                    ))
182                    .await?;
183                    procs.lock().await.push(proc);
184                }
185                Allocator2Process::StopAndExit(code) => {
186                    tracing::info!("stopping procs with code {code}");
187                    {
188                        for proc_to_stop in procs.lock().await.iter_mut() {
189                            if let Err(err) = proc_to_stop
190                                .destroy_and_wait(Duration::from_millis(10), None)
191                                .await
192                            {
193                                tracing::error!(
194                                    "error while stopping proc {}: {}",
195                                    proc_to_stop.proc_id(),
196                                    err
197                                );
198                            }
199                        }
200                    }
201                    tracing::info!("exiting with {code}");
202                    std::process::exit(code);
203                }
204                Allocator2Process::Exit(code) => {
205                    tracing::info!("exiting with {code}");
206                    std::process::exit(code);
207                }
208            }
209        }
210    }
211
212    go().await.unwrap_err()
213}
214
215/// A variant of [`bootstrap`] that logs the error and exits the process
216/// if bootstrapping fails.
217pub async fn bootstrap_or_die() -> ! {
218    let err = bootstrap().await;
219    tracing::error!("failed to bootstrap mesh process: {}", err);
220    std::process::exit(1)
221}