hyperactor_mesh/alloc/
process.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#![allow(dead_code)] // some things currently used only in tests

use std::collections::HashMap;
use std::future::Future;
use std::os::unix::process::ExitStatusExt;
use std::process::ExitStatus;
use std::process::Stdio;
use std::sync::Arc;
use std::sync::OnceLock;

use async_trait::async_trait;
use enum_as_inner::EnumAsInner;
use hyperactor::ProcId;
use hyperactor::WorldId;
use hyperactor::channel;
use hyperactor::channel::ChannelAddr;
use hyperactor::channel::ChannelError;
use hyperactor::channel::ChannelTransport;
use hyperactor::channel::ChannelTx;
use hyperactor::channel::Rx;
use hyperactor::channel::Tx;
use hyperactor::channel::TxStatus;
use hyperactor::sync::flag;
use hyperactor::sync::monitor;
use ndslice::view::Extent;
use nix::sys::signal;
use nix::unistd::Pid;
use serde::Deserialize;
use serde::Serialize;
use tokio::io;
use tokio::process::Command;
use tokio::sync::Mutex;
use tokio::task::JoinSet;

use super::Alloc;
use super::AllocSpec;
use super::Allocator;
use super::AllocatorError;
use super::ProcState;
use super::ProcStopReason;
use super::logtailer::LogTailer;
use crate::assign::Ranks;
use crate::bootstrap;
use crate::bootstrap::Allocator2Process;
use crate::bootstrap::Process2Allocator;
use crate::bootstrap::Process2AllocatorMessage;
use crate::logging::create_log_writers;
use crate::shortuuid::ShortUuid;

/// The maximum number of log lines to keep when tailing the output of
/// managed processes.
const MAX_TAIL_LOG_LINES: usize = 100;

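/// Label key under which the client's trace ID is carried in
/// [`AllocSpec`] constraints; its value is forwarded to spawned
/// processes for log correlation.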
pub const CLIENT_TRACE_ID_LABEL: &str = "CLIENT_TRACE_ID";

/// An allocator that allocates procs by executing managed (local)
/// processes. ProcessAllocator is configured with a [`Command`] (template)
/// to spawn external processes. These processes must invoke [`hyperactor_mesh::bootstrap`] or
/// [`hyperactor_mesh::bootstrap_or_die`], which is responsible for coordinating
/// with the allocator.
///
/// The process allocator tees the stdout and stderr of each proc to the parent process.
pub struct ProcessAllocator {
    cmd: Arc<Mutex<Command>>,
}

impl ProcessAllocator {
    /// Create a new allocator using the provided command (template).
    /// The command is used to spawn child processes that host procs.
    /// The binary should yield control to [`hyperactor_mesh::bootstrap`]
    /// or [`hyperactor_mesh::bootstrap_or_die`] after initialization.
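    ///
    /// A minimal usage sketch (the bootstrap binary path is illustrative):
    ///
    /// ```ignore
    /// let mut allocator = ProcessAllocator::new(Command::new("path/to/bootstrap"));
    /// let alloc = allocator
    ///     .allocate(AllocSpec {
    ///         extent: ndslice::extent!(replica = 2),
    ///         constraints: Default::default(),
    ///         proc_name: None,
    ///         transport: ChannelTransport::Unix,
    ///     })
    ///     .await?;
    /// ```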
    pub fn new(cmd: Command) -> Self {
        Self {
            cmd: Arc::new(Mutex::new(cmd)),
        }
    }
}

#[async_trait]
impl Allocator for ProcessAllocator {
    type Alloc = ProcessAlloc;

    #[hyperactor::instrument(fields(name = "process_allocate", monarch_client_trace_id = spec.constraints.match_labels.get(CLIENT_TRACE_ID_LABEL).cloned().unwrap_or_else(|| "".to_string())))]
    async fn allocate(&mut self, spec: AllocSpec) -> Result<ProcessAlloc, AllocatorError> {
        let (bootstrap_addr, rx) = channel::serve(ChannelAddr::any(ChannelTransport::Unix))
            .map_err(anyhow::Error::from)?;

        if spec.transport == ChannelTransport::Local {
            return Err(AllocatorError::Other(anyhow::anyhow!(
                "ProcessAllocator does not support local transport"
            )));
        }

        let name = ShortUuid::generate();
        Ok(ProcessAlloc {
            name: name.clone(),
            world_id: WorldId(name.to_string()),
            spec: spec.clone(),
            bootstrap_addr,
            rx,
            active: HashMap::new(),
            ranks: Ranks::new(spec.extent.num_ranks()),
            created: Vec::new(),
            cmd: Arc::clone(&self.cmd),
            children: JoinSet::new(),
            running: true,
            failed: false,
            client_context: ClientContext {
                trace_id: spec
                    .constraints
                    .match_labels
                    .get(CLIENT_TRACE_ID_LABEL)
                    .cloned()
                    .unwrap_or_else(|| "".to_string()),
            },
        })
    }
}

// ClientContext is saved in ProcessAlloc and is also passed in the
// RemoteProcessAllocator's Allocate method.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientContext {
    /// Trace ID for correlating logs across client and worker processes.
    pub trace_id: String,
}

/// An allocation produced by [`ProcessAllocator`].
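///
/// Events are consumed by polling [`Alloc::next`] until it returns `None`;
/// a minimal sketch:
///
/// ```ignore
/// while let Some(state) = alloc.next().await {
///     match state {
///         ProcState::Running { proc_id, .. } => println!("running: {}", proc_id),
///         ProcState::Stopped { reason, .. } => println!("stopped: {:?}", reason),
///         _ => {}
///     }
/// }
/// ```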
pub struct ProcessAlloc {
    name: ShortUuid,
    world_id: WorldId, // to provide storage
    spec: AllocSpec,
    bootstrap_addr: ChannelAddr,
    rx: channel::ChannelRx<Process2Allocator>,
    active: HashMap<usize, Child>,
    // Maps process index to its rank.
    ranks: Ranks<usize>,
    // Created processes by index.
    created: Vec<ShortUuid>,
    cmd: Arc<Mutex<Command>>,
    children: JoinSet<(usize, ProcStopReason)>,
    running: bool,
    failed: bool,
    client_context: ClientContext,
}

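/// State of the allocator-to-process message channel for a [`Child`]:
/// `NotConnected` until the process sends its `Hello`, then `Connected`,
/// or `Failed` if dialing the advertised address fails.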
#[derive(EnumAsInner)]
enum ChannelState {
    NotConnected,
    Connected(ChannelTx<Allocator2Process>),
    Failed(ChannelError),
}

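/// A managed child process hosting a single proc: its control channel,
/// monitor group, log tailers, and stop bookkeeping.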
struct Child {
    local_rank: usize,
    channel: ChannelState,
    group: monitor::Group,
    exit_flag: Option<flag::Flag>,
    stdout: Option<LogTailer>,
    stderr: Option<LogTailer>,
    stop_reason: Arc<OnceLock<ProcStopReason>>,
    process_pid: Arc<std::sync::Mutex<Option<i32>>>,
}

impl Child {
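    /// Wrap a spawned process in a `Child` handle, returning the handle
    /// together with a monitor future. The future resolves with the
    /// process's [`ProcStopReason`] once it exits, either on its own or
    /// because the monitor group failed (in which case the process is
    /// killed first).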
    fn monitored(
        local_rank: usize,
        mut process: tokio::process::Child,
        log_channel: ChannelAddr,
    ) -> (Self, impl Future<Output = ProcStopReason>) {
        let (group, handle) = monitor::group();
        let (exit_flag, exit_guard) = flag::guarded();
        let stop_reason = Arc::new(OnceLock::new());

        // Set up stdout and stderr writers
        let mut stdout_tee: Box<dyn io::AsyncWrite + Send + Unpin + 'static> =
            Box::new(io::stdout());
        let mut stderr_tee: Box<dyn io::AsyncWrite + Send + Unpin + 'static> =
            Box::new(io::stderr());

        // Use the helper function to create both writers at once
        match create_log_writers(local_rank, log_channel, process.id().unwrap_or(0)) {
            Ok((stdout_writer, stderr_writer)) => {
                stdout_tee = stdout_writer;
                stderr_tee = stderr_writer;
            }
            Err(e) => {
                tracing::error!("failed to create log writers: {}", e);
            }
        }

        let stdout = LogTailer::tee(
            MAX_TAIL_LOG_LINES,
            process.stdout.take().unwrap(),
            stdout_tee,
        );

        let stderr = LogTailer::tee(
            MAX_TAIL_LOG_LINES,
            process.stderr.take().unwrap(),
            stderr_tee,
        );

        let process_pid = Arc::new(std::sync::Mutex::new(process.id().map(|id| id as i32)));

        let child = Self {
            local_rank,
            channel: ChannelState::NotConnected,
            group,
            exit_flag: Some(exit_flag),
            stdout: Some(stdout),
            stderr: Some(stderr),
            stop_reason: Arc::clone(&stop_reason),
            process_pid: process_pid.clone(),
        };

        let monitor = async move {
            let reason = tokio::select! {
                _ = handle => {
                    Self::ensure_killed(process_pid);
                    Self::exit_status_to_reason(process.wait().await)
                }
                result = process.wait() => {
                    Self::exit_status_to_reason(result)
                }
            };
            exit_guard.signal();

            stop_reason.get_or_init(|| reason).clone()
        };

        (child, monitor)
    }

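    /// Send SIGTERM to the process if we still hold its pid. ESRCH
    /// (process already gone) is expected and only logged at debug level.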
    fn ensure_killed(pid: Arc<std::sync::Mutex<Option<i32>>>) {
        if let Some(pid) = pid.lock().unwrap().take() {
            if let Err(e) = signal::kill(Pid::from_raw(pid), signal::SIGTERM) {
                match e {
                    nix::errno::Errno::ESRCH => {
                        // Process already gone.
                        tracing::debug!("pid {} already exited", pid);
                    }
                    _ => {
                        tracing::error!("failed to kill {}: {}", pid, e);
                    }
                }
            }
        }
    }

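    /// Translate a process exit status into a [`ProcStopReason`]: a
    /// successful exit maps to `Stopped`, death by signal to `Killed`,
    /// and a nonzero exit code to `Exited` (stderr content is attached
    /// later, in [`ProcessAlloc::next`]).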
    fn exit_status_to_reason(result: io::Result<ExitStatus>) -> ProcStopReason {
        match result {
            Ok(status) if status.success() => ProcStopReason::Stopped,
            Ok(status) => {
                if let Some(signal) = status.signal() {
                    ProcStopReason::Killed(signal, status.core_dumped())
                } else if let Some(code) = status.code() {
                    ProcStopReason::Exited(code, String::new())
                } else {
                    ProcStopReason::Unknown
                }
            }
            Err(e) => {
                tracing::error!("error waiting for process: {}", e);
                ProcStopReason::Unknown
            }
        }
    }

    #[hyperactor::instrument_infallible]
    fn stop(&self, reason: ProcStopReason) {
        let _ = self.stop_reason.set(reason); // first stop wins
        self.group.fail();
    }

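    /// Dial the address the process advertised in its `Hello` message and
    /// start monitoring channel health, failing the group (and thereby
    /// killing the process) if the channel closes. Returns `false` if a
    /// connection was already attempted.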
    fn connect(&mut self, addr: ChannelAddr) -> bool {
        if !self.channel.is_not_connected() {
            return false;
        }

        match channel::dial(addr) {
            Ok(channel) => {
                let mut status = channel.status().clone();
                self.channel = ChannelState::Connected(channel);
                // Monitor the channel, killing the process if it becomes unavailable
                // (fails keepalive).
                self.group.spawn(async move {
                    let _ = status
                        .wait_for(|status| matches!(status, TxStatus::Closed))
                        .await;
                    Result::<(), ()>::Err(())
                });
            }
            Err(err) => {
                self.channel = ChannelState::Failed(err);
                self.stop(ProcStopReason::Watchdog);
            }
        };
        true
    }

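    /// Spawn a watchdog that fails the group (killing the process) if the
    /// child has not exited within `PROCESS_EXIT_TIMEOUT`. Used during
    /// [`ProcessAlloc::stop`] so that a wedged process cannot stall
    /// teardown indefinitely.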
    fn spawn_watchdog(&mut self) {
        let Some(exit_flag) = self.exit_flag.take() else {
            tracing::info!("exit flag set, not spawning watchdog");
            return;
        };
        let group = self.group.clone();
        let stop_reason = self.stop_reason.clone();
        tracing::info!("spawning watchdog");
        tokio::spawn(async move {
            let exit_timeout =
                hyperactor::config::global::get(hyperactor::config::PROCESS_EXIT_TIMEOUT);
            #[allow(clippy::disallowed_methods)]
            if tokio::time::timeout(exit_timeout, exit_flag).await.is_err() {
                tracing::info!("watchdog timeout, killing process");
                let _ = stop_reason.set(ProcStopReason::Watchdog);
                group.fail();
            }
            tracing::info!("Watchdog task exit");
        });
    }

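    /// Post a message to the process over its control channel; if the
    /// channel is not connected, the child is stopped with a watchdog
    /// reason instead.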
    #[hyperactor::instrument_infallible]
    fn post(&mut self, message: Allocator2Process) {
        if let ChannelState::Connected(channel) = &mut self.channel {
            channel.post(message);
        } else {
            self.stop(ProcStopReason::Watchdog);
        }
    }

    #[cfg(test)]
    fn fail_group(&self) {
        self.group.fail();
    }
}

impl Drop for Child {
    fn drop(&mut self) {
        Self::ensure_killed(self.process_pid.clone());
    }
}

impl ProcessAlloc {
    // TODO: also implement exit (for graceful exit).

    // Currently procs and processes are 1:1, so `stop` simply exits
    // the process entirely.

    #[hyperactor::instrument_infallible]
    fn stop(&mut self, proc_id: &ProcId, reason: ProcStopReason) -> Result<(), anyhow::Error> {
        self.get_mut(proc_id)?.stop(reason);
        Ok(())
    }

    fn get(&self, proc_id: &ProcId) -> Result<&Child, anyhow::Error> {
        self.active.get(&self.index(proc_id)?).ok_or_else(|| {
            anyhow::anyhow!(
                "proc {} not currently active in alloc {}",
                proc_id,
                self.name
            )
        })
    }

    fn get_mut(&mut self, proc_id: &ProcId) -> Result<&mut Child, anyhow::Error> {
        self.active.get_mut(&self.index(proc_id)?).ok_or_else(|| {
            anyhow::anyhow!(
                "proc {} not currently active in alloc {}",
                &proc_id,
                self.name
            )
        })
    }

    /// The "world name" assigned to this alloc.
    pub(crate) fn name(&self) -> &ShortUuid {
        &self.name
    }

    fn index(&self, proc_id: &ProcId) -> Result<usize, anyhow::Error> {
        anyhow::ensure!(
            proc_id
                .world_name()
                .expect("proc must be ranked for allocation index")
                .parse::<ShortUuid>()?
                == self.name,
            "proc {} does not belong to alloc {}",
            proc_id,
            self.name
        );
        Ok(proc_id
            .rank()
            .expect("proc must be ranked for allocation index"))
    }

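    /// Spawn the next child process if the allocation has not yet reached
    /// its full extent. Returns `Some(ProcState::Created)` on success, a
    /// terminal `Some(ProcState::Failed)` if the spawn itself fails, and
    /// `None` if there is nothing to spawn or no rank could be assigned.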
    #[hyperactor::instrument_infallible]
    async fn maybe_spawn(&mut self) -> Option<ProcState> {
        if self.active.len() >= self.spec.extent.num_ranks() {
            return None;
        }
        let mut cmd = self.cmd.lock().await;
        let log_channel: ChannelAddr = ChannelAddr::any(ChannelTransport::Unix);

        let index = self.created.len();
        self.created.push(ShortUuid::generate());
        let create_key = &self.created[index];

        cmd.env(
            bootstrap::BOOTSTRAP_ADDR_ENV,
            self.bootstrap_addr.to_string(),
        );
        cmd.env(
            bootstrap::CLIENT_TRACE_ID_ENV,
            self.client_context.trace_id.as_str(),
        );
        cmd.env(bootstrap::BOOTSTRAP_INDEX_ENV, index.to_string());
        cmd.env(bootstrap::BOOTSTRAP_LOG_CHANNEL, log_channel.to_string());
        cmd.stdout(Stdio::piped());
        cmd.stderr(Stdio::piped());

        tracing::debug!("spawning process {:?}", cmd);
        match cmd.spawn() {
            Err(err) => {
                // A retry likely won't help here, so fail permanently.
                let message = format!(
                    "spawn {} index: {}, command: {:?}: {}",
                    create_key, index, cmd, err
                );
                tracing::error!(message);
                self.failed = true;
                Some(ProcState::Failed {
                    world_id: self.world_id.clone(),
                    description: message,
                })
            }
            Ok(mut process) => {
                let pid = process.id().unwrap_or(0);
                match self.ranks.assign(index) {
                    Err(_index) => {
                        tracing::info!("could not assign rank to {}", create_key);
                        let _ = process.kill().await;
                        None
                    }
                    Ok(rank) => {
                        let (handle, monitor) = Child::monitored(rank, process, log_channel);
                        self.children.spawn(async move { (index, monitor.await) });
                        self.active.insert(index, handle);
                        // Adjust for shape slice offset for non-zero shapes (sub-shapes).
                        let point = self.spec.extent.point_of_rank(rank).unwrap();
                        Some(ProcState::Created {
                            create_key: create_key.clone(),
                            point,
                            pid,
                        })
                    }
                }
            }
        }
    }

    fn remove(&mut self, index: usize) -> Option<Child> {
        self.ranks.unassign(index);
        self.active.remove(&index)
    }
}

#[async_trait]
impl Alloc for ProcessAlloc {
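    /// Produce the next [`ProcState`] event: spawn any outstanding
    /// processes, then wait for either a message from a child (`Hello`,
    /// `StartedProc`, `Heartbeat`) or a child exit, translating these into
    /// `Created`, `Running`, and `Stopped` events. Returns `None` once the
    /// alloc has stopped and all children have exited.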
    #[hyperactor::instrument_infallible]
    async fn next(&mut self) -> Option<ProcState> {
        if !self.running && self.active.is_empty() {
            return None;
        }

        loop {
            // Do not allocate new processes if we are in a failed state.
            if self.running
                && !self.failed
                && let state @ Some(_) = self.maybe_spawn().await
            {
                return state;
            }

            let transport = self.transport().clone();

            tokio::select! {
                Ok(Process2Allocator(index, message)) = self.rx.recv() => {
                    let child = match self.active.get_mut(&index) {
                        None => {
                            tracing::info!("message {:?} from zombie {}", message, index);
                            continue;
                        }
                        Some(child) => child,
                    };

                    match message {
                        Process2AllocatorMessage::Hello(addr) => {
                            if !child.connect(addr.clone()) {
                                tracing::error!("received multiple hellos from {}", index);
                                continue;
                            }

                            child.post(Allocator2Process::StartProc(
                                self.spec.proc_name.clone().map_or(
                                    ProcId::Ranked(WorldId(self.name.to_string()), index),
                                    |name| ProcId::Direct(addr.clone(), name)),
                                transport,
                            ));
                        }

                        Process2AllocatorMessage::StartedProc(proc_id, mesh_agent, addr) => {
                            break Some(ProcState::Running {
                                create_key: self.created[index].clone(),
                                proc_id,
                                mesh_agent,
                                addr,
                            });
                        }
                        Process2AllocatorMessage::Heartbeat => {
                            tracing::trace!("recv heartbeat from {index}");
                        }
                    }
                },

                Some(Ok((index, mut reason))) = self.children.join_next() => {
                    let stderr_content = if let Some(mut child) = self.remove(index) {
                        let stdout = child.stdout.take().unwrap();
                        let stderr = child.stderr.take().unwrap();
                        stdout.abort();
                        stderr.abort();
                        let (_stdout, _) = stdout.join().await;
                        let (stderr_lines, _) = stderr.join().await;
                        stderr_lines.join("\n")
                    } else {
                        String::new()
                    };

                    if let ProcStopReason::Exited(code, _) = &mut reason {
                        reason = ProcStopReason::Exited(*code, stderr_content);
                    }

                    tracing::info!("child stopped with ProcStopReason::{:?}", reason);

                    break Some(ProcState::Stopped {
                        create_key: self.created[index].clone(),
                        reason,
                    });
                },
            }
        }
    }

    fn spec(&self) -> &AllocSpec {
        &self.spec
    }

    fn extent(&self) -> &Extent {
        &self.spec.extent
    }

    fn world_id(&self) -> &WorldId {
        &self.world_id
    }

    async fn stop(&mut self) -> Result<(), AllocatorError> {
        // We rely on graceful teardown here: each process should exit on
        // its own once told to stop. The watchdog provides a hard timeout
        // as well, so that we never rely on the system functioning
        // correctly for liveness.
        for (_index, child) in self.active.iter_mut() {
            child.post(Allocator2Process::StopAndExit(0));
            child.spawn_watchdog();
        }

        self.running = false;
        Ok(())
    }
}

impl Drop for ProcessAlloc {
    fn drop(&mut self) {
        tracing::debug!(
            "dropping ProcessAlloc of name: {}, world id: {}",
            self.name,
            self.world_id
        );
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[cfg(fbcode_build)] // we use an external binary, produced by buck
    crate::alloc_test_suite!(ProcessAllocator::new(Command::new(
        crate::testresource::get("monarch/hyperactor_mesh/bootstrap")
    )));

    #[tokio::test]
    async fn test_sigterm_on_group_fail() {
        let bootstrap_binary = crate::testresource::get("monarch/hyperactor_mesh/bootstrap");
        let mut allocator = ProcessAllocator::new(Command::new(bootstrap_binary));

        let mut alloc = allocator
            .allocate(AllocSpec {
                extent: ndslice::extent!(replica = 1),
                constraints: Default::default(),
                proc_name: None,
                transport: ChannelTransport::Unix,
            })
            .await
            .unwrap();

        let proc_id = {
            loop {
                match alloc.next().await {
                    Some(ProcState::Running { proc_id, .. }) => {
                        break proc_id;
                    }
                    Some(ProcState::Failed { description, .. }) => {
                        panic!("Process allocation failed: {}", description);
                    }
                    Some(_other) => {}
                    None => {
                        panic!("Allocation ended unexpectedly");
                    }
                }
            }
        };

        if let Some(child) = alloc.active.get(
            &proc_id
                .rank()
                .expect("proc must be ranked for allocation lookup"),
        ) {
            child.fail_group();
        }

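        // The group failure should SIGTERM the child: signal 15, no core dump.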
        assert!(matches!(
            alloc.next().await,
            Some(ProcState::Stopped {
                reason: ProcStopReason::Killed(15, false),
                ..
            })
        ));
    }
}