hyperactor_mesh/alloc/process.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#![allow(dead_code)] // some things currently used only in tests

use std::collections::HashMap;
use std::future::Future;
use std::os::unix::process::ExitStatusExt;
use std::process::ExitStatus;
use std::process::Stdio;
use std::sync::Arc;
use std::sync::OnceLock;

use async_trait::async_trait;
use enum_as_inner::EnumAsInner;
use hyperactor::ProcId;
use hyperactor::WorldId;
use hyperactor::channel;
use hyperactor::channel::ChannelAddr;
use hyperactor::channel::ChannelError;
use hyperactor::channel::ChannelTransport;
use hyperactor::channel::ChannelTx;
use hyperactor::channel::Rx;
use hyperactor::channel::Tx;
use hyperactor::channel::TxStatus;
use hyperactor::sync::flag;
use hyperactor::sync::monitor;
use ndslice::view::Extent;
use nix::sys::signal;
use nix::unistd::Pid;
use serde::Deserialize;
use serde::Serialize;
use tokio::io;
use tokio::process::Command;
use tokio::sync::Mutex;
use tokio::task::JoinSet;

use super::Alloc;
use super::AllocSpec;
use super::Allocator;
use super::AllocatorError;
use super::ProcState;
use super::ProcStopReason;
use super::logtailer::LogTailer;
use crate::assign::Ranks;
use crate::bootstrap;
use crate::bootstrap::Allocator2Process;
use crate::bootstrap::Process2Allocator;
use crate::bootstrap::Process2AllocatorMessage;
use crate::logging::create_log_writers;
use crate::shortuuid::ShortUuid;

/// The maximum number of log lines to keep in the tail for managed processes.
const MAX_TAIL_LOG_LINES: usize = 100;

/// Constraint label under which the client's trace ID is passed.
pub const CLIENT_TRACE_ID_LABEL: &str = "CLIENT_TRACE_ID";

/// An allocator that allocates procs by executing managed (local)
/// processes. ProcessAllocator is configured with a [`Command`] (template)
/// to spawn external processes. These processes must invoke
/// [`hyperactor_mesh::bootstrap`] or [`hyperactor_mesh::bootstrap_or_die`],
/// which is responsible for coordinating with the allocator.
///
/// The process allocator tees the stdout and stderr of each proc to the
/// parent process.
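///
/// A minimal usage sketch (the bootstrap binary path below is a placeholder;
/// it must point at a program that calls
/// [`hyperactor_mesh::bootstrap_or_die`]):
///
/// ```ignore
/// use tokio::process::Command;
///
/// let mut allocator = ProcessAllocator::new(Command::new("/path/to/bootstrap"));
/// let mut alloc = allocator
///     .allocate(AllocSpec {
///         extent: ndslice::extent!(replica = 2),
///         constraints: Default::default(),
///     })
///     .await?;
/// ```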
pub struct ProcessAllocator {
    cmd: Arc<Mutex<Command>>,
}

impl ProcessAllocator {
    /// Create a new allocator using the provided command (template).
    /// The command is used to spawn child processes that host procs.
    /// The binary should yield control to [`hyperactor_mesh::bootstrap`]
    /// or [`hyperactor_mesh::bootstrap_or_die`] after initialization.
    pub fn new(cmd: Command) -> Self {
        Self {
            cmd: Arc::new(Mutex::new(cmd)),
        }
    }
}

#[async_trait]
impl Allocator for ProcessAllocator {
    type Alloc = ProcessAlloc;

    async fn allocate(&mut self, spec: AllocSpec) -> Result<ProcessAlloc, AllocatorError> {
        let (bootstrap_addr, rx) = channel::serve(ChannelAddr::any(ChannelTransport::Unix))
            .await
            .map_err(anyhow::Error::from)?;

        let name = ShortUuid::generate();
        Ok(ProcessAlloc {
            name: name.clone(),
            world_id: WorldId(name.to_string()),
            spec: spec.clone(),
            bootstrap_addr,
            rx,
            index: 0,
            active: HashMap::new(),
            ranks: Ranks::new(spec.extent.num_ranks()),
            cmd: Arc::clone(&self.cmd),
            children: JoinSet::new(),
            running: true,
            failed: false,
            client_context: ClientContext {
                trace_id: spec
                    .constraints
                    .match_labels
                    .get(CLIENT_TRACE_ID_LABEL)
                    .cloned()
                    .unwrap_or_else(|| "".to_string()),
            },
        })
    }
}

// ClientContext is saved in ProcessAlloc, and is also passed to the
// RemoteProcessAllocator's Allocate method.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientContext {
    /// Trace ID for correlating logs across client and worker processes.
    pub trace_id: String,
}

/// An allocation produced by [`ProcessAllocator`].
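///
/// The alloc is driven by polling [`Alloc::next`] until it yields `None`;
/// a sketch, mirroring the test at the bottom of this file:
///
/// ```ignore
/// while let Some(state) = alloc.next().await {
///     match state {
///         ProcState::Running { proc_id, .. } => { /* proc is up */ }
///         ProcState::Stopped { reason, .. } => { /* proc exited */ }
///         _ => {}
///     }
/// }
/// ```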
pub struct ProcessAlloc {
    name: ShortUuid,
    world_id: WorldId, // stored here so `world_id()` can return a reference
    spec: AllocSpec,
    bootstrap_addr: ChannelAddr,
    rx: channel::ChannelRx<Process2Allocator>,
    index: usize,
    active: HashMap<usize, Child>,
    // Maps process index to its rank.
    ranks: Ranks<usize>,
    cmd: Arc<Mutex<Command>>,
    children: JoinSet<(usize, ProcStopReason)>,
    running: bool,
    failed: bool,
    client_context: ClientContext,
}

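/// Connection state of the control channel from the allocator to a child process.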
#[derive(EnumAsInner)]
enum ChannelState {
    NotConnected,
    Connected(ChannelTx<Allocator2Process>),
    Failed(ChannelError),
}

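/// A managed child process hosting a single proc: its control channel,
/// monitor group, log tailers, and stop-reason bookkeeping.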
struct Child {
    local_rank: usize,
    channel: ChannelState,
    group: monitor::Group,
    exit_flag: Option<flag::Flag>,
    stdout: Option<LogTailer>,
    stderr: Option<LogTailer>,
    stop_reason: Arc<OnceLock<ProcStopReason>>,
    process_pid: Arc<std::sync::Mutex<Option<i32>>>,
}

impl Child {
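    /// Wrap a freshly spawned process, returning the `Child` handle together
    /// with a monitor future that resolves with the process's stop reason
    /// once the process exits or its task group fails.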
    fn monitored(
        local_rank: usize,
        mut process: tokio::process::Child,
        log_channel: ChannelAddr,
    ) -> (Self, impl Future<Output = ProcStopReason>) {
        let (group, handle) = monitor::group();
        let (exit_flag, exit_guard) = flag::guarded();
        let stop_reason = Arc::new(OnceLock::new());

        // Set up stdout and stderr writers
        let mut stdout_tee: Box<dyn io::AsyncWrite + Send + Unpin + 'static> =
            Box::new(io::stdout());
        let mut stderr_tee: Box<dyn io::AsyncWrite + Send + Unpin + 'static> =
            Box::new(io::stderr());

        // Use the helper function to create both writers at once
        match create_log_writers(local_rank, log_channel, process.id().unwrap_or(0)) {
            Ok((stdout_writer, stderr_writer)) => {
                stdout_tee = stdout_writer;
                stderr_tee = stderr_writer;
            }
            Err(e) => {
                tracing::error!("failed to create log writers: {}", e);
            }
        }

        let stdout = LogTailer::tee(
            MAX_TAIL_LOG_LINES,
            process.stdout.take().unwrap(),
            stdout_tee,
        );

        let stderr = LogTailer::tee(
            MAX_TAIL_LOG_LINES,
            process.stderr.take().unwrap(),
            stderr_tee,
        );

        let process_pid = Arc::new(std::sync::Mutex::new(process.id().map(|id| id as i32)));

        let child = Self {
            local_rank,
            channel: ChannelState::NotConnected,
            group,
            exit_flag: Some(exit_flag),
            stdout: Some(stdout),
            stderr: Some(stderr),
            stop_reason: Arc::clone(&stop_reason),
            process_pid: process_pid.clone(),
        };

        let monitor = async move {
            let reason = tokio::select! {
                _ = handle => {
                    Self::ensure_killed(process_pid);
                    Self::exit_status_to_reason(process.wait().await)
                }
                result = process.wait() => {
                    Self::exit_status_to_reason(result)
                }
            };
            exit_guard.signal();

            stop_reason.get_or_init(|| reason).clone()
        };

        (child, monitor)
    }

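    /// Best-effort: send SIGTERM to the process if it is still running. The
    /// pid is taken out of the mutex, so the signal is sent at most once.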
    fn ensure_killed(pid: Arc<std::sync::Mutex<Option<i32>>>) {
        if let Some(pid) = pid.lock().unwrap().take() {
            if let Err(e) = signal::kill(Pid::from_raw(pid), signal::SIGTERM) {
                match e {
                    nix::errno::Errno::ESRCH => {
                        // Process already gone.
                        tracing::debug!("pid {} already exited", pid);
                    }
                    _ => {
                        tracing::error!("failed to kill {}: {}", pid, e);
                    }
                }
            }
        }
    }
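
    /// Translate a process exit status into a [`ProcStopReason`].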
    fn exit_status_to_reason(result: io::Result<ExitStatus>) -> ProcStopReason {
        match result {
            Ok(status) if status.success() => ProcStopReason::Stopped,
            Ok(status) => {
                if let Some(signal) = status.signal() {
                    ProcStopReason::Killed(signal, status.core_dumped())
                } else if let Some(code) = status.code() {
                    ProcStopReason::Exited(code, String::new())
                } else {
                    ProcStopReason::Unknown
                }
            }
            Err(e) => {
                tracing::error!("error waiting for process: {}", e);
                ProcStopReason::Unknown
            }
        }
    }

    #[hyperactor::instrument_infallible]
    fn stop(&self, reason: ProcStopReason) {
        let _ = self.stop_reason.set(reason); // first stop wins
        self.group.fail();
    }

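    /// Dial the address provided in the process's `Hello` message and begin
    /// monitoring the channel, failing the group if it closes. Returns false
    /// if a connection was already attempted (i.e., a duplicate `Hello`).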
    fn connect(&mut self, addr: ChannelAddr) -> bool {
        if !self.channel.is_not_connected() {
            return false;
        }

        match channel::dial(addr) {
            Ok(channel) => {
                let mut status = channel.status().clone();
                self.channel = ChannelState::Connected(channel);
                // Monitor the channel, killing the process if it becomes unavailable
                // (fails keepalive).
                self.group.spawn(async move {
                    let _ = status
                        .wait_for(|status| matches!(status, TxStatus::Closed))
                        .await;
                    Result::<(), ()>::Err(())
                });
            }
            Err(err) => {
                self.channel = ChannelState::Failed(err);
                self.stop(ProcStopReason::Watchdog);
            }
        };
        true
    }

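    /// Spawn a task that fails the child's group with a `Watchdog` stop
    /// reason if the process does not exit within the configured
    /// `PROCESS_EXIT_TIMEOUT`.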
    fn spawn_watchdog(&mut self) {
        let Some(exit_flag) = self.exit_flag.take() else {
            tracing::info!("exit flag already taken, not spawning watchdog");
            return;
        };
        let group = self.group.clone();
        let stop_reason = self.stop_reason.clone();
        tracing::info!("spawning watchdog");
        tokio::spawn(async move {
            let exit_timeout =
                hyperactor::config::global::get(hyperactor::config::PROCESS_EXIT_TIMEOUT);
            #[allow(clippy::disallowed_methods)]
            if tokio::time::timeout(exit_timeout, exit_flag).await.is_err() {
                tracing::info!("watchdog timeout, killing process");
                let _ = stop_reason.set(ProcStopReason::Watchdog);
                group.fail();
            }
            tracing::info!("watchdog task exiting");
        });
    }

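    /// Post a message to the process over its control channel; if no channel
    /// is connected, the child is stopped with a `Watchdog` reason instead.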
    #[hyperactor::instrument_infallible]
    fn post(&mut self, message: Allocator2Process) {
        if let ChannelState::Connected(channel) = &mut self.channel {
            channel.post(message);
        } else {
            self.stop(ProcStopReason::Watchdog);
        }
    }

    #[cfg(test)]
    fn fail_group(&self) {
        self.group.fail();
    }
}

impl Drop for Child {
    fn drop(&mut self) {
        Self::ensure_killed(self.process_pid.clone());
    }
}

impl ProcessAlloc {
    // TODO: also implement `exit` (for graceful exit). Currently procs and
    // processes are 1:1, so `stop` just fully exits the process.

    #[hyperactor::instrument_infallible]
    fn stop(&mut self, proc_id: &ProcId, reason: ProcStopReason) -> Result<(), anyhow::Error> {
        self.get_mut(proc_id)?.stop(reason);
        Ok(())
    }

    fn get(&self, proc_id: &ProcId) -> Result<&Child, anyhow::Error> {
        self.active.get(&self.index(proc_id)?).ok_or_else(|| {
            anyhow::anyhow!(
                "proc {} not currently active in alloc {}",
                proc_id,
                self.name
            )
        })
    }

    fn get_mut(&mut self, proc_id: &ProcId) -> Result<&mut Child, anyhow::Error> {
        self.active.get_mut(&self.index(proc_id)?).ok_or_else(|| {
            anyhow::anyhow!(
                "proc {} not currently active in alloc {}",
                &proc_id,
                self.name
            )
        })
    }

    /// The "world name" assigned to this alloc.
    pub(crate) fn name(&self) -> &ShortUuid {
        &self.name
    }

    fn index(&self, proc_id: &ProcId) -> Result<usize, anyhow::Error> {
        anyhow::ensure!(
            proc_id
                .world_name()
                .expect("proc must be ranked for allocation index")
                .parse::<ShortUuid>()?
                == self.name,
            "proc {} does not belong to alloc {}",
            proc_id,
            self.name
        );
        Ok(proc_id
            .rank()
            .expect("proc must be ranked for allocation index"))
    }

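    /// Spawn at most one new child process if the allocation is not yet at
    /// capacity, returning the corresponding `ProcState` event (`Created` on
    /// success, `Failed` if the command could not be spawned).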
    #[hyperactor::instrument_infallible]
    async fn maybe_spawn(&mut self) -> Option<ProcState> {
        if self.active.len() >= self.spec.extent.num_ranks() {
            return None;
        }
        let mut cmd = self.cmd.lock().await;
        let index = self.index;
        self.index += 1;
        let log_channel: ChannelAddr = ChannelAddr::any(ChannelTransport::Unix);

        cmd.env(
            bootstrap::BOOTSTRAP_ADDR_ENV,
            self.bootstrap_addr.to_string(),
        );
        cmd.env(
            bootstrap::CLIENT_TRACE_ID_ENV,
            self.client_context.trace_id.as_str(),
        );
        cmd.env(bootstrap::BOOTSTRAP_INDEX_ENV, index.to_string());
        cmd.env(bootstrap::BOOTSTRAP_LOG_CHANNEL, log_channel.to_string());
        cmd.stdout(Stdio::piped());
        cmd.stderr(Stdio::piped());

        let proc_id = ProcId::Ranked(WorldId(self.name.to_string()), index);
        tracing::debug!("Spawning process {:?}", cmd);
        match cmd.spawn() {
            Err(err) => {
                // A retry likely won't help here, so fail permanently.
                let message = format!("spawn index: {}, command: {:?}: {}", index, cmd, err);
                tracing::error!(message);
                self.failed = true;
                Some(ProcState::Failed {
                    world_id: self.world_id.clone(),
                    description: message,
                })
            }
            Ok(mut process) => {
                let pid = process.id().unwrap_or(0);
                match self.ranks.assign(index) {
                    Err(_index) => {
                        tracing::info!("could not assign rank to {}", proc_id);
                        let _ = process.kill().await;
                        None
                    }
                    Ok(rank) => {
                        let (handle, monitor) = Child::monitored(rank, process, log_channel);
                        self.children.spawn(async move { (index, monitor.await) });
                        self.active.insert(index, handle);
                        // Map the assigned rank to its point in the alloc's extent.
                        let point = self.spec.extent.point_of_rank(rank).unwrap();
                        Some(ProcState::Created {
                            proc_id,
                            point,
                            pid,
                        })
                    }
                }
            }
        }
    }

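    /// Remove the child at `index` from the active set, unassigning its rank.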
    fn remove(&mut self, index: usize) -> Option<Child> {
        self.ranks.unassign(index);
        self.active.remove(&index)
    }
}

#[async_trait]
impl Alloc for ProcessAlloc {
    #[hyperactor::instrument_infallible]
    async fn next(&mut self) -> Option<ProcState> {
        if !self.running && self.active.is_empty() {
            return None;
        }

        loop {
            // Do not allocate new processes if we are in a failed state.
            if self.running && !self.failed {
                if let state @ Some(_) = self.maybe_spawn().await {
                    return state;
                }
            }

            let transport = self.transport().clone();

            tokio::select! {
                Ok(Process2Allocator(index, message)) = self.rx.recv() => {
                    let child = match self.active.get_mut(&index) {
                        None => {
                            tracing::info!("message {:?} from zombie {}", message, index);
                            continue;
                        }
                        Some(child) => child,
                    };

                    match message {
                        Process2AllocatorMessage::Hello(addr) => {
                            if !child.connect(addr.clone()) {
                                tracing::error!("received multiple hellos from {}", index);
                                continue;
                            }

                            child.post(Allocator2Process::StartProc(
                                ProcId::Ranked(WorldId(self.name.to_string()), index),
                                transport,
                            ));
                        }

                        Process2AllocatorMessage::StartedProc(proc_id, mesh_agent, addr) => {
                            break Some(ProcState::Running {
                                proc_id,
                                mesh_agent,
                                addr,
                            });
                        }
                        Process2AllocatorMessage::Heartbeat => {
                            tracing::trace!("recv heartbeat from {index}");
                        }
                    }
                },

                Some(Ok((index, mut reason))) = self.children.join_next() => {
                    let stderr_content = if let Some(mut child) = self.remove(index) {
                        let stdout = child.stdout.take().unwrap();
                        let stderr = child.stderr.take().unwrap();
                        stdout.abort();
                        stderr.abort();
                        let (_stdout, _) = stdout.join().await;
                        let (stderr_lines, _) = stderr.join().await;
                        stderr_lines.join("\n")
                    } else {
                        String::new()
                    };

                    if let ProcStopReason::Exited(code, _) = &mut reason {
                        reason = ProcStopReason::Exited(*code, stderr_content);
                    }

                    tracing::info!("child stopped with ProcStopReason::{:?}", reason);

                    break Some(ProcState::Stopped {
                        proc_id: ProcId::Ranked(WorldId(self.name.to_string()), index),
                        reason
                    });
                },
            }
        }
    }

    fn extent(&self) -> &Extent {
        &self.spec.extent
    }

    fn world_id(&self) -> &WorldId {
        &self.world_id
    }

    fn transport(&self) -> ChannelTransport {
        ChannelTransport::Unix
    }

    async fn stop(&mut self) -> Result<(), AllocatorError> {
        // We rely on the teardown here, and that the process should
        // exit on its own. We should have a hard timeout here as well,
        // so that we never rely on the system functioning correctly
        // for liveness.
        for (_index, child) in self.active.iter_mut() {
            child.post(Allocator2Process::StopAndExit(0));
            child.spawn_watchdog();
        }

        self.running = false;
        Ok(())
    }
}

impl Drop for ProcessAlloc {
    fn drop(&mut self) {
        tracing::debug!(
            "dropping ProcessAlloc of name: {}, world id: {}",
            self.name,
            self.world_id
        );
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[cfg(fbcode_build)] // we use an external binary, produced by buck
    crate::alloc_test_suite!(ProcessAllocator::new(Command::new(
        buck_resources::get("monarch/hyperactor_mesh/bootstrap").unwrap()
    )));

    #[tokio::test]
    async fn test_sigterm_on_group_fail() {
        let bootstrap_binary = buck_resources::get("monarch/hyperactor_mesh/bootstrap").unwrap();
        let mut allocator = ProcessAllocator::new(Command::new(bootstrap_binary));

        let mut alloc = allocator
            .allocate(AllocSpec {
                extent: ndslice::extent!(replica = 1),
                constraints: Default::default(),
            })
            .await
            .unwrap();

        let proc_id = {
            loop {
                match alloc.next().await {
                    Some(ProcState::Running { proc_id, .. }) => {
                        break proc_id;
                    }
                    Some(ProcState::Failed { description, .. }) => {
                        panic!("Process allocation failed: {}", description);
                    }
                    Some(_other) => {}
                    None => {
                        panic!("Allocation ended unexpectedly");
                    }
                }
            }
        };

        if let Some(child) = alloc.active.get(
            &proc_id
                .rank()
                .expect("proc must be ranked for allocation lookup"),
        ) {
            child.fail_group();
        }

        // Expect the child to have been killed by SIGTERM (signal 15),
        // without a core dump.
        assert!(matches!(
            alloc.next().await,
            Some(ProcState::Stopped {
                reason: ProcStopReason::Killed(15, false),
                ..
            })
        ));
    }
}