1pub mod local;
13pub mod process;
14pub mod remoteprocess;
15pub mod sim;
16
17use std::collections::HashMap;
18use std::fmt;
19
20use async_trait::async_trait;
21use enum_as_inner::EnumAsInner;
22use hyperactor::ActorRef;
23use hyperactor::ProcId;
24use hyperactor::WorldId;
25use hyperactor::channel::ChannelAddr;
26use hyperactor::channel::ChannelTransport;
27use hyperactor::channel::MetaTlsAddr;
28use hyperactor_config::CONFIG;
29use hyperactor_config::ConfigAttr;
30use hyperactor_config::attrs::declare_attrs;
31pub use local::LocalAlloc;
32pub use local::LocalAllocator;
33use mockall::predicate::*;
34use mockall::*;
35use ndslice::Shape;
36use ndslice::Slice;
37use ndslice::view::Extent;
38use ndslice::view::Point;
39pub use process::ProcessAlloc;
40pub use process::ProcessAllocator;
41use serde::Deserialize;
42use serde::Serialize;
43use strum::AsRefStr;
44use typeuri::Named;
45
46use crate::alloc::test_utils::MockAllocWrapper;
47use crate::assign::Ranks;
48use crate::proc_mesh::mesh_agent::ProcMeshAgent;
49use crate::shortuuid::ShortUuid;
50
51#[derive(Debug, thiserror::Error)]
53pub enum AllocatorError {
54 #[error("incomplete allocation; expected: {0}")]
55 Incomplete(Extent),
56
57 #[error("not enough resources; requested: {requested}, available: {available}")]
59 NotEnoughResources { requested: Extent, available: usize },
60
61 #[error(transparent)]
63 Other(#[from] anyhow::Error),
64}
65
66#[derive(Debug, Clone, Serialize, Deserialize, Default)]
68pub struct AllocConstraints {
69 pub match_labels: HashMap<String, String>,
72}
73
74#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
76#[derive(Default)]
77pub enum ProcAllocationMode {
78 #[default]
83 ProcLevel,
84 HostLevel,
88}
89
90#[derive(Debug, Clone, Serialize, Deserialize)]
92pub struct AllocSpec {
93 pub extent: Extent,
98
99 pub constraints: AllocConstraints,
101
102 pub proc_name: Option<String>,
105
106 pub transport: ChannelTransport,
108
109 #[serde(default = "default_proc_allocation_mode")]
112 pub proc_allocation_mode: ProcAllocationMode,
113}
114
115fn default_proc_allocation_mode() -> ProcAllocationMode {
116 ProcAllocationMode::ProcLevel
117}
118
119#[automock(type Alloc=MockAllocWrapper;)]
121#[async_trait]
122pub trait Allocator {
123 type Alloc: Alloc;
125
126 async fn allocate(&mut self, spec: AllocSpec) -> Result<Self::Alloc, AllocatorError>;
131}
132
133#[derive(
136 Clone,
137 Debug,
138 PartialEq,
139 EnumAsInner,
140 Serialize,
141 Deserialize,
142 AsRefStr,
143 Named
144)]
145pub enum ProcState {
146 Created {
148 create_key: ShortUuid,
151 point: Point,
153 pid: u32,
155 },
156 Running {
158 create_key: ShortUuid,
160 proc_id: ProcId,
162 mesh_agent: ActorRef<ProcMeshAgent>,
165 addr: ChannelAddr,
168 },
169 Stopped {
171 create_key: ShortUuid,
172 reason: ProcStopReason,
173 },
174 Failed {
182 world_id: WorldId,
186 description: String,
188 },
189}
190
191impl fmt::Display for ProcState {
192 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
193 match self {
194 ProcState::Created {
195 create_key,
196 point,
197 pid,
198 } => {
199 write!(f, "{}: created at ({}) with PID {}", create_key, point, pid)
200 }
201 ProcState::Running { proc_id, addr, .. } => {
202 write!(f, "{}: running at {}", proc_id, addr)
203 }
204 ProcState::Stopped { create_key, reason } => {
205 write!(f, "{}: stopped: {}", create_key, reason)
206 }
207 ProcState::Failed {
208 description,
209 world_id,
210 } => {
211 write!(f, "{}: failed: {}", world_id, description)
212 }
213 }
214 }
215}
216
217#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, EnumAsInner)]
219pub enum ProcStopReason {
220 Stopped,
222 Exited(i32, String),
224 Killed(i32, bool),
227 Watchdog,
229 HostWatchdog,
232 Unknown,
234}
235
236impl fmt::Display for ProcStopReason {
237 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
238 match self {
239 Self::Stopped => write!(f, "stopped"),
240 Self::Exited(code, stderr) => {
241 if stderr.is_empty() {
242 write!(f, "exited with code {}", code)
243 } else {
244 write!(f, "exited with code {}: {}", code, stderr)
245 }
246 }
247 Self::Killed(signal, dumped) => {
248 write!(f, "killed with signal {} (core dumped={})", signal, dumped)
249 }
250 Self::Watchdog => write!(f, "proc watchdog failure"),
251 Self::HostWatchdog => write!(f, "host watchdog failure"),
252 Self::Unknown => write!(f, "unknown"),
253 }
254 }
255}
256
257#[automock]
259#[async_trait]
260pub trait Alloc {
261 async fn next(&mut self) -> Option<ProcState>;
264
265 fn spec(&self) -> &AllocSpec;
267
268 fn extent(&self) -> &Extent;
270
271 fn shape(&self) -> Shape {
273 let slice = Slice::new_row_major(self.extent().sizes());
274 Shape::new(self.extent().labels().to_vec(), slice).unwrap()
275 }
276
277 fn world_id(&self) -> &WorldId;
281
282 fn transport(&self) -> ChannelTransport {
284 self.spec().transport.clone()
285 }
286
287 async fn stop(&mut self) -> Result<(), AllocatorError>;
291
292 async fn stop_and_wait(&mut self) -> Result<(), AllocatorError> {
295 tracing::error!(
296 name = "AllocStatus",
297 alloc_name = %self.world_id(),
298 status = "StopAndWait",
299 );
300 self.stop().await?;
301 while let Some(event) = self.next().await {
302 tracing::debug!(
303 alloc_name = %self.world_id(),
304 "drained event: {event:?}"
305 );
306 }
307 tracing::error!(
308 name = "AllocStatus",
309 alloc_name = %self.world_id(),
310 status = "Stopped",
311 );
312 Ok(())
313 }
314
315 fn is_local(&self) -> bool {
318 false
319 }
320
321 fn client_router_addr(&self) -> ChannelAddr {
323 ChannelAddr::any(self.transport())
324 }
325}
326
327#[derive(Debug, Clone, PartialEq, Eq, Hash)]
328pub(crate) struct AllocatedProc {
329 pub create_key: ShortUuid,
330 pub proc_id: ProcId,
331 pub addr: ChannelAddr,
332 pub mesh_agent: ActorRef<ProcMeshAgent>,
333}
334
335impl fmt::Display for AllocatedProc {
336 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
337 write!(
338 f,
339 "AllocatedProc {{ create_key: {}, proc_id: {}, addr: {}, mesh_agent: {} }}",
340 self.create_key, self.proc_id, self.addr, self.mesh_agent
341 )
342 }
343}
344
345#[async_trait]
346pub(crate) trait AllocExt {
347 async fn initialize(&mut self) -> Result<Vec<AllocatedProc>, AllocatorError>;
350}
351
352#[async_trait]
353impl<A: ?Sized + Send + Alloc> AllocExt for A {
354 async fn initialize(&mut self) -> Result<Vec<AllocatedProc>, AllocatorError> {
355 let shape = self.shape().clone();
357
358 let mut created = Ranks::new(shape.slice().len());
359 let mut running = Ranks::new(shape.slice().len());
360
361 while !running.is_full() {
362 let Some(state) = self.next().await else {
363 return Err(AllocatorError::Incomplete(self.extent().clone()));
365 };
366
367 let name = tracing::Span::current()
368 .metadata()
369 .map(|m| m.name())
370 .unwrap_or("initialize");
371 let status = format!("ProcState:{}", state.arm().unwrap_or("unknown"));
372
373 match state {
374 ProcState::Created {
375 create_key, point, ..
376 } => {
377 let rank = point.rank();
378 if let Some(old_create_key) = created.insert(rank, create_key.clone()) {
379 tracing::warn!(
380 name,
381 status,
382 rank,
383 "rank {rank} reassigned from {old_create_key} to {create_key}"
384 );
385 }
386 tracing::info!(
387 name,
388 status,
389 rank,
390 "proc with create key {}, rank {}: created",
391 create_key,
392 rank
393 );
394 }
395 ProcState::Running {
396 create_key,
397 proc_id,
398 mesh_agent,
399 addr,
400 } => {
401 let Some(rank) = created.rank(&create_key) else {
402 tracing::warn!(
403 name,
404 %proc_id,
405 status,
406 "proc id {proc_id} with create key {create_key} \
407 is running, but was not created"
408 );
409 continue;
410 };
411
412 let allocated_proc = AllocatedProc {
413 create_key,
414 proc_id: proc_id.clone(),
415 addr: addr.clone(),
416 mesh_agent: mesh_agent.clone(),
417 };
418 if let Some(old_allocated_proc) = running.insert(*rank, allocated_proc.clone())
419 {
420 tracing::warn!(
421 name,
422 %proc_id,
423 status,
424 rank,
425 "duplicate running notifications for {rank}: \
426 old:{old_allocated_proc}; \
427 new:{allocated_proc}"
428 )
429 }
430 tracing::info!(
431 name,
432 %proc_id,
433 status,
434 "proc {} rank {}: running at addr:{addr} mesh_agent:{mesh_agent}",
435 proc_id,
436 rank
437 );
438 }
439 ProcState::Stopped { create_key, reason } => {
443 tracing::error!(
444 name,
445 status,
446 "allocation failed for proc with create key {}: {}",
447 create_key,
448 reason
449 );
450 return Err(AllocatorError::Other(anyhow::Error::msg(reason)));
451 }
452 ProcState::Failed {
453 world_id,
454 description,
455 } => {
456 tracing::error!(
457 name,
458 status,
459 "allocation failed for world {}: {}",
460 world_id,
461 description
462 );
463 return Err(AllocatorError::Other(anyhow::Error::msg(description)));
464 }
465 }
466 }
467
468 Ok(running.into_iter().map(Option::unwrap).collect())
471 }
472}
473
474pub(crate) fn with_unspecified_port_or_any(addr: &ChannelAddr) -> ChannelAddr {
479 match addr {
480 ChannelAddr::Tcp(socket) => {
481 let mut new_socket = socket.clone();
482 new_socket.set_port(0);
483 ChannelAddr::Tcp(new_socket)
484 }
485 ChannelAddr::MetaTls(MetaTlsAddr::Socket(socket)) => {
486 let mut new_socket = socket.clone();
487 new_socket.set_port(0);
488 ChannelAddr::MetaTls(MetaTlsAddr::Socket(new_socket))
489 }
490 ChannelAddr::MetaTls(MetaTlsAddr::Host { hostname, port: _ }) => {
491 ChannelAddr::MetaTls(MetaTlsAddr::Host {
492 hostname: hostname.clone(),
493 port: 0,
494 })
495 }
496 _ => addr.transport().any(),
497 }
498}
499
500pub mod test_utils {
501 use std::time::Duration;
502
503 use hyperactor::Actor;
504 use hyperactor::Context;
505 use hyperactor::Handler;
506 use libc::atexit;
507 use tokio::sync::broadcast::Receiver;
508 use tokio::sync::broadcast::Sender;
509 use typeuri::Named;
510
511 use super::*;
512
513 extern "C" fn exit_handler() {
514 loop {
515 #[allow(clippy::disallowed_methods)]
516 std::thread::sleep(Duration::from_mins(1));
517 }
518 }
519
520 #[derive(Debug, Default)]
525 #[hyperactor::export(
526 spawn = true,
527 handlers = [
528 Wait
529 ],
530 )]
531 pub struct TestActor;
532
533 impl Actor for TestActor {}
534
535 #[derive(Debug, Serialize, Deserialize, Named, Clone)]
536 pub struct Wait;
537
538 #[async_trait]
539 impl Handler<Wait> for TestActor {
540 async fn handle(&mut self, _: &Context<Self>, _: Wait) -> Result<(), anyhow::Error> {
541 unsafe {
544 atexit(exit_handler);
545 }
546 Ok(())
547 }
548 }
549
550 pub struct MockAllocWrapper {
553 pub alloc: MockAlloc,
554 pub block_next_after: usize,
555 notify_tx: Sender<()>,
556 notify_rx: Receiver<()>,
557 next_unblocked: bool,
558 }
559
560 impl MockAllocWrapper {
561 pub fn new(alloc: MockAlloc) -> Self {
562 Self::new_block_next(alloc, usize::MAX)
563 }
564
565 pub fn new_block_next(alloc: MockAlloc, count: usize) -> Self {
566 let (tx, rx) = tokio::sync::broadcast::channel(1);
567 Self {
568 alloc,
569 block_next_after: count,
570 notify_tx: tx,
571 notify_rx: rx,
572 next_unblocked: false,
573 }
574 }
575
576 pub fn notify_tx(&self) -> Sender<()> {
577 self.notify_tx.clone()
578 }
579 }
580
581 #[async_trait]
582 impl Alloc for MockAllocWrapper {
583 async fn next(&mut self) -> Option<ProcState> {
584 match self.block_next_after {
585 0 => {
586 if !self.next_unblocked {
587 self.notify_rx.recv().await.unwrap();
588 self.next_unblocked = true;
589 }
590 }
591 1.. => {
592 self.block_next_after -= 1;
593 }
594 }
595
596 self.alloc.next().await
597 }
598
599 fn spec(&self) -> &AllocSpec {
600 self.alloc.spec()
601 }
602
603 fn extent(&self) -> &Extent {
604 self.alloc.extent()
605 }
606
607 fn world_id(&self) -> &WorldId {
608 self.alloc.world_id()
609 }
610
611 async fn stop(&mut self) -> Result<(), AllocatorError> {
612 self.alloc.stop().await
613 }
614 }
615}
616
/// Reusable allocator tests, plus a macro for instantiating them against a
/// concrete allocator.
#[cfg(test)]
pub(crate) mod testing {
    use core::panic;
    use std::collections::HashMap;
    use std::collections::HashSet;
    use std::time::Duration;

    use hyperactor::Instance;
    use hyperactor::actor::remote::Remote;
    use hyperactor::channel;
    use hyperactor::context;
    use hyperactor::mailbox;
    use hyperactor::mailbox::BoxedMailboxSender;
    use hyperactor::mailbox::DialMailboxRouter;
    use hyperactor::mailbox::IntoBoxedMailboxSender;
    use hyperactor::mailbox::MailboxServer;
    use hyperactor::mailbox::UndeliverableMailboxSender;
    use hyperactor::proc::Proc;
    use hyperactor::reference::Reference;
    use ndslice::extent;
    use tokio::process::Command;

    use super::*;
    use crate::alloc::test_utils::TestActor;
    use crate::alloc::test_utils::Wait;
    use crate::proc_mesh::default_transport;
    use crate::proc_mesh::mesh_agent::GspawnResult;
    use crate::proc_mesh::mesh_agent::MeshAgentMessageClient;

    /// Expands to the shared allocator test suite, run against the allocator
    /// produced by the `$allocator` expression.
    #[macro_export]
    macro_rules! alloc_test_suite {
        ($allocator:expr) => {
            #[tokio::test]
            async fn test_allocator_basic() {
                $crate::alloc::testing::test_allocator_basic($allocator).await;
            }
        };
    }

    /// Allocates four replicas, waits for all of them to be running, checks
    /// that every point of the extent is covered and that all procs share one
    /// world, then stops the alloc and checks that exactly the running procs
    /// report `ProcStopReason::Stopped`.
    pub(crate) async fn test_allocator_basic(mut allocator: impl Allocator) {
        let extent = extent!(replica = 4);
        let spec = AllocSpec {
            extent: extent.clone(),
            constraints: Default::default(),
            proc_name: None,
            transport: default_transport(),
            proc_allocation_mode: Default::default(),
        };
        let mut alloc = allocator.allocate(spec).await.unwrap();

        // proc id -> point, filled in as procs start running.
        let mut procs = HashMap::new();
        // create key -> point, for procs created but not yet running.
        let mut pending = HashMap::new();
        let mut running = HashSet::new();
        while running.len() != 4 {
            let event = alloc.next().await.unwrap();
            match event {
                ProcState::Created {
                    create_key, point, ..
                } => {
                    pending.insert(create_key, point);
                }
                ProcState::Running {
                    create_key,
                    proc_id,
                    ..
                } => {
                    // Each proc must report running exactly once, and only
                    // after it was created.
                    assert!(running.insert(create_key.clone()));
                    let point = pending.remove(&create_key).unwrap();
                    procs.insert(proc_id, point);
                }
                event => panic!("unexpected event: {:?}", event),
            }
        }

        // Every point of the extent is covered.
        let points: HashSet<_> = procs.values().collect();
        for x in 0..4 {
            assert!(points.contains(&extent.point(vec![x]).unwrap()));
        }

        // All procs belong to a single world.
        let worlds: HashSet<_> = procs.keys().map(|proc_id| proc_id.world_id()).collect();
        assert_eq!(worlds.len(), 1);

        alloc.stop().await.unwrap();
        let mut stopped = HashSet::new();
        while let Some(ProcState::Stopped {
            create_key, reason, ..
        }) = alloc.next().await
        {
            assert_eq!(reason, ProcStopReason::Stopped);
            stopped.insert(create_key);
        }
        assert!(alloc.next().await.is_none());
        assert_eq!(stopped, running);
    }

    /// Sets up a client-side router and proc on `transport`.
    ///
    /// Returns the router, a client instance named "test_proc", the client
    /// proc, and the address the router is served on.
    async fn spawn_proc(
        transport: ChannelTransport,
    ) -> (DialMailboxRouter, Instance<()>, Proc, ChannelAddr) {
        let (router_channel_addr, router_rx) =
            channel::serve(ChannelAddr::any(transport.clone())).unwrap();
        let undeliverable = (UndeliverableMailboxSender {}).into_boxed();
        let router = DialMailboxRouter::new_with_default(undeliverable);
        router.clone().serve(router_rx);

        let client_proc_id = ProcId::Ranked(WorldId("test_stuck".to_string()), 0);
        let (client_proc_addr, client_rx) = channel::serve(ChannelAddr::any(transport)).unwrap();
        let forwarder = BoxedMailboxSender::new(router.clone());
        let client_proc = Proc::new(client_proc_id.clone(), forwarder);
        client_proc.clone().serve(client_rx);
        router.bind(client_proc_id.clone().into(), client_proc_addr);

        let client = client_proc.instance("test_proc").unwrap().0;
        (router, client, client_proc, router_channel_addr)
    }

    /// Configures `mesh_agent` for `rank` and asks it to spawn a
    /// [`TestActor`] named "Stuck", returning a reference to that actor.
    ///
    /// # Panics
    /// Panics if any step fails, including when the agent reports a spawn
    /// error.
    async fn spawn_test_actor(
        rank: usize,
        client_proc: &Proc,
        cx: &impl context::Actor,
        router_channel_addr: ChannelAddr,
        mesh_agent: ActorRef<ProcMeshAgent>,
    ) -> ActorRef<TestActor> {
        let (supervisor, _supervisor_handle) = client_proc.instance("supervisor").unwrap();
        let (supervision_port, _) = supervisor.open_port();
        let (config_handle, _) = cx.mailbox().open_port();
        mesh_agent
            .configure(
                cx,
                rank,
                router_channel_addr,
                Some(supervision_port.bind()),
                HashMap::new(),
                config_handle.bind(),
                false,
            )
            .await
            .unwrap();

        // Look up the registered type name under which to spawn `TestActor`.
        let actor_type = Remote::collect()
            .name_of::<TestActor>()
            .ok_or(anyhow::anyhow!("actor not registered"))
            .unwrap()
            .to_string();
        let serialized_params = bincode::serialize(&()).unwrap();
        let (completed_handle, mut completed_receiver) = mailbox::open_port(cx);
        mesh_agent
            .gspawn(
                cx,
                actor_type,
                "Stuck".to_string(),
                serialized_params,
                completed_handle.bind(),
            )
            .await
            .unwrap();

        match completed_receiver.recv().await.unwrap() {
            GspawnResult::Success { actor_id, .. } => ActorRef::attest(actor_id),
            GspawnResult::Error(error_msg) => {
                panic!("gspawn failed: {}", error_msg);
            }
        }
    }

    /// Spawns one proc whose [`TestActor`] installs a never-returning exit
    /// hook, then stops the alloc with `PROCESS_EXIT_TIMEOUT` overridden to
    /// one second and expects the proc to be reported as stopped with
    /// `ProcStopReason::Watchdog`.
    #[tokio::test]
    #[cfg(fbcode_build)]
    async fn test_allocator_stuck_task() {
        let config = hyperactor_config::global::lock();
        // Keep the override alive for the duration of the test.
        let _guard = config.override_key(
            hyperactor::config::PROCESS_EXIT_TIMEOUT,
            Duration::from_secs(1),
        );

        let command = Command::new(crate::testresource::get(
            "monarch/hyperactor_mesh/bootstrap",
        ));
        let mut allocator = ProcessAllocator::new(command);
        let spec = AllocSpec {
            extent: extent! { replica = 1 },
            constraints: Default::default(),
            proc_name: None,
            transport: ChannelTransport::Unix,
            proc_allocation_mode: Default::default(),
        };
        let mut alloc = allocator.allocate(spec).await.unwrap();

        // create key -> point for created procs.
        let mut procs = HashMap::new();
        let mut running = HashSet::new();
        let mut actor_ref = None;
        let (router, client, client_proc, router_addr) = spawn_proc(alloc.transport()).await;
        while running.is_empty() {
            let event = alloc.next().await.unwrap();
            match event {
                ProcState::Created {
                    create_key, point, ..
                } => {
                    procs.insert(create_key, point);
                }
                ProcState::Running {
                    create_key,
                    proc_id,
                    mesh_agent,
                    addr,
                } => {
                    router.bind(Reference::Proc(proc_id.clone()), addr.clone());

                    assert!(procs.contains_key(&create_key));
                    assert!(!running.contains(&create_key));

                    let spawned =
                        spawn_test_actor(0, &client_proc, &client, router_addr, mesh_agent).await;
                    actor_ref = Some(spawned);
                    running.insert(create_key.clone());
                    break;
                }
                event => panic!("unexpected event: {:?}", event),
            }
        }
        // Ask the actor to install its never-returning exit hook.
        assert!(actor_ref.unwrap().send(&client, Wait).is_ok());

        alloc.stop().await.unwrap();
        let mut stopped = HashSet::new();
        while let Some(ProcState::Stopped {
            create_key, reason, ..
        }) = alloc.next().await
        {
            assert_eq!(reason, ProcStopReason::Watchdog);
            stopped.insert(create_key);
        }
        assert!(alloc.next().await.is_none());
        assert_eq!(stopped, running);
    }
}