1#![allow(dead_code)] use std::collections::HashMap;
12use std::future::Future;
13use std::os::unix::process::ExitStatusExt;
14use std::process::ExitStatus;
15use std::process::Stdio;
16use std::sync::Arc;
17use std::sync::OnceLock;
18
19use async_trait::async_trait;
20use enum_as_inner::EnumAsInner;
21use hyperactor::channel;
22use hyperactor::channel::ChannelAddr;
23use hyperactor::channel::ChannelError;
24use hyperactor::channel::ChannelTransport;
25use hyperactor::channel::ChannelTx;
26use hyperactor::channel::Rx;
27use hyperactor::channel::Tx;
28use hyperactor::channel::TxStatus;
29use hyperactor::reference as hyperactor_reference;
30use hyperactor::sync::flag;
31use hyperactor::sync::monitor;
32use ndslice::view::Extent;
33use nix::sys::signal;
34use nix::unistd::Pid;
35use serde::Deserialize;
36use serde::Serialize;
37use tokio::io;
38use tokio::process::Command;
39use tokio::sync::Mutex;
40use tokio::task::JoinSet;
41
42use super::Alloc;
43use super::AllocName;
44use super::AllocSpec;
45use super::Allocator;
46use super::AllocatorError;
47use super::ProcState;
48use super::ProcStopReason;
49use crate::assign::Ranks;
50use crate::bootstrap;
51use crate::bootstrap::Allocator2Process;
52use crate::bootstrap::MESH_ENABLE_LOG_FORWARDING;
53use crate::bootstrap::MESH_TAIL_LOG_LINES;
54use crate::bootstrap::Process2Allocator;
55use crate::bootstrap::Process2AllocatorMessage;
56use crate::logging::OutputTarget;
57use crate::logging::StreamFwder;
58use crate::shortuuid::ShortUuid;
59
/// Label key used to look up the client-provided trace id in
/// `AllocSpec::constraints::match_labels`; the value is propagated to
/// child processes via `bootstrap::CLIENT_TRACE_ID_ENV`.
pub const CLIENT_TRACE_ID_LABEL: &str = "CLIENT_TRACE_ID";
61
/// An [`Allocator`] that runs each proc as a separate OS process,
/// spawned from the provided [`Command`] template.
pub struct ProcessAllocator {
    // Shared with every `ProcessAlloc` produced by this allocator so
    // spawning can mutate (env, stdio) and spawn from the same template.
    cmd: Arc<Mutex<Command>>,
}
72
impl ProcessAllocator {
    /// Create a new allocator that spawns child processes from `cmd`.
    pub fn new(cmd: Command) -> Self {
        Self {
            cmd: Arc::new(Mutex::new(cmd)),
        }
    }
}
84
85#[async_trait]
86impl Allocator for ProcessAllocator {
87 type Alloc = ProcessAlloc;
88
89 #[hyperactor::instrument(fields(name = "process_allocate", monarch_client_trace_id = spec.constraints.match_labels.get(CLIENT_TRACE_ID_LABEL).cloned().unwrap_or_else(|| "".to_string())))]
90 async fn allocate(&mut self, spec: AllocSpec) -> Result<ProcessAlloc, AllocatorError> {
91 let (bootstrap_addr, rx) = channel::serve(ChannelAddr::any(ChannelTransport::Unix))
92 .map_err(anyhow::Error::from)?;
93
94 if spec.transport == ChannelTransport::Local {
95 return Err(AllocatorError::Other(anyhow::anyhow!(
96 "ProcessAllocator does not support local transport"
97 )));
98 }
99
100 let name = ShortUuid::generate();
101 let alloc_name = AllocName(name.to_string());
102 tracing::info!(
103 name = "ProcessAllocStatus",
104 alloc_name = %alloc_name,
105 addr = %bootstrap_addr,
106 status = "Allocated",
107 );
108 Ok(ProcessAlloc {
109 name: name.clone(),
110 alloc_name,
111 spec: spec.clone(),
112 bootstrap_addr,
113 rx,
114 active: HashMap::new(),
115 ranks: Ranks::new(spec.extent.num_ranks()),
116 created: Vec::new(),
117 cmd: Arc::clone(&self.cmd),
118 children: JoinSet::new(),
119 running: true,
120 failed: false,
121 client_context: ClientContext {
122 trace_id: spec
123 .constraints
124 .match_labels
125 .get(CLIENT_TRACE_ID_LABEL)
126 .cloned()
127 .unwrap_or_else(|| "".to_string()),
128 },
129 })
130 }
131}
132
/// Client-provided context propagated to every spawned child process.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientContext {
    // Trace id taken from the `CLIENT_TRACE_ID_LABEL` match label;
    // empty string when the caller supplied none.
    pub trace_id: String,
}
140
/// A single allocation produced by [`ProcessAllocator`]: a set of child
/// OS processes (one per rank) bootstrapped over a Unix channel.
pub struct ProcessAlloc {
    // Generated short uuid; also the prefix of every child proc's name.
    name: ShortUuid,
    alloc_name: AllocName,
    spec: AllocSpec,
    // Address children dial back to (passed via `BOOTSTRAP_ADDR_ENV`).
    bootstrap_addr: ChannelAddr,
    // Receives `Process2Allocator` messages from all children.
    rx: channel::ChannelRx<Process2Allocator>,
    // Live children keyed by spawn index.
    active: HashMap<usize, Child>,
    ranks: Ranks<usize>,
    // Create key per spawned child, indexed by spawn order.
    created: Vec<ShortUuid>,
    // Command template shared with the parent `ProcessAllocator`.
    cmd: Arc<Mutex<Command>>,
    // Monitor futures for children, yielding `(index, stop_reason)`.
    children: JoinSet<(usize, ProcStopReason)>,
    // Set to false once `stop()` has been called.
    running: bool,
    // Set to true once a spawn failed; suppresses further spawning.
    failed: bool,
    client_context: ClientContext,
}
159
/// Connection state of the allocator-to-process control channel.
#[derive(EnumAsInner)]
enum ChannelState {
    // No `Hello` received from the child yet.
    NotConnected,
    Connected(ChannelTx<Allocator2Process>),
    Failed(ChannelError),
}
166
/// Handle to a single spawned child process: its control channel,
/// supervision group, output stream forwarders, and stop state.
struct Child {
    local_rank: usize,
    channel: ChannelState,
    // Supervision group; failing it triggers the kill/reap path in the
    // monitor future created by `Child::monitored`.
    group: monitor::Group,
    // Consumed by `spawn_watchdog`; signalled when the monitor observes exit.
    exit_flag: Option<flag::Flag>,
    stdout_fwder: Arc<std::sync::Mutex<Option<StreamFwder>>>,
    stderr_fwder: Arc<std::sync::Mutex<Option<StreamFwder>>>,
    // First stop reason recorded wins (explicit stop, watchdog, or exit status).
    stop_reason: Arc<OnceLock<ProcStopReason>>,
    // OS pid; taken (set to None) once a SIGTERM has been delivered.
    process_pid: Arc<std::sync::Mutex<Option<i32>>>,
}
177
impl Child {
    /// Wrap a freshly spawned `process` in a `Child` handle plus a monitor
    /// future that resolves with the process's [`ProcStopReason`].
    ///
    /// If stdout/stderr pipes were configured on the process, a
    /// [`StreamFwder`] is started for each to forward/tail output, tagged
    /// with `proc_id` and `local_rank`.
    fn monitored(
        local_rank: usize,
        mut process: tokio::process::Child,
        log_channel: Option<ChannelAddr>,
        tail_size: usize,
        proc_id: hyperactor_reference::ProcId,
    ) -> (Self, impl Future<Output = ProcStopReason>) {
        let (group, handle) = monitor::group();
        let (exit_flag, exit_guard) = flag::guarded();
        let stop_reason = Arc::new(OnceLock::new());
        // Recorded so the process can still be signalled (watchdog / Drop)
        // after `process` itself has been moved into the monitor future.
        let process_pid = Arc::new(std::sync::Mutex::new(process.id().map(|id| id as i32)));

        let stdout_pipe = process.stdout.take();
        let stderr_pipe = process.stderr.take();

        let child = Self {
            local_rank,
            channel: ChannelState::NotConnected,
            group,
            exit_flag: Some(exit_flag),
            stdout_fwder: Arc::new(std::sync::Mutex::new(None)),
            stderr_fwder: Arc::new(std::sync::Mutex::new(None)),
            stop_reason: Arc::clone(&stop_reason),
            process_pid: process_pid.clone(),
        };

        let child_stdout_fwder = child.stdout_fwder.clone();
        let child_stderr_fwder = child.stderr_fwder.clone();

        if let Some(stdout) = stdout_pipe {
            let stdout_fwder = child_stdout_fwder.clone();
            let log_channel_clone = log_channel.clone();
            let proc_id_clone = proc_id.clone();
            *stdout_fwder.lock().expect("stdout_fwder mutex poisoned") = Some(StreamFwder::start(
                stdout,
                None,
                OutputTarget::Stdout,
                tail_size,
                log_channel_clone,
                &proc_id_clone,
                local_rank,
            ));
        }

        if let Some(stderr) = stderr_pipe {
            let stderr_fwder = child_stderr_fwder.clone();
            *stderr_fwder.lock().expect("stderr_fwder mutex poisoned") = Some(StreamFwder::start(
                stderr,
                None,
                OutputTarget::Stderr,
                tail_size,
                log_channel,
                &proc_id,
                local_rank,
            ));
        }

        let monitor = async move {
            let reason = tokio::select! {
                // The supervision group failed (explicit stop, watchdog, or
                // channel closure): signal the OS process, then reap it.
                _ = handle => {
                    Self::ensure_killed(process_pid);
                    Self::exit_status_to_reason(process.wait().await)
                }
                // The process exited on its own.
                result = process.wait() => {
                    Self::exit_status_to_reason(result)
                }
            };
            exit_guard.signal();

            // A previously recorded stop reason (via `stop()` / watchdog)
            // takes precedence over the reason derived from exit status.
            stop_reason.get_or_init(|| reason).clone()
        };

        (child, monitor)
    }

    /// Deliver SIGTERM to the recorded pid, if any. The pid is `take()`n
    /// so the signal is sent at most once; ESRCH (already exited) is benign.
    fn ensure_killed(pid: Arc<std::sync::Mutex<Option<i32>>>) {
        if let Some(pid) = pid.lock().unwrap().take() {
            if let Err(e) = signal::kill(Pid::from_raw(pid), signal::SIGTERM) {
                match e {
                    nix::errno::Errno::ESRCH => {
                        tracing::debug!("pid {} already exited", pid);
                    }
                    _ => {
                        tracing::error!("failed to kill {}: {}", pid, e);
                    }
                }
            }
        }
    }

    /// Translate a `wait()` result into a [`ProcStopReason`].
    fn exit_status_to_reason(result: io::Result<ExitStatus>) -> ProcStopReason {
        match result {
            Ok(status) if status.success() => ProcStopReason::Stopped,
            Ok(status) => {
                if let Some(signal) = status.signal() {
                    ProcStopReason::Killed(signal, status.core_dumped())
                } else if let Some(code) = status.code() {
                    // The stderr tail is filled in later by the caller
                    // (see `ProcessAlloc::next`), hence the empty string.
                    ProcStopReason::Exited(code, String::new())
                } else {
                    ProcStopReason::Unknown
                }
            }
            Err(e) => {
                tracing::error!("error waiting for process: {}", e);
                ProcStopReason::Unknown
            }
        }
    }

    /// Record `reason` (first writer wins) and fail the supervision group,
    /// which triggers the kill/reap path in the monitor future.
    #[hyperactor::instrument_infallible]
    fn stop(&self, reason: ProcStopReason) {
        let _ = self.stop_reason.set(reason); self.group.fail();
    }

    /// Dial the child's control channel at `addr`. Returns `false` if a
    /// connection attempt was already made (guards duplicate `Hello`s).
    /// On dial failure the child is stopped with `Watchdog`.
    fn connect(&mut self, addr: ChannelAddr) -> bool {
        if !self.channel.is_not_connected() {
            return false;
        }

        match channel::dial(addr) {
            Ok(channel) => {
                let mut status = channel.status().clone();
                self.channel = ChannelState::Connected(channel);
                // Fail the group if the channel ever closes, so a lost
                // connection tears the child down.
                self.group.spawn(async move {
                    let _ = status
                        .wait_for(|status| matches!(status, TxStatus::Closed))
                        .await;
                    Result::<(), ()>::Err(())
                });
            }
            Err(err) => {
                self.channel = ChannelState::Failed(err);
                self.stop(ProcStopReason::Watchdog);
            }
        };
        true
    }

    /// After a stop request, give the process `PROCESS_EXIT_TIMEOUT` to
    /// exit on its own; on timeout, record `Watchdog` and fail the group.
    /// No-op if the watchdog was already spawned (the exit flag is consumed).
    fn spawn_watchdog(&mut self) {
        let Some(exit_flag) = self.exit_flag.take() else {
            tracing::info!("exit flag set, not spawning watchdog");
            return;
        };
        let group = self.group.clone();
        let stop_reason = self.stop_reason.clone();
        tracing::info!("spawning watchdog");
        tokio::spawn(async move {
            let exit_timeout =
                hyperactor_config::global::get(hyperactor::config::PROCESS_EXIT_TIMEOUT);
            if tokio::time::timeout(exit_timeout, exit_flag).await.is_err() {
                tracing::info!("watchdog timeout, killing process");
                let _ = stop_reason.set(ProcStopReason::Watchdog);
                group.fail();
            }
            tracing::info!("Watchdog task exit");
        });
    }

    /// Post `message` over the child's channel; if not connected, treat
    /// the child as lost and stop it with `Watchdog`.
    #[hyperactor::instrument_infallible]
    fn post(&mut self, message: Allocator2Process) {
        if let ChannelState::Connected(channel) = &mut self.channel {
            channel.post(message);
        } else {
            self.stop(ProcStopReason::Watchdog);
        }
    }

    // Test hook: fail the supervision group directly.
    #[cfg(test)]
    fn fail_group(&self) {
        self.group.fail();
    }

    /// Take ownership of the stdout/stderr forwarders (leaving `None`
    /// behind) so the caller can abort them and collect tail output.
    fn take_stream_monitors(&self) -> (Option<StreamFwder>, Option<StreamFwder>) {
        let out = self
            .stdout_fwder
            .lock()
            .expect("stdout_tailer mutex poisoned")
            .take();
        let err = self
            .stderr_fwder
            .lock()
            .expect("stderr_tailer mutex poisoned")
            .take();
        (out, err)
    }
}
388
impl Drop for Child {
    /// Best-effort SIGTERM on drop so an abandoned handle does not leak
    /// the OS process (no-op if the pid was already taken or has exited).
    fn drop(&mut self) {
        Self::ensure_killed(self.process_pid.clone());
    }
}
394
395impl ProcessAlloc {
396 #[hyperactor::instrument_infallible]
402 fn stop(
403 &mut self,
404 proc_id: &hyperactor_reference::ProcId,
405 reason: ProcStopReason,
406 ) -> Result<(), anyhow::Error> {
407 self.get_mut(proc_id)?.stop(reason);
408 Ok(())
409 }
410
411 fn get(&self, proc_id: &hyperactor_reference::ProcId) -> Result<&Child, anyhow::Error> {
412 self.active.get(&self.index(proc_id)?).ok_or_else(|| {
413 anyhow::anyhow!(
414 "proc {} not currently active in alloc {}",
415 proc_id,
416 self.name
417 )
418 })
419 }
420
421 fn get_mut(
422 &mut self,
423 proc_id: &hyperactor_reference::ProcId,
424 ) -> Result<&mut Child, anyhow::Error> {
425 self.active.get_mut(&self.index(proc_id)?).ok_or_else(|| {
426 anyhow::anyhow!(
427 "proc {} not currently active in alloc {}",
428 &proc_id,
429 self.name
430 )
431 })
432 }
433
434 pub(crate) fn name(&self) -> &ShortUuid {
436 &self.name
437 }
438
439 fn index(&self, proc_id: &hyperactor_reference::ProcId) -> Result<usize, anyhow::Error> {
440 let name = proc_id.name();
442 let expected_prefix = format!("{}_", self.name);
443 anyhow::ensure!(
444 name.starts_with(&expected_prefix),
445 "proc {} does not belong to alloc {}",
446 proc_id,
447 self.name
448 );
449 let index_str = &name[expected_prefix.len()..];
450 index_str
451 .parse::<usize>()
452 .map_err(|e| anyhow::anyhow!("failed to parse index from proc name '{}': {}", name, e))
453 }
454
455 #[hyperactor::instrument_infallible]
456 async fn maybe_spawn(&mut self) -> Option<ProcState> {
457 if self.active.len() >= self.spec.extent.num_ranks() {
458 return None;
459 }
460 let mut cmd = self.cmd.lock().await;
461
462 let enable_forwarding = hyperactor_config::global::get(MESH_ENABLE_LOG_FORWARDING);
470 let tail_size = hyperactor_config::global::get(MESH_TAIL_LOG_LINES);
471 if enable_forwarding || tail_size > 0 {
472 cmd.stdout(Stdio::piped()).stderr(Stdio::piped());
473 } else {
474 cmd.stdout(Stdio::inherit()).stderr(Stdio::inherit());
475 tracing::info!(
476 "child stdio NOT captured (forwarding/file_capture/tail all disabled); \
477 inheriting parent console"
478 );
479 }
480 let log_channel: Option<ChannelAddr> = None;
486
487 let index = self.created.len();
488 self.created.push(ShortUuid::generate());
489 let create_key = &self.created[index];
490
491 let client_config = hyperactor_config::global::attrs();
493 let bootstrap = bootstrap::Bootstrap::V0ProcMesh {
494 config: Some(client_config),
495 };
496 bootstrap.to_env(&mut cmd);
497
498 cmd.env(
499 bootstrap::BOOTSTRAP_ADDR_ENV,
500 self.bootstrap_addr.to_string(),
501 );
502 cmd.env(
503 bootstrap::CLIENT_TRACE_ID_ENV,
504 self.client_context.trace_id.as_str(),
505 );
506 cmd.env(bootstrap::BOOTSTRAP_INDEX_ENV, index.to_string());
507
508 cmd.env(
509 "HYPERACTOR_PROCESS_NAME",
510 format!(
511 "host rank:{} @{}",
512 index,
513 hostname::get()
514 .unwrap_or_else(|_| "unknown_host".into())
515 .into_string()
516 .unwrap_or("unknown_host".to_string())
517 ),
518 );
519
520 tracing::debug!("spawning process {:?}", cmd);
521 match cmd.spawn() {
522 Err(err) => {
523 let message = format!(
525 "spawn {} index: {}, command: {:?}: {}",
526 create_key, index, cmd, err
527 );
528 tracing::error!(message);
529 self.failed = true;
530 Some(ProcState::Failed {
531 alloc_name: self.alloc_name.clone(),
532 description: message,
533 })
534 }
535 Ok(mut process) => {
536 let pid = process.id().unwrap_or(0);
537 match self.ranks.assign(index) {
538 Err(_index) => {
539 tracing::info!("could not assign rank to {}", create_key);
540 let _ = process.kill().await;
541 None
542 }
543 Ok(rank) => {
544 let temp_addr = ChannelAddr::any(ChannelTransport::Local);
547 let proc_id = hyperactor_reference::ProcId::with_name(
548 temp_addr,
549 format!("{}_{}", self.alloc_name.name(), rank),
550 );
551 let (handle, monitor) =
552 Child::monitored(rank, process, log_channel, tail_size, proc_id);
553
554 self.active.insert(index, handle);
557
558 self.children.spawn(async move { (index, monitor.await) });
560
561 let point = self.spec.extent.point_of_rank(rank).unwrap();
563 Some(ProcState::Created {
564 create_key: create_key.clone(),
565 point,
566 pid,
567 })
568 }
569 }
570 }
571 }
572 }
573
574 fn remove(&mut self, index: usize) -> Option<Child> {
575 self.ranks.unassign(index);
576 self.active.remove(&index)
577 }
578}
579
#[async_trait]
impl Alloc for ProcessAlloc {
    /// Drive the allocation state machine: spawn children up to capacity,
    /// then wait for either a message from a child or a child exit,
    /// translating each into a [`ProcState`] event. Returns `None` once
    /// the alloc has been stopped and all children have exited.
    #[hyperactor::instrument_infallible]
    async fn next(&mut self) -> Option<ProcState> {
        if !self.running && self.active.is_empty() {
            return None;
        }

        loop {
            // Keep spawning until at capacity, unless stopped or a prior
            // spawn failure marked the alloc as failed.
            if self.running
                && !self.failed
                && let state @ Some(_) = self.maybe_spawn().await
            {
                return state;
            }

            let transport = self.transport().clone();

            tokio::select! {
                Ok(Process2Allocator(index, message)) = self.rx.recv() => {
                    let child = match self.active.get_mut(&index) {
                        None => {
                            // Message from a child we no longer track.
                            tracing::info!("message {:?} from zombie {}", message, index);
                            continue;
                        }
                        Some(child) => child,
                    };

                    match message {
                        Process2AllocatorMessage::Hello(addr) => {
                            // `connect` returns false on a repeated Hello.
                            if !child.connect(addr.clone()) {
                                tracing::error!("received multiple hellos from {}", index);
                                continue;
                            }

                            let proc_name = match &self.spec.proc_name {
                                Some(name) => name.clone(),
                                None => format!("{}_{}", self.name, index),
                            };
                            // Now that the child is reachable, tell it to
                            // start its proc.
                            child.post(Allocator2Process::StartProc(
                                hyperactor_reference::ProcId::with_name(addr.clone(), proc_name),
                                transport,
                            ));
                        }

                        Process2AllocatorMessage::StartedProc(proc_id, mesh_agent, addr) => {
                            break Some(ProcState::Running {
                                create_key: self.created[index].clone(),
                                proc_id,
                                mesh_agent,
                                addr,
                            });
                        }
                        Process2AllocatorMessage::Heartbeat => {
                            tracing::trace!("recv heartbeat from {index}");
                        }
                    }
                },

                Some(Ok((index, mut reason))) = self.children.join_next() => {
                    // Child exited: tear down its stream forwarders and
                    // collect the stderr tail for diagnostics.
                    let stderr_content = if let Some(child) = self.remove(index) {
                        let mut stderr_lines = Vec::new();

                        let (stdout_mon, stderr_mon) = child.take_stream_monitors();

                        if let Some(stdout_monitor) = stdout_mon {
                            let (_lines, _result) = stdout_monitor.abort().await;
                            if let Err(e) = _result {
                                tracing::warn!("stdout monitor abort error: {}", e);
                            }
                        }

                        if let Some(stderr_monitor) = stderr_mon {
                            let (lines, result) = stderr_monitor.abort().await;
                            stderr_lines = lines;
                            if let Err(e) = result {
                                tracing::warn!("stderr monitor abort error: {}", e);
                            }
                        }

                        stderr_lines.join("\n")
                    } else {
                        String::new()
                    };

                    // Attach the captured stderr tail to plain exits
                    // (`exit_status_to_reason` leaves it empty).
                    if let ProcStopReason::Exited(code, _) = &mut reason {
                        reason = ProcStopReason::Exited(*code, stderr_content);
                    }

                    tracing::info!("child stopped with ProcStopReason::{:?}", reason);

                    break Some(ProcState::Stopped {
                        create_key: self.created[index].clone(),
                        reason,
                    });
                },
            }
        }
    }

    fn spec(&self) -> &AllocSpec {
        &self.spec
    }

    fn extent(&self) -> &Extent {
        &self.spec.extent
    }

    fn alloc_name(&self) -> &AllocName {
        &self.alloc_name
    }

    /// Ask every active child to stop and exit, arming a per-child watchdog
    /// so unresponsive processes are eventually killed. Completion is
    /// observed via subsequent `next()` calls yielding `Stopped` events.
    async fn stop(&mut self) -> Result<(), AllocatorError> {
        tracing::info!(
            name = "ProcessAllocStatus",
            alloc_name = %self.alloc_name(),
            status = "Stopping",
        );
        for (_index, child) in self.active.iter_mut() {
            child.post(Allocator2Process::StopAndExit(0));
            child.spawn_watchdog();
        }

        self.running = false;
        tracing::info!(
            name = "ProcessAllocStatus",
            alloc_name = %self.alloc_name(),
            status = "Stop::Sent",
            "StopAndExit was sent to allocators; check their logs for the stop progress."
        );
        Ok(())
    }
}
720
impl Drop for ProcessAlloc {
    /// Log the drop for observability; per-child cleanup (SIGTERM) happens
    /// via each `Child`'s own `Drop` impl.
    fn drop(&mut self) {
        tracing::info!(
            name = "ProcessAllocStatus",
            alloc_name = %self.alloc_name(),
            status = "Dropped",
            "dropping ProcessAlloc of name: {}, alloc_name: {}",
            self.name,
            self.alloc_name
        );
    }
}
733
#[cfg(test)]
mod tests {
    use super::*;

    // Run the shared alloc test suite against ProcessAllocator, using the
    // bootstrap test binary as the child command.
    #[cfg(fbcode_build)]
    crate::alloc_test_suite!(ProcessAllocator::new(Command::new(
        crate::testresource::get("monarch/hyperactor_mesh/bootstrap")
    )));

    // Verifies that failing a child's supervision group delivers SIGTERM
    // (signal 15) to the process, surfaced as ProcStopReason::Killed.
    #[cfg(fbcode_build)]
    #[tokio::test]
    async fn test_sigterm_on_group_fail() {
        let bootstrap_binary = crate::testresource::get("monarch/hyperactor_mesh/bootstrap");
        let mut allocator = ProcessAllocator::new(Command::new(bootstrap_binary));

        let mut alloc = allocator
            .allocate(AllocSpec {
                extent: ndslice::extent!(replica = 1),
                constraints: Default::default(),
                proc_name: None,
                transport: ChannelTransport::Unix,
                proc_allocation_mode: Default::default(),
            })
            .await
            .unwrap();

        // Drive the alloc until the proc reports Running.
        let proc_id = {
            loop {
                match alloc.next().await {
                    Some(ProcState::Running { proc_id, .. }) => {
                        break proc_id;
                    }
                    Some(ProcState::Failed { description, .. }) => {
                        panic!("Process allocation failed: {}", description);
                    }
                    Some(_other) => {}
                    None => {
                        panic!("Allocation ended unexpectedly");
                    }
                }
            }
        };

        if let Some(child) = alloc.active.get(
            &alloc
                .index(&proc_id)
                .expect("proc must be in allocation for lookup"),
        ) {
            child.fail_group();
        }

        // SIGTERM == 15; core not dumped.
        assert!(matches!(
            alloc.next().await,
            Some(ProcState::Stopped {
                reason: ProcStopReason::Killed(15, false),
                ..
            })
        ));
    }
}