1use std::collections::HashMap;
10use std::future::Future;
11use std::io::Read;
12use std::io::Write;
13use std::process::Stdio;
14use std::thread;
15
16use anyhow::Context;
17use anyhow::Result;
18use anyhow::anyhow;
19use anyhow::bail;
20use async_trait::async_trait;
21use hyperactor::Actor;
22use hyperactor::HandleClient;
23use hyperactor::Handler;
24use hyperactor::forward;
25use hyperactor::mailbox::OncePortHandle;
26use monarch_messages::controller::WorkerError;
27use monarch_types::PyTree;
28use nix::sys::wait::WaitStatus;
29use nix::unistd::Pid;
30use serde::Deserialize;
31use serde::Serialize;
32use serde::de::DeserializeOwned;
33use tokio::io::AsyncRead;
34use tokio::io::AsyncReadExt;
35use tokio::io::AsyncWrite;
36use tokio::io::AsyncWriteExt;
37use tokio::process::Child;
38use tokio::process::Command;
39use tokio::sync::mpsc;
40use tokio::task;
41use torch_sys::RValue;
42
43use crate::ResolvableFunction;
44
/// A two-way, message-oriented channel whose operations are asynchronous.
///
/// `send` hands one value of type `T` to the other end; `recv` waits for the next value
/// produced by the other end. See [`Pipe`] for the blocking counterpart.
pub trait AsyncPipe<T> {
    /// Send `val` to the other end of the pipe.
    fn send(&mut self, val: T) -> impl Future<Output = Result<()>>;
    /// Wait for and return the next value sent by the other end of the pipe.
    fn recv(&mut self) -> impl Future<Output = Result<T>>;
}
50
/// Blocking counterpart of [`AsyncPipe`]: a two-way, message-oriented channel.
///
/// `send` hands one value of type `T` to the other end; `recv` blocks the calling thread
/// until the other end produces a value.
pub trait Pipe<T> {
    /// Send `val` to the other end of the pipe.
    fn send(&mut self, val: T) -> Result<()>;
    /// Block until the next value sent by the other end arrives, and return it.
    fn recv(&mut self) -> Result<T>;
}
58
/// Setup message sent to an out-of-process pipe server as the first message on its pipe.
///
/// `PipeActor` builds this from its `PipeParams` and sends it through an
/// `AsyncStreamPipe`, which encodes it with bincode. Bincode is not self-describing, so the
/// field order and types below are part of the wire format: keep them in sync with the
/// decoding side.
#[derive(Serialize, Deserialize)]
pub struct OutOfProcessSetupParams {
    // NOTE(review): presumably dimension name -> size of that dimension; confirm with callers.
    pub sizes: HashMap<String, usize>,
    // NOTE(review): presumably dimension name -> this worker's rank along it; confirm.
    pub ranks: HashMap<String, usize>,
    /// The function the pipe server is asked to run.
    pub function: ResolvableFunction,
    /// Positional arguments for `function`.
    pub args: Vec<PyTree<RValue>>,
    /// Keyword arguments for `function`, keyed by parameter name.
    pub kwargs: HashMap<String, PyTree<RValue>>,
}
67
68impl<T: Send + Sync + 'static> AsyncPipe<T>
69 for (mpsc::UnboundedSender<T>, mpsc::UnboundedReceiver<T>)
70{
71 async fn send(&mut self, val: T) -> Result<()> {
72 Ok(self.0.send(val)?)
73 }
74
75 async fn recv(&mut self) -> Result<T> {
76 Ok(self
77 .1
78 .recv()
79 .await
80 .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::UnexpectedEof, ""))?)
81 }
82}
83
84impl<T: Send + Sync + 'static> Pipe<T> for (mpsc::UnboundedSender<T>, mpsc::UnboundedReceiver<T>) {
85 fn send(&mut self, val: T) -> Result<()> {
86 Ok(self.0.send(val)?)
87 }
88
89 fn recv(&mut self) -> Result<T> {
90 Ok(self
91 .1
92 .blocking_recv()
93 .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::UnexpectedEof, ""))?)
94 }
95}
96
97pub fn create_local_pipe<T>() -> (
100 (mpsc::UnboundedSender<T>, mpsc::UnboundedReceiver<T>),
101 (mpsc::UnboundedSender<T>, mpsc::UnboundedReceiver<T>),
102) {
103 let (t1, r1) = mpsc::unbounded_channel();
104 let (t2, r2) = mpsc::unbounded_channel();
105 ((t1, r2), (t2, r1))
106}
107
108pub trait AsyncWriteDebug: std::fmt::Debug + AsyncWrite + Sync + Send + Unpin {}
109impl<T: std::fmt::Debug + AsyncWrite + Unpin + Sync + Send> AsyncWriteDebug for T {}
110
/// Asynchronous, length-prefixed message pipe over a pair of byte streams.
///
/// Outgoing messages are written directly to `writer`. Incoming frames are read by a
/// background task spawned in `AsyncStreamPipe::new` and queued in `channel_reader`.
#[derive(Debug)]
pub struct AsyncStreamPipe {
    /// Sink that outgoing frames are written to.
    writer: Box<dyn AsyncWriteDebug>,
    /// Receives raw frame payloads from the background reader task.
    channel_reader: mpsc::Receiver<Vec<u8>>,
}
116
117impl AsyncStreamPipe {
118 pub fn new(
124 mut reader: impl AsyncRead + Unpin + Send + 'static,
125 writer: impl AsyncWriteDebug + 'static,
126 max_messages: usize,
127 ) -> Self {
128 let (channel_writer, channel_reader) = mpsc::channel::<Vec<u8>>(max_messages);
129
130 task::spawn(async move {
131 loop {
132 let mut buf = vec![0; 8];
133 match reader.read_exact(&mut buf).await {
134 Ok(_) => (),
135 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
136 Err(e) => panic!("preamble read failed: {}", e),
139 }
140 let len = u64::from_be_bytes(buf.try_into().unwrap());
141 buf = vec![0; len as usize];
142 match reader.read_exact(&mut buf).await {
143 Ok(_) => (),
144 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
145 Err(e) => panic!("read failed: {}", e),
148 }
149 if channel_writer.send(buf).await.is_err() {
150 break;
152 }
153 }
154 });
155
156 AsyncStreamPipe {
157 writer: Box::new(writer),
158 channel_reader,
159 }
160 }
161}
162
163impl<T: Serialize + DeserializeOwned> AsyncPipe<T> for AsyncStreamPipe {
164 async fn send(&mut self, val: T) -> Result<()> {
165 let bytes = bincode::serialize(&val)?;
166 let len = bytes.len();
167 self.writer.write_all(&len.to_be_bytes()).await?;
168 self.writer.write_all(&bytes).await?;
169 Ok(())
170 }
171
172 async fn recv(&mut self) -> Result<T> {
173 let buf = self.channel_reader.recv().await.expect("recv failed");
174 Ok(bincode::deserialize(&buf)?)
175 }
176}
177
/// A blocking byte sink that is also `Debug` and can be shared across threads, so it can
/// be stored as a boxed trait object (as `StreamPipe` does for its writer).
pub trait WriteDebug: Write + std::fmt::Debug + Send + Sync {}

// Blanket impl: every type meeting the bounds is a `WriteDebug` automatically.
impl<W> WriteDebug for W where W: Write + std::fmt::Debug + Send + Sync {}
180
/// Blocking, length-prefixed message pipe over a pair of byte streams.
///
/// Outgoing messages are written directly to `writer`. Incoming frames are read by a
/// dedicated thread spawned in `StreamPipe::new` and queued in `channel_reader`.
pub struct StreamPipe {
    /// Sink that outgoing frames are written to.
    writer: Box<dyn WriteDebug>,
    /// Receives raw frame payloads from the reader thread.
    // NOTE(review): std `Receiver` is not `Sync`; the `Mutex` presumably makes `StreamPipe`
    // `Sync`. The purpose of the extra `Arc` is not visible here — confirm.
    channel_reader: std::sync::Arc<std::sync::Mutex<::std::sync::mpsc::Receiver<Vec<u8>>>>,
}
185
186impl StreamPipe {
187 pub fn new(
193 mut reader: impl Read + Send + 'static,
194 writer: impl WriteDebug + 'static,
195 max_messages: usize,
196 ) -> Self {
197 let (channel_writer, channel_reader) =
198 ::std::sync::mpsc::sync_channel::<Vec<u8>>(max_messages);
199
200 thread::spawn(move || {
201 loop {
202 let mut buf = vec![0; 8];
203 match reader.read_exact(&mut buf) {
204 Ok(_) => (),
205 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
206 Err(e) => panic!("preamble read failed: {}", e),
209 }
210 let len = u64::from_be_bytes(buf.try_into().unwrap());
211 buf = vec![0; len as usize];
212 match reader.read_exact(&mut buf) {
213 Ok(_) => (),
214 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
215 Err(e) => panic!("preamble read failed: {}", e),
216 }
217 if channel_writer.send(buf).is_err() {
218 break;
220 }
221 }
222 });
223
224 StreamPipe {
225 writer: Box::new(writer),
226 channel_reader: std::sync::Arc::new(std::sync::Mutex::new(channel_reader)),
227 }
228 }
229}
230
231impl<T: Serialize + DeserializeOwned> Pipe<T> for StreamPipe {
232 fn send(&mut self, val: T) -> Result<()> {
233 let bytes = bincode::serialize(&val)?;
234 let len = bytes.len();
235 self.writer.write_all(&len.to_be_bytes())?;
236 self.writer.write_all(&bytes)?;
237 self.writer.flush()?;
238 Ok(())
239 }
240
241 fn recv(&mut self) -> Result<T> {
242 let buf = self
243 .channel_reader
244 .lock()
245 .unwrap()
246 .recv()
247 .expect("recv failed");
248 Ok(bincode::deserialize(&buf)?)
249 }
250}
251
/// Messages handled by [`PipeActor`].
// NOTE(review): the reason for `allow(dead_code)` is not visible here; presumably some
// variant or generated item is unused in this crate — confirm and document.
#[allow(dead_code)]
#[derive(Handler, HandleClient, Debug)]
pub enum PipeMessage {
    /// Forward a value to the pipe server. An `Err` is reported as a worker error instead of
    /// being sent.
    SendValue(Result<PyTree<RValue>, WorkerError>),

    /// Receive the next value from the pipe server; the value is delivered on the reply port.
    RecvValue(#[reply] OncePortHandle<PyTree<RValue>>),
}
259
/// Actor that owns an out-of-process pipe server and relays values to and from it.
#[derive(Debug)]
pub struct PipeActor {
    /// Pipe connected to the server's stdout/stdin. Set to `Some` in `Actor::new`. Taken in
    /// `Drop`, so our end of the pipe is closed before the server is killed.
    pipe: Option<AsyncStreamPipe>,
    /// The pipe server subprocess.
    handle: Child,
}
266
/// Construction parameters for [`PipeActor`].
#[derive(Debug, Clone)]
pub struct PipeParams {
    /// Function the pipe server is asked to run.
    pub function: ResolvableFunction,
    /// Capacity of the incoming-message buffer of the actor's `AsyncStreamPipe`
    /// (converted to `usize`).
    pub max_messages: i64,
    // NOTE(review): presumably dimension name -> this worker's rank; confirm.
    pub ranks: HashMap<String, usize>,
    // NOTE(review): presumably dimension name -> dimension size; confirm.
    pub sizes: HashMap<String, usize>,
    /// Positional arguments for `function`.
    pub args: Vec<PyTree<RValue>>,
    /// Keyword arguments for `function`, keyed by parameter name.
    pub kwargs: HashMap<String, PyTree<RValue>>,
}
277
278#[async_trait]
279impl Actor for PipeActor {
280 type Params = PipeParams;
281
282 async fn new(params: Self::Params) -> Result<Self> {
283 let mut command = Command::new(
284 std::env::var("MONARCH_TENSOR_WORKER_EXE")
285 .map_err(|e| anyhow!("could not get var MONARCH_TENSOR_WORKER_EXE: {}", e))?,
286 );
287 if let Ok(main) = std::env::var("MONARCH_TENSOR_WORKER_MAIN") {
288 if std::env::var("FB_XAR_INVOKED_NAME").is_ok() {
289 command.env("PAR_MAIN_OVERRIDE", main);
290 } else {
291 command.arg("-m").arg(main);
292 }
293 }
294
295 let mut handle = command
297 .arg("pipe")
298 .stdout(Stdio::piped())
299 .stdin(Stdio::piped())
300 .kill_on_drop(true)
301 .spawn()?;
302
303 let mut pipe = AsyncStreamPipe::new(
305 handle.stdout.take().unwrap(),
306 handle.stdin.take().unwrap(),
307 params.max_messages as usize,
308 );
309 let params = OutOfProcessSetupParams {
310 ranks: params.ranks,
311 sizes: params.sizes,
312 function: params.function,
313 args: params.args,
314 kwargs: params.kwargs,
315 };
316 tokio::select! {
317 res = handle.wait() => bail!("pipe server exited: {:?}", res),
318 res = pipe.send(params) => res?,
319 }
320
321 Ok(Self {
322 pipe: Some(pipe),
323 handle,
324 })
325 }
326}
327
328impl PipeActor {
329 fn kill_pipe_server(&mut self) -> Result<()> {
332 self.handle.start_kill()?;
333
334 let pid = Pid::from_raw(self.handle.id().context("cannot get pid")? as i32);
338 match nix::sys::wait::waitpid(pid, None)? {
339 WaitStatus::Exited(_, 0) => (),
340 status => bail!("exited abnormally: {:?}", status),
341 }
342 Ok(())
343 }
344}
345
346impl Drop for PipeActor {
349 fn drop(&mut self) {
350 self.pipe.take();
353
354 if let Err(err) = self.kill_pipe_server() {
356 tracing::warn!("error cleaning up pipe server: {}", err);
357 }
358 }
359}
360
361#[async_trait]
362#[forward(PipeMessage)]
363impl PipeMessageHandler for PipeActor {
364 async fn send_value(
365 &mut self,
366 _cx: &hyperactor::Context<Self>,
367 val: Result<PyTree<RValue>, WorkerError>,
368 ) -> Result<()> {
369 let val = val.map_err(|err| anyhow::anyhow!(err.backtrace).context("worker error"))?;
371 tokio::select! {
372 res = self.handle.wait() => bail!("pipe server exited: {:?}", res),
373 res = self.pipe.as_mut().unwrap().send(val) => res?,
374 };
375 Ok(())
376 }
377
378 async fn recv_value(&mut self, _cx: &hyperactor::Context<Self>) -> Result<PyTree<RValue>> {
379 tokio::select! {
381 res = self.handle.wait() => bail!("pipe server exited: {:?}", res),
382 res = self.pipe.as_mut().unwrap().recv() => res
383 }
384 }
385}