#![allow(dead_code)]

use std::collections::HashMap;
use std::future::Future;
use std::os::unix::process::ExitStatusExt;
use std::process::ExitStatus;
use std::process::Stdio;
use std::sync::Arc;
use std::sync::OnceLock;

use async_trait::async_trait;
use enum_as_inner::EnumAsInner;
use hyperactor::ProcId;
use hyperactor::WorldId;
use hyperactor::channel;
use hyperactor::channel::ChannelAddr;
use hyperactor::channel::ChannelError;
use hyperactor::channel::ChannelTransport;
use hyperactor::channel::ChannelTx;
use hyperactor::channel::Rx;
use hyperactor::channel::Tx;
use hyperactor::channel::TxStatus;
use hyperactor::sync::flag;
use hyperactor::sync::monitor;
use ndslice::view::Extent;
use nix::sys::signal;
use nix::unistd::Pid;
use serde::Deserialize;
use serde::Serialize;
use tokio::io;
use tokio::process::Command;
use tokio::sync::Mutex;
use tokio::task::JoinSet;

use super::Alloc;
use super::AllocSpec;
use super::Allocator;
use super::AllocatorError;
use super::ProcState;
use super::ProcStopReason;
use super::logtailer::LogTailer;
use crate::assign::Ranks;
use crate::bootstrap;
use crate::bootstrap::Allocator2Process;
use crate::bootstrap::Process2Allocator;
use crate::bootstrap::Process2AllocatorMessage;
use crate::logging::create_log_writers;
use crate::shortuuid::ShortUuid;

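/// Maximum number of log lines retained per output stream for each child process.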
const MAX_TAIL_LOG_LINES: usize = 100;

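/// Constraint label carrying the client's trace id; its value is
/// forwarded to spawned procs.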
pub const CLIENT_TRACE_ID_LABEL: &str = "CLIENT_TRACE_ID";

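/// An [`Allocator`] that spawns procs as local OS processes, using the
/// provided [`Command`] as a template for each child. Bootstrap
/// parameters are passed to children through environment variables.
///
/// A minimal usage sketch (the bootstrap binary path here is a
/// placeholder, not a real resource):
///
/// ```ignore
/// let mut allocator = ProcessAllocator::new(Command::new("/path/to/bootstrap"));
/// let alloc = allocator.allocate(spec).await?;
/// ```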
pub struct ProcessAllocator {
    cmd: Arc<Mutex<Command>>,
}

impl ProcessAllocator {
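    /// Create a new allocator that launches procs with the given
    /// command template.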
    pub fn new(cmd: Command) -> Self {
        Self {
            cmd: Arc::new(Mutex::new(cmd)),
        }
    }
}

#[async_trait]
impl Allocator for ProcessAllocator {
    type Alloc = ProcessAlloc;

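    /// Serve a fresh Unix-transport bootstrap channel and return a
    /// `ProcessAlloc` that spawns procs against it on demand.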
    async fn allocate(&mut self, spec: AllocSpec) -> Result<ProcessAlloc, AllocatorError> {
        let (bootstrap_addr, rx) = channel::serve(ChannelAddr::any(ChannelTransport::Unix))
            .await
            .map_err(anyhow::Error::from)?;

        let name = ShortUuid::generate();
        Ok(ProcessAlloc {
            name: name.clone(),
            world_id: WorldId(name.to_string()),
            spec: spec.clone(),
            bootstrap_addr,
            rx,
            index: 0,
            active: HashMap::new(),
            ranks: Ranks::new(spec.extent.num_ranks()),
            cmd: Arc::clone(&self.cmd),
            children: JoinSet::new(),
            running: true,
            failed: false,
            client_context: ClientContext {
                trace_id: spec
                    .constraints
                    .match_labels
                    .get(CLIENT_TRACE_ID_LABEL)
                    .cloned()
                    .unwrap_or_default(),
            },
        })
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
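/// Client-supplied context for an allocation; currently just the
/// trace id used to correlate logs across processes.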
pub struct ClientContext {
    pub trace_id: String,
}

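/// An allocation of procs backed by local OS processes, created by
/// [`ProcessAllocator`].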
pub struct ProcessAlloc {
    name: ShortUuid,
    world_id: WorldId,
    spec: AllocSpec,
    bootstrap_addr: ChannelAddr,
    rx: channel::ChannelRx<Process2Allocator>,
    index: usize,
    active: HashMap<usize, Child>,
    ranks: Ranks<usize>,
    cmd: Arc<Mutex<Command>>,
    children: JoinSet<(usize, ProcStopReason)>,
    running: bool,
    failed: bool,
    client_context: ClientContext,
}

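/// The state of the control channel from the allocator to a child process.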
#[derive(EnumAsInner)]
enum ChannelState {
    NotConnected,
    Connected(ChannelTx<Allocator2Process>),
    Failed(ChannelError),
}

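/// A single managed child process: its control channel, output
/// tailers, and the monitoring state used to derive its stop reason.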
struct Child {
    local_rank: usize,
    channel: ChannelState,
    group: monitor::Group,
    exit_flag: Option<flag::Flag>,
    stdout: Option<LogTailer>,
    stderr: Option<LogTailer>,
    stop_reason: Arc<OnceLock<ProcStopReason>>,
    process_pid: Arc<std::sync::Mutex<Option<i32>>>,
}

impl Child {
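    /// Wrap a freshly spawned process, returning the `Child` handle and
    /// a future that resolves with the reason the process stopped.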
    fn monitored(
        local_rank: usize,
        mut process: tokio::process::Child,
        log_channel: ChannelAddr,
    ) -> (Self, impl Future<Output = ProcStopReason>) {
        let (group, handle) = monitor::group();
        let (exit_flag, exit_guard) = flag::guarded();
        let stop_reason = Arc::new(OnceLock::new());

        let mut stdout_tee: Box<dyn io::AsyncWrite + Send + Unpin + 'static> =
            Box::new(io::stdout());
        let mut stderr_tee: Box<dyn io::AsyncWrite + Send + Unpin + 'static> =
            Box::new(io::stderr());

        match create_log_writers(local_rank, log_channel, process.id().unwrap_or(0)) {
            Ok((stdout_writer, stderr_writer)) => {
                stdout_tee = stdout_writer;
                stderr_tee = stderr_writer;
            }
            Err(e) => {
                tracing::error!("failed to create log writers: {}", e);
            }
        }

        let stdout = LogTailer::tee(
            MAX_TAIL_LOG_LINES,
            process.stdout.take().unwrap(),
            stdout_tee,
        );

        let stderr = LogTailer::tee(
            MAX_TAIL_LOG_LINES,
            process.stderr.take().unwrap(),
            stderr_tee,
        );

        let process_pid = Arc::new(std::sync::Mutex::new(process.id().map(|id| id as i32)));

        let child = Self {
            local_rank,
            channel: ChannelState::NotConnected,
            group,
            exit_flag: Some(exit_flag),
            stdout: Some(stdout),
            stderr: Some(stderr),
            stop_reason: Arc::clone(&stop_reason),
            process_pid: process_pid.clone(),
        };

        let monitor = async move {
            let reason = tokio::select! {
                _ = handle => {
                    Self::ensure_killed(process_pid);
                    Self::exit_status_to_reason(process.wait().await)
                }
                result = process.wait() => {
                    Self::exit_status_to_reason(result)
                }
            };
            exit_guard.signal();

            stop_reason.get_or_init(|| reason).clone()
        };

        (child, monitor)
    }

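    /// Send SIGTERM to the process if it is still running. The pid is
    /// taken out of the mutex so the signal is delivered at most once.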
    fn ensure_killed(pid: Arc<std::sync::Mutex<Option<i32>>>) {
        if let Some(pid) = pid.lock().unwrap().take() {
            if let Err(e) = signal::kill(Pid::from_raw(pid), signal::SIGTERM) {
                match e {
                    nix::errno::Errno::ESRCH => {
                        tracing::debug!("pid {} already exited", pid);
                    }
                    _ => {
                        tracing::error!("failed to kill {}: {}", pid, e);
                    }
                }
            }
        }
    }

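    /// Translate a process exit status (or wait error) into a
    /// `ProcStopReason`.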
    fn exit_status_to_reason(result: io::Result<ExitStatus>) -> ProcStopReason {
        match result {
            Ok(status) if status.success() => ProcStopReason::Stopped,
            Ok(status) => {
                if let Some(signal) = status.signal() {
                    ProcStopReason::Killed(signal, status.core_dumped())
                } else if let Some(code) = status.code() {
                    ProcStopReason::Exited(code, String::new())
                } else {
                    ProcStopReason::Unknown
                }
            }
            Err(e) => {
                tracing::error!("error waiting for process: {}", e);
                ProcStopReason::Unknown
            }
        }
    }

    #[hyperactor::instrument_infallible]
    fn stop(&self, reason: ProcStopReason) {
        let _ = self.stop_reason.set(reason);
        self.group.fail();
    }

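    /// Dial the child's channel address. Returns false if a connection
    /// was already attempted. A watch task fails the group if the
    /// channel transitions to closed.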
    fn connect(&mut self, addr: ChannelAddr) -> bool {
        if !self.channel.is_not_connected() {
            return false;
        }

        match channel::dial(addr) {
            Ok(channel) => {
                let mut status = channel.status().clone();
                self.channel = ChannelState::Connected(channel);
                self.group.spawn(async move {
                    let _ = status
                        .wait_for(|status| matches!(status, TxStatus::Closed))
                        .await;
                    Result::<(), ()>::Err(())
                });
            }
            Err(err) => {
                self.channel = ChannelState::Failed(err);
                self.stop(ProcStopReason::Watchdog);
            }
        };
        true
    }

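    /// Spawn a watchdog that fails the group if the process has not
    /// exited within the configured `PROCESS_EXIT_TIMEOUT`.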
    fn spawn_watchdog(&mut self) {
        let Some(exit_flag) = self.exit_flag.take() else {
            tracing::info!("exit flag already taken, not spawning watchdog");
            return;
        };
        let group = self.group.clone();
        let stop_reason = self.stop_reason.clone();
        tracing::info!("spawning watchdog");
        tokio::spawn(async move {
            let exit_timeout =
                hyperactor::config::global::get(hyperactor::config::PROCESS_EXIT_TIMEOUT);
            #[allow(clippy::disallowed_methods)]
            if tokio::time::timeout(exit_timeout, exit_flag).await.is_err() {
                tracing::info!("watchdog timeout, killing process");
                let _ = stop_reason.set(ProcStopReason::Watchdog);
                group.fail();
            }
            tracing::info!("watchdog task exiting");
        });
    }

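    /// Post a message to the child over its control channel, stopping
    /// it with a watchdog reason if the channel is not connected.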
    #[hyperactor::instrument_infallible]
    fn post(&mut self, message: Allocator2Process) {
        if let ChannelState::Connected(channel) = &mut self.channel {
            channel.post(message);
        } else {
            self.stop(ProcStopReason::Watchdog);
        }
    }

    #[cfg(test)]
    fn fail_group(&self) {
        self.group.fail();
    }
}

impl Drop for Child {
    fn drop(&mut self) {
        Self::ensure_killed(self.process_pid.clone());
    }
}

impl ProcessAlloc {
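    /// Stop the proc identified by `proc_id` with the given reason.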
    #[hyperactor::instrument_infallible]
    fn stop(&mut self, proc_id: &ProcId, reason: ProcStopReason) -> Result<(), anyhow::Error> {
        self.get_mut(proc_id)?.stop(reason);
        Ok(())
    }

    fn get(&self, proc_id: &ProcId) -> Result<&Child, anyhow::Error> {
        self.active.get(&self.index(proc_id)?).ok_or_else(|| {
            anyhow::anyhow!(
                "proc {} not currently active in alloc {}",
                proc_id,
                self.name
            )
        })
    }

    fn get_mut(&mut self, proc_id: &ProcId) -> Result<&mut Child, anyhow::Error> {
        self.active.get_mut(&self.index(proc_id)?).ok_or_else(|| {
            anyhow::anyhow!(
                "proc {} not currently active in alloc {}",
                proc_id,
                self.name
            )
        })
    }

    pub(crate) fn name(&self) -> &ShortUuid {
        &self.name
    }

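    /// Recover the bootstrap index of `proc_id`, verifying that the
    /// proc belongs to this alloc.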
    fn index(&self, proc_id: &ProcId) -> Result<usize, anyhow::Error> {
        anyhow::ensure!(
            proc_id
                .world_name()
                .expect("proc must be ranked for allocation index")
                .parse::<ShortUuid>()?
                == self.name,
            "proc {} does not belong to alloc {}",
            proc_id,
            self.name
        );
        Ok(proc_id
            .rank()
            .expect("proc must be ranked for allocation index"))
    }

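    /// Spawn the next child process if the alloc is below capacity,
    /// returning the resulting `ProcState::Created` event.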
    #[hyperactor::instrument_infallible]
    async fn maybe_spawn(&mut self) -> Option<ProcState> {
        if self.active.len() >= self.spec.extent.num_ranks() {
            return None;
        }
        let mut cmd = self.cmd.lock().await;
        let index = self.index;
        self.index += 1;
        let log_channel: ChannelAddr = ChannelAddr::any(ChannelTransport::Unix);

        cmd.env(
            bootstrap::BOOTSTRAP_ADDR_ENV,
            self.bootstrap_addr.to_string(),
        );
        cmd.env(
            bootstrap::CLIENT_TRACE_ID_ENV,
            self.client_context.trace_id.as_str(),
        );
        cmd.env(bootstrap::BOOTSTRAP_INDEX_ENV, index.to_string());
        cmd.env(bootstrap::BOOTSTRAP_LOG_CHANNEL, log_channel.to_string());
        cmd.stdout(Stdio::piped());
        cmd.stderr(Stdio::piped());

        let proc_id = ProcId::Ranked(WorldId(self.name.to_string()), index);
        tracing::debug!("spawning process {:?}", cmd);
        match cmd.spawn() {
            Err(err) => {
                let message = format!("spawn index: {}, command: {:?}: {}", index, cmd, err);
                tracing::error!(message);
                self.failed = true;
                Some(ProcState::Failed {
                    world_id: self.world_id.clone(),
                    description: message,
                })
            }
            Ok(mut process) => {
                let pid = process.id().unwrap_or(0);
                match self.ranks.assign(index) {
                    Err(_index) => {
                        tracing::info!("could not assign rank to {}", proc_id);
                        let _ = process.kill().await;
                        None
                    }
                    Ok(rank) => {
                        let (handle, monitor) = Child::monitored(rank, process, log_channel);
                        self.children.spawn(async move { (index, monitor.await) });
                        self.active.insert(index, handle);
                        let point = self.spec.extent.point_of_rank(rank).unwrap();
                        Some(ProcState::Created {
                            proc_id,
                            point,
                            pid,
                        })
                    }
                }
            }
        }
    }

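    /// Unassign the rank at `index` and remove its child from the
    /// active set.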
    fn remove(&mut self, index: usize) -> Option<Child> {
        self.ranks.unassign(index);
        self.active.remove(&index)
    }
}

#[async_trait]
impl Alloc for ProcessAlloc {
    #[hyperactor::instrument_infallible]
    async fn next(&mut self) -> Option<ProcState> {
        if !self.running && self.active.is_empty() {
            return None;
        }

        loop {
            if self.running && !self.failed {
                if let state @ Some(_) = self.maybe_spawn().await {
                    return state;
                }
            }

            let transport = self.transport().clone();

            tokio::select! {
                Ok(Process2Allocator(index, message)) = self.rx.recv() => {
                    let child = match self.active.get_mut(&index) {
                        None => {
                            tracing::info!("message {:?} from zombie {}", message, index);
                            continue;
                        }
                        Some(child) => child,
                    };

                    match message {
                        Process2AllocatorMessage::Hello(addr) => {
                            if !child.connect(addr.clone()) {
                                tracing::error!("received multiple hellos from {}", index);
                                continue;
                            }

                            child.post(Allocator2Process::StartProc(
                                ProcId::Ranked(WorldId(self.name.to_string()), index),
                                transport,
                            ));
                        }

                        Process2AllocatorMessage::StartedProc(proc_id, mesh_agent, addr) => {
                            break Some(ProcState::Running {
                                proc_id,
                                mesh_agent,
                                addr,
                            });
                        }
                        Process2AllocatorMessage::Heartbeat => {
                            tracing::trace!("recv heartbeat from {index}");
                        }
                    }
                },

                Some(Ok((index, mut reason))) = self.children.join_next() => {
                    let stderr_content = if let Some(mut child) = self.remove(index) {
                        let stdout = child.stdout.take().unwrap();
                        let stderr = child.stderr.take().unwrap();
                        stdout.abort();
                        stderr.abort();
                        let (_stdout, _) = stdout.join().await;
                        let (stderr_lines, _) = stderr.join().await;
                        stderr_lines.join("\n")
                    } else {
                        String::new()
                    };

                    if let ProcStopReason::Exited(code, _) = &mut reason {
                        reason = ProcStopReason::Exited(*code, stderr_content);
                    }

                    tracing::info!("child stopped with ProcStopReason::{:?}", reason);

                    break Some(ProcState::Stopped {
                        proc_id: ProcId::Ranked(WorldId(self.name.to_string()), index),
                        reason,
                    });
                },
            }
        }
    }

    fn extent(&self) -> &Extent {
        &self.spec.extent
    }

    fn world_id(&self) -> &WorldId {
        &self.world_id
    }

    fn transport(&self) -> ChannelTransport {
        ChannelTransport::Unix
    }

    async fn stop(&mut self) -> Result<(), AllocatorError> {
        for (_index, child) in self.active.iter_mut() {
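            // Ask each child to stop gracefully; the watchdog kills the
            // process if it does not exit within the configured timeout.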
            child.post(Allocator2Process::StopAndExit(0));
            child.spawn_watchdog();
        }

        self.running = false;
        Ok(())
    }
}

impl Drop for ProcessAlloc {
    fn drop(&mut self) {
        tracing::debug!(
            "dropping ProcessAlloc of name: {}, world id: {}",
            self.name,
            self.world_id
        );
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[cfg(fbcode_build)]
    crate::alloc_test_suite!(ProcessAllocator::new(Command::new(
        buck_resources::get("monarch/hyperactor_mesh/bootstrap").unwrap()
    )));

    #[tokio::test]
    async fn test_sigterm_on_group_fail() {
        let bootstrap_binary = buck_resources::get("monarch/hyperactor_mesh/bootstrap").unwrap();
        let mut allocator = ProcessAllocator::new(Command::new(bootstrap_binary));

        let mut alloc = allocator
            .allocate(AllocSpec {
                extent: ndslice::extent!(replica = 1),
                constraints: Default::default(),
            })
            .await
            .unwrap();

        let proc_id = {
            loop {
                match alloc.next().await {
                    Some(ProcState::Running { proc_id, .. }) => {
                        break proc_id;
                    }
                    Some(ProcState::Failed { description, .. }) => {
                        panic!("process allocation failed: {}", description);
                    }
                    Some(_other) => {}
                    None => {
                        panic!("allocation ended unexpectedly");
                    }
                }
            }
        };

        if let Some(child) = alloc.active.get(
            &proc_id
                .rank()
                .expect("proc must be ranked for allocation lookup"),
        ) {
            child.fail_group();
        }

        assert!(matches!(
            alloc.next().await,
            Some(ProcState::Stopped {
                reason: ProcStopReason::Killed(15, false),
                ..
            })
        ));
    }
}