monarch_tensor_worker/
pipe.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::collections::HashMap;
10use std::future::Future;
11use std::io::Read;
12use std::io::Write;
13use std::process::Stdio;
14use std::thread;
15
16use anyhow::Context;
17use anyhow::Result;
18use anyhow::anyhow;
19use anyhow::bail;
20use async_trait::async_trait;
21use hyperactor::Actor;
22use hyperactor::HandleClient;
23use hyperactor::Handler;
24use hyperactor::forward;
25use hyperactor::mailbox::OncePortHandle;
26use monarch_messages::controller::WorkerError;
27use monarch_types::PyTree;
28use nix::sys::wait::WaitStatus;
29use nix::unistd::Pid;
30use serde::Deserialize;
31use serde::Serialize;
32use serde::de::DeserializeOwned;
33use tokio::io::AsyncRead;
34use tokio::io::AsyncReadExt;
35use tokio::io::AsyncWrite;
36use tokio::io::AsyncWriteExt;
37use tokio::process::Child;
38use tokio::process::Command;
39use tokio::sync::mpsc;
40use tokio::task;
41use torch_sys::RValue;
42
43use crate::ResolvableFunction;
44
/// Simple communication channel to send/recv objects over an async stream.
///
/// Both operations take `&mut self`, so a single pipe value cannot be used
/// for a concurrent send and receive.
pub trait AsyncPipe<T> {
    /// Send `val` to the other end of the pipe.
    fn send(&mut self, val: T) -> impl Future<Output = Result<()>>;
    /// Receive the next value from the other end of the pipe.
    fn recv(&mut self) -> impl Future<Output = Result<T>>;
}
50
/// Simple communication channel to send/recv objects over a synchronous stream.
/// NOTE: This synchronous specialization is mainly useful when wrapped w/ the
/// `PyPipe` struct, which is also synchronous (via Python).
pub trait Pipe<T> {
    /// Send `val` to the other end of the pipe.
    fn send(&mut self, val: T) -> Result<()>;
    /// Receive the next value from the other end of the pipe.
    fn recv(&mut self) -> Result<T>;
}
58
/// Setup parameters serialized and sent to the out-of-process pipe server as
/// the first message after it is spawned (see `PipeActor::new`).
///
/// NOTE(review): the exact meaning of `sizes`/`ranks` is not visible in this
/// file; presumably they are keyed by mesh dimension name -- confirm against
/// the caller.
#[derive(Serialize, Deserialize)]
pub struct OutOfProcessSetupParams {
    // Copied verbatim from `PipeParams::sizes`.
    pub sizes: HashMap<String, usize>,
    // Copied verbatim from `PipeParams::ranks`.
    pub ranks: HashMap<String, usize>,
    // The function the pipe server is set up with.
    pub function: ResolvableFunction,
    // Positional arguments accompanying `function`.
    pub args: Vec<PyTree<RValue>>,
    // Keyword arguments accompanying `function`.
    pub kwargs: HashMap<String, PyTree<RValue>>,
}
67
68impl<T: Send + Sync + 'static> AsyncPipe<T>
69    for (mpsc::UnboundedSender<T>, mpsc::UnboundedReceiver<T>)
70{
71    async fn send(&mut self, val: T) -> Result<()> {
72        Ok(self.0.send(val)?)
73    }
74
75    async fn recv(&mut self) -> Result<T> {
76        Ok(self
77            .1
78            .recv()
79            .await
80            .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::UnexpectedEof, ""))?)
81    }
82}
83
84impl<T: Send + Sync + 'static> Pipe<T> for (mpsc::UnboundedSender<T>, mpsc::UnboundedReceiver<T>) {
85    fn send(&mut self, val: T) -> Result<()> {
86        Ok(self.0.send(val)?)
87    }
88
89    fn recv(&mut self) -> Result<T> {
90        Ok(self
91            .1
92            .blocking_recv()
93            .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::UnexpectedEof, ""))?)
94    }
95}
96
97/// Return a pair of parent/child pipes connected via unbounded tokio mpsc
98/// queues.
99pub fn create_local_pipe<T>() -> (
100    (mpsc::UnboundedSender<T>, mpsc::UnboundedReceiver<T>),
101    (mpsc::UnboundedSender<T>, mpsc::UnboundedReceiver<T>),
102) {
103    let (t1, r1) = mpsc::unbounded_channel();
104    let (t2, r2) = mpsc::unbounded_channel();
105    ((t1, r2), (t2, r1))
106}
107
/// Helper trait combining `AsyncWrite` with `Debug`, `Unpin`, `Sync` and
/// `Send` so that a writer can be stored as a single boxed trait object
/// (see `AsyncStreamPipe::writer`).
pub trait AsyncWriteDebug: std::fmt::Debug + AsyncWrite + Sync + Send + Unpin {}
// Blanket impl: every type satisfying the bounds is an `AsyncWriteDebug`.
impl<T: std::fmt::Debug + AsyncWrite + Unpin + Sync + Send> AsyncWriteDebug for T {}
110
/// Async pipe that exchanges length-prefixed, bincode-serialized messages
/// over a reader/writer pair.
#[derive(Debug)]
pub struct AsyncStreamPipe {
    // Destination for outgoing messages.
    writer: Box<dyn AsyncWriteDebug>,
    // Raw message payloads read ahead by the background task spawned in
    // `AsyncStreamPipe::new`.
    channel_reader: mpsc::Receiver<Vec<u8>>,
}
116
117impl AsyncStreamPipe {
118    /// Create a new `AsyncStreamPipe` from a reader/writer pair.
119    /// The pipe will run a background task to read-ahead up to max_messages
120    /// messages to make them immediately available to read.
121    /// When reader is closed, the background task will exit and further
122    /// reads will return an error.
123    pub fn new(
124        mut reader: impl AsyncRead + Unpin + Send + 'static,
125        writer: impl AsyncWriteDebug + 'static,
126        max_messages: usize,
127    ) -> Self {
128        let (channel_writer, channel_reader) = mpsc::channel::<Vec<u8>>(max_messages);
129
130        task::spawn(async move {
131            loop {
132                let mut buf = vec![0; 8];
133                match reader.read_exact(&mut buf).await {
134                    Ok(_) => (),
135                    Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
136                    // Other errors should not be expected. We should perhaps log and break
137                    // instead of panicking.
138                    Err(e) => panic!("preamble read failed: {}", e),
139                }
140                let len = u64::from_be_bytes(buf.try_into().unwrap());
141                buf = vec![0; len as usize];
142                match reader.read_exact(&mut buf).await {
143                    Ok(_) => (),
144                    Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
145                    // Other errors should not be expected. We should perhaps log and break
146                    // instead of panicking.
147                    Err(e) => panic!("read failed: {}", e),
148                }
149                if channel_writer.send(buf).await.is_err() {
150                    // receiver closed AsyncStreamPipe dropped, so we can break out of the loop
151                    break;
152                }
153            }
154        });
155
156        AsyncStreamPipe {
157            writer: Box::new(writer),
158            channel_reader,
159        }
160    }
161}
162
163impl<T: Serialize + DeserializeOwned> AsyncPipe<T> for AsyncStreamPipe {
164    async fn send(&mut self, val: T) -> Result<()> {
165        let bytes = bincode::serialize(&val)?;
166        let len = bytes.len();
167        self.writer.write_all(&len.to_be_bytes()).await?;
168        self.writer.write_all(&bytes).await?;
169        Ok(())
170    }
171
172    async fn recv(&mut self) -> Result<T> {
173        let buf = self.channel_reader.recv().await.expect("recv failed");
174        Ok(bincode::deserialize(&buf)?)
175    }
176}
177
/// Helper trait combining `Write` with `Debug`, `Sync` and `Send` so that a
/// writer can be stored as a single boxed trait object (see
/// `StreamPipe::writer`).
pub trait WriteDebug: std::fmt::Debug + Write + Sync + Send {}
// Blanket impl: every type satisfying the bounds is a `WriteDebug`.
impl<T: std::fmt::Debug + Write + Sync + Send> WriteDebug for T {}
180
/// Synchronous pipe that exchanges length-prefixed, bincode-serialized
/// messages over a reader/writer pair.
pub struct StreamPipe {
    // Destination for outgoing messages.
    writer: Box<dyn WriteDebug>,
    // Raw message payloads read ahead by the background thread spawned in
    // `StreamPipe::new`.
    channel_reader: std::sync::Arc<std::sync::Mutex<::std::sync::mpsc::Receiver<Vec<u8>>>>,
}
185
186impl StreamPipe {
187    /// Create a new `AsyncStreamPipe` from a reader/writer pair.
188    /// The pipe will run a background thread to read-ahead up to max_messages
189    /// messages to make them immediately available to read.
190    /// When reader is closed, the background thread will exit and further
191    /// reads will return an error.
192    pub fn new(
193        mut reader: impl Read + Send + 'static,
194        writer: impl WriteDebug + 'static,
195        max_messages: usize,
196    ) -> Self {
197        let (channel_writer, channel_reader) =
198            ::std::sync::mpsc::sync_channel::<Vec<u8>>(max_messages);
199
200        thread::spawn(move || {
201            loop {
202                let mut buf = vec![0; 8];
203                match reader.read_exact(&mut buf) {
204                    Ok(_) => (),
205                    Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
206                    // Other errors should not be expected. We should perhaps log and break
207                    // instead of panicking.
208                    Err(e) => panic!("preamble read failed: {}", e),
209                }
210                let len = u64::from_be_bytes(buf.try_into().unwrap());
211                buf = vec![0; len as usize];
212                match reader.read_exact(&mut buf) {
213                    Ok(_) => (),
214                    Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
215                    Err(e) => panic!("preamble read failed: {}", e),
216                }
217                if channel_writer.send(buf).is_err() {
218                    // receiver closed StreamPipe dropped, so we can break out of the loop
219                    break;
220                }
221            }
222        });
223
224        StreamPipe {
225            writer: Box::new(writer),
226            channel_reader: std::sync::Arc::new(std::sync::Mutex::new(channel_reader)),
227        }
228    }
229}
230
231impl<T: Serialize + DeserializeOwned> Pipe<T> for StreamPipe {
232    fn send(&mut self, val: T) -> Result<()> {
233        let bytes = bincode::serialize(&val)?;
234        let len = bytes.len();
235        self.writer.write_all(&len.to_be_bytes())?;
236        self.writer.write_all(&bytes)?;
237        self.writer.flush()?;
238        Ok(())
239    }
240
241    fn recv(&mut self) -> Result<T> {
242        let buf = self
243            .channel_reader
244            .lock()
245            .unwrap()
246            .recv()
247            .expect("recv failed");
248        Ok(bincode::deserialize(&buf)?)
249    }
250}
251
/// Messages handled by `PipeActor`.
#[allow(dead_code)]
#[derive(Handler, HandleClient, Debug)]
pub enum PipeMessage {
    /// Send a value to the pipe server process. An `Err` payload is not
    /// forwarded to the server; the handler fails with the worker error
    /// instead.
    SendValue(Result<PyTree<RValue>, WorkerError>),

    /// Receive the next value from the pipe server process; the result is
    /// delivered through the reply port.
    RecvValue(#[reply] OncePortHandle<PyTree<RValue>>),
}
259
/// Actor that owns an out-of-process pipe server and relays values to and
/// from it over the server's stdin/stdout.
#[derive(Debug)]
pub struct PipeActor {
    // NOTE: Use `Option` wrappers to allow moving in `Drop` impl below.
    // The pipe is `Some` for the whole life of the actor and is only taken
    // when the actor is dropped.
    pipe: Option<AsyncStreamPipe>,
    // Handle to the spawned pipe server process.
    handle: Child,
}
266
/// Initialization parameters for `PipeActor`.
///
/// Everything except `max_messages` is forwarded to the pipe server process
/// as `OutOfProcessSetupParams`.
#[derive(Debug, Clone)]
pub struct PipeParams {
    // The function the pipe server is set up with.
    pub function: ResolvableFunction,
    // Read-ahead depth passed to `AsyncStreamPipe::new` (converted to
    // `usize` there).
    pub max_messages: i64,
    // NOTE(review): presumably keyed by mesh dimension name -- confirm
    // against the caller.
    pub ranks: HashMap<String, usize>,
    // NOTE(review): presumably keyed by mesh dimension name -- confirm
    // against the caller.
    pub sizes: HashMap<String, usize>,
    // Positional arguments accompanying `function`.
    pub args: Vec<PyTree<RValue>>,
    // Keyword arguments accompanying `function`.
    pub kwargs: HashMap<String, PyTree<RValue>>,
}
277
278#[async_trait]
279impl Actor for PipeActor {
280    type Params = PipeParams;
281
282    async fn new(params: Self::Params) -> Result<Self> {
283        let mut command = Command::new(
284            std::env::var("MONARCH_TENSOR_WORKER_EXE")
285                .map_err(|e| anyhow!("could not get var MONARCH_TENSOR_WORKER_EXE: {}", e))?,
286        );
287        if let Ok(main) = std::env::var("MONARCH_TENSOR_WORKER_MAIN") {
288            if std::env::var("FB_XAR_INVOKED_NAME").is_ok() {
289                command.env("PAR_MAIN_OVERRIDE", main);
290            } else {
291                command.arg("-m").arg(main);
292            }
293        }
294
295        // Spawn server process.
296        let mut handle = command
297            .arg("pipe")
298            .stdout(Stdio::piped())
299            .stdin(Stdio::piped())
300            .kill_on_drop(true)
301            .spawn()?;
302
303        // Send init args.
304        let mut pipe = AsyncStreamPipe::new(
305            handle.stdout.take().unwrap(),
306            handle.stdin.take().unwrap(),
307            params.max_messages as usize,
308        );
309        let params = OutOfProcessSetupParams {
310            ranks: params.ranks,
311            sizes: params.sizes,
312            function: params.function,
313            args: params.args,
314            kwargs: params.kwargs,
315        };
316        tokio::select! {
317            res = handle.wait() => bail!("pipe server exited: {:?}", res),
318            res = pipe.send(params) => res?,
319        }
320
321        Ok(Self {
322            pipe: Some(pipe),
323            handle,
324        })
325    }
326}
327
328impl PipeActor {
329    /// Forcibly kill and cleanup the pipe server. Avoids `await` to be usable
330    /// in `Drop`.
331    fn kill_pipe_server(&mut self) -> Result<()> {
332        self.handle.start_kill()?;
333
334        // NOT(agallagher): Since this is called from `drop()`, we can't
335        // use the async `wait()` method (is there a way to convert to
336        // `std::process::Child`?).
337        let pid = Pid::from_raw(self.handle.id().context("cannot get pid")? as i32);
338        match nix::sys::wait::waitpid(pid, None)? {
339            WaitStatus::Exited(_, 0) => (),
340            status => bail!("exited abnormally: {:?}", status),
341        }
342        Ok(())
343    }
344}
345
346// TODO(agallager): It'd be nice if the `Actor` API had a `shutdown` mechanism
347// which could allow for preserving error propagation in cases like this.
348impl Drop for PipeActor {
349    fn drop(&mut self) {
350        // Close the pipe first, which should make the server end get an EPIPE
351        // and die.
352        self.pipe.take();
353
354        // Kill/cleanup the server.
355        if let Err(err) = self.kill_pipe_server() {
356            tracing::warn!("error cleaning up pipe server: {}", err);
357        }
358    }
359}
360
361#[async_trait]
362#[forward(PipeMessage)]
363impl PipeMessageHandler for PipeActor {
364    async fn send_value(
365        &mut self,
366        _cx: &hyperactor::Context<Self>,
367        val: Result<PyTree<RValue>, WorkerError>,
368    ) -> Result<()> {
369        // TODO(agallagher): Propagate failures and use a timeout and handle worker errors?
370        let val = val.map_err(|err| anyhow::anyhow!(err.backtrace).context("worker error"))?;
371        tokio::select! {
372            res = self.handle.wait() => bail!("pipe server exited: {:?}", res),
373            res = self.pipe.as_mut().unwrap().send(val) => res?,
374        };
375        Ok(())
376    }
377
378    async fn recv_value(&mut self, _cx: &hyperactor::Context<Self>) -> Result<PyTree<RValue>> {
379        // TODO(agallagher): Propagate failures and use a timeout?
380        tokio::select! {
381            res = self.handle.wait() => bail!("pipe server exited: {:?}", res),
382            res = self.pipe.as_mut().unwrap().recv() => res
383        }
384    }
385}