1use std::collections::HashMap;
10use std::io::BufRead;
11use std::io::Write;
12use std::os::fd::RawFd;
13use std::time::Duration;
14
15use anyhow::Result;
16use anyhow::anyhow;
17use hyperactor::ActorHandle;
18use hyperactor::ProcId;
19use hyperactor::WorldId;
20use hyperactor::actor::ActorStatus;
21use hyperactor::channel::ChannelAddr;
22use hyperactor_multiprocess::proc_actor::ProcActor;
23use hyperactor_multiprocess::system_actor::ProcLifecycleMode;
24use monarch_hyperactor::runtime::get_tokio_runtime;
25use pyo3::prelude::*;
26use pyo3::types::PyType;
27use serde::Deserialize;
28use serde::Serialize;
29
30use crate::pipe::OutOfProcessSetupParams;
31use crate::pipe::Pipe;
32use crate::pipe::StreamPipe;
33use crate::py_pipe::PyPipe;
34use crate::py_pipe::run_py_pipe;
35
/// Top-level command-line modes of the binary; each variant is one `clap` subcommand.
#[derive(clap::Parser)]
pub enum BinaryArgs {
    /// Worker server mode, given a pair of raw file descriptors (`rd`, `wr`).
    // NOTE(review): `rd`/`wr` presumably become the input and output streams
    // handed to `worker_server`; the dispatch code is not in this file — confirm.
    WorkerServer { rd: RawFd, wr: RawFd },
    /// Worker mode, configured entirely by `WorkerBootstrapArgs`.
    // NOTE(review): presumably dispatched to `bootstrap_worker_proc` — confirm.
    Worker(WorkerBootstrapArgs),
    /// Pipe mode; takes no command-line arguments.
    // NOTE(review): presumably dispatched to `bootstrap_pipe`, which reads its
    // setup parameters from stdin instead of the command line — confirm.
    Pipe,
}
45
/// Arguments consumed by `bootstrap_worker_proc` to bootstrap a worker proc.
#[derive(Debug, clap::Args)]
pub struct WorkerBootstrapArgs {
    /// Id of the world, forwarded to `ProcActor::bootstrap`.
    #[arg(long)]
    pub world_id: WorldId,

    /// Id of the proc, forwarded to `ProcActor::bootstrap`.
    #[arg(long)]
    pub proc_id: ProcId,

    /// Bootstrap address, forwarded to `ProcActor::bootstrap`.
    // `bootstrap_worker_proc` also builds a `ChannelAddr::any(..)` on this
    // address's transport and passes it alongside.
    // NOTE(review): that second address is presumably the proc's own listen
    // address — confirm against `ProcActor::bootstrap`.
    #[arg(long)]
    pub bootstrap_addr: ChannelAddr,

    /// Supervision update interval in whole seconds (defaults to 5).
    #[arg(long, default_value_t = 5)]
    pub supervision_update_interval_in_sec: u64,

    /// Extra proc labels, each written as `KEY=value`.
    // Every value is parsed by `parse_key_val`; the pairs are collected into a
    // `HashMap` by `bootstrap_worker_proc`, and `None` yields an empty map.
    #[clap(long, value_parser=parse_key_val)]
    extra_proc_labels: Option<Vec<(String, String)>>,
}
72
73pub async fn bootstrap_worker_proc(
78 args: WorkerBootstrapArgs,
79) -> Result<ActorHandle<ProcActor>, anyhow::Error> {
80 let labels: HashMap<String, String> = match args.extra_proc_labels {
81 Some(extra_lables) => extra_lables.into_iter().collect(),
82 _ => HashMap::new(),
83 };
84
85 tracing::info!(
86 "bootstrap worker proc {} in world {} with labels: {:?}",
87 args.proc_id,
88 args.world_id,
89 labels
90 );
91 let bootstrap = ProcActor::bootstrap(
92 args.proc_id,
93 args.world_id,
94 ChannelAddr::any(args.bootstrap_addr.transport()),
95 args.bootstrap_addr.clone(),
96 Duration::from_secs(args.supervision_update_interval_in_sec),
97 labels,
98 ProcLifecycleMode::ManagedBySystem,
99 )
100 .await?;
101
102 Ok(bootstrap.proc_actor)
103}
104
105pub fn bootstrap_pipe() -> Result<(), anyhow::Error> {
106 let mut pipe = StreamPipe::new(std::io::stdin(), std::io::stdout(), 4);
109 let init: OutOfProcessSetupParams = pipe.recv()?;
110 run_py_pipe(
115 PyPipe::new(Box::new(pipe), init.ranks, init.sizes, true),
116 init.function,
117 init.args,
118 init.kwargs,
119 )?;
120
121 Ok(())
122}
123
124fn parse_key_val(s: &str) -> anyhow::Result<(String, String)> {
125 match s.split_once('=') {
126 None => Err(anyhow::anyhow!("invalid KEY=value: no `=` found in `{s}`")),
127 Some((a, b)) => Ok((a.to_owned(), b.to_owned())),
128 }
129}
130
/// A request handled by `worker_server`, which reads one JSON-encoded request per
/// input line. Exposed to Python as a frozen class.
#[pyclass(
    frozen,
    module = "monarch._rust_bindings.monarch_tensor_worker.bootstrap"
)]
#[derive(Debug, Serialize, Deserialize)]
pub enum WorkerServerRequest {
    /// Bootstrap a worker proc and wait for its actor to finish.
    ///
    /// `world_id`, `proc_id` and `bootstrap_addr` are carried as strings and parsed
    /// into their typed forms by `worker_server`; `labels` are passed through as the
    /// proc's extra labels.
    Run {
        world_id: String,
        proc_id: String,
        bootstrap_addr: String,
        labels: Vec<(String, String)>,
    },
    /// Stop the `worker_server` request loop; no response is written for it.
    Exit(),
}
145
146#[pymethods]
147impl WorkerServerRequest {
148 fn to_json(&self) -> PyResult<String> {
149 Ok(serde_json::to_string(self).map_err(|e| anyhow!(e))?)
150 }
151
152 fn __str__(&self) -> String {
153 format!("{:?}", self)
154 }
155}
156
/// A response written by `worker_server`, one JSON document per line.
/// Exposed to Python as a frozen class.
#[pyclass(
    frozen,
    module = "monarch._rust_bindings.monarch_tensor_worker.bootstrap"
)]
#[derive(Debug, Serialize, Deserialize)]
pub enum WorkerServerResponse {
    /// Outcome of a `Run` request.
    ///
    /// `error` is `None` when the worker's actor ended with `ActorStatus::Stopped`;
    /// otherwise it holds the bootstrap error message or a description of the
    /// unexpected actor status.
    Finished { error: Option<String> },
}
165
166#[pymethods]
167impl WorkerServerResponse {
168 #[classmethod]
169 fn from_json(_: &Bound<'_, PyType>, json: &str) -> PyResult<Self> {
170 Ok(serde_json::from_str(json).map_err(|e| anyhow!(e))?)
171 }
172
173 fn __str__(&self) -> String {
174 format!("{:?}", self)
175 }
176}
177
178pub fn worker_server(inp: impl BufRead, mut outp: impl Write) -> Result<()> {
179 tracing::info!("running worker server on {}", std::process::id());
180
181 for line in inp.lines() {
182 let line = line?;
183 let request: WorkerServerRequest = serde_json::from_str(&line)?;
184 tracing::info!("got worker request: {:?}", request);
185 let response = match serde_json::from_str(&line)? {
186 WorkerServerRequest::Run {
187 world_id,
188 proc_id,
189 bootstrap_addr,
190 labels,
191 } => {
192 let args = WorkerBootstrapArgs {
193 world_id: world_id.parse()?,
194 proc_id: proc_id.parse()?,
195 bootstrap_addr: bootstrap_addr.parse()?,
196 supervision_update_interval_in_sec: 5,
197 extra_proc_labels: Some(labels),
198 };
199 let res = get_tokio_runtime()
200 .block_on(async move { anyhow::Ok(bootstrap_worker_proc(args).await?.await) });
201 WorkerServerResponse::Finished {
202 error: match res {
203 Err(err) => Some(format!("{}", err)),
204 Ok(ActorStatus::Stopped) => None,
205 Ok(status) => Some(format!("unexpected actor status: {}", status)),
206 },
207 }
208 }
209 WorkerServerRequest::Exit() => break,
210 };
211 tracing::info!("sending worker response: {:?}", response);
212 writeln!(outp, "{}", &serde_json::to_string(&response)?)?;
213 }
214
215 tracing::info!("finished running worker server");
216
217 std::process::exit(0);
221}
222
223pub fn register_python_bindings(worker_mod: &Bound<'_, PyModule>) -> PyResult<()> {
224 worker_mod.add_class::<WorkerServerRequest>()?;
225 worker_mod.add_class::<WorkerServerResponse>()?;
226
227 Ok(())
228}
229
#[cfg(test)]
mod tests {
    use hyperactor::channel::ChannelTransport;
    use hyperactor::id;
    use hyperactor_multiprocess::System;
    use timed_test::async_timed_test;

    use super::*;

    /// Bootstraps a worker proc against a locally served system, then stops it.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_worker_bootstrap() {
        let system = System::serve(
            ChannelAddr::any(ChannelTransport::Local),
            Duration::from_secs(10),
            Duration::from_secs(10),
        )
        .await
        .unwrap();

        let world_id = id!(worker);
        let proc_id = world_id.proc_id(0);
        let bootstrap_args = WorkerBootstrapArgs {
            world_id,
            proc_id,
            bootstrap_addr: system.local_addr().clone(),
            supervision_update_interval_in_sec: 5,
            extra_proc_labels: None,
        };

        let worker = bootstrap_worker_proc(bootstrap_args).await.unwrap();
        worker.drain_and_stop().unwrap();
    }

    /// A plain `key=value` input splits into its two halves.
    #[test]
    fn test_parse_key_val_valid_input() {
        let parsed = parse_key_val("key=value").unwrap();
        assert_eq!(parsed, ("key".to_string(), "value".to_string()));
    }

    /// Only the first `=` separates key from value; later ones stay in the value.
    #[test]
    fn test_parse_key_val_extra_equals() {
        let parsed = parse_key_val("key=value=3").unwrap();
        assert_eq!(parsed, ("key".to_string(), "value=3".to_string()));
    }

    /// Input without any `=` is rejected.
    #[test]
    fn test_parse_key_val_invalid() {
        let outcome = parse_key_val("invalid");
        assert!(outcome.is_err());
    }
}