// Submodules of the allocation API. `LocalAlloc`/`LocalAllocator` (from
// `local`) and `ProcessAlloc`/`ProcessAllocator` (from `process`) are
// re-exported by the `pub use` lines in this file; `logtailer` is
// crate-private.
pub mod local;
pub(crate) mod logtailer;
pub mod process;
pub mod remoteprocess;
pub mod sim;
17
18use std::collections::HashMap;
19use std::fmt;
20
21use async_trait::async_trait;
22use enum_as_inner::EnumAsInner;
23use hyperactor::ActorRef;
24use hyperactor::ProcId;
25use hyperactor::WorldId;
26use hyperactor::channel::ChannelAddr;
27use hyperactor::channel::ChannelTransport;
28pub use local::LocalAlloc;
29pub use local::LocalAllocator;
30use mockall::predicate::*;
31use mockall::*;
32use ndslice::Shape;
33use ndslice::Slice;
34use ndslice::view::Extent;
35use ndslice::view::Point;
36pub use process::ProcessAlloc;
37pub use process::ProcessAllocator;
38use serde::Deserialize;
39use serde::Serialize;
40
41use crate::alloc::test_utils::MockAllocWrapper;
42use crate::proc_mesh::mesh_agent::MeshAgent;
43
/// Errors returned by [`Allocator::allocate`] and [`Alloc::stop`].
#[derive(Debug, thiserror::Error)]
pub enum AllocatorError {
    /// The allocation is incomplete; carries the extent that was expected.
    #[error("incomplete allocation; expected: {0}")]
    Incomplete(Extent),

    /// Not enough resources: `requested` is the extent that was asked for
    /// and `available` is the amount reported as available.
    #[error("not enough resources; requested: {requested}, available: {available}")]
    NotEnoughResources { requested: Extent, available: usize },

    /// Any other failure, wrapped as an [`anyhow::Error`] and displayed
    /// transparently. `#[from]` lets `?` convert into this variant.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
58
/// Constraints carried by an [`AllocSpec`]. The derived `Default` has no
/// labels.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct AllocConstraints {
    /// Label key/value pairs.
    /// NOTE(review): presumably an allocation must match all of these
    /// labels — confirm against the allocator implementations.
    pub match_labels: HashMap<String, String>,
}
66
/// Describes a requested allocation; this is the argument to
/// [`Allocator::allocate`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AllocSpec {
    /// The extent (labels and sizes) of the requested allocation.
    pub extent: Extent,
    /// Additional constraints on the allocation.
    pub constraints: AllocConstraints,
}
78
/// Something that can create allocations ([`Alloc`]s) from an
/// [`AllocSpec`].
///
/// `#[automock]` generates a mock of this trait whose associated `Alloc`
/// type is fixed to [`MockAllocWrapper`].
#[automock(type Alloc=MockAllocWrapper;)]
#[async_trait]
pub trait Allocator {
    /// The allocation type produced by this allocator.
    type Alloc: Alloc;

    /// Creates an allocation for `spec`, or returns an [`AllocatorError`]
    /// if it cannot be created.
    async fn allocate(&mut self, spec: AllocSpec) -> Result<Self::Alloc, AllocatorError>;
}
92
/// A state-change event yielded by [`Alloc::next`].
#[derive(Clone, Debug, PartialEq, EnumAsInner, Serialize, Deserialize)]
pub enum ProcState {
    /// A proc has been created.
    Created {
        /// Id of the created proc.
        proc_id: ProcId,
        /// The proc's point. The tests in this file compare it against
        /// points of the allocation's extent.
        point: Point,
        /// Rendered as "PID" by the `Display` impl.
        pid: u32,
    },
    /// A proc is running.
    Running {
        /// Id of the running proc.
        proc_id: ProcId,
        /// Reference to the proc's [`MeshAgent`] actor.
        mesh_agent: ActorRef<MeshAgent>,
        /// Channel address of the proc; the tests in this file bind it in
        /// a router to reach the proc.
        addr: ChannelAddr,
    },
    /// A proc has stopped.
    Stopped {
        /// Id of the stopped proc.
        proc_id: ProcId,
        /// Why the proc stopped.
        reason: ProcStopReason,
    },
    /// A failure identified by world rather than by proc.
    Failed {
        /// The world the failure is reported for.
        world_id: WorldId,
        /// Human-readable description of the failure.
        description: String,
    },
}
135
136impl fmt::Display for ProcState {
137 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
138 match self {
139 ProcState::Created {
140 proc_id,
141 point,
142 pid,
143 } => {
144 write!(f, "{}: created at ({}) with PID {}", proc_id, point, pid)
145 }
146 ProcState::Running { proc_id, addr, .. } => {
147 write!(f, "{}: running at {}", proc_id, addr)
148 }
149 ProcState::Stopped { proc_id, reason } => {
150 write!(f, "{}: stopped: {}", proc_id, reason)
151 }
152 ProcState::Failed {
153 description,
154 world_id,
155 } => {
156 write!(f, "{}: failed: {}", world_id, description)
157 }
158 }
159 }
160}
161
/// The reason a proc stopped, as carried by [`ProcState::Stopped`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, EnumAsInner)]
pub enum ProcStopReason {
    /// The proc was stopped. The basic allocator test in this file expects
    /// this reason after calling `stop()`.
    Stopped,
    /// The proc exited: exit code, then captured stderr text (may be
    /// empty).
    Exited(i32, String),
    /// The proc was killed: signal number, then whether a core was dumped.
    Killed(i32, bool),
    /// Displayed as "proc watchdog failure". The stuck-task test in this
    /// file expects this reason for a proc whose exit hangs.
    Watchdog,
    /// Displayed as "host watchdog failure".
    HostWatchdog,
    /// The reason is not known.
    Unknown,
}
180
181impl fmt::Display for ProcStopReason {
182 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
183 match self {
184 Self::Stopped => write!(f, "stopped"),
185 Self::Exited(code, stderr) => {
186 if stderr.is_empty() {
187 write!(f, "exited with code {}", code)
188 } else {
189 write!(f, "exited with code {}: {}", code, stderr)
190 }
191 }
192 Self::Killed(signal, dumped) => {
193 write!(f, "killed with signal {} (core dumped={})", signal, dumped)
194 }
195 Self::Watchdog => write!(f, "proc watchdog failure"),
196 Self::HostWatchdog => write!(f, "host watchdog failure"),
197 Self::Unknown => write!(f, "unknown"),
198 }
199 }
200}
201
202#[automock]
204#[async_trait]
205pub trait Alloc {
206 async fn next(&mut self) -> Option<ProcState>;
209
210 fn extent(&self) -> &Extent;
212
213 fn shape(&self) -> Shape {
215 let slice = Slice::new_row_major(self.extent().sizes());
216 Shape::new(self.extent().labels().to_vec(), slice).unwrap()
217 }
218
219 fn world_id(&self) -> &WorldId;
223
224 fn transport(&self) -> ChannelTransport;
226
227 async fn stop(&mut self) -> Result<(), AllocatorError>;
231
232 async fn stop_and_wait(&mut self) -> Result<(), AllocatorError> {
235 self.stop().await?;
236 while let Some(event) = self.next().await {
237 tracing::debug!("drained event: {:?}", event);
238 }
239 Ok(())
240 }
241}
242
243pub mod test_utils {
244 use std::time::Duration;
245
246 use hyperactor::Actor;
247 use hyperactor::Context;
248 use hyperactor::Handler;
249 use hyperactor::Named;
250 use libc::atexit;
251 use tokio::sync::broadcast::Receiver;
252 use tokio::sync::broadcast::Sender;
253
254 use super::*;
255
/// Process-exit hook registered through `atexit` by the `Wait` handler.
/// It never returns: it sleeps in 60-second steps forever.
extern "C" fn exit_handler() {
    let nap = Duration::from_secs(60);
    loop {
        #[allow(clippy::disallowed_methods)]
        std::thread::sleep(nap);
    }
}
262
/// Stateless actor used by the allocator tests; it handles the [`Wait`]
/// message.
/// NOTE(review): `spawn = true` presumably registers the actor so it can
/// be spawned remotely by type name — confirm against `hyperactor::export`.
#[derive(Debug, Default, Actor)]
#[hyperactor::export(
    spawn = true,
    handlers = [
        Wait
    ],
)]
pub struct TestActor;
275
/// Message for [`TestActor`]. Its handler registers an `atexit` hook that
/// never returns.
#[derive(Debug, Serialize, Deserialize, Named, Clone)]
pub struct Wait;
278
279 #[async_trait]
280 impl Handler<Wait> for TestActor {
281 async fn handle(&mut self, _: &Context<Self>, _: Wait) -> Result<(), anyhow::Error> {
282 unsafe {
285 atexit(exit_handler);
286 }
287 Ok(())
288 }
289 }
290
/// Wraps a [`MockAlloc`] so that a test can make `next()` block until it
/// is explicitly notified.
pub struct MockAllocWrapper {
    /// The wrapped mock; the `Alloc` impl delegates to it.
    pub alloc: MockAlloc,
    /// How many more `next()` calls pass straight through. Once this is
    /// zero, `next()` first waits for a notification.
    pub block_next_after: usize,
    /// Sending half of the notification channel; `notify_tx()` hands out
    /// clones of it.
    notify_tx: Sender<()>,
    /// Receiving half, awaited by `next()` while blocked.
    notify_rx: Receiver<()>,
    /// Set after the first notification has been received, so later
    /// `next()` calls do not wait again.
    next_unblocked: bool,
}
300
301 impl MockAllocWrapper {
302 pub fn new(alloc: MockAlloc) -> Self {
303 Self::new_block_next(alloc, usize::MAX)
304 }
305
306 pub fn new_block_next(alloc: MockAlloc, count: usize) -> Self {
307 let (tx, rx) = tokio::sync::broadcast::channel(1);
308 Self {
309 alloc,
310 block_next_after: count,
311 notify_tx: tx,
312 notify_rx: rx,
313 next_unblocked: false,
314 }
315 }
316
317 pub fn notify_tx(&self) -> Sender<()> {
318 self.notify_tx.clone()
319 }
320 }
321
322 #[async_trait]
323 impl Alloc for MockAllocWrapper {
324 async fn next(&mut self) -> Option<ProcState> {
325 match self.block_next_after {
326 0 => {
327 if !self.next_unblocked {
328 self.notify_rx.recv().await.unwrap();
329 self.next_unblocked = true;
330 }
331 }
332 1.. => {
333 self.block_next_after -= 1;
334 }
335 }
336
337 self.alloc.next().await
338 }
339
340 fn extent(&self) -> &Extent {
341 self.alloc.extent()
342 }
343
344 fn world_id(&self) -> &WorldId {
345 self.alloc.world_id()
346 }
347
348 fn transport(&self) -> ChannelTransport {
349 self.alloc.transport()
350 }
351
352 async fn stop(&mut self) -> Result<(), AllocatorError> {
353 self.alloc.stop().await
354 }
355 }
356}
357
358#[cfg(test)]
359pub(crate) mod testing {
360 use core::panic;
361 use std::collections::HashMap;
362 use std::collections::HashSet;
363 use std::time::Duration;
364
365 use hyperactor::Mailbox;
366 use hyperactor::actor::remote::Remote;
367 use hyperactor::channel;
368 use hyperactor::mailbox;
369 use hyperactor::mailbox::BoxedMailboxSender;
370 use hyperactor::mailbox::DialMailboxRouter;
371 use hyperactor::mailbox::IntoBoxedMailboxSender;
372 use hyperactor::mailbox::MailboxServer;
373 use hyperactor::mailbox::UndeliverableMailboxSender;
374 use hyperactor::proc::Proc;
375 use hyperactor::reference::Reference;
376 use ndslice::extent;
377 use tokio::process::Command;
378
379 use super::*;
380 use crate::alloc::test_utils::TestActor;
381 use crate::alloc::test_utils::Wait;
382 use crate::proc_mesh::mesh_agent::GspawnResult;
383 use crate::proc_mesh::mesh_agent::MeshAgentMessageClient;
384
/// Expands to a `#[tokio::test]` named `test_allocator_basic` that runs
/// `$crate::alloc::testing::test_allocator_basic` against the allocator
/// produced by the `$allocator` expression.
#[macro_export]
macro_rules! alloc_test_suite {
    ($allocator:expr_2021) => {
        #[tokio::test]
        async fn test_allocator_basic() {
            $crate::alloc::testing::test_allocator_basic($allocator).await;
        }
    };
}
394
395 pub(crate) async fn test_allocator_basic(mut allocator: impl Allocator) {
396 let extent = extent!(replica = 4);
397 let mut alloc = allocator
398 .allocate(AllocSpec {
399 extent: extent.clone(),
400 constraints: Default::default(),
401 })
402 .await
403 .unwrap();
404
405 let mut procs = HashMap::new();
408 let mut running = HashSet::new();
409 while running.len() != 4 {
410 match alloc.next().await.unwrap() {
411 ProcState::Created { proc_id, point, .. } => {
412 procs.insert(proc_id, point);
413 }
414 ProcState::Running { proc_id, .. } => {
415 assert!(procs.contains_key(&proc_id));
416 assert!(!running.contains(&proc_id));
417 running.insert(proc_id);
418 }
419 event => panic!("unexpected event: {:?}", event),
420 }
421 }
422
423 let points: HashSet<_> = procs.values().collect();
425 for x in 0..4 {
426 assert!(points.contains(&extent.point(vec![x]).unwrap()));
427 }
428
429 let worlds: HashSet<_> = procs.keys().map(|proc_id| proc_id.world_id()).collect();
431 assert_eq!(worlds.len(), 1);
432
433 alloc.stop().await.unwrap();
436 let mut stopped = HashSet::new();
437 while let Some(ProcState::Stopped { proc_id, reason }) = alloc.next().await {
438 assert_eq!(reason, ProcStopReason::Stopped);
439 stopped.insert(proc_id);
440 }
441 assert!(alloc.next().await.is_none());
442 assert_eq!(stopped, running);
443 }
444
445 async fn spawn_proc(
446 transport: ChannelTransport,
447 ) -> (DialMailboxRouter, Mailbox, Proc, ChannelAddr) {
448 let (router_channel_addr, router_rx) = channel::serve(ChannelAddr::any(transport.clone()))
449 .await
450 .unwrap();
451 let router =
452 DialMailboxRouter::new_with_default((UndeliverableMailboxSender {}).into_boxed());
453 router.clone().serve(router_rx);
454
455 let client_proc_id = ProcId::Ranked(WorldId("test_stuck".to_string()), 0);
456 let (client_proc_addr, client_rx) =
457 channel::serve(ChannelAddr::any(transport)).await.unwrap();
458 let client_proc = Proc::new(
459 client_proc_id.clone(),
460 BoxedMailboxSender::new(router.clone()),
461 );
462 client_proc.clone().serve(client_rx);
463 router.bind(client_proc_id.clone().into(), client_proc_addr);
464 (
465 router,
466 client_proc.attach("test_proc").unwrap(),
467 client_proc,
468 router_channel_addr,
469 )
470 }
471
472 async fn spawn_test_actor(
473 rank: usize,
474 client_proc: &Proc,
475 client: &Mailbox,
476 router_channel_addr: ChannelAddr,
477 mesh_agent: ActorRef<MeshAgent>,
478 ) -> ActorRef<TestActor> {
479 let supervisor = client_proc.attach("supervisor").unwrap();
480 let (supervison_port, _) = supervisor.open_port();
481 let (config_handle, _) = client.open_port();
482 mesh_agent
483 .configure(
484 client,
485 rank,
486 router_channel_addr,
487 supervison_port.bind(),
488 HashMap::new(),
489 config_handle.bind(),
490 )
491 .await
492 .unwrap();
493 let remote = Remote::collect();
494 let actor_type = remote
495 .name_of::<TestActor>()
496 .ok_or(anyhow::anyhow!("actor not registered"))
497 .unwrap()
498 .to_string();
499 let params = &();
500 let (completed_handle, mut completed_receiver) = mailbox::open_port(client);
501 mesh_agent
503 .gspawn(
504 client,
505 actor_type,
506 "Stuck".to_string(),
507 bincode::serialize(params).unwrap(),
508 completed_handle.bind(),
509 )
510 .await
511 .unwrap();
512 let result = completed_receiver.recv().await.unwrap();
513 match result {
514 GspawnResult::Success { actor_id, .. } => ActorRef::attest(actor_id),
515 GspawnResult::Error(error_msg) => {
516 panic!("gspawn failed: {}", error_msg);
517 }
518 }
519 }
520
521 #[tokio::test]
526 async fn test_allocator_stuck_task() {
527 let config = hyperactor::config::global::lock();
530 let _guard = config.override_key(
531 hyperactor::config::PROCESS_EXIT_TIMEOUT,
532 Duration::from_secs(1),
533 );
534
535 let command =
536 Command::new(buck_resources::get("monarch/hyperactor_mesh/bootstrap").unwrap());
537 let mut allocator = ProcessAllocator::new(command);
538 let mut alloc = allocator
539 .allocate(AllocSpec {
540 extent: extent! { replica = 1 },
541 constraints: Default::default(),
542 })
543 .await
544 .unwrap();
545
546 let mut procs = HashMap::new();
548 let mut running = HashSet::new();
549 let mut actor_ref = None;
550 let (router, client, client_proc, router_addr) = spawn_proc(alloc.transport()).await;
551 while running.is_empty() {
552 match alloc.next().await.unwrap() {
553 ProcState::Created { proc_id, point, .. } => {
554 procs.insert(proc_id, point);
555 }
556 ProcState::Running {
557 proc_id,
558 mesh_agent,
559 addr,
560 } => {
561 router.bind(Reference::Proc(proc_id.clone()), addr.clone());
562
563 assert!(procs.contains_key(&proc_id));
564 assert!(!running.contains(&proc_id));
565
566 actor_ref = Some(
567 spawn_test_actor(0, &client_proc, &client, router_addr, mesh_agent).await,
568 );
569 running.insert(proc_id);
570 break;
571 }
572 event => panic!("unexpected event: {:?}", event),
573 }
574 }
575 assert!(actor_ref.unwrap().send(&client, Wait).is_ok());
576
577 alloc.stop().await.unwrap();
579 let mut stopped = HashSet::new();
580 while let Some(ProcState::Stopped { proc_id, reason }) = alloc.next().await {
581 assert_eq!(reason, ProcStopReason::Watchdog);
582 stopped.insert(proc_id);
583 }
584 assert!(alloc.next().await.is_none());
585 assert_eq!(stopped, running);
586 }
587}