#![allow(dead_code)]

use std::collections::HashMap;
use std::future::Future;
use std::os::unix::process::ExitStatusExt;
use std::process::ExitStatus;
use std::process::Stdio;
use std::sync::Arc;
use std::sync::OnceLock;

use async_trait::async_trait;
use enum_as_inner::EnumAsInner;
use hyperactor::ProcId;
use hyperactor::WorldId;
use hyperactor::channel;
use hyperactor::channel::ChannelAddr;
use hyperactor::channel::ChannelError;
use hyperactor::channel::ChannelTransport;
use hyperactor::channel::ChannelTx;
use hyperactor::channel::Rx;
use hyperactor::channel::Tx;
use hyperactor::channel::TxStatus;
use hyperactor::sync::flag;
use hyperactor::sync::monitor;
use ndslice::view::Extent;
use nix::sys::signal;
use nix::unistd::Pid;
use serde::Deserialize;
use serde::Serialize;
use tokio::io;
use tokio::process::Command;
use tokio::sync::Mutex;
use tokio::task::JoinSet;

use super::Alloc;
use super::AllocSpec;
use super::Allocator;
use super::AllocatorError;
use super::ProcState;
use super::ProcStopReason;
use super::logtailer::LogTailer;
use crate::assign::Ranks;
use crate::bootstrap;
use crate::bootstrap::Allocator2Process;
use crate::bootstrap::Process2Allocator;
use crate::bootstrap::Process2AllocatorMessage;
use crate::logging::create_log_writers;
use crate::shortuuid::ShortUuid;

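/// Maximum number of stdout/stderr lines retained in memory for each child
/// process.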
const MAX_TAIL_LOG_LINES: usize = 100;

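/// Match-label key under which an `AllocSpec` carries the client's trace id.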
pub const CLIENT_TRACE_ID_LABEL: &str = "CLIENT_TRACE_ID";

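/// An allocator that runs each proc as a local OS process, spawned from a
/// configurable [`Command`] template.
///
/// A minimal usage sketch (the bootstrap binary path here is hypothetical):
///
/// ```ignore
/// let mut allocator = ProcessAllocator::new(Command::new("path/to/bootstrap"));
/// let alloc = allocator.allocate(spec).await?;
/// ```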
pub struct ProcessAllocator {
    cmd: Arc<Mutex<Command>>,
}

impl ProcessAllocator {
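    /// Create a new allocator that spawns procs with the given command.
    /// The command is expected to launch the hyperactor bootstrap entry
    /// point; bootstrap parameters are passed via environment variables.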
    pub fn new(cmd: Command) -> Self {
        Self {
            cmd: Arc::new(Mutex::new(cmd)),
        }
    }
}

#[async_trait]
impl Allocator for ProcessAllocator {
    type Alloc = ProcessAlloc;

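    /// Serve a fresh bootstrap channel and return a [`ProcessAlloc`] that
    /// spawns one process per rank in the spec's extent.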
    #[hyperactor::instrument(fields(name = "process_allocate", monarch_client_trace_id = spec.constraints.match_labels.get(CLIENT_TRACE_ID_LABEL).cloned().unwrap_or_else(|| "".to_string())))]
    async fn allocate(&mut self, spec: AllocSpec) -> Result<ProcessAlloc, AllocatorError> {
        let (bootstrap_addr, rx) = channel::serve(ChannelAddr::any(ChannelTransport::Unix))
            .map_err(anyhow::Error::from)?;

        if spec.transport == ChannelTransport::Local {
            return Err(AllocatorError::Other(anyhow::anyhow!(
                "ProcessAllocator does not support local transport"
            )));
        }

        let name = ShortUuid::generate();
        Ok(ProcessAlloc {
            name: name.clone(),
            world_id: WorldId(name.to_string()),
            spec: spec.clone(),
            bootstrap_addr,
            rx,
            active: HashMap::new(),
            ranks: Ranks::new(spec.extent.num_ranks()),
            created: Vec::new(),
            cmd: Arc::clone(&self.cmd),
            children: JoinSet::new(),
            running: true,
            failed: false,
            client_context: ClientContext {
                trace_id: spec
                    .constraints
                    .match_labels
                    .get(CLIENT_TRACE_ID_LABEL)
                    .cloned()
                    .unwrap_or_else(|| "".to_string()),
            },
        })
    }
}

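/// Client-supplied metadata propagated to every spawned process.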
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientContext {
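    /// Trace id used to correlate client-side and process-side logs.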
    pub trace_id: String,
}

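/// An allocation of OS-process-backed procs, as produced by
/// [`ProcessAllocator`].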
pub struct ProcessAlloc {
    name: ShortUuid,
    world_id: WorldId,
    spec: AllocSpec,
    bootstrap_addr: ChannelAddr,
    rx: channel::ChannelRx<Process2Allocator>,
    active: HashMap<usize, Child>,
    ranks: Ranks<usize>,
    created: Vec<ShortUuid>,
    cmd: Arc<Mutex<Command>>,
    children: JoinSet<(usize, ProcStopReason)>,
    running: bool,
    failed: bool,
    client_context: ClientContext,
}

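/// Connection state of the control channel to a child process.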
#[derive(EnumAsInner)]
enum ChannelState {
    NotConnected,
    Connected(ChannelTx<Allocator2Process>),
    Failed(ChannelError),
}

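/// A single child process under management, together with the monitoring
/// state used to track its lifecycle.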
struct Child {
    local_rank: usize,
    channel: ChannelState,
    group: monitor::Group,
    exit_flag: Option<flag::Flag>,
    stdout: Option<LogTailer>,
    stderr: Option<LogTailer>,
    stop_reason: Arc<OnceLock<ProcStopReason>>,
    process_pid: Arc<std::sync::Mutex<Option<i32>>>,
}

impl Child {
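    /// Wrap a spawned process, teeing its stdout/stderr into log writers
    /// and in-memory tailers. Returns the child handle along with a future
    /// that resolves to the process's stop reason.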
    fn monitored(
        local_rank: usize,
        mut process: tokio::process::Child,
        log_channel: ChannelAddr,
    ) -> (Self, impl Future<Output = ProcStopReason>) {
        let (group, handle) = monitor::group();
        let (exit_flag, exit_guard) = flag::guarded();
        let stop_reason = Arc::new(OnceLock::new());

        // By default, tee the child's output into this process's own
        // stdout/stderr.
        let mut stdout_tee: Box<dyn io::AsyncWrite + Send + Unpin + 'static> =
            Box::new(io::stdout());
        let mut stderr_tee: Box<dyn io::AsyncWrite + Send + Unpin + 'static> =
            Box::new(io::stderr());

        match create_log_writers(local_rank, log_channel, process.id().unwrap_or(0)) {
            Ok((stdout_writer, stderr_writer)) => {
                stdout_tee = stdout_writer;
                stderr_tee = stderr_writer;
            }
            Err(e) => {
                tracing::error!("failed to create log writers: {}", e);
            }
        }

        let stdout = LogTailer::tee(
            MAX_TAIL_LOG_LINES,
            process.stdout.take().unwrap(),
            stdout_tee,
        );

        let stderr = LogTailer::tee(
            MAX_TAIL_LOG_LINES,
            process.stderr.take().unwrap(),
            stderr_tee,
        );

        let process_pid = Arc::new(std::sync::Mutex::new(process.id().map(|id| id as i32)));

        let child = Self {
            local_rank,
            channel: ChannelState::NotConnected,
            group,
            exit_flag: Some(exit_flag),
            stdout: Some(stdout),
            stderr: Some(stderr),
            stop_reason: Arc::clone(&stop_reason),
            process_pid: process_pid.clone(),
        };

        let monitor = async move {
            let reason = tokio::select! {
                // The monitor group failed: kill the process and reap its
                // exit status.
                _ = handle => {
                    Self::ensure_killed(process_pid);
                    Self::exit_status_to_reason(process.wait().await)
                }
                // The process exited on its own.
                result = process.wait() => {
                    Self::exit_status_to_reason(result)
                }
            };
            exit_guard.signal();

            stop_reason.get_or_init(|| reason).clone()
        };

        (child, monitor)
    }

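    /// Send SIGTERM to the process if it has not already been reaped. The
    /// pid is taken out of the mutex, so the signal is sent at most once.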
    fn ensure_killed(pid: Arc<std::sync::Mutex<Option<i32>>>) {
        if let Some(pid) = pid.lock().unwrap().take() {
            if let Err(e) = signal::kill(Pid::from_raw(pid), signal::SIGTERM) {
                match e {
                    nix::errno::Errno::ESRCH => {
                        // The process has already exited.
                        tracing::debug!("pid {} already exited", pid);
                    }
                    _ => {
                        tracing::error!("failed to kill {}: {}", pid, e);
                    }
                }
            }
        }
    }

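    /// Translate a process exit status into a [`ProcStopReason`].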
    fn exit_status_to_reason(result: io::Result<ExitStatus>) -> ProcStopReason {
        match result {
            Ok(status) if status.success() => ProcStopReason::Stopped,
            Ok(status) => {
                if let Some(signal) = status.signal() {
                    ProcStopReason::Killed(signal, status.core_dumped())
                } else if let Some(code) = status.code() {
                    // The stderr tail is filled in later by the caller.
                    ProcStopReason::Exited(code, String::new())
                } else {
                    ProcStopReason::Unknown
                }
            }
            Err(e) => {
                tracing::error!("error waiting for process: {}", e);
                ProcStopReason::Unknown
            }
        }
    }

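    /// Record the reason this child stopped and fail its task group,
    /// tearing down its associated tasks.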
    #[hyperactor::instrument_infallible]
    fn stop(&self, reason: ProcStopReason) {
        let _ = self.stop_reason.set(reason);
        self.group.fail();
    }

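    /// Connect to the child's control channel at `addr`. Returns `false`
    /// if a connection was already attempted; otherwise dials the address,
    /// stopping the child if the dial fails.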
    fn connect(&mut self, addr: ChannelAddr) -> bool {
        if !self.channel.is_not_connected() {
            return false;
        }

        match channel::dial(addr) {
            Ok(channel) => {
                let mut status = channel.status().clone();
                self.channel = ChannelState::Connected(channel);
                // Fail the group if the channel is ever closed.
                self.group.spawn(async move {
                    let _ = status
                        .wait_for(|status| matches!(status, TxStatus::Closed))
                        .await;
                    Result::<(), ()>::Err(())
                });
            }
            Err(err) => {
                self.channel = ChannelState::Failed(err);
                self.stop(ProcStopReason::Watchdog);
            }
        };
        true
    }

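    /// Spawn a watchdog task that fails the child's group if the process
    /// does not exit within the configured timeout.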
    fn spawn_watchdog(&mut self) {
        let Some(exit_flag) = self.exit_flag.take() else {
            tracing::info!("exit flag set, not spawning watchdog");
            return;
        };
        let group = self.group.clone();
        let stop_reason = self.stop_reason.clone();
        tracing::info!("spawning watchdog");
        tokio::spawn(async move {
            let exit_timeout =
                hyperactor::config::global::get(hyperactor::config::PROCESS_EXIT_TIMEOUT);
            #[allow(clippy::disallowed_methods)]
            if tokio::time::timeout(exit_timeout, exit_flag).await.is_err() {
                tracing::info!("watchdog timeout, killing process");
                let _ = stop_reason.set(ProcStopReason::Watchdog);
                group.fail();
            }
            tracing::info!("Watchdog task exit");
        });
    }

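    /// Post a message to the child's control channel; if the channel is
    /// not connected, the child is stopped with a watchdog reason instead.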
    #[hyperactor::instrument_infallible]
    fn post(&mut self, message: Allocator2Process) {
        if let ChannelState::Connected(channel) = &mut self.channel {
            channel.post(message);
        } else {
            self.stop(ProcStopReason::Watchdog);
        }
    }

    #[cfg(test)]
    fn fail_group(&self) {
        self.group.fail();
    }
}

impl Drop for Child {
    fn drop(&mut self) {
        // Ensure the OS process is terminated when its handle is dropped.
        Self::ensure_killed(self.process_pid.clone());
    }
}

impl ProcessAlloc {
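    /// Stop a single proc in this allocation with the provided reason.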
    #[hyperactor::instrument_infallible]
    fn stop(&mut self, proc_id: &ProcId, reason: ProcStopReason) -> Result<(), anyhow::Error> {
        self.get_mut(proc_id)?.stop(reason);
        Ok(())
    }

    fn get(&self, proc_id: &ProcId) -> Result<&Child, anyhow::Error> {
        self.active.get(&self.index(proc_id)?).ok_or_else(|| {
            anyhow::anyhow!(
                "proc {} not currently active in alloc {}",
                proc_id,
                self.name
            )
        })
    }

    fn get_mut(&mut self, proc_id: &ProcId) -> Result<&mut Child, anyhow::Error> {
        self.active.get_mut(&self.index(proc_id)?).ok_or_else(|| {
            anyhow::anyhow!(
                "proc {} not currently active in alloc {}",
                proc_id,
                self.name
            )
        })
    }

    pub(crate) fn name(&self) -> &ShortUuid {
        &self.name
    }

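    /// Return the local rank of `proc_id`, verifying that it belongs to
    /// this allocation.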
    fn index(&self, proc_id: &ProcId) -> Result<usize, anyhow::Error> {
        anyhow::ensure!(
            proc_id
                .world_name()
                .expect("proc must be ranked for allocation index")
                .parse::<ShortUuid>()?
                == self.name,
            "proc {} does not belong to alloc {}",
            proc_id,
            self.name
        );
        Ok(proc_id
            .rank()
            .expect("proc must be ranked for allocation index"))
    }

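    /// Spawn the next child process if the allocation has not yet reached
    /// its full extent. Returns the resulting `ProcState::Created` event,
    /// a `ProcState::Failed` if the spawn failed, or `None` if there is
    /// nothing to do.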
    #[hyperactor::instrument_infallible]
    async fn maybe_spawn(&mut self) -> Option<ProcState> {
        if self.active.len() >= self.spec.extent.num_ranks() {
            return None;
        }
        let mut cmd = self.cmd.lock().await;
        let log_channel: ChannelAddr = ChannelAddr::any(ChannelTransport::Unix);

        let index = self.created.len();
        self.created.push(ShortUuid::generate());
        let create_key = &self.created[index];

        // Pass bootstrap parameters to the child via environment variables.
        cmd.env(
            bootstrap::BOOTSTRAP_ADDR_ENV,
            self.bootstrap_addr.to_string(),
        );
        cmd.env(
            bootstrap::CLIENT_TRACE_ID_ENV,
            self.client_context.trace_id.as_str(),
        );
        cmd.env(bootstrap::BOOTSTRAP_INDEX_ENV, index.to_string());
        cmd.env(bootstrap::BOOTSTRAP_LOG_CHANNEL, log_channel.to_string());
        cmd.stdout(Stdio::piped());
        cmd.stderr(Stdio::piped());

        tracing::debug!("spawning process {:?}", cmd);
        match cmd.spawn() {
            Err(err) => {
                let message = format!(
                    "spawn {} index: {}, command: {:?}: {}",
                    create_key, index, cmd, err
                );
                tracing::error!(message);
                self.failed = true;
                Some(ProcState::Failed {
                    world_id: self.world_id.clone(),
                    description: message,
                })
            }
            Ok(mut process) => {
                let pid = process.id().unwrap_or(0);
                match self.ranks.assign(index) {
                    Err(_index) => {
                        // No rank is available; kill the process.
                        tracing::info!("could not assign rank to {}", create_key);
                        let _ = process.kill().await;
                        None
                    }
                    Ok(rank) => {
                        let (handle, monitor) = Child::monitored(rank, process, log_channel);
                        self.children.spawn(async move { (index, monitor.await) });
                        self.active.insert(index, handle);
                        let point = self.spec.extent.point_of_rank(rank).unwrap();
                        Some(ProcState::Created {
                            create_key: create_key.clone(),
                            point,
                            pid,
                        })
                    }
                }
            }
        }
    }

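    /// Remove a child from the active set, unassigning its rank.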
    fn remove(&mut self, index: usize) -> Option<Child> {
        self.ranks.unassign(index);
        self.active.remove(&index)
    }
}

#[async_trait]
impl Alloc for ProcessAlloc {
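    /// Produce the next lifecycle event for this allocation: spawn procs
    /// until the extent is filled, then surface `Created`, `Running`, and
    /// `Stopped` transitions as they occur.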
    #[hyperactor::instrument_infallible]
    async fn next(&mut self) -> Option<ProcState> {
        if !self.running && self.active.is_empty() {
            return None;
        }

        loop {
            if self.running
                && !self.failed
                && let state @ Some(_) = self.maybe_spawn().await
            {
                return state;
            }

            let transport = self.transport().clone();

            tokio::select! {
                Ok(Process2Allocator(index, message)) = self.rx.recv() => {
                    let child = match self.active.get_mut(&index) {
                        None => {
                            tracing::info!("message {:?} from zombie {}", message, index);
                            continue;
                        }
                        Some(child) => child,
                    };

                    match message {
                        Process2AllocatorMessage::Hello(addr) => {
                            if !child.connect(addr.clone()) {
                                tracing::error!("received multiple hellos from {}", index);
                                continue;
                            }

                            child.post(Allocator2Process::StartProc(
                                self.spec.proc_name.clone().map_or(
                                    ProcId::Ranked(WorldId(self.name.to_string()), index),
                                    |name| ProcId::Direct(addr.clone(), name)),
                                transport,
                            ));
                        }

                        Process2AllocatorMessage::StartedProc(proc_id, mesh_agent, addr) => {
                            break Some(ProcState::Running {
                                create_key: self.created[index].clone(),
                                proc_id,
                                mesh_agent,
                                addr,
                            });
                        }
                        Process2AllocatorMessage::Heartbeat => {
                            tracing::trace!("recv heartbeat from {index}");
                        }
                    }
                },

                Some(Ok((index, mut reason))) = self.children.join_next() => {
                    // Drain the tailers and capture the stderr tail for
                    // error reporting.
                    let stderr_content = if let Some(mut child) = self.remove(index) {
                        let stdout = child.stdout.take().unwrap();
                        let stderr = child.stderr.take().unwrap();
                        stdout.abort();
                        stderr.abort();
                        let (_stdout, _) = stdout.join().await;
                        let (stderr_lines, _) = stderr.join().await;
                        stderr_lines.join("\n")
                    } else {
                        String::new()
                    };

                    if let ProcStopReason::Exited(code, _) = &mut reason {
                        reason = ProcStopReason::Exited(*code, stderr_content);
                    }

                    tracing::info!("child stopped with ProcStopReason::{:?}", reason);

                    break Some(ProcState::Stopped {
                        create_key: self.created[index].clone(),
                        reason,
                    });
                },
            }
        }
    }

    fn spec(&self) -> &AllocSpec {
        &self.spec
    }

    fn extent(&self) -> &Extent {
        &self.spec.extent
    }

    fn world_id(&self) -> &WorldId {
        &self.world_id
    }

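    /// Ask every active proc to stop and exit, arming a watchdog for each;
    /// the allocation then drains through subsequent calls to `next`.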
    async fn stop(&mut self) -> Result<(), AllocatorError> {
        for (_index, child) in self.active.iter_mut() {
            child.post(Allocator2Process::StopAndExit(0));
            child.spawn_watchdog();
        }

        self.running = false;
        Ok(())
    }
}

impl Drop for ProcessAlloc {
    fn drop(&mut self) {
        tracing::debug!(
            "dropping ProcessAlloc of name: {}, world id: {}",
            self.name,
            self.world_id
        );
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[cfg(fbcode_build)]
    crate::alloc_test_suite!(ProcessAllocator::new(Command::new(
        crate::testresource::get("monarch/hyperactor_mesh/bootstrap")
    )));

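    /// Failing a child's monitor group should SIGTERM the underlying
    /// process, which is then reported as `Killed(15, false)`.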
    #[tokio::test]
    async fn test_sigterm_on_group_fail() {
        let bootstrap_binary = crate::testresource::get("monarch/hyperactor_mesh/bootstrap");
        let mut allocator = ProcessAllocator::new(Command::new(bootstrap_binary));

        let mut alloc = allocator
            .allocate(AllocSpec {
                extent: ndslice::extent!(replica = 1),
                constraints: Default::default(),
                proc_name: None,
                transport: ChannelTransport::Unix,
            })
            .await
            .unwrap();

        // Drive the allocation until the proc reports Running.
        let proc_id = {
            loop {
                match alloc.next().await {
                    Some(ProcState::Running { proc_id, .. }) => {
                        break proc_id;
                    }
                    Some(ProcState::Failed { description, .. }) => {
                        panic!("Process allocation failed: {}", description);
                    }
                    Some(_other) => {}
                    None => {
                        panic!("Allocation ended unexpectedly");
                    }
                }
            }
        };

        // Fail the child's monitor group directly; this should SIGTERM the
        // underlying process.
        if let Some(child) = alloc.active.get(
            &proc_id
                .rank()
                .expect("proc must be ranked for allocation lookup"),
        ) {
            child.fail_group();
        }

        // SIGTERM is signal 15; no core dump is expected.
        assert!(matches!(
            alloc.next().await,
            Some(ProcState::Stopped {
                reason: ProcStopReason::Killed(15, false),
                ..
            })
        ));
    }
}