1#![allow(unused_assignments)]
17
18pub mod local;
19pub mod process;
20pub mod remoteprocess;
21
22use std::collections::HashMap;
23use std::fmt;
24
25use async_trait::async_trait;
26use enum_as_inner::EnumAsInner;
27use hyperactor::channel::ChannelAddr;
28use hyperactor::channel::ChannelTransport;
29use hyperactor::channel::TlsAddr;
30use hyperactor::reference as hyperactor_reference;
31pub use local::LocalAlloc;
32pub use local::LocalAllocator;
33use mockall::predicate::*;
34use mockall::*;
35use ndslice::Shape;
36use ndslice::Slice;
37use ndslice::view::Extent;
38use ndslice::view::Point;
39pub use process::ProcessAlloc;
40pub use process::ProcessAllocator;
41use serde::Deserialize;
42use serde::Serialize;
43use strum::AsRefStr;
44use typeuri::Named;
45
46use crate::alloc::test_utils::MockAllocWrapper;
47use crate::assign::Ranks;
48use crate::proc_agent::ProcAgent;
49use crate::shortuuid::ShortUuid;
50
51#[derive(
53 Debug,
54 Serialize,
55 Deserialize,
56 Clone,
57 PartialEq,
58 Eq,
59 PartialOrd,
60 Hash,
61 Ord
62)]
63pub struct AllocName(pub String);
64
65impl AllocName {
66 pub fn name(&self) -> &str {
68 &self.0
69 }
70}
71
72impl fmt::Display for AllocName {
73 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
74 write!(f, "{}", self.0)
75 }
76}
77
78#[derive(Debug, thiserror::Error)]
80pub enum AllocatorError {
81 #[error("incomplete allocation; expected: {0}")]
82 Incomplete(Extent),
83
84 #[error("not enough resources; requested: {requested}, available: {available}")]
86 NotEnoughResources { requested: Extent, available: usize },
87
88 #[error(transparent)]
90 Other(#[from] anyhow::Error),
91}
92
93#[derive(Debug, Clone, Serialize, Deserialize, Default)]
95pub struct AllocConstraints {
96 pub match_labels: HashMap<String, String>,
99}
100
101#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
103#[derive(Default)]
104pub enum ProcAllocationMode {
105 #[default]
110 ProcLevel,
111 HostLevel,
115}
116
117#[derive(Debug, Clone, Serialize, Deserialize)]
119pub struct AllocSpec {
120 pub extent: Extent,
125
126 pub constraints: AllocConstraints,
128
129 pub proc_name: Option<String>,
132
133 pub transport: ChannelTransport,
135
136 #[serde(default = "default_proc_allocation_mode")]
139 pub proc_allocation_mode: ProcAllocationMode,
140}
141
142fn default_proc_allocation_mode() -> ProcAllocationMode {
143 ProcAllocationMode::ProcLevel
144}
145
146#[automock(type Alloc=MockAllocWrapper;)]
148#[async_trait]
149pub trait Allocator {
150 type Alloc: Alloc;
152
153 async fn allocate(&mut self, spec: AllocSpec) -> Result<Self::Alloc, AllocatorError>;
158}
159
160#[derive(
163 Clone,
164 Debug,
165 PartialEq,
166 EnumAsInner,
167 Serialize,
168 Deserialize,
169 AsRefStr,
170 Named
171)]
172pub enum ProcState {
173 Created {
175 create_key: ShortUuid,
178 point: Point,
180 pid: u32,
182 },
183 Running {
185 create_key: ShortUuid,
187 proc_id: hyperactor_reference::ProcId,
189 mesh_agent: hyperactor_reference::ActorRef<ProcAgent>,
192 addr: ChannelAddr,
195 },
196 Stopped {
198 create_key: ShortUuid,
199 reason: ProcStopReason,
200 },
201 Failed {
209 alloc_name: AllocName,
211 description: String,
213 },
214}
215
216impl fmt::Display for ProcState {
217 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
218 match self {
219 ProcState::Created {
220 create_key,
221 point,
222 pid,
223 } => {
224 write!(f, "{}: created at ({}) with PID {}", create_key, point, pid)
225 }
226 ProcState::Running { proc_id, addr, .. } => {
227 write!(f, "{}: running at {}", proc_id, addr)
228 }
229 ProcState::Stopped { create_key, reason } => {
230 write!(f, "{}: stopped: {}", create_key, reason)
231 }
232 ProcState::Failed {
233 description,
234 alloc_name,
235 } => {
236 write!(f, "{}: failed: {}", alloc_name, description)
237 }
238 }
239 }
240}
241
242#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, EnumAsInner)]
244pub enum ProcStopReason {
245 Stopped,
247 Exited(i32, String),
249 Killed(i32, bool),
252 Watchdog,
254 HostWatchdog,
257 Unknown,
259}
260
261impl fmt::Display for ProcStopReason {
262 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
263 match self {
264 Self::Stopped => write!(f, "stopped"),
265 Self::Exited(code, stderr) => {
266 if stderr.is_empty() {
267 write!(f, "exited with code {}", code)
268 } else {
269 write!(f, "exited with code {}: {}", code, stderr)
270 }
271 }
272 Self::Killed(signal, dumped) => {
273 write!(f, "killed with signal {} (core dumped={})", signal, dumped)
274 }
275 Self::Watchdog => write!(f, "proc watchdog failure"),
276 Self::HostWatchdog => write!(f, "host watchdog failure"),
277 Self::Unknown => write!(f, "unknown"),
278 }
279 }
280}
281
282#[automock]
284#[async_trait]
285pub trait Alloc {
286 async fn next(&mut self) -> Option<ProcState>;
289
290 fn spec(&self) -> &AllocSpec;
292
293 fn extent(&self) -> &Extent;
295
296 fn shape(&self) -> Shape {
298 let slice = Slice::new_row_major(self.extent().sizes());
299 Shape::new(self.extent().labels().to_vec(), slice).unwrap()
300 }
301
302 fn alloc_name(&self) -> &AllocName;
304
305 fn transport(&self) -> ChannelTransport {
307 self.spec().transport.clone()
308 }
309
310 async fn stop(&mut self) -> Result<(), AllocatorError>;
314
315 async fn stop_and_wait(&mut self) -> Result<(), AllocatorError> {
318 tracing::error!(
319 name = "AllocStatus",
320 alloc_name = %self.alloc_name(),
321 status = "StopAndWait",
322 );
323 self.stop().await?;
324 while let Some(event) = self.next().await {
325 tracing::debug!(
326 alloc_name = %self.alloc_name(),
327 "drained event: {event:?}"
328 );
329 }
330 tracing::error!(
331 name = "AllocStatus",
332 alloc_name = %self.alloc_name(),
333 status = "Stopped",
334 );
335 Ok(())
336 }
337
338 fn is_local(&self) -> bool {
341 false
342 }
343
344 fn client_router_addr(&self) -> ChannelAddr {
346 ChannelAddr::any(self.transport())
347 }
348}
349
350#[derive(Debug, Clone, PartialEq, Eq, Hash)]
351pub(crate) struct AllocatedProc {
352 pub create_key: ShortUuid,
353 pub proc_id: hyperactor_reference::ProcId,
354 pub addr: ChannelAddr,
355 pub mesh_agent: hyperactor_reference::ActorRef<ProcAgent>,
356}
357
358impl fmt::Display for AllocatedProc {
359 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
360 write!(
361 f,
362 "AllocatedProc {{ create_key: {}, proc_id: {}, addr: {}, mesh_agent: {} }}",
363 self.create_key, self.proc_id, self.addr, self.mesh_agent
364 )
365 }
366}
367
368#[async_trait]
369pub(crate) trait AllocExt {
370 async fn initialize(&mut self) -> Result<Vec<AllocatedProc>, AllocatorError>;
373}
374
375#[async_trait]
376impl<A: ?Sized + Send + Alloc> AllocExt for A {
377 async fn initialize(&mut self) -> Result<Vec<AllocatedProc>, AllocatorError> {
378 let shape = self.shape().clone();
380
381 let mut created = Ranks::new(shape.slice().len());
382 let mut running = Ranks::new(shape.slice().len());
383
384 while !running.is_full() {
385 let Some(state) = self.next().await else {
386 return Err(AllocatorError::Incomplete(self.extent().clone()));
388 };
389
390 let name = tracing::Span::current()
391 .metadata()
392 .map(|m| m.name())
393 .unwrap_or("initialize");
394 let status = format!("ProcState:{}", state.arm().unwrap_or("unknown"));
395
396 match state {
397 ProcState::Created {
398 create_key, point, ..
399 } => {
400 let rank = point.rank();
401 if let Some(old_create_key) = created.insert(rank, create_key.clone()) {
402 tracing::warn!(
403 name,
404 status,
405 rank,
406 "rank {rank} reassigned from {old_create_key} to {create_key}"
407 );
408 }
409 tracing::info!(
410 name,
411 status,
412 rank,
413 "proc with create key {}, rank {}: created",
414 create_key,
415 rank
416 );
417 }
418 ProcState::Running {
419 create_key,
420 proc_id,
421 mesh_agent,
422 addr,
423 } => {
424 let Some(rank) = created.rank(&create_key) else {
425 tracing::warn!(
426 name,
427 %proc_id,
428 status,
429 "proc id {proc_id} with create key {create_key} \
430 is running, but was not created"
431 );
432 continue;
433 };
434
435 let allocated_proc = AllocatedProc {
436 create_key,
437 proc_id: proc_id.clone(),
438 addr: addr.clone(),
439 mesh_agent: mesh_agent.clone(),
440 };
441 if let Some(old_allocated_proc) = running.insert(*rank, allocated_proc.clone())
442 {
443 tracing::warn!(
444 name,
445 %proc_id,
446 status,
447 rank,
448 "duplicate running notifications for {rank}: \
449 old:{old_allocated_proc}; \
450 new:{allocated_proc}"
451 )
452 }
453 tracing::info!(
454 name,
455 %proc_id,
456 status,
457 "proc {} rank {}: running at addr:{addr} mesh_agent:{mesh_agent}",
458 proc_id,
459 rank
460 );
461 }
462 ProcState::Stopped { create_key, reason } => {
466 tracing::error!(
467 name,
468 status,
469 "allocation failed for proc with create key {}: {}",
470 create_key,
471 reason
472 );
473 return Err(AllocatorError::Other(anyhow::Error::msg(reason)));
474 }
475 ProcState::Failed {
476 alloc_name,
477 description,
478 } => {
479 tracing::error!(
480 name,
481 status,
482 "allocation failed for {}: {}",
483 alloc_name,
484 description
485 );
486 return Err(AllocatorError::Other(anyhow::Error::msg(description)));
487 }
488 }
489 }
490
491 Ok(running.into_iter().map(Option::unwrap).collect())
494 }
495}
496
497pub(crate) fn with_unspecified_port_or_any(addr: &ChannelAddr) -> ChannelAddr {
502 match addr {
503 ChannelAddr::Tcp(socket) => {
504 let mut new_socket = socket.clone();
505 new_socket.set_port(0);
506 ChannelAddr::Tcp(new_socket)
507 }
508 ChannelAddr::MetaTls(TlsAddr { hostname, .. }) => {
509 ChannelAddr::MetaTls(TlsAddr::new(hostname.clone(), 0))
510 }
511 _ => addr.transport().any(),
512 }
513}
514
515pub mod test_utils {
516 use std::time::Duration;
517
518 use hyperactor::Actor;
519 use hyperactor::Context;
520 use hyperactor::Handler;
521 use libc::atexit;
522 use tokio::sync::broadcast::Receiver;
523 use tokio::sync::broadcast::Sender;
524 use typeuri::Named;
525
526 use super::*;
527
528 extern "C" fn exit_handler() {
529 loop {
530 std::thread::sleep(Duration::from_mins(1));
531 }
532 }
533
534 #[derive(Debug, Default)]
539 #[hyperactor::export(
540 spawn = true,
541 handlers = [
542 Wait
543 ],
544 )]
545 pub struct TestActor;
546
547 impl Actor for TestActor {}
548
549 #[derive(Debug, Serialize, Deserialize, Named, Clone)]
550 pub struct Wait;
551
552 #[async_trait]
553 impl Handler<Wait> for TestActor {
554 async fn handle(&mut self, _: &Context<Self>, _: Wait) -> Result<(), anyhow::Error> {
555 unsafe {
558 atexit(exit_handler);
559 }
560 Ok(())
561 }
562 }
563
564 pub struct MockAllocWrapper {
567 pub alloc: MockAlloc,
568 pub block_next_after: usize,
569 notify_tx: Sender<()>,
570 notify_rx: Receiver<()>,
571 next_unblocked: bool,
572 }
573
574 impl MockAllocWrapper {
575 pub fn new(alloc: MockAlloc) -> Self {
576 Self::new_block_next(alloc, usize::MAX)
577 }
578
579 pub fn new_block_next(alloc: MockAlloc, count: usize) -> Self {
580 let (tx, rx) = tokio::sync::broadcast::channel(1);
581 Self {
582 alloc,
583 block_next_after: count,
584 notify_tx: tx,
585 notify_rx: rx,
586 next_unblocked: false,
587 }
588 }
589
590 pub fn notify_tx(&self) -> Sender<()> {
591 self.notify_tx.clone()
592 }
593 }
594
595 #[async_trait]
596 impl Alloc for MockAllocWrapper {
597 async fn next(&mut self) -> Option<ProcState> {
598 match self.block_next_after {
599 0 => {
600 if !self.next_unblocked {
601 self.notify_rx.recv().await.unwrap();
602 self.next_unblocked = true;
603 }
604 }
605 1.. => {
606 self.block_next_after -= 1;
607 }
608 }
609
610 self.alloc.next().await
611 }
612
613 fn spec(&self) -> &AllocSpec {
614 self.alloc.spec()
615 }
616
617 fn extent(&self) -> &Extent {
618 self.alloc.extent()
619 }
620
621 fn alloc_name(&self) -> &AllocName {
622 self.alloc.alloc_name()
623 }
624
625 async fn stop(&mut self) -> Result<(), AllocatorError> {
626 self.alloc.stop().await
627 }
628 }
629}
630
631#[cfg(test)]
632pub(crate) mod testing {
633 use core::panic;
634 use std::collections::HashMap;
635 use std::collections::HashSet;
636 use std::time::Duration;
637
638 use hyperactor::Instance;
639 use hyperactor::actor::remote::Remote;
640 use hyperactor::channel;
641 use hyperactor::context;
642 use hyperactor::mailbox;
643 use hyperactor::mailbox::BoxedMailboxSender;
644 use hyperactor::mailbox::DialMailboxRouter;
645 use hyperactor::mailbox::IntoBoxedMailboxSender;
646 use hyperactor::mailbox::MailboxServer;
647 use hyperactor::mailbox::UndeliverableMailboxSender;
648 use hyperactor::proc::Proc;
649 use hyperactor::reference::Reference;
650 use ndslice::extent;
651 use tokio::process::Command;
652
653 use super::*;
654 use crate::alloc::test_utils::TestActor;
655 use crate::alloc::test_utils::Wait;
656 use crate::proc_agent::GspawnResult;
657 use crate::proc_agent::MeshAgentMessageClient;
658 use crate::transport::default_transport;
659
660 #[macro_export]
661 macro_rules! alloc_test_suite {
662 ($allocator:expr) => {
663 #[tokio::test]
664 async fn test_allocator_basic() {
665 $crate::alloc::testing::test_allocator_basic($allocator).await;
666 }
667 };
668 }
669
670 pub(crate) async fn test_allocator_basic(mut allocator: impl Allocator) {
671 let extent = extent!(replica = 4);
672 let mut alloc = allocator
673 .allocate(AllocSpec {
674 extent: extent.clone(),
675 constraints: Default::default(),
676 proc_name: None,
677 transport: default_transport(),
678 proc_allocation_mode: Default::default(),
679 })
680 .await
681 .unwrap();
682
683 let mut procs = HashMap::new();
686 let mut created = HashMap::new();
687 let mut running = HashSet::new();
688 while running.len() != 4 {
689 match alloc.next().await.unwrap() {
690 ProcState::Created {
691 create_key, point, ..
692 } => {
693 created.insert(create_key, point);
694 }
695 ProcState::Running {
696 create_key,
697 proc_id,
698 ..
699 } => {
700 assert!(running.insert(create_key.clone()));
701 procs.insert(proc_id, created.remove(&create_key).unwrap());
702 }
703 event => panic!("unexpected event: {:?}", event),
704 }
705 }
706
707 let points: HashSet<_> = procs.values().collect();
709 for x in 0..4 {
710 assert!(points.contains(&extent.point(vec![x]).unwrap()));
711 }
712
713 let alloc_names: HashSet<_> = procs
716 .keys()
717 .filter_map(|proc_id| {
718 proc_id
719 .name()
720 .rsplit_once('_')
721 .map(|(prefix, _)| prefix.to_string())
722 })
723 .collect();
724 assert_eq!(alloc_names.len(), 1);
725
726 alloc.stop().await.unwrap();
729 let mut stopped = HashSet::new();
730 while let Some(ProcState::Stopped {
731 create_key, reason, ..
732 }) = alloc.next().await
733 {
734 assert_eq!(reason, ProcStopReason::Stopped);
735 stopped.insert(create_key);
736 }
737 assert!(alloc.next().await.is_none());
738 assert_eq!(stopped, running);
739 }
740
741 async fn spawn_proc(
742 transport: ChannelTransport,
743 ) -> (DialMailboxRouter, Instance<()>, Proc, ChannelAddr) {
744 let (router_channel_addr, router_rx) =
745 channel::serve(ChannelAddr::any(transport.clone())).unwrap();
746 let router =
747 DialMailboxRouter::new_with_default((UndeliverableMailboxSender {}).into_boxed());
748 router.clone().serve(router_rx);
749
750 let client_proc_id = hyperactor_reference::ProcId::with_name(
751 ChannelAddr::any(ChannelTransport::Local),
752 "test_stuck_0",
753 );
754 let (client_proc_addr, client_rx) = channel::serve(ChannelAddr::any(transport)).unwrap();
755 let client_proc = Proc::configured(
756 client_proc_id.clone(),
757 BoxedMailboxSender::new(router.clone()),
758 );
759 client_proc.clone().serve(client_rx);
760 router.bind(client_proc_id.clone().into(), client_proc_addr);
761 (
762 router,
763 client_proc.instance("test_proc").unwrap().0,
764 client_proc,
765 router_channel_addr,
766 )
767 }
768
769 async fn spawn_test_actor(
770 rank: usize,
771 client_proc: &Proc,
772 cx: &impl context::Actor,
773 router_channel_addr: ChannelAddr,
774 mesh_agent: hyperactor_reference::ActorRef<ProcAgent>,
775 ) -> hyperactor_reference::ActorRef<TestActor> {
776 let (supervisor, _supervisor_handle) = client_proc.instance("supervisor").unwrap();
777 let (supervison_port, _) = supervisor.open_port();
778 let (config_handle, _) = cx.mailbox().open_port();
779 mesh_agent
780 .configure(
781 cx,
782 rank,
783 router_channel_addr,
784 Some(supervison_port.bind()),
785 HashMap::new(),
786 config_handle.bind(),
787 false,
788 )
789 .await
790 .unwrap();
791 let remote = Remote::collect();
792 let actor_type = remote
793 .name_of::<TestActor>()
794 .ok_or(anyhow::anyhow!("actor not registered"))
795 .unwrap()
796 .to_string();
797 let params = &();
798 let (completed_handle, mut completed_receiver) = mailbox::open_port(cx);
799 mesh_agent
801 .gspawn(
802 cx,
803 actor_type,
804 "Stuck".to_string(),
805 bincode::serialize(params).unwrap(),
806 completed_handle.bind(),
807 )
808 .await
809 .unwrap();
810 let result = completed_receiver.recv().await.unwrap();
811 match result {
812 GspawnResult::Success { actor_id, .. } => {
813 hyperactor_reference::ActorRef::attest(actor_id)
814 }
815 GspawnResult::Error(error_msg) => {
816 panic!("gspawn failed: {}", error_msg);
817 }
818 }
819 }
820
821 #[tokio::test]
826 #[cfg(fbcode_build)]
827 async fn test_allocator_stuck_task() {
828 let config = hyperactor_config::global::lock();
831 let _guard = config.override_key(
832 hyperactor::config::PROCESS_EXIT_TIMEOUT,
833 Duration::from_secs(1),
834 );
835
836 let command = Command::new(crate::testresource::get(
837 "monarch/hyperactor_mesh/bootstrap",
838 ));
839 let mut allocator = ProcessAllocator::new(command);
840 let mut alloc = allocator
841 .allocate(AllocSpec {
842 extent: extent! { replica = 1 },
843 constraints: Default::default(),
844 proc_name: None,
845 transport: ChannelTransport::Unix,
846 proc_allocation_mode: Default::default(),
847 })
848 .await
849 .unwrap();
850
851 let mut procs = HashMap::new();
853 let mut running = HashSet::new();
854 let mut actor_ref = None;
855 let (router, client, client_proc, router_addr) = spawn_proc(alloc.transport()).await;
856 while running.is_empty() {
857 match alloc.next().await.unwrap() {
858 ProcState::Created {
859 create_key, point, ..
860 } => {
861 procs.insert(create_key, point);
862 }
863 ProcState::Running {
864 create_key,
865 proc_id,
866 mesh_agent,
867 addr,
868 } => {
869 router.bind(Reference::Proc(proc_id.clone()), addr.clone());
870
871 assert!(procs.contains_key(&create_key));
872 assert!(!running.contains(&create_key));
873
874 actor_ref = Some(
875 spawn_test_actor(0, &client_proc, &client, router_addr, mesh_agent).await,
876 );
877 running.insert(create_key.clone());
878 break;
879 }
880 event => panic!("unexpected event: {:?}", event),
881 }
882 }
883 assert!(actor_ref.unwrap().send(&client, Wait).is_ok());
884
885 alloc.stop().await.unwrap();
887 let mut stopped = HashSet::new();
888 while let Some(ProcState::Stopped {
889 create_key, reason, ..
890 }) = alloc.next().await
891 {
892 assert_eq!(reason, ProcStopReason::Watchdog);
893 stopped.insert(create_key);
894 }
895 assert!(alloc.next().await.is_none());
896 assert_eq!(stopped, running);
897 }
898}