monarch_tensor_worker/
bootstrap.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::collections::HashMap;
10use std::io::BufRead;
11use std::io::Write;
12use std::os::fd::RawFd;
13use std::time::Duration;
14
15use anyhow::Result;
16use anyhow::anyhow;
17use hyperactor::ActorHandle;
18use hyperactor::ProcId;
19use hyperactor::WorldId;
20use hyperactor::actor::ActorStatus;
21use hyperactor::channel::ChannelAddr;
22use hyperactor_multiprocess::proc_actor::ProcActor;
23use hyperactor_multiprocess::system_actor::ProcLifecycleMode;
24use monarch_hyperactor::runtime::get_tokio_runtime;
25use pyo3::prelude::*;
26use pyo3::types::PyType;
27use serde::Deserialize;
28use serde::Serialize;
29
30use crate::pipe::OutOfProcessSetupParams;
31use crate::pipe::Pipe;
32use crate::pipe::StreamPipe;
33use crate::py_pipe::PyPipe;
34use crate::py_pipe::run_py_pipe;
35
36#[derive(clap::Parser)]
37pub enum BinaryArgs {
38    /// Starts running a worker server.
39    WorkerServer { rd: RawFd, wr: RawFd },
40    /// Starts running a worker.
41    Worker(WorkerBootstrapArgs),
42    /// Starts running a pipe.
43    Pipe,
44}
45
46/// Bootstrap arguments used for bootstrapping workers for both Python and Rust entry points.
47// TODO: We might want to convert these arguments to environment variables depending on
48// how we launch the processes in the end.
49#[derive(Debug, clap::Args)]
50pub struct WorkerBootstrapArgs {
51    /// The world id of the launched worker.
52    #[arg(long)]
53    pub world_id: WorldId,
54
55    /// The proc id of the launched worker.
56    #[arg(long)]
57    pub proc_id: ProcId,
58
59    /// The system address for the worker to connect to.
60    #[arg(long)]
61    pub bootstrap_addr: ChannelAddr,
62
63    /// The supervision update interval for worker proc actor.
64    #[arg(long, default_value_t = 5)]
65    pub supervision_update_interval_in_sec: u64,
66
67    /// Proc metadata which will be available through system.
68    /// Keys are not allowed to contain '='.
69    #[clap(long, value_parser=parse_key_val)]
70    extra_proc_labels: Option<Vec<(String, String)>>,
71}
72
73/// Bootstrap the worker proc and join the system at `bootstrap_addr`.
74/// The actual worker actor is spawned by the corresponding controller.
75/// The worker Python dependencies need to be packaged
76/// separately and loaded during the runtime.
77pub async fn bootstrap_worker_proc(
78    args: WorkerBootstrapArgs,
79) -> Result<ActorHandle<ProcActor>, anyhow::Error> {
80    let labels: HashMap<String, String> = match args.extra_proc_labels {
81        Some(extra_lables) => extra_lables.into_iter().collect(),
82        _ => HashMap::new(),
83    };
84
85    tracing::info!(
86        "bootstrap worker proc {} in world {} with labels: {:?}",
87        args.proc_id,
88        args.world_id,
89        labels
90    );
91    let bootstrap = ProcActor::bootstrap(
92        args.proc_id,
93        args.world_id,
94        ChannelAddr::any(args.bootstrap_addr.transport()),
95        args.bootstrap_addr.clone(),
96        Duration::from_secs(args.supervision_update_interval_in_sec),
97        labels,
98        ProcLifecycleMode::ManagedBySystem,
99    )
100    .await?;
101
102    Ok(bootstrap.proc_actor)
103}
104
105pub fn bootstrap_pipe() -> Result<(), anyhow::Error> {
106    // Use a temp pipe just to ship the init params.
107    // Value of 4 is arbitrary as our side does not need to do buffering.
108    let mut pipe = StreamPipe::new(std::io::stdin(), std::io::stdout(), 4);
109    let init: OutOfProcessSetupParams = pipe.recv()?;
110    // Create a PyPipe that allows unsafe object conversion. This allows the pipe to
111    // receive tensors, which we know is safe because StreamPipe receives the serialized
112    // tensors from out-of-process, and they therefore can't be owned by anything except
113    // the pipe's python code.
114    run_py_pipe(
115        PyPipe::new(Box::new(pipe), init.ranks, init.sizes, true),
116        init.function,
117        init.args,
118        init.kwargs,
119    )?;
120
121    Ok(())
122}
123
124fn parse_key_val(s: &str) -> anyhow::Result<(String, String)> {
125    match s.split_once('=') {
126        None => Err(anyhow::anyhow!("invalid KEY=value: no `=` found in `{s}`")),
127        Some((a, b)) => Ok((a.to_owned(), b.to_owned())),
128    }
129}
130
// Request understood by `worker_server`, which deserializes one of these from
// each JSON line it reads. Exposed to Python as a frozen class.
#[pyclass(
    frozen,
    module = "monarch._rust_bindings.monarch_tensor_worker.bootstrap"
)]
#[derive(Debug, Serialize, Deserialize)]
pub enum WorkerServerRequest {
    // Bootstrap a worker proc. The id and address fields travel as strings and
    // are parsed by `worker_server` into their typed counterparts.
    Run {
        // String form of the world id.
        world_id: String,
        // String form of the proc id.
        proc_id: String,
        // String form of the system address the worker connects to.
        bootstrap_addr: String,
        // Key/value pairs passed on as the worker's extra proc labels.
        labels: Vec<(String, String)>,
    },
    // Makes `worker_server` leave its request loop; no response is sent for it.
    Exit(),
}
145
146#[pymethods]
147impl WorkerServerRequest {
148    fn to_json(&self) -> PyResult<String> {
149        Ok(serde_json::to_string(self).map_err(|e| anyhow!(e))?)
150    }
151
152    fn __str__(&self) -> String {
153        format!("{:?}", self)
154    }
155}
156
// Response written by `worker_server` as one JSON line per handled `Run`
// request. Exposed to Python as a frozen class.
#[pyclass(
    frozen,
    module = "monarch._rust_bindings.monarch_tensor_worker.bootstrap"
)]
#[derive(Debug, Serialize, Deserialize)]
pub enum WorkerServerResponse {
    // The worker run ended. `worker_server` sets `error` to `None` when the
    // proc actor reported `ActorStatus::Stopped`, and otherwise to a message
    // describing the bootstrap failure or the unexpected actor status.
    Finished { error: Option<String> },
}
165
166#[pymethods]
167impl WorkerServerResponse {
168    #[classmethod]
169    fn from_json(_: &Bound<'_, PyType>, json: &str) -> PyResult<Self> {
170        Ok(serde_json::from_str(json).map_err(|e| anyhow!(e))?)
171    }
172
173    fn __str__(&self) -> String {
174        format!("{:?}", self)
175    }
176}
177
178pub fn worker_server(inp: impl BufRead, mut outp: impl Write) -> Result<()> {
179    tracing::info!("running worker server on {}", std::process::id());
180
181    for line in inp.lines() {
182        let line = line?;
183        let request: WorkerServerRequest = serde_json::from_str(&line)?;
184        tracing::info!("got worker request: {:?}", request);
185        let response = match serde_json::from_str(&line)? {
186            WorkerServerRequest::Run {
187                world_id,
188                proc_id,
189                bootstrap_addr,
190                labels,
191            } => {
192                let args = WorkerBootstrapArgs {
193                    world_id: world_id.parse()?,
194                    proc_id: proc_id.parse()?,
195                    bootstrap_addr: bootstrap_addr.parse()?,
196                    supervision_update_interval_in_sec: 5,
197                    extra_proc_labels: Some(labels),
198                };
199                let res = get_tokio_runtime()
200                    .block_on(async move { anyhow::Ok(bootstrap_worker_proc(args).await?.await) });
201                WorkerServerResponse::Finished {
202                    error: match res {
203                        Err(err) => Some(format!("{}", err)),
204                        Ok(ActorStatus::Stopped) => None,
205                        Ok(status) => Some(format!("unexpected actor status: {}", status)),
206                    },
207                }
208            }
209            WorkerServerRequest::Exit() => break,
210        };
211        tracing::info!("sending worker response: {:?}", response);
212        writeln!(outp, "{}", &serde_json::to_string(&response)?)?;
213    }
214
215    tracing::info!("finished running worker server");
216
217    // TODO(agallagher): Forcing an exit here saves 700ms on shutdown for some
218    // reasons -- does this avoid some slow Python shutdown code?
219    //Ok(())
220    std::process::exit(0);
221}
222
223pub fn register_python_bindings(worker_mod: &Bound<'_, PyModule>) -> PyResult<()> {
224    worker_mod.add_class::<WorkerServerRequest>()?;
225    worker_mod.add_class::<WorkerServerResponse>()?;
226
227    Ok(())
228}
229
#[cfg(test)]
mod tests {
    use hyperactor::channel::ChannelTransport;
    use hyperactor::id;
    use hyperactor_multiprocess::System;
    use timed_test::async_timed_test;

    use super::*;

    // Serves a system on a local channel, bootstraps a single worker proc
    // against it and then asks that proc to drain and stop.
    #[async_timed_test(timeout_secs = 60)]
    async fn test_worker_bootstrap() {
        let system_addr = ChannelAddr::any(ChannelTransport::Local);
        let server_handle = System::serve(
            system_addr,
            Duration::from_secs(10),
            Duration::from_secs(10),
        )
        .await
        .unwrap();

        let world_id = id!(worker);
        let proc_id = world_id.proc_id(0);
        let bootstrap_args = WorkerBootstrapArgs {
            world_id,
            proc_id,
            bootstrap_addr: server_handle.local_addr().clone(),
            supervision_update_interval_in_sec: 5,
            extra_proc_labels: None,
        };
        let proc_handle = bootstrap_worker_proc(bootstrap_args).await.unwrap();

        proc_handle.drain_and_stop().unwrap();
    }

    // A plain `key=value` argument splits into its two halves.
    #[test]
    fn test_parse_key_val_valid_input() {
        let (key, value) = parse_key_val("key=value").unwrap();
        assert_eq!(key, "key");
        assert_eq!(value, "value");
    }

    // Only the first `=` separates key from value; later ones stay in the value.
    #[test]
    fn test_parse_key_val_extra_equals() {
        let (key, value) = parse_key_val("key=value=3").unwrap();
        assert_eq!(key, "key");
        assert_eq!(value, "value=3");
    }

    // Input without any `=` is rejected.
    #[test]
    fn test_parse_key_val_invalid() {
        assert!(parse_key_val("invalid").is_err());
    }
}