hyperactor_mesh/
bootstrap.rs1use std::sync::Arc;
10use std::time::Duration;
11
12use hyperactor::ActorRef;
13use hyperactor::Named;
14use hyperactor::ProcId;
15use hyperactor::channel;
16use hyperactor::channel::ChannelAddr;
17use hyperactor::channel::ChannelTransport;
18use hyperactor::channel::Rx;
19use hyperactor::channel::Tx;
20use hyperactor::clock::Clock;
21use hyperactor::clock::RealClock;
22use hyperactor::mailbox::MailboxServer;
23use hyperactor::proc::Proc;
24use serde::Deserialize;
25use serde::Serialize;
26use tokio::sync::Mutex;
27
28use crate::proc_mesh::mesh_agent::MeshAgent;
29
/// Name of the environment variable holding the allocator's channel address.
/// `bootstrap` reads it and parses the value as a `ChannelAddr`.
pub const BOOTSTRAP_ADDR_ENV: &str = "HYPERACTOR_MESH_BOOTSTRAP_ADDR";
/// Name of the environment variable holding this process's bootstrap index.
/// `bootstrap` parses the value as a `usize` and tags every
/// `Process2Allocator` message it sends with it.
pub const BOOTSTRAP_INDEX_ENV: &str = "HYPERACTOR_MESH_INDEX";
/// NOTE(review): not read anywhere in the visible part of this file;
/// presumably the environment variable carrying a client trace id — confirm
/// against the users of this constant.
pub const CLIENT_TRACE_ID_ENV: &str = "MONARCH_CLIENT_TRACE_ID";
/// NOTE(review): not referenced in the visible part of this file; presumably
/// the name under which a bootstrap log channel is published — confirm against
/// the users of this constant.
pub(crate) const BOOTSTRAP_LOG_CHANNEL: &str = "BOOTSTRAP_LOG_CHANNEL";
37
/// Envelope for every message a bootstrapped process sends to the allocator.
///
/// Field 0 is the sender's bootstrap index (the value `bootstrap` parses from
/// `BOOTSTRAP_INDEX_ENV`); field 1 is the actual message.
#[derive(Debug, Clone, Serialize, Deserialize, Named)]
pub(crate) struct Process2Allocator(pub usize, pub Process2AllocatorMessage);
43
/// Payload of a `Process2Allocator` message (process -> allocator direction).
#[derive(Debug, Clone, Serialize, Deserialize, Named)]
pub(crate) enum Process2AllocatorMessage {
    /// First message `bootstrap` sends after dialing the allocator. Carries
    /// the address on which this process serves its `Allocator2Process`
    /// receiver.
    Hello(ChannelAddr),

    /// Sent by `bootstrap` once a `StartProc` request has been handled.
    /// Carries the proc's id, a reference to the mesh agent bootstrapped for
    /// it, and the address on which the proc is being served.
    StartedProc(ProcId, ActorRef<MeshAgent>, ChannelAddr),

    /// Periodic liveness message sent by `exit_if_missed_heartbeat`.
    Heartbeat,
}
60
/// Messages received by `bootstrap` on the channel it announces via
/// `Process2AllocatorMessage::Hello` (allocator -> process direction).
#[derive(Debug, Clone, Serialize, Deserialize, Named)]
pub(crate) enum Allocator2Process {
    /// Request to bootstrap a mesh agent for the given proc id and to serve
    /// the resulting proc on any address of the given transport. Answered
    /// with `Process2AllocatorMessage::StartedProc`.
    StartProc(ProcId, ChannelTransport),

    /// Request to destroy every proc started so far and then exit the
    /// process with the given exit code.
    StopAndExit(i32),

    /// Request to exit the process immediately with the given exit code,
    /// without stopping procs first.
    Exit(i32),
}
76
77async fn exit_if_missed_heartbeat(bootstrap_index: usize, bootstrap_addr: ChannelAddr) {
78 let tx = match channel::dial(bootstrap_addr.clone()) {
79 Ok(tx) => tx,
80
81 Err(err) => {
82 tracing::error!(
83 "Failed to establish heartbeat connection to allocator, exiting! (addr: {:?}): {}",
84 bootstrap_addr,
85 err
86 );
87 std::process::exit(1);
88 }
89 };
90 tracing::info!(
91 "Heartbeat connection established to allocator (idx: {bootstrap_index}, addr: {bootstrap_addr:?})",
92 );
93 loop {
94 RealClock.sleep(Duration::from_secs(5)).await;
95
96 let result = tx
97 .send(Process2Allocator(
98 bootstrap_index,
99 Process2AllocatorMessage::Heartbeat,
100 ))
101 .await;
102
103 if let Err(err) = result {
104 tracing::error!(
105 "Heartbeat failed to allocator, exiting! (addr: {:?}): {}",
106 bootstrap_addr,
107 err
108 );
109 std::process::exit(1);
110 }
111 }
112}
113
114pub async fn bootstrap() -> anyhow::Error {
129 pub async fn go() -> Result<(), anyhow::Error> {
130 let procs = Arc::new(Mutex::new(Vec::<Proc>::new()));
131 let procs_for_cleanup = procs.clone();
132 let _cleanup_guard = hyperactor::register_signal_cleanup_scoped(Box::pin(async move {
133 for proc_to_stop in procs_for_cleanup.lock().await.iter_mut() {
134 if let Err(err) = proc_to_stop
135 .destroy_and_wait(Duration::from_millis(10), None)
136 .await
137 {
138 tracing::error!(
139 "error while stopping proc {}: {}",
140 proc_to_stop.proc_id(),
141 err
142 );
143 }
144 }
145 }));
146
147 let bootstrap_addr: ChannelAddr = std::env::var(BOOTSTRAP_ADDR_ENV)
148 .map_err(|err| anyhow::anyhow!("read `{}`: {}", BOOTSTRAP_ADDR_ENV, err))?
149 .parse()?;
150 let bootstrap_index: usize = std::env::var(BOOTSTRAP_INDEX_ENV)
151 .map_err(|err| anyhow::anyhow!("read `{}`: {}", BOOTSTRAP_INDEX_ENV, err))?
152 .parse()?;
153 let listen_addr = ChannelAddr::any(bootstrap_addr.transport());
154 let (serve_addr, mut rx) = channel::serve(listen_addr).await?;
155 let tx = channel::dial(bootstrap_addr.clone())?;
156
157 tx.send(Process2Allocator(
158 bootstrap_index,
159 Process2AllocatorMessage::Hello(serve_addr),
160 ))
161 .await?;
162
163 tokio::spawn(exit_if_missed_heartbeat(bootstrap_index, bootstrap_addr));
164
165 loop {
166 let _ = hyperactor::tracing::info_span!("wait_for_next_message_from_mesh_agent");
167 match rx.recv().await? {
168 Allocator2Process::StartProc(proc_id, listen_transport) => {
169 let (proc, mesh_agent) = MeshAgent::bootstrap(proc_id.clone()).await?;
170 let (proc_addr, proc_rx) =
171 channel::serve(ChannelAddr::any(listen_transport)).await?;
172 let handle = proc.clone().serve(proc_rx);
173 drop(handle); tx.send(Process2Allocator(
175 bootstrap_index,
176 Process2AllocatorMessage::StartedProc(
177 proc_id.clone(),
178 mesh_agent.bind(),
179 proc_addr,
180 ),
181 ))
182 .await?;
183 procs.lock().await.push(proc);
184 }
185 Allocator2Process::StopAndExit(code) => {
186 tracing::info!("stopping procs with code {code}");
187 {
188 for proc_to_stop in procs.lock().await.iter_mut() {
189 if let Err(err) = proc_to_stop
190 .destroy_and_wait(Duration::from_millis(10), None)
191 .await
192 {
193 tracing::error!(
194 "error while stopping proc {}: {}",
195 proc_to_stop.proc_id(),
196 err
197 );
198 }
199 }
200 }
201 tracing::info!("exiting with {code}");
202 std::process::exit(code);
203 }
204 Allocator2Process::Exit(code) => {
205 tracing::info!("exiting with {code}");
206 std::process::exit(code);
207 }
208 }
209 }
210 }
211
212 go().await.unwrap_err()
213}
214
215pub async fn bootstrap_or_die() -> ! {
218 let err = bootstrap().await;
219 tracing::error!("failed to bootstrap mesh process: {}", err);
220 std::process::exit(1)
221}