1#![allow(dead_code)] use std::collections::HashMap;
12use std::future::Future;
13use std::os::unix::process::ExitStatusExt;
14use std::process::ExitStatus;
15use std::process::Stdio;
16use std::sync::Arc;
17use std::sync::OnceLock;
18
19use async_trait::async_trait;
20use enum_as_inner::EnumAsInner;
21use hyperactor::channel;
22use hyperactor::channel::ChannelAddr;
23use hyperactor::channel::ChannelError;
24use hyperactor::channel::ChannelTransport;
25use hyperactor::channel::ChannelTx;
26use hyperactor::channel::Rx;
27use hyperactor::channel::Tx;
28use hyperactor::reference as hyperactor_reference;
29use hyperactor::sync::flag;
30use hyperactor::sync::monitor;
31use ndslice::view::Extent;
32use nix::sys::signal;
33use nix::unistd::Pid;
34use serde::Deserialize;
35use serde::Serialize;
36use tokio::io;
37use tokio::process::Command;
38use tokio::sync::Mutex;
39use tokio::task::JoinSet;
40
41use super::Alloc;
42use super::AllocName;
43use super::AllocSpec;
44use super::Allocator;
45use super::AllocatorError;
46use super::ProcState;
47use super::ProcStopReason;
48use crate::assign::Ranks;
49use crate::bootstrap;
50use crate::bootstrap::Allocator2Process;
51use crate::bootstrap::MESH_ENABLE_LOG_FORWARDING;
52use crate::bootstrap::MESH_TAIL_LOG_LINES;
53use crate::bootstrap::Process2Allocator;
54use crate::bootstrap::Process2AllocatorMessage;
55use crate::logging::OutputTarget;
56use crate::logging::StreamFwder;
57use crate::shortuuid::ShortUuid;
58
/// Key in `AllocSpec::constraints.match_labels` carrying the client's trace id;
/// read in `allocate` for span tagging and forwarded to children via
/// `bootstrap::CLIENT_TRACE_ID_ENV`.
pub const CLIENT_TRACE_ID_LABEL: &str = "CLIENT_TRACE_ID";
60
/// An [`Allocator`] that runs each proc as a separate local OS process,
/// launched from a shared [`Command`] template.
pub struct ProcessAllocator {
    // Command template used to spawn every child; shared (behind a tokio
    // Mutex) with each ProcessAlloc this allocator creates.
    cmd: Arc<Mutex<Command>>,
}
71
72impl ProcessAllocator {
73 pub fn new(cmd: Command) -> Self {
78 Self {
79 cmd: Arc::new(Mutex::new(cmd)),
80 }
81 }
82}
83
84#[async_trait]
85impl Allocator for ProcessAllocator {
86 type Alloc = ProcessAlloc;
87
88 #[hyperactor::instrument(fields(name = "process_allocate", monarch_client_trace_id = spec.constraints.match_labels.get(CLIENT_TRACE_ID_LABEL).cloned().unwrap_or_else(|| "".to_string())))]
89 async fn allocate(&mut self, spec: AllocSpec) -> Result<ProcessAlloc, AllocatorError> {
90 let (bootstrap_addr, rx) = channel::serve(ChannelAddr::any(ChannelTransport::Unix))
91 .map_err(anyhow::Error::from)?;
92
93 if spec.transport == ChannelTransport::Local {
94 return Err(AllocatorError::Other(anyhow::anyhow!(
95 "ProcessAllocator does not support local transport"
96 )));
97 }
98
99 let name = ShortUuid::generate();
100 let alloc_name = AllocName(name.to_string());
101 tracing::info!(
102 name = "ProcessAllocStatus",
103 alloc_name = %alloc_name,
104 addr = %bootstrap_addr,
105 status = "Allocated",
106 );
107 Ok(ProcessAlloc {
108 name: name.clone(),
109 alloc_name,
110 spec: spec.clone(),
111 bootstrap_addr,
112 rx,
113 active: HashMap::new(),
114 ranks: Ranks::new(spec.extent.num_ranks()),
115 created: Vec::new(),
116 cmd: Arc::clone(&self.cmd),
117 children: JoinSet::new(),
118 running: true,
119 failed: false,
120 client_context: ClientContext {
121 trace_id: spec
122 .constraints
123 .match_labels
124 .get(CLIENT_TRACE_ID_LABEL)
125 .cloned()
126 .unwrap_or_else(|| "".to_string()),
127 },
128 })
129 }
130}
131
/// Client-side context captured at allocation time and forwarded to
/// spawned processes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientContext {
    // Trace id taken from the CLIENT_TRACE_ID_LABEL constraint label
    // (empty string when absent).
    pub trace_id: String,
}
139
/// A live allocation of OS-process-backed procs, produced by
/// [`ProcessAllocator::allocate`] and driven through the [`Alloc`] trait.
pub struct ProcessAlloc {
    // Unique id for this allocation; proc names are "{name}_{index}".
    name: ShortUuid,
    // Display name derived from `name`, reported in ProcState events.
    alloc_name: AllocName,
    // The spec this alloc was created from (extent, constraints, transport).
    spec: AllocSpec,
    // Address children dial back to (passed via BOOTSTRAP_ADDR_ENV).
    bootstrap_addr: ChannelAddr,
    // Receives Process2Allocator messages from all children.
    rx: channel::ChannelRx<Process2Allocator>,
    // Children currently alive, keyed by spawn index.
    active: HashMap<usize, Child>,
    // Rank assignments for spawn indices, bounded by the extent's rank count.
    ranks: Ranks<usize>,
    // Per-spawn create keys, indexed by spawn index; grows monotonically.
    created: Vec<ShortUuid>,
    // Shared command template (see ProcessAllocator::cmd).
    cmd: Arc<Mutex<Command>>,
    // Monitor tasks, each resolving to (spawn index, stop reason).
    children: JoinSet<(usize, ProcStopReason)>,
    // False once stop() has been requested; gates further spawning.
    running: bool,
    // True once any spawn attempt failed; also gates further spawning.
    failed: bool,
    // Client trace context forwarded to children.
    client_context: ClientContext,
}
158
/// Connection state of the allocator→child control channel.
#[derive(EnumAsInner)]
enum ChannelState {
    // No Hello received from the child yet.
    NotConnected,
    // Dialed successfully; messages can be posted.
    Connected(ChannelTx<Allocator2Process>),
    // Dial failed; the child is being stopped.
    Failed(ChannelError),
}
165
/// Handle to one spawned child process and its monitoring state.
struct Child {
    // Rank assigned to this child within the alloc's extent.
    local_rank: usize,
    // Control channel to the child (see ChannelState).
    channel: ChannelState,
    // Failure group: failing it triggers SIGTERM via the monitor task.
    group: monitor::Group,
    // Consumed by spawn_watchdog; Some until the watchdog is armed.
    exit_flag: Option<flag::Flag>,
    // Forwarder for the child's captured stdout, if piped.
    stdout_fwder: Arc<std::sync::Mutex<Option<StreamFwder>>>,
    // Forwarder for the child's captured stderr, if piped.
    stderr_fwder: Arc<std::sync::Mutex<Option<StreamFwder>>>,
    // First stop reason wins; set by stop()/watchdog, read by the monitor.
    stop_reason: Arc<OnceLock<ProcStopReason>>,
    // OS pid, taken (set to None) once a kill has been issued.
    process_pid: Arc<std::sync::Mutex<Option<i32>>>,
}
176
impl Child {
    /// Wraps a freshly spawned `process` into a `Child` handle plus a monitor
    /// future that resolves with the process's stop reason.
    ///
    /// If stdout/stderr were captured (piped), each is attached to a
    /// [`StreamFwder`] that forwards/tails output. The monitor future races
    /// group failure (then SIGTERMs the process and reaps it) against natural
    /// exit, signals the exit guard for the watchdog, and returns the first
    /// recorded stop reason.
    fn monitored(
        local_rank: usize,
        mut process: tokio::process::Child,
        log_channel: Option<ChannelAddr>,
        tail_size: usize,
        proc_id: hyperactor_reference::ProcId,
    ) -> (Self, impl Future<Output = ProcStopReason>) {
        let (group, handle) = monitor::group();
        let (exit_flag, exit_guard) = flag::guarded();
        let stop_reason = Arc::new(OnceLock::new());
        // Capture the pid up front so it can be killed even after `process`
        // is moved into the monitor future.
        let process_pid = Arc::new(std::sync::Mutex::new(process.id().map(|id| id as i32)));

        // Take the pipes before `process` moves into the monitor; they are
        // None when stdio was inherited rather than piped.
        let stdout_pipe = process.stdout.take();
        let stderr_pipe = process.stderr.take();

        let child = Self {
            local_rank,
            channel: ChannelState::NotConnected,
            group,
            exit_flag: Some(exit_flag),
            stdout_fwder: Arc::new(std::sync::Mutex::new(None)),
            stderr_fwder: Arc::new(std::sync::Mutex::new(None)),
            stop_reason: Arc::clone(&stop_reason),
            process_pid: process_pid.clone(),
        };

        let child_stdout_fwder = child.stdout_fwder.clone();
        let child_stderr_fwder = child.stderr_fwder.clone();

        if let Some(stdout) = stdout_pipe {
            let stdout_fwder = child_stdout_fwder.clone();
            let log_channel_clone = log_channel.clone();
            let proc_id_clone = proc_id.clone();
            *stdout_fwder.lock().expect("stdout_fwder mutex poisoned") = Some(StreamFwder::start(
                stdout,
                None,
                OutputTarget::Stdout,
                tail_size,
                log_channel_clone,
                &proc_id_clone,
                local_rank,
            ));
        }

        if let Some(stderr) = stderr_pipe {
            let stderr_fwder = child_stderr_fwder.clone();
            *stderr_fwder.lock().expect("stderr_fwder mutex poisoned") = Some(StreamFwder::start(
                stderr,
                None,
                OutputTarget::Stderr,
                tail_size,
                log_channel,
                &proc_id,
                local_rank,
            ));
        }

        let monitor = async move {
            let reason = tokio::select! {
                // Group failed (stop requested, channel closed, watchdog):
                // kill the process, then reap it for its real exit status.
                _ = handle => {
                    Self::ensure_killed(process_pid);
                    Self::exit_status_to_reason(process.wait().await)
                }
                // Process exited on its own.
                result = process.wait() => {
                    Self::exit_status_to_reason(result)
                }
            };
            // Unblock the watchdog (if armed) so it doesn't fire after exit.
            exit_guard.signal();

            // A reason set via stop()/watchdog takes precedence over the
            // exit-status-derived one.
            stop_reason.get_or_init(|| reason).clone()
        };

        (child, monitor)
    }

    /// Sends SIGTERM to the pid, if one is still recorded. Takes the pid out
    /// of the slot so a second call is a no-op; ESRCH (already exited) is
    /// logged at debug, other errors at error level.
    fn ensure_killed(pid: Arc<std::sync::Mutex<Option<i32>>>) {
        if let Some(pid) = pid.lock().unwrap().take() {
            if let Err(e) = signal::kill(Pid::from_raw(pid), signal::SIGTERM) {
                match e {
                    nix::errno::Errno::ESRCH => {
                        tracing::debug!("pid {} already exited", pid);
                    }
                    _ => {
                        tracing::error!("failed to kill {}: {}", pid, e);
                    }
                }
            }
        }
    }

    /// Maps a `wait()` result to a [`ProcStopReason`]: success → Stopped,
    /// signal death → Killed(signal, core_dumped), nonzero code → Exited
    /// (stderr tail filled in later by the caller), otherwise Unknown.
    fn exit_status_to_reason(result: io::Result<ExitStatus>) -> ProcStopReason {
        match result {
            Ok(status) if status.success() => ProcStopReason::Stopped,
            Ok(status) => {
                if let Some(signal) = status.signal() {
                    ProcStopReason::Killed(signal, status.core_dumped())
                } else if let Some(code) = status.code() {
                    ProcStopReason::Exited(code, String::new())
                } else {
                    ProcStopReason::Unknown
                }
            }
            Err(e) => {
                tracing::error!("error waiting for process: {}", e);
                ProcStopReason::Unknown
            }
        }
    }

    /// Records `reason` (first writer wins) and fails the monitor group,
    /// which causes the monitor task to SIGTERM and reap the process.
    #[hyperactor::instrument_infallible]
    fn stop(&self, reason: ProcStopReason) {
        let _ = self.stop_reason.set(reason);
        self.group.fail();
    }

    /// Handles the child's Hello by dialing `addr`. Returns false if already
    /// connected (duplicate Hello). On dial success, spawns a task that fails
    /// the group when the channel closes; on failure, stops the child with
    /// a Watchdog reason. Note: returns true in both of those cases.
    fn connect(&mut self, addr: ChannelAddr) -> bool {
        if !self.channel.is_not_connected() {
            return false;
        }

        match channel::dial(addr) {
            Ok(channel) => {
                let mut status = channel.status().clone();
                self.channel = ChannelState::Connected(channel);
                self.group.spawn(async move {
                    // Treat channel closure as a group failure so the child
                    // gets torn down.
                    let _ = status.wait_for(|status| status.is_closed()).await;
                    Result::<(), ()>::Err(())
                });
            }
            Err(err) => {
                self.channel = ChannelState::Failed(err);
                self.stop(ProcStopReason::Watchdog);
            }
        };
        true
    }

    /// Arms a one-shot watchdog: if the process has not exited within the
    /// configured PROCESS_EXIT_TIMEOUT, records a Watchdog stop reason and
    /// fails the group (which kills the process). No-op if already armed.
    fn spawn_watchdog(&mut self) {
        let Some(exit_flag) = self.exit_flag.take() else {
            tracing::info!("exit flag set, not spawning watchdog");
            return;
        };
        let group = self.group.clone();
        let stop_reason = self.stop_reason.clone();
        tracing::info!("spawning watchdog");
        tokio::spawn(async move {
            let exit_timeout =
                hyperactor_config::global::get(hyperactor::config::PROCESS_EXIT_TIMEOUT);
            if tokio::time::timeout(exit_timeout, exit_flag).await.is_err() {
                tracing::info!("watchdog timeout, killing process");
                let _ = stop_reason.set(ProcStopReason::Watchdog);
                group.fail();
            }
            tracing::info!("Watchdog task exit");
        });
    }

    /// Posts `message` on the control channel; if not connected, stops the
    /// child with a Watchdog reason instead.
    #[hyperactor::instrument_infallible]
    fn post(&mut self, message: Allocator2Process) {
        if let ChannelState::Connected(channel) = &mut self.channel {
            channel.post(message);
        } else {
            self.stop(ProcStopReason::Watchdog);
        }
    }

    // Test hook: fail the monitor group directly (simulates supervision
    // failure) without going through stop().
    #[cfg(test)]
    fn fail_group(&self) {
        self.group.fail();
    }

    /// Takes ownership of both stream forwarders (leaving None behind) so
    /// the caller can abort them and collect tailed output.
    fn take_stream_monitors(&self) -> (Option<StreamFwder>, Option<StreamFwder>) {
        let out = self
            .stdout_fwder
            .lock()
            .expect("stdout_tailer mutex poisoned")
            .take();
        let err = self
            .stderr_fwder
            .lock()
            .expect("stderr_tailer mutex poisoned")
            .take();
        (out, err)
    }
}
385
386impl Drop for Child {
387 fn drop(&mut self) {
388 Self::ensure_killed(self.process_pid.clone());
389 }
390}
391
impl ProcessAlloc {
    /// Requests that the proc identified by `proc_id` stop with `reason`.
    ///
    /// # Errors
    ///
    /// Fails if `proc_id` does not belong to this alloc or is not active.
    #[hyperactor::instrument_infallible]
    fn stop(
        &mut self,
        proc_id: &hyperactor_reference::ProcId,
        reason: ProcStopReason,
    ) -> Result<(), anyhow::Error> {
        self.get_mut(proc_id)?.stop(reason);
        Ok(())
    }

    /// Looks up the active child for `proc_id`.
    fn get(&self, proc_id: &hyperactor_reference::ProcId) -> Result<&Child, anyhow::Error> {
        self.active.get(&self.index(proc_id)?).ok_or_else(|| {
            anyhow::anyhow!(
                "proc {} not currently active in alloc {}",
                proc_id,
                self.name
            )
        })
    }

    /// Mutable variant of [`Self::get`].
    fn get_mut(
        &mut self,
        proc_id: &hyperactor_reference::ProcId,
    ) -> Result<&mut Child, anyhow::Error> {
        self.active.get_mut(&self.index(proc_id)?).ok_or_else(|| {
            anyhow::anyhow!(
                "proc {} not currently active in alloc {}",
                &proc_id,
                self.name
            )
        })
    }

    /// The unique id of this allocation.
    pub(crate) fn name(&self) -> &ShortUuid {
        &self.name
    }

    /// Extracts the numeric index from a proc name of the form
    /// `"{alloc_name}_{index}"`.
    ///
    /// # Errors
    ///
    /// Fails if the name does not carry this alloc's prefix or the suffix is
    /// not a valid `usize`.
    fn index(&self, proc_id: &hyperactor_reference::ProcId) -> Result<usize, anyhow::Error> {
        let name = proc_id.name();
        let expected_prefix = format!("{}_", self.name);
        anyhow::ensure!(
            name.starts_with(&expected_prefix),
            "proc {} does not belong to alloc {}",
            proc_id,
            self.name
        );
        let index_str = &name[expected_prefix.len()..];
        index_str
            .parse::<usize>()
            .map_err(|e| anyhow::anyhow!("failed to parse index from proc name '{}': {}", name, e))
    }

    /// Spawns one more child process if the alloc is not yet at capacity.
    ///
    /// Returns `Some(ProcState::Created)` on success,
    /// `Some(ProcState::Failed)` if the spawn itself failed (also sets
    /// `self.failed`), and `None` if at capacity or no rank could be
    /// assigned.
    #[hyperactor::instrument_infallible]
    async fn maybe_spawn(&mut self) -> Option<ProcState> {
        if self.active.len() >= self.spec.extent.num_ranks() {
            return None;
        }
        let mut cmd = self.cmd.lock().await;

        // Pipe child stdio only when forwarding or tailing is enabled;
        // otherwise inherit the parent console.
        let enable_forwarding = hyperactor_config::global::get(MESH_ENABLE_LOG_FORWARDING);
        let tail_size = hyperactor_config::global::get(MESH_TAIL_LOG_LINES);
        if enable_forwarding || tail_size > 0 {
            cmd.stdout(Stdio::piped()).stderr(Stdio::piped());
        } else {
            cmd.stdout(Stdio::inherit()).stderr(Stdio::inherit());
            tracing::info!(
                "child stdio NOT captured (forwarding/file_capture/tail all disabled); \
                inheriting parent console"
            );
        }
        // NOTE(review): log forwarding channel is currently never set here.
        let log_channel: Option<ChannelAddr> = None;

        // Each spawn attempt gets a fresh create key at the next index.
        let index = self.created.len();
        self.created.push(ShortUuid::generate());
        let create_key = &self.created[index];

        // Propagate the client's config to the child via the bootstrap env.
        let client_config = hyperactor_config::global::attrs();
        let bootstrap = bootstrap::Bootstrap::V0ProcMesh {
            config: Some(client_config),
        };
        bootstrap.to_env(&mut cmd);

        cmd.env(
            bootstrap::BOOTSTRAP_ADDR_ENV,
            self.bootstrap_addr.to_string(),
        );
        cmd.env(
            bootstrap::CLIENT_TRACE_ID_ENV,
            self.client_context.trace_id.as_str(),
        );
        cmd.env(bootstrap::BOOTSTRAP_INDEX_ENV, index.to_string());

        cmd.env(
            "HYPERACTOR_PROCESS_NAME",
            format!(
                "host rank:{} @{}",
                index,
                hostname::get()
                    .unwrap_or_else(|_| "unknown_host".into())
                    .into_string()
                    .unwrap_or("unknown_host".to_string())
            ),
        );

        tracing::debug!("spawning process {:?}", cmd);
        match cmd.spawn() {
            Err(err) => {
                let message = format!(
                    "spawn {} index: {}, command: {:?}: {}",
                    create_key, index, cmd, err
                );
                tracing::error!(message);
                // A failed spawn stops all future spawning for this alloc.
                self.failed = true;
                Some(ProcState::Failed {
                    alloc_name: self.alloc_name.clone(),
                    description: message,
                })
            }
            Ok(mut process) => {
                let pid = process.id().unwrap_or(0);
                match self.ranks.assign(index) {
                    Err(_index) => {
                        // No rank available: kill the just-spawned process.
                        tracing::info!("could not assign rank to {}", create_key);
                        let _ = process.kill().await;
                        None
                    }
                    Ok(rank) => {
                        // Placeholder address: only the proc name matters
                        // for the StreamFwder labeling here.
                        let temp_addr = ChannelAddr::any(ChannelTransport::Local);
                        let proc_id = hyperactor_reference::ProcId::with_name(
                            temp_addr,
                            format!("{}_{}", self.alloc_name.name(), rank),
                        );
                        let (handle, monitor) =
                            Child::monitored(rank, process, log_channel, tail_size, proc_id);

                        self.active.insert(index, handle);

                        // Monitor resolves with (index, stop reason) once
                        // the child exits; consumed in Alloc::next.
                        self.children.spawn(async move { (index, monitor.await) });

                        let point = self.spec.extent.point_of_rank(rank).unwrap();
                        Some(ProcState::Created {
                            create_key: create_key.clone(),
                            point,
                            pid,
                        })
                    }
                }
            }
        }
    }

    /// Removes the child at `index` from the active set and frees its rank.
    fn remove(&mut self, index: usize) -> Option<Child> {
        self.ranks.unassign(index);
        self.active.remove(&index)
    }
}
576
#[async_trait]
impl Alloc for ProcessAlloc {
    /// Drives the allocation forward, returning the next proc lifecycle
    /// event, or `None` once stopped and all children have exited.
    ///
    /// Each call first tries to spawn a missing child (while running and
    /// not failed), then waits on either a message from a child or a child
    /// exit.
    #[hyperactor::instrument_infallible]
    async fn next(&mut self) -> Option<ProcState> {
        if !self.running && self.active.is_empty() {
            return None;
        }

        loop {
            // Spawn lazily; a Created/Failed event is returned immediately.
            if self.running
                && !self.failed
                && let state @ Some(_) = self.maybe_spawn().await
            {
                return state;
            }

            let transport = self.transport().clone();

            tokio::select! {
                Ok(Process2Allocator(index, message)) = self.rx.recv() => {
                    let child = match self.active.get_mut(&index) {
                        None => {
                            // Message from a child we no longer track.
                            tracing::info!("message {:?} from zombie {}", message, index);
                            continue;
                        }
                        Some(child) => child,
                    };

                    match message {
                        Process2AllocatorMessage::Hello(addr) => {
                            if !child.connect(addr.clone()) {
                                tracing::error!("received multiple hellos from {}", index);
                                continue;
                            }

                            // Explicit proc_name from the spec wins over the
                            // default "{alloc}_{index}" naming.
                            let proc_name = match &self.spec.proc_name {
                                Some(name) => name.clone(),
                                None => format!("{}_{}", self.name, index),
                            };
                            child.post(Allocator2Process::StartProc(
                                hyperactor_reference::ProcId::with_name(addr.clone(), proc_name),
                                transport,
                            ));
                        }

                        Process2AllocatorMessage::StartedProc(proc_id, mesh_agent, addr) => {
                            break Some(ProcState::Running {
                                create_key: self.created[index].clone(),
                                proc_id,
                                mesh_agent,
                                addr,
                            });
                        }
                        Process2AllocatorMessage::Heartbeat => {
                            tracing::trace!("recv heartbeat from {index}");
                        }
                    }
                },

                Some(Ok((index, mut reason))) = self.children.join_next() => {
                    // Child exited: tear down its stream forwarders and
                    // collect the tailed stderr for the Stopped event.
                    let stderr_content = if let Some(child) = self.remove(index) {
                        let mut stderr_lines = Vec::new();

                        let (stdout_mon, stderr_mon) = child.take_stream_monitors();

                        if let Some(stdout_monitor) = stdout_mon {
                            let (_lines, _result) = stdout_monitor.abort().await;
                            if let Err(e) = _result {
                                tracing::warn!("stdout monitor abort error: {}", e);
                            }
                        }

                        if let Some(stderr_monitor) = stderr_mon {
                            let (lines, result) = stderr_monitor.abort().await;
                            stderr_lines = lines;
                            if let Err(e) = result {
                                tracing::warn!("stderr monitor abort error: {}", e);
                            }
                        }

                        stderr_lines.join("\n")
                    } else {
                        String::new()
                    };

                    // Attach the stderr tail to Exited reasons (the monitor
                    // produced them with an empty string placeholder).
                    if let ProcStopReason::Exited(code, _) = &mut reason {
                        reason = ProcStopReason::Exited(*code, stderr_content);
                    }

                    tracing::info!("child stopped with ProcStopReason::{:?}", reason);

                    break Some(ProcState::Stopped {
                        create_key: self.created[index].clone(),
                        reason,
                    });
                },
            }
        }
    }

    /// The spec this alloc was created from.
    fn spec(&self) -> &AllocSpec {
        &self.spec
    }

    /// The extent (shape) of the allocation.
    fn extent(&self) -> &Extent {
        &self.spec.extent
    }

    /// The display name of this allocation.
    fn alloc_name(&self) -> &AllocName {
        &self.alloc_name
    }

    /// Asks every active child to stop and exit, arming a watchdog per
    /// child, and marks the alloc as no longer running. Does not wait for
    /// exits — keep calling [`Self::next`] to observe Stopped events.
    async fn stop(&mut self) -> Result<(), AllocatorError> {
        tracing::info!(
            name = "ProcessAllocStatus",
            alloc_name = %self.alloc_name(),
            status = "Stopping",
        );
        for (_index, child) in self.active.iter_mut() {
            child.post(Allocator2Process::StopAndExit(0));
            child.spawn_watchdog();
        }

        self.running = false;
        tracing::info!(
            name = "ProcessAllocStatus",
            alloc_name = %self.alloc_name(),
            status = "Stop::Sent",
            "StopAndExit was sent to allocators; check their logs for the stop progress."
        );
        Ok(())
    }
}
717
718impl Drop for ProcessAlloc {
719 fn drop(&mut self) {
720 tracing::info!(
721 name = "ProcessAllocStatus",
722 alloc_name = %self.alloc_name(),
723 status = "Dropped",
724 "dropping ProcessAlloc of name: {}, alloc_name: {}",
725 self.name,
726 self.alloc_name
727 );
728 }
729}
730
#[cfg(test)]
mod tests {
    use super::*;

    // Run the shared allocator test-suite against ProcessAllocator with the
    // bootstrap test binary (fbcode builds only).
    #[cfg(fbcode_build)]
    crate::alloc_test_suite!(ProcessAllocator::new(Command::new(
        crate::testresource::get("monarch/hyperactor_mesh/bootstrap")
    )));

    // Failing a child's monitor group must SIGTERM the process: the alloc
    // should then report Stopped with Killed(15, false) — SIGTERM, no core.
    #[cfg(fbcode_build)]
    #[tokio::test]
    async fn test_sigterm_on_group_fail() {
        let bootstrap_binary = crate::testresource::get("monarch/hyperactor_mesh/bootstrap");
        let mut allocator = ProcessAllocator::new(Command::new(bootstrap_binary));

        let mut alloc = allocator
            .allocate(AllocSpec {
                extent: ndslice::extent!(replica = 1),
                constraints: Default::default(),
                proc_name: None,
                transport: ChannelTransport::Unix,
                proc_allocation_mode: Default::default(),
            })
            .await
            .unwrap();

        // Drive the alloc until the single child reports Running.
        let proc_id = {
            loop {
                match alloc.next().await {
                    Some(ProcState::Running { proc_id, .. }) => {
                        break proc_id;
                    }
                    Some(ProcState::Failed { description, .. }) => {
                        panic!("Process allocation failed: {}", description);
                    }
                    Some(_other) => {}
                    None => {
                        panic!("Allocation ended unexpectedly");
                    }
                }
            }
        };

        // Simulate a supervision failure via the test-only hook.
        if let Some(child) = alloc.active.get(
            &alloc
                .index(&proc_id)
                .expect("proc must be in allocation for lookup"),
        ) {
            child.fail_group();
        }

        assert!(matches!(
            alloc.next().await,
            Some(ProcState::Stopped {
                reason: ProcStopReason::Killed(15, false),
                ..
            })
        ));
    }
}