1use std::collections::HashMap;
10use std::collections::HashSet;
11use std::collections::VecDeque;
12use std::sync::Arc;
13use std::time::Duration;
14
15use anyhow::Context;
16use async_trait::async_trait;
17use dashmap::DashMap;
18use futures::FutureExt;
19use futures::future::join_all;
20use futures::future::select_all;
21use hyperactor::Named;
22use hyperactor::WorldId;
23use hyperactor::channel;
24use hyperactor::channel::ChannelAddr;
25use hyperactor::channel::ChannelRx;
26use hyperactor::channel::ChannelTransport;
27use hyperactor::channel::ChannelTx;
28use hyperactor::channel::Rx;
29use hyperactor::channel::Tx;
30use hyperactor::channel::TxStatus;
31use hyperactor::clock;
32use hyperactor::clock::Clock;
33use hyperactor::clock::RealClock;
34use hyperactor::config;
35use hyperactor::mailbox::DialMailboxRouter;
36use hyperactor::mailbox::MailboxServer;
37use hyperactor::observe_async;
38use hyperactor::observe_result;
39use hyperactor::reference::Reference;
40use hyperactor::serde_json;
41use mockall::automock;
42use ndslice::Region;
43use ndslice::View;
44use ndslice::ViewExt;
45use ndslice::view::Extent;
46use ndslice::view::Point;
47use serde::Deserialize;
48use serde::Serialize;
49use strum::AsRefStr;
50use tokio::io::AsyncWriteExt;
51use tokio::process::Command;
52use tokio::sync::mpsc::UnboundedReceiver;
53use tokio::sync::mpsc::UnboundedSender;
54use tokio::sync::mpsc::unbounded_channel;
55use tokio::task::JoinHandle;
56use tokio_stream::StreamExt;
57use tokio_stream::wrappers::WatchStream;
58use tokio_util::sync::CancellationToken;
59
60use crate::alloc::Alloc;
61use crate::alloc::AllocConstraints;
62use crate::alloc::AllocSpec;
63use crate::alloc::Allocator;
64use crate::alloc::AllocatorError;
65use crate::alloc::ProcState;
66use crate::alloc::ProcStopReason;
67use crate::alloc::ProcessAllocator;
68use crate::alloc::process::CLIENT_TRACE_ID_LABEL;
69use crate::alloc::process::ClientContext;
70use crate::shortuuid::ShortUuid;
71
72#[derive(Debug, Clone, Serialize, Deserialize, Named, AsRefStr)]
74pub enum RemoteProcessAllocatorMessage {
75 Allocate {
77 alloc_key: ShortUuid,
79 extent: Extent,
81 bootstrap_addr: ChannelAddr,
83 hosts: Vec<String>,
86 client_context: Option<ClientContext>,
90 },
91 Stop,
93 HeartBeat,
96}
97
98#[derive(Debug, Clone, Serialize, Deserialize, Named, AsRefStr)]
102pub enum RemoteProcessProcStateMessage {
103 Allocated {
105 alloc_key: ShortUuid,
106 world_id: WorldId,
107 },
108 Update(ShortUuid, ProcState),
110 Done(ShortUuid),
112 HeartBeat,
114}
115
116pub struct RemoteProcessAllocator {
118 cancel_token: CancellationToken,
119}
120
/// Awaits `t` when a timer is supplied; with `None` it never completes.
///
/// This lets an optional timeout be used as a `tokio::select!` branch that
/// simply never fires when no timeout is configured.
async fn conditional_sleeper<F: std::future::Future<Output = ()>>(t: Option<F>) {
    if let Some(timer) = t {
        timer.await;
    } else {
        std::future::pending::<()>().await;
    }
}
127
128impl RemoteProcessAllocator {
129 pub fn new() -> Arc<Self> {
131 Arc::new(Self {
132 cancel_token: CancellationToken::new(),
133 })
134 }
135
136 pub fn terminate(&self) {
138 self.cancel_token.cancel();
139 }
140
141 #[hyperactor::instrument]
157 pub async fn start(
158 &self,
159 cmd: Command,
160 serve_addr: ChannelAddr,
161 timeout: Option<Duration>,
162 ) -> Result<(), anyhow::Error> {
163 let process_allocator = ProcessAllocator::new(cmd);
164 self.start_with_allocator(serve_addr, process_allocator, timeout)
165 .await
166 }
167
168 pub async fn start_with_allocator<A: Allocator + Send + Sync + 'static>(
172 &self,
173 serve_addr: ChannelAddr,
174 mut process_allocator: A,
175 timeout: Option<Duration>,
176 ) -> Result<(), anyhow::Error>
177 where
178 <A as Allocator>::Alloc: Send,
179 <A as Allocator>::Alloc: Sync,
180 {
181 tracing::info!("starting remote allocator on: {}", serve_addr);
182 let (_, mut rx) = channel::serve(serve_addr.clone()).map_err(anyhow::Error::from)?;
183
184 struct ActiveAllocation {
185 handle: JoinHandle<()>,
186 cancel_token: CancellationToken,
187 }
188 #[observe_async("RemoteProcessAllocator")]
189 async fn ensure_previous_alloc_stopped(active_allocation: &mut Option<ActiveAllocation>) {
190 if let Some(active_allocation) = active_allocation.take() {
191 tracing::info!("previous alloc found, stopping");
192 active_allocation.cancel_token.cancel();
193 match active_allocation.handle.await {
194 Ok(_) => {
195 tracing::info!("allocation stopped.")
197 }
198 Err(e) => {
199 tracing::error!("allocation handler failed: {}", e);
200 }
201 }
202 }
203 }
204
205 let mut active_allocation: Option<ActiveAllocation> = None;
206 loop {
207 let sleep = conditional_sleeper(timeout.map(|t| RealClock.sleep(t)));
210 tokio::select! {
211 msg = rx.recv() => {
212 match msg {
213 Ok(RemoteProcessAllocatorMessage::Allocate {
214 alloc_key,
215 extent,
216 bootstrap_addr,
217 hosts,
218 client_context,
219 }) => {
220 tracing::info!("received allocation request for {} with extent {}", alloc_key, extent);
221 ensure_previous_alloc_stopped(&mut active_allocation).await;
222
223 let mut constraints: AllocConstraints = Default::default();
225 if let Some(context) = &client_context {
226 constraints = AllocConstraints {
227 match_labels: HashMap::from([(
228 CLIENT_TRACE_ID_LABEL.to_string(),
229 context.trace_id.to_string(),
230 )]
231 )};
232 tracing::info!(
233 monarch_client_trace_id = context.trace_id.to_string(),
234 "allocating...",
235 );
236 }
237
238
239 let spec = AllocSpec {
240 extent,
241 constraints,
242 proc_name: None, transport: ChannelTransport::Unix,
244 };
245
246 match process_allocator.allocate(spec.clone()).await {
247 Ok(alloc) => {
248 let cancel_token = CancellationToken::new();
249 active_allocation = Some(ActiveAllocation {
250 cancel_token: cancel_token.clone(),
251 handle: tokio::spawn(Self::handle_allocation_request(
252 Box::new(alloc) as Box<dyn Alloc + Send + Sync>,
253 alloc_key,
254 serve_addr.transport(),
255 bootstrap_addr,
256 hosts,
257 cancel_token,
258 )),
259 })
260 }
261 Err(e) => {
262 tracing::error!("allocation for {:?} failed: {}", spec, e);
263 continue;
264 }
265 }
266 }
267 Ok(RemoteProcessAllocatorMessage::Stop) => {
268 tracing::info!("received stop request");
269
270 ensure_previous_alloc_stopped(&mut active_allocation).await;
271 }
272 Ok(RemoteProcessAllocatorMessage::HeartBeat) => {}
276 Err(e) => {
277 tracing::error!("upstream channel error: {}", e);
278 continue;
279 }
280 }
281 }
282 _ = self.cancel_token.cancelled() => {
283 tracing::info!("main loop cancelled");
284
285 ensure_previous_alloc_stopped(&mut active_allocation).await;
286
287 break;
288 }
289 _ = sleep => {
290 if active_allocation.is_some() {
292 continue;
293 }
294 tracing::warn!("timeout of {} seconds elapsed without any allocations, exiting", timeout.unwrap_or_default().as_secs());
297 break;
298 }
299 }
300 }
301
302 Ok(())
303 }
304
305 #[observe_async("RemoteProcessAllocator")]
306 async fn handle_allocation_request(
307 alloc: Box<dyn Alloc + Send + Sync>,
308 alloc_key: ShortUuid,
309 serve_transport: ChannelTransport,
310 bootstrap_addr: ChannelAddr,
311 hosts: Vec<String>,
312 cancel_token: CancellationToken,
313 ) {
314 tracing::info!("handle allocation request, bootstrap_addr: {bootstrap_addr}");
315 let (forwarder_addr, forwarder_rx) = match channel::serve(ChannelAddr::any(serve_transport))
320 {
321 Ok(v) => v,
322 Err(e) => {
323 tracing::error!("failed to to bootstrap forwarder actor: {}", e);
324 return;
325 }
326 };
327 let router = DialMailboxRouter::new();
328 let mailbox_handle = router.clone().serve(forwarder_rx);
329 tracing::info!("started forwarder on: {}", forwarder_addr);
330
331 if let Ok(hosts_file) = std::env::var("TORCH_ELASTIC_CUSTOM_HOSTNAMES_LIST_FILE") {
334 tracing::info!("writing hosts to {}", hosts_file);
335 #[derive(Serialize)]
336 struct Hosts {
337 hostnames: Vec<String>,
338 }
339 match serde_json::to_string(&Hosts { hostnames: hosts }) {
340 Ok(json) => match tokio::fs::File::create(&hosts_file).await {
341 Ok(mut file) => {
342 if file.write_all(json.as_bytes()).await.is_err() {
343 tracing::error!("failed to write hosts to {}", hosts_file);
344 return;
345 }
346 }
347 Err(e) => {
348 tracing::error!("failed to open hosts file {}: {}", hosts_file, e);
349 return;
350 }
351 },
352 Err(e) => {
353 tracing::error!("failed to serialize hosts: {}", e);
354 return;
355 }
356 }
357 }
358
359 Self::handle_allocation_loop(
360 alloc,
361 alloc_key,
362 bootstrap_addr,
363 router,
364 forwarder_addr,
365 cancel_token,
366 )
367 .await;
368
369 mailbox_handle.stop("alloc stopped");
370 if let Err(e) = mailbox_handle.await {
371 tracing::error!("failed to join forwarder: {}", e);
372 }
373 }
374
375 async fn handle_allocation_loop(
376 mut alloc: Box<dyn Alloc + Send + Sync>,
377 alloc_key: ShortUuid,
378 bootstrap_addr: ChannelAddr,
379 router: DialMailboxRouter,
380 forward_addr: ChannelAddr,
381 cancel_token: CancellationToken,
382 ) {
383 let world_id = alloc.world_id().clone();
384 tracing::info!("starting handle allocation loop for {}", world_id);
385 let tx = match channel::dial(bootstrap_addr) {
386 Ok(tx) => tx,
387 Err(err) => {
388 tracing::error!("failed to dial bootstrap address: {}", err);
389 return;
390 }
391 };
392 let message = RemoteProcessProcStateMessage::Allocated {
393 alloc_key: alloc_key.clone(),
394 world_id,
395 };
396 tracing::info!(name = message.as_ref(), "sending allocated message",);
397 if let Err(e) = tx.send(message).await {
398 tracing::error!("failed to send Allocated message: {}", e);
399 return;
400 }
401
402 let mut mesh_agents_by_create_key = HashMap::new();
403 let mut running = true;
404 let tx_status = tx.status().clone();
405 let mut tx_watcher = WatchStream::new(tx_status);
406 loop {
407 tokio::select! {
408 _ = cancel_token.cancelled(), if running => {
409 tracing::info!("cancelled, stopping allocation");
410 running = false;
411 if let Err(e) = alloc.stop().await {
412 tracing::error!("stop failed: {}", e);
413 break;
414 }
415 }
416 status = tx_watcher.next(), if running => {
417 match status {
418 Some(TxStatus::Closed) => {
419 tracing::error!("upstream channel state closed");
420 break;
421 },
422 _ => {
423 tracing::debug!("got channel event: {:?}", status.unwrap());
424 continue;
425 }
426 }
427 }
428 e = alloc.next() => {
429 match e {
430 Some(event) => {
431 tracing::debug!(name = event.as_ref(), "got event: {:?}", event);
432 let event = match event {
433 ProcState::Created { .. } => event,
434 ProcState::Running { create_key, proc_id, mesh_agent, addr } => {
435 tracing::debug!("remapping mesh_agent {}: addr {} -> {}", mesh_agent, addr, forward_addr);
437 mesh_agents_by_create_key.insert(create_key.clone(), mesh_agent.clone());
438 router.bind(mesh_agent.actor_id().proc_id().clone().into(), addr);
439 ProcState::Running { create_key, proc_id, mesh_agent, addr: forward_addr.clone() }
440 },
441 ProcState::Stopped { create_key, reason } => {
442 match mesh_agents_by_create_key.remove(&create_key) {
443 Some(mesh_agent) => {
444 tracing::debug!("unmapping mesh_agent {}", mesh_agent);
445 let agent_ref: Reference = mesh_agent.actor_id().proc_id().clone().into();
446 router.unbind(&agent_ref);
447 },
448 None => {
449 tracing::warn!("mesh_agent not found for create key {}", create_key);
450 }
451 }
452 ProcState::Stopped { create_key, reason }
453 },
454 ProcState::Failed { ref world_id, ref description } => {
455 tracing::error!("allocation failed for {}: {}", world_id, description);
456 event
457 }
458 };
459 tracing::debug!(name = event.as_ref(), "sending event: {:?}", event);
460 tx.post(RemoteProcessProcStateMessage::Update(alloc_key.clone(), event));
461 }
462 None => {
463 tracing::debug!("sending done");
464 tx.post(RemoteProcessProcStateMessage::Done(alloc_key.clone()));
465 running = false;
466 break;
467 }
468 }
469 }
470 _ = RealClock.sleep(config::global::get(config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL)) => {
471 tracing::trace!("sending heartbeat");
472 tx.post(RemoteProcessProcStateMessage::HeartBeat);
473 }
474 }
475 }
476 tracing::info!("allocation handler loop exited");
477 if running {
478 tracing::info!("stopping processes");
479 if let Err(e) = alloc.stop_and_wait().await {
480 tracing::error!("stop failed: {}", e);
481 return;
482 }
483 tracing::info!("stop finished");
484 }
485 }
486}
487
/// Identifier of a host taking part in a [`RemoteProcessAlloc`]; it is the key
/// for all per-host bookkeeping.
type HostId = String;
491
492#[derive(Clone)]
494pub struct RemoteProcessAllocHost {
495 pub id: HostId,
498 pub hostname: String,
500}
501
502struct RemoteProcessAllocHostState {
504 alloc_key: ShortUuid,
506 host_id: HostId,
508 tx: ChannelTx<RemoteProcessAllocatorMessage>,
510 active_procs: HashSet<ShortUuid>,
512 region: Region,
514 world_id: Option<WorldId>,
516 failed: bool,
518 allocated: bool,
520}
521
522#[automock]
523#[async_trait]
524pub trait RemoteProcessAllocInitializer {
526 async fn initialize_alloc(&mut self) -> Result<Vec<RemoteProcessAllocHost>, anyhow::Error>;
528}
529
530struct HostStates {
533 inner: HashMap<HostId, RemoteProcessAllocHostState>,
534 host_addresses: Arc<DashMap<HostId, ChannelAddr>>,
535}
536
537impl HostStates {
538 fn new(host_addresses: Arc<DashMap<HostId, ChannelAddr>>) -> HostStates {
539 Self {
540 inner: HashMap::new(),
541 host_addresses,
542 }
543 }
544
545 fn insert(
546 &mut self,
547 host_id: HostId,
548 state: RemoteProcessAllocHostState,
549 address: ChannelAddr,
550 ) {
551 self.host_addresses.insert(host_id.clone(), address);
552 self.inner.insert(host_id, state);
553 }
554
555 fn get(&self, host_id: &HostId) -> Option<&RemoteProcessAllocHostState> {
556 self.inner.get(host_id)
557 }
558
559 fn get_mut(&mut self, host_id: &HostId) -> Option<&mut RemoteProcessAllocHostState> {
560 self.inner.get_mut(host_id)
561 }
562
563 fn remove(&mut self, host_id: &HostId) -> Option<RemoteProcessAllocHostState> {
564 self.host_addresses.remove(host_id);
565 self.inner.remove(host_id)
566 }
567
568 fn iter(&self) -> impl Iterator<Item = (&HostId, &RemoteProcessAllocHostState)> {
569 self.inner.iter()
570 }
571
572 fn iter_mut(&mut self) -> impl Iterator<Item = (&HostId, &mut RemoteProcessAllocHostState)> {
573 self.inner.iter_mut()
574 }
575
576 fn is_empty(&self) -> bool {
577 self.inner.is_empty()
578 }
579 }
581
582pub struct RemoteProcessAlloc {
585 initializer: Box<dyn RemoteProcessAllocInitializer + Send + Sync>,
588 spec: AllocSpec,
589 remote_allocator_port: u16,
590 world_id: WorldId,
591 ordered_hosts: Vec<RemoteProcessAllocHost>,
592 started: bool,
594 running: bool,
596 failed: bool,
598 alloc_to_host: HashMap<ShortUuid, HostId>,
600 host_states: HostStates,
601 world_offsets: HashMap<WorldId, usize>,
602 event_queue: VecDeque<ProcState>,
603 comm_watcher_tx: UnboundedSender<HostId>,
604 comm_watcher_rx: UnboundedReceiver<HostId>,
605
606 bootstrap_addr: ChannelAddr,
607 rx: ChannelRx<RemoteProcessProcStateMessage>,
608 _signal_cleanup_guard: hyperactor::SignalCleanupGuard,
609}
610
611impl RemoteProcessAlloc {
612 #[observe_result("RemoteProcessAlloc")]
617 pub async fn new(
618 spec: AllocSpec,
619 world_id: WorldId,
620 remote_allocator_port: u16,
621 initializer: impl RemoteProcessAllocInitializer + Send + Sync + 'static,
622 ) -> Result<Self, anyhow::Error> {
623 let (bootstrap_addr, rx) = channel::serve(ChannelAddr::any(spec.transport.clone()))
624 .map_err(anyhow::Error::from)?;
625
626 tracing::info!(
627 "starting alloc for {} on: {}",
628 world_id,
629 bootstrap_addr.clone()
630 );
631
632 let (comm_watcher_tx, comm_watcher_rx) = unbounded_channel();
633
634 let host_addresses = Arc::new(DashMap::<HostId, ChannelAddr>::new());
635 let host_addresses_for_signal = host_addresses.clone();
636
637 let signal_cleanup_guard =
639 hyperactor::register_signal_cleanup_scoped(Box::pin(async move {
640 join_all(host_addresses_for_signal.iter().map(|entry| async move {
641 let addr = entry.value().clone();
642 match channel::dial(addr.clone()) {
643 Ok(tx) => {
644 if let Err(e) = tx.send(RemoteProcessAllocatorMessage::Stop).await {
645 tracing::error!("Failed to send Stop to {}: {}", addr, e);
646 }
647 }
648 Err(e) => {
649 tracing::error!("Failed to dial {} during signal cleanup: {}", addr, e);
650 }
651 }
652 }))
653 .await;
654 }));
655
656 Ok(Self {
657 spec,
658 world_id,
659 remote_allocator_port,
660 initializer: Box::new(initializer),
661 world_offsets: HashMap::new(),
662 ordered_hosts: Vec::new(),
663 alloc_to_host: HashMap::new(),
664 host_states: HostStates::new(host_addresses),
665 bootstrap_addr,
666 event_queue: VecDeque::new(),
667 comm_watcher_tx,
668 comm_watcher_rx,
669 rx,
670 started: false,
671 running: true,
672 failed: false,
673 _signal_cleanup_guard: signal_cleanup_guard,
674 })
675 }
676
677 async fn start_comm_watcher(&self) {
685 let mut tx_watchers = Vec::new();
686 for host in &self.ordered_hosts {
687 let tx_status = self.host_states.get(&host.id).unwrap().tx.status().clone();
688 let watcher = WatchStream::new(tx_status);
689 tx_watchers.push((watcher, host.id.clone()));
690 }
691 assert!(!tx_watchers.is_empty());
692 let tx = self.comm_watcher_tx.clone();
693 tokio::spawn(async move {
694 loop {
695 let mut tx_status_futures = Vec::new();
696 for (watcher, _) in &mut tx_watchers {
697 let fut = watcher.next().boxed();
698 tx_status_futures.push(fut);
699 }
700 let (tx_status, index, _) = select_all(tx_status_futures).await;
701 let host_id = match tx_watchers.get(index) {
702 Some((_, host_id)) => host_id.clone(),
703 None => {
704 tracing::error!(
706 "got selected index {} with no matching host in {}",
707 index,
708 tx_watchers.len()
709 );
710 continue;
711 }
712 };
713 if let Some(tx_status) = tx_status {
714 tracing::debug!("host {} channel event: {:?}", host_id, tx_status);
715 if tx_status == TxStatus::Closed {
716 if tx.send(host_id.clone()).is_err() {
717 break;
719 }
720 tx_watchers.remove(index);
721 if tx_watchers.is_empty() {
722 break;
724 }
725 }
726 }
727 }
728 });
729 }
730
731 async fn ensure_started(&mut self) -> Result<(), anyhow::Error> {
736 if self.started || self.failed {
737 return Ok(());
738 }
739
740 self.started = true;
741 let hosts = self
742 .initializer
743 .initialize_alloc()
744 .await
745 .context("alloc initializer error")?;
746 if hosts.is_empty() {
747 anyhow::bail!("initializer returned empty list of hosts");
748 }
749 let hostnames: Vec<_> = hosts.iter().map(|e| e.hostname.clone()).collect();
752 tracing::info!("obtained {} hosts for this allocation", hostnames.len());
753
754 anyhow::ensure!(
756 self.spec.extent.len() >= 2,
757 "invalid extent: {}, expected at least 2 dimensions",
758 self.spec.extent
759 );
760
761 let split_dim = &self.spec.extent.labels()[self.spec.extent.len() - 1];
763 for (i, region) in self.spec.extent.group_by(split_dim)?.enumerate() {
764 let host = &hosts[i];
765 tracing::debug!("allocating: {} for host: {}", region, host.id);
766
767 let remote_addr = match self.spec.transport {
768 ChannelTransport::MetaTls(_) => {
769 format!("metatls!{}:{}", host.hostname, self.remote_allocator_port)
770 }
771 ChannelTransport::Tcp => {
772 format!("tcp!{}:{}", host.hostname, self.remote_allocator_port)
773 }
774 ChannelTransport::Unix => host.hostname.clone(),
776 _ => {
777 anyhow::bail!(
778 "unsupported transport for host {}: {:?}",
779 host.id,
780 self.spec.transport,
781 );
782 }
783 };
784
785 tracing::debug!("dialing remote: {} for host {}", remote_addr, host.id);
786 let remote_addr = remote_addr.parse::<ChannelAddr>()?;
787 let tx = channel::dial(remote_addr.clone())
788 .map_err(anyhow::Error::from)
789 .context(format!(
790 "failed to dial remote {} for host {}",
791 remote_addr, host.id
792 ))?;
793
794 let alloc_key = ShortUuid::generate();
796 assert!(
797 self.alloc_to_host
798 .insert(alloc_key.clone(), host.id.clone())
799 .is_none()
800 );
801
802 let trace_id = hyperactor_telemetry::trace::get_or_create_trace_id();
803 let client_context = Some(ClientContext { trace_id });
804 let message = RemoteProcessAllocatorMessage::Allocate {
805 alloc_key: alloc_key.clone(),
806 extent: region.extent(),
807 bootstrap_addr: self.bootstrap_addr.clone(),
808 hosts: hostnames.clone(),
809 client_context,
810 };
811 tracing::info!(
812 name = message.as_ref(),
813 "sending allocate message to workers"
814 );
815 tx.post(message);
816
817 self.host_states.insert(
818 host.id.clone(),
819 RemoteProcessAllocHostState {
820 alloc_key,
821 host_id: host.id.clone(),
822 tx,
823 active_procs: HashSet::new(),
824 region,
825 world_id: None,
826 failed: false,
827 allocated: false,
828 },
829 remote_addr,
830 );
831 }
832
833 self.ordered_hosts = hosts;
834 self.start_comm_watcher().await;
835 self.started = true;
836
837 Ok(())
838 }
839
840 fn get_host_state_mut(
842 &mut self,
843 alloc_key: &ShortUuid,
844 ) -> Result<&mut RemoteProcessAllocHostState, anyhow::Error> {
845 let host_id: &HostId = self
846 .alloc_to_host
847 .get(alloc_key)
848 .ok_or_else(|| anyhow::anyhow!("alloc with key {} not found", alloc_key))?;
849
850 self.host_states
851 .get_mut(host_id)
852 .ok_or_else(|| anyhow::anyhow!("no host state found for host {}", host_id))
853 }
854
855 fn get_host_state(
857 &self,
858 alloc_key: &ShortUuid,
859 ) -> Result<&RemoteProcessAllocHostState, anyhow::Error> {
860 let host_id: &HostId = self
861 .alloc_to_host
862 .get(alloc_key)
863 .ok_or_else(|| anyhow::anyhow!("alloc with key {} not found", alloc_key))?;
864
865 self.host_states
866 .get(host_id)
867 .ok_or_else(|| anyhow::anyhow!("no host state found for host {}", host_id))
868 }
869
870 fn remove_host_state(
871 &mut self,
872 alloc_key: &ShortUuid,
873 ) -> Result<RemoteProcessAllocHostState, anyhow::Error> {
874 let host_id: &HostId = self
875 .alloc_to_host
876 .get(alloc_key)
877 .ok_or_else(|| anyhow::anyhow!("alloc with key {} not found", alloc_key))?;
878
879 self.host_states
880 .remove(host_id)
881 .ok_or_else(|| anyhow::anyhow!("no host state found for host {}", host_id))
882 }
883
884 fn add_proc_id_to_host_state(
885 &mut self,
886 alloc_key: &ShortUuid,
887 create_key: &ShortUuid,
888 ) -> Result<(), anyhow::Error> {
889 let task_state = self.get_host_state_mut(alloc_key)?;
890 if !task_state.active_procs.insert(create_key.clone()) {
891 tracing::error!("proc with create key {} already in host state", create_key);
893 }
894 task_state.allocated = true;
895 Ok(())
896 }
897
898 fn remove_proc_from_host_state(
899 &mut self,
900 alloc_key: &ShortUuid,
901 create_key: &ShortUuid,
902 ) -> Result<(), anyhow::Error> {
903 let task_state = self.get_host_state_mut(alloc_key)?;
904 if !task_state.active_procs.remove(create_key) {
905 tracing::error!("proc with create_key already in host state: {}", create_key);
907 }
908 Ok(())
909 }
910
911 fn project_proc_into_global_extent(
913 &self,
914 alloc_key: &ShortUuid,
915 point: &Point,
916 ) -> Result<Point, anyhow::Error> {
917 let global_rank = self
918 .get_host_state(alloc_key)?
919 .region
920 .get(point.rank())
921 .ok_or_else(|| {
922 anyhow::anyhow!(
923 "rank {} out of bounds for in alloc {}",
924 point.rank(),
925 alloc_key
926 )
927 })?;
928 Ok(self.spec.extent.point_of_rank(global_rank)?)
929 }
930
931 fn cleanup_host_channel_closed(
933 &mut self,
934 host_id: HostId,
935 ) -> Result<Vec<ShortUuid>, anyhow::Error> {
936 let state = match self.host_states.remove(&host_id) {
937 Some(state) => state,
938 None => {
939 anyhow::bail!(
941 "got channel closed event for host {} which has no known state",
942 host_id
943 );
944 }
945 };
946 self.ordered_hosts.retain(|host| host.id != host_id);
947 self.alloc_to_host.remove(&state.alloc_key);
948 if let Some(world_id) = state.world_id {
949 self.world_offsets.remove(&world_id);
950 }
951 let create_keys = state.active_procs.iter().cloned().collect();
952
953 Ok(create_keys)
954 }
955}
956
957#[async_trait]
958impl Alloc for RemoteProcessAlloc {
959 async fn next(&mut self) -> Option<ProcState> {
960 loop {
961 if let state @ Some(_) = self.event_queue.pop_front() {
962 break state;
963 }
964
965 if !self.running {
966 break None;
967 }
968
969 if let Err(e) = self.ensure_started().await {
970 break Some(ProcState::Failed {
971 world_id: self.world_id.clone(),
972 description: format!("failed to ensure started: {:#}", e),
973 });
974 }
975
976 let heartbeat_interval =
977 config::global::get(config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL);
978 let mut heartbeat_time = hyperactor::clock::RealClock.now() + heartbeat_interval;
979 let mut reloop = false;
981 let update = loop {
982 tokio::select! {
983 msg = self.rx.recv() => {
984 tracing::debug!("got ProcState message from allocator: {:?}", msg);
985 match msg {
986 Ok(RemoteProcessProcStateMessage::Allocated { alloc_key, world_id }) => {
987 tracing::info!("remote alloc {}: allocated", alloc_key);
988 match self.get_host_state_mut(&alloc_key) {
989 Ok(state) => {
990 state.world_id = Some(world_id.clone());
991 }
992 Err(err) => {
993 tracing::error!(
995 "received allocated message alloc: {} with no known state: {}",
996 alloc_key, err,
997 );
998 }
999 }
1000 }
1001 Ok(RemoteProcessProcStateMessage::Update(alloc_key, proc_state)) => {
1002 let update = match proc_state {
1003 ProcState::Created { ref create_key, .. } => {
1004 if let Err(e) = self.add_proc_id_to_host_state(&alloc_key, create_key) {
1005 tracing::error!("failed to add proc with create key {} host state: {}", create_key, e);
1006 }
1007 proc_state
1008 }
1009 ProcState::Stopped{ ref create_key, ..} => {
1010 if let Err(e) = self.remove_proc_from_host_state(&alloc_key, create_key) {
1011 tracing::error!("failed to remove proc with create key {} host state: {}", create_key, e);
1012 }
1013 proc_state
1014 }
1015 ProcState::Failed { ref world_id, ref description } => {
1016 match self.get_host_state_mut(&alloc_key) {
1017 Ok(state) => {
1018 state.failed = true;
1019 ProcState::Failed {
1020 world_id: world_id.clone(),
1021 description: format!("host {} failed: {}", state.host_id, description),
1022 }
1023 }
1024 Err(e) => {
1025 tracing::error!("failed to find host state for world id: {}: {}", world_id, e);
1026 proc_state
1027 }
1028 }
1029 }
1030 _ => proc_state
1031 };
1032
1033 break Some((Some(alloc_key), update));
1034 }
1035 Ok(RemoteProcessProcStateMessage::Done(alloc_key)) => {
1036 tracing::info!("allocator {} is done", alloc_key);
1037
1038 if let Ok(state) = self.remove_host_state(&alloc_key) {
1039 if !state.active_procs.is_empty() {
1040 tracing::error!("received done for alloc {} with active procs: {:?}", alloc_key, state.active_procs);
1041 }
1042 } else {
1043 tracing::error!("received done for unknown alloc {}", alloc_key);
1044 }
1045
1046 if self.host_states.is_empty() {
1047 self.running = false;
1048 break None;
1049 }
1050 }
1051 Ok(RemoteProcessProcStateMessage::HeartBeat) => {}
1055 Err(e) => {
1056 break Some((None, ProcState::Failed {world_id: self.world_id.clone(), description: format!("error receiving events: {}", e)}));
1057 }
1058 }
1059 }
1060
1061 _ = clock::RealClock.sleep_until(heartbeat_time) => {
1062 self.host_states.iter().for_each(|(_, host_state)| host_state.tx.post(RemoteProcessAllocatorMessage::HeartBeat));
1063 heartbeat_time = hyperactor::clock::RealClock.now() + heartbeat_interval;
1064 }
1065
1066 closed_host_id = self.comm_watcher_rx.recv() => {
1067 if let Some(closed_host_id) = closed_host_id {
1068 tracing::debug!("host {} channel closed, cleaning up", closed_host_id);
1069 if let Some(state) = self.host_states.get(&closed_host_id)
1070 && !state.allocated {
1071 break Some((None, ProcState::Failed {
1072 world_id: self.world_id.clone(),
1073 description: format!(
1074 "no process has ever been allocated on {} before the channel is closed; \
1075 a common issue could be the channel was never established",
1076 closed_host_id
1077 )}));
1078 }
1079 let create_keys = match self.cleanup_host_channel_closed(closed_host_id) {
1080 Ok(create_keys) => create_keys,
1081 Err(err) => {
1082 tracing::error!("failed to cleanup disconnected host: {}", err);
1083 continue;
1084 }
1085 };
1086 for create_key in create_keys {
1087 tracing::debug!("queuing Stopped state for proc with create key {}", create_key);
1088 self.event_queue.push_back(
1089 ProcState::Stopped {
1090 create_key,
1091 reason: ProcStopReason::HostWatchdog
1092 }
1093 );
1094 }
1095 if self.host_states.is_empty() {
1097 tracing::info!("no more hosts left, stopping the alloc");
1098 self.running = false;
1099 }
1100 reloop = true;
1102 break None;
1103 } else {
1104 tracing::warn!("unexpected comm watcher channel close");
1106 break None;
1107 }
1108 }
1109 }
1110 };
1111
1112 if reloop {
1113 assert!(update.is_none());
1116 continue;
1117 }
1118
1119 break match update {
1120 Some((
1121 Some(alloc_key),
1122 ProcState::Created {
1123 create_key,
1124 point,
1125 pid,
1126 },
1127 )) => match self.project_proc_into_global_extent(&alloc_key, &point) {
1128 Ok(global_point) => {
1129 tracing::debug!("reprojected coords: {} -> {}", point, global_point);
1130 Some(ProcState::Created {
1131 create_key,
1132 point: global_point,
1133 pid,
1134 })
1135 }
1136 Err(e) => {
1137 tracing::error!(
1138 "failed to project coords for proc: {}.{}: {}",
1139 alloc_key,
1140 create_key,
1141 e
1142 );
1143 None
1144 }
1145 },
1146 Some((None, ProcState::Created { .. })) => {
1147 panic!("illegal state: missing alloc_key for ProcState::Created event")
1148 }
1149 Some((_, update)) => {
1150 if let ProcState::Failed { description, .. } = &update {
1151 tracing::error!(description);
1152 self.failed = true;
1153 }
1154 Some(update)
1155 }
1156 None => None,
1157 };
1158 }
1159 }
1160
1161 fn spec(&self) -> &AllocSpec {
1162 &self.spec
1163 }
1164
1165 fn extent(&self) -> &Extent {
1166 &self.spec.extent
1167 }
1168
1169 fn world_id(&self) -> &WorldId {
1170 &self.world_id
1171 }
1172
1173 async fn stop(&mut self) -> Result<(), AllocatorError> {
1174 tracing::info!("stopping alloc");
1175
1176 for (host_id, task_state) in self.host_states.iter_mut() {
1177 tracing::debug!("stopping alloc at host {}", host_id);
1178 task_state.tx.post(RemoteProcessAllocatorMessage::Stop);
1179 }
1180
1181 Ok(())
1182 }
1183}
1184
1185impl Drop for RemoteProcessAlloc {
1186 fn drop(&mut self) {
1187 tracing::debug!("dropping RemoteProcessAlloc of world_id {}", self.world_id);
1188 }
1189}
1190
1191#[cfg(test)]
1192mod test {
1193 use std::assert_matches::assert_matches;
1194
1195 use hyperactor::ActorRef;
1196 use hyperactor::channel::ChannelRx;
1197 use hyperactor::clock::ClockKind;
1198 use hyperactor::id;
1199 use ndslice::extent;
1200 use tokio::sync::oneshot;
1201
1202 use super::*;
1203 use crate::alloc::ChannelTransport;
1204 use crate::alloc::MockAlloc;
1205 use crate::alloc::MockAllocWrapper;
1206 use crate::alloc::MockAllocator;
1207 use crate::alloc::ProcStopReason;
1208 use crate::proc_mesh::mesh_agent::ProcMeshAgent;
1209
1210 async fn read_all_created(rx: &mut ChannelRx<RemoteProcessProcStateMessage>, alloc_len: usize) {
1211 let mut i: usize = 0;
1212 while i < alloc_len {
1213 let m = rx.recv().await.unwrap();
1214 match m {
1215 RemoteProcessProcStateMessage::Update(_, ProcState::Created { .. }) => i += 1,
1216 RemoteProcessProcStateMessage::HeartBeat => {}
1217 _ => panic!("unexpected message: {:?}", m),
1218 }
1219 }
1220 }
1221
1222 async fn read_all_running(rx: &mut ChannelRx<RemoteProcessProcStateMessage>, alloc_len: usize) {
1223 let mut i: usize = 0;
1224 while i < alloc_len {
1225 let m = rx.recv().await.unwrap();
1226 match m {
1227 RemoteProcessProcStateMessage::Update(_, ProcState::Running { .. }) => i += 1,
1228 RemoteProcessProcStateMessage::HeartBeat => {}
1229 _ => panic!("unexpected message: {:?}", m),
1230 }
1231 }
1232 }
1233
1234 async fn read_all_stopped(rx: &mut ChannelRx<RemoteProcessProcStateMessage>, alloc_len: usize) {
1235 let mut i: usize = 0;
1236 while i < alloc_len {
1237 let m = rx.recv().await.unwrap();
1238 match m {
1239 RemoteProcessProcStateMessage::Update(_, ProcState::Stopped { .. }) => i += 1,
1240 RemoteProcessProcStateMessage::HeartBeat => {}
1241 _ => panic!("unexpected message: {:?}", m),
1242 }
1243 }
1244 }
1245
1246 fn set_procstate_expectations(alloc: &mut MockAlloc, extent: Extent) {
1247 alloc.expect_extent().return_const(extent.clone());
1248 let mut create_keys = Vec::new();
1249 for point in extent.points() {
1250 let create_key = ShortUuid::generate();
1251 create_keys.push(create_key.clone());
1252 alloc.expect_next().times(1).return_once(move || {
1253 Some(ProcState::Created {
1254 create_key: create_key.clone(),
1255 point,
1256 pid: 0,
1257 })
1258 });
1259 }
1260 for (i, create_key) in create_keys
1261 .iter()
1262 .take(extent.num_ranks())
1263 .cloned()
1264 .enumerate()
1265 {
1266 let proc_id = format!("test[{i}]").parse().unwrap();
1267 let mesh_agent = ActorRef::<ProcMeshAgent>::attest(
1268 format!("test[{i}].mesh_agent[{i}]").parse().unwrap(),
1269 );
1270 alloc.expect_next().times(1).return_once(move || {
1271 Some(ProcState::Running {
1272 create_key,
1273 proc_id,
1274 addr: ChannelAddr::Unix("/proc0".parse().unwrap()),
1275 mesh_agent,
1276 })
1277 });
1278 }
1279 for create_key in create_keys.iter().take(extent.num_ranks()).cloned() {
1280 alloc.expect_next().times(1).return_once(|| {
1281 Some(ProcState::Stopped {
1282 create_key,
1283 reason: ProcStopReason::Unknown,
1284 })
1285 });
1286 }
1287 }
1288
/// Happy path through the remote allocator, as seen by the client.
///
/// After one `Allocate`, the client must receive, in order:
/// - `Allocated`;
/// - a `Created` update for every rank of a 1x2 extent, in rank order;
/// - the matching `Running` updates;
/// - the matching `Stopped` updates;
/// - `Done`.
///
/// Every message is tagged with the client's alloc key. Heartbeats (interval
/// overridden to 100ms) may be interleaved and are skipped.
#[timed_test::async_timed_test(timeout_secs = 5)]
async fn test_simple() {
    let config = hyperactor::config::global::lock();
    let _guard = config.override_key(
        hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
        Duration::from_millis(100),
    );
    hyperactor_telemetry::initialize_logging(ClockKind::default());
    // The test plays the client: it sends commands to `serve_addr` and
    // receives proc-state messages on `bootstrap_addr`.
    let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
    let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
    let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();

    let extent = extent!(host = 1, gpu = 2);
    let tx = channel::dial(serve_addr.clone()).unwrap();

    let world_id: WorldId = id!(test_world_id);
    let mut alloc = MockAlloc::new();
    alloc.expect_world_id().return_const(world_id.clone());
    alloc.expect_extent().return_const(extent.clone());

    set_procstate_expectations(&mut alloc, extent.clone());

    // Registered after the scripted lifecycle, so it is only reached once
    // every `times(1)` expectation above has been used.
    alloc.expect_next().return_const(None);

    let mut allocator = MockAllocator::new();
    // Created + Running + Stopped per rank, plus the trailing `None`.
    // NOTE(review): presumably `new_block_next` only blocks `next()` after
    // this many calls, so this wrapper never blocks here. Confirm against
    // MockAllocWrapper.
    let total_messages = extent.num_ranks() * 3 + 1;
    let mock_wrapper = MockAllocWrapper::new_block_next(
        alloc,
        total_messages,
    );
    allocator
        .expect_allocate()
        .times(1)
        .return_once(move |_| Ok(mock_wrapper));

    let remote_allocator = RemoteProcessAllocator::new();
    let handle = tokio::spawn({
        let remote_allocator = remote_allocator.clone();
        async move {
            remote_allocator
                .start_with_allocator(serve_addr, allocator, None)
                .await
        }
    });

    let alloc_key = ShortUuid::generate();

    tx.send(RemoteProcessAllocatorMessage::Allocate {
        alloc_key: alloc_key.clone(),
        extent: extent.clone(),
        bootstrap_addr,
        hosts: vec![],
        client_context: None,
    })
    .await
    .unwrap();

    let m = rx.recv().await.unwrap();
    assert_matches!(
        m, RemoteProcessProcStateMessage::Allocated { alloc_key: got_alloc_key, world_id: got_world_id }
        if got_world_id == world_id && got_alloc_key == alloc_key
    );

    // `Created` updates must arrive in rank order. Remember each create key
    // so the later states can be matched against it.
    let mut rank: usize = 0;
    let mut create_keys = Vec::with_capacity(extent.num_ranks());
    while rank < extent.num_ranks() {
        let m = rx.recv().await.unwrap();
        match m {
            RemoteProcessProcStateMessage::Update(
                got_alloc_key,
                ProcState::Created {
                    create_key, point, ..
                },
            ) => {
                let expected_point = extent.point_of_rank(rank).unwrap();
                assert_eq!(got_alloc_key, alloc_key);
                assert_eq!(point, expected_point);
                create_keys.push(create_key);
                rank += 1;
            }
            RemoteProcessProcStateMessage::HeartBeat => {}
            _ => panic!("unexpected message: {:?}", m),
        }
    }
    // `Running` updates carry the ids scripted by `set_procstate_expectations`.
    let mut rank: usize = 0;
    while rank < extent.num_ranks() {
        let m = rx.recv().await.unwrap();
        match m {
            RemoteProcessProcStateMessage::Update(
                got_alloc_key,
                ProcState::Running {
                    create_key,
                    proc_id,
                    mesh_agent,
                    addr: _,
                },
            ) => {
                assert_eq!(got_alloc_key, alloc_key);
                assert_eq!(create_key, create_keys[rank]);
                let expected_proc_id = format!("test[{}]", rank).parse().unwrap();
                let expected_mesh_agent = ActorRef::<ProcMeshAgent>::attest(
                    format!("test[{}].mesh_agent[{}]", rank, rank)
                        .parse()
                        .unwrap(),
                );
                assert_eq!(proc_id, expected_proc_id);
                assert_eq!(mesh_agent, expected_mesh_agent);
                rank += 1;
            }
            RemoteProcessProcStateMessage::HeartBeat => {}
            _ => panic!("unexpected message: {:?}", m),
        }
    }
    // `Stopped` updates, still in rank order and with the scripted reason.
    let mut rank: usize = 0;
    while rank < extent.num_ranks() {
        let m = rx.recv().await.unwrap();
        match m {
            RemoteProcessProcStateMessage::Update(
                got_alloc_key,
                ProcState::Stopped {
                    create_key,
                    reason: ProcStopReason::Unknown,
                },
            ) => {
                assert_eq!(got_alloc_key, alloc_key);
                assert_eq!(create_key, create_keys[rank]);
                rank += 1;
            }
            RemoteProcessProcStateMessage::HeartBeat => {}
            _ => panic!("unexpected message: {:?}", m),
        }
    }
    // The mock's `None` ends the alloc; the allocator reports `Done`.
    loop {
        let m = rx.recv().await.unwrap();
        match m {
            RemoteProcessProcStateMessage::Done(got_alloc_key) => {
                assert_eq!(got_alloc_key, alloc_key);
                break;
            }
            RemoteProcessProcStateMessage::HeartBeat => {}
            _ => panic!("unexpected message: {:?}", m),
        }
    }

    remote_allocator.terminate();
    handle.await.unwrap().unwrap();
}
1443
/// A client `Stop` after all procs are running leads to one `Stopped` update
/// per rank. The mock alloc's `stop()` must be called exactly once.
#[timed_test::async_timed_test(timeout_secs = 15)]
async fn test_normal_stop() {
    let config = hyperactor::config::global::lock();
    let _guard = config.override_key(
        hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
        Duration::from_millis(100),
    );
    hyperactor_telemetry::initialize_logging(ClockKind::default());
    let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
    let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
    let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();

    let extent = extent!(host = 1, gpu = 2);
    let tx = channel::dial(serve_addr.clone()).unwrap();

    let world_id: WorldId = id!(test_world_id);
    // NOTE(review): presumably `next()` blocks after the Created + Running
    // states until `next_tx` fires, holding back the scripted `Stopped`
    // states until the client has sent `Stop`. Confirm against
    // MockAllocWrapper.
    let mut alloc = MockAllocWrapper::new_block_next(
        MockAlloc::new(),
        extent.num_ranks() * 2,
    );
    let next_tx = alloc.notify_tx();
    alloc.alloc.expect_world_id().return_const(world_id.clone());
    alloc.alloc.expect_extent().return_const(extent.clone());

    set_procstate_expectations(&mut alloc.alloc, extent.clone());

    alloc.alloc.expect_next().return_const(None);
    alloc.alloc.expect_stop().times(1).return_once(|| Ok(()));

    let mut allocator = MockAllocator::new();
    allocator
        .expect_allocate()
        .times(1)
        .return_once(|_| Ok(alloc));

    let remote_allocator = RemoteProcessAllocator::new();
    let handle = tokio::spawn({
        let remote_allocator = remote_allocator.clone();
        async move {
            remote_allocator
                .start_with_allocator(serve_addr, allocator, None)
                .await
        }
    });

    let alloc_key = ShortUuid::generate();
    tx.send(RemoteProcessAllocatorMessage::Allocate {
        alloc_key: alloc_key.clone(),
        extent: extent.clone(),
        bootstrap_addr,
        hosts: vec![],
        client_context: None,
    })
    .await
    .unwrap();

    let m = rx.recv().await.unwrap();
    assert_matches!(
        m,
        RemoteProcessProcStateMessage::Allocated { world_id: got_world_id, alloc_key: got_alloc_key }
        if world_id == got_world_id && alloc_key == got_alloc_key
    );

    read_all_created(&mut rx, extent.num_ranks()).await;
    read_all_running(&mut rx, extent.num_ranks()).await;

    tracing::info!("stopping allocation");
    tx.send(RemoteProcessAllocatorMessage::Stop).await.unwrap();
    // Unblock the mock so it can yield the scripted `Stopped` states.
    next_tx.send(()).unwrap();

    read_all_stopped(&mut rx, extent.num_ranks()).await;

    remote_allocator.terminate();
    handle.await.unwrap().unwrap();
}
1523
1524 #[timed_test::async_timed_test(timeout_secs = 15)]
1525 async fn test_realloc() {
1526 let config = hyperactor::config::global::lock();
1527 let _guard = config.override_key(
1528 hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
1529 Duration::from_millis(100),
1530 );
1531 hyperactor_telemetry::initialize_logging(ClockKind::default());
1532 let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
1533 let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
1534 let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();
1535
1536 let extent = extent!(host = 1, gpu = 2);
1537
1538 let tx = channel::dial(serve_addr.clone()).unwrap();
1539
1540 let world_id: WorldId = id!(test_world_id);
1541 let mut alloc1 = MockAllocWrapper::new_block_next(
1542 MockAlloc::new(),
1543 extent.num_ranks() * 2,
1545 );
1546 let next_tx1 = alloc1.notify_tx();
1547 alloc1
1548 .alloc
1549 .expect_world_id()
1550 .return_const(world_id.clone());
1551 alloc1.alloc.expect_extent().return_const(extent.clone());
1552
1553 set_procstate_expectations(&mut alloc1.alloc, extent.clone());
1554 alloc1.alloc.expect_next().return_const(None);
1555 alloc1.alloc.expect_stop().times(1).return_once(|| Ok(()));
1556 let mut alloc2 = MockAllocWrapper::new_block_next(
1558 MockAlloc::new(),
1559 extent.num_ranks() * 2,
1561 );
1562 let next_tx2 = alloc2.notify_tx();
1563 alloc2
1564 .alloc
1565 .expect_world_id()
1566 .return_const(world_id.clone());
1567 alloc2.alloc.expect_extent().return_const(extent.clone());
1568 set_procstate_expectations(&mut alloc2.alloc, extent.clone());
1569 alloc2.alloc.expect_next().return_const(None);
1570 alloc2.alloc.expect_stop().times(1).return_once(|| Ok(()));
1571
1572 let mut allocator = MockAllocator::new();
1573 allocator
1574 .expect_allocate()
1575 .times(1)
1576 .return_once(|_| Ok(alloc1));
1577 allocator
1579 .expect_allocate()
1580 .times(1)
1581 .return_once(|_| Ok(alloc2));
1582
1583 let remote_allocator = RemoteProcessAllocator::new();
1584 let handle = tokio::spawn({
1585 let remote_allocator = remote_allocator.clone();
1586 async move {
1587 remote_allocator
1588 .start_with_allocator(serve_addr, allocator, None)
1589 .await
1590 }
1591 });
1592
1593 let alloc_key = ShortUuid::generate();
1594
1595 tx.send(RemoteProcessAllocatorMessage::Allocate {
1596 alloc_key: alloc_key.clone(),
1597 extent: extent.clone(),
1598 bootstrap_addr: bootstrap_addr.clone(),
1599 hosts: vec![],
1600 client_context: None,
1601 })
1602 .await
1603 .unwrap();
1604
1605 let m = rx.recv().await.unwrap();
1607 assert_matches!(
1608 m,
1609 RemoteProcessProcStateMessage::Allocated { world_id: got_world_id, alloc_key: got_alloc_key }
1610 if got_world_id == world_id && got_alloc_key == alloc_key
1611 );
1612
1613 read_all_created(&mut rx, extent.num_ranks()).await;
1614 read_all_running(&mut rx, extent.num_ranks()).await;
1615
1616 let alloc_key = ShortUuid::generate();
1617
1618 tx.send(RemoteProcessAllocatorMessage::Allocate {
1620 alloc_key: alloc_key.clone(),
1621 extent: extent.clone(),
1622 bootstrap_addr,
1623 hosts: vec![],
1624 client_context: None,
1625 })
1626 .await
1627 .unwrap();
1628 next_tx1.send(()).unwrap();
1630 read_all_stopped(&mut rx, extent.num_ranks()).await;
1632 let m = rx.recv().await.unwrap();
1633 assert_matches!(m, RemoteProcessProcStateMessage::Done(_));
1634 let m = rx.recv().await.unwrap();
1635 assert_matches!(
1636 m,
1637 RemoteProcessProcStateMessage::Allocated { world_id: got_world_id, alloc_key: got_alloc_key }
1638 if got_world_id == world_id && got_alloc_key == alloc_key
1639 );
1640 read_all_created(&mut rx, extent.num_ranks()).await;
1642 read_all_running(&mut rx, extent.num_ranks()).await;
1643 tracing::info!("stopping allocation");
1645 tx.send(RemoteProcessAllocatorMessage::Stop).await.unwrap();
1646 next_tx2.send(()).unwrap();
1648
1649 read_all_stopped(&mut rx, extent.num_ranks()).await;
1650
1651 remote_allocator.terminate();
1652 handle.await.unwrap().unwrap();
1653 }
1654
1655 #[timed_test::async_timed_test(timeout_secs = 15)]
1656 async fn test_upstream_closed() {
1657 let config = hyperactor::config::global::lock();
1659 let _guard1 = config.override_key(
1660 hyperactor::config::MESSAGE_DELIVERY_TIMEOUT,
1661 Duration::from_secs(1),
1662 );
1663 let _guard2 = config.override_key(
1664 hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
1665 Duration::from_millis(100),
1666 );
1667
1668 hyperactor_telemetry::initialize_logging(ClockKind::default());
1669 let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
1670 let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
1671 let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();
1672
1673 let extent = extent!(host = 1, gpu = 2);
1674
1675 let tx = channel::dial(serve_addr.clone()).unwrap();
1676
1677 let world_id: WorldId = id!(test_world_id);
1678 let mut alloc = MockAllocWrapper::new_block_next(
1679 MockAlloc::new(),
1680 extent.num_ranks() * 2,
1682 );
1683 let next_tx = alloc.notify_tx();
1684 alloc.alloc.expect_world_id().return_const(world_id.clone());
1685 alloc.alloc.expect_extent().return_const(extent.clone());
1686
1687 set_procstate_expectations(&mut alloc.alloc, extent.clone());
1688
1689 alloc.alloc.expect_next().return_const(None);
1690 let (stop_tx, stop_rx) = oneshot::channel();
1693 alloc.alloc.expect_stop().times(1).return_once(|| {
1694 stop_tx.send(()).unwrap();
1695 Ok(())
1696 });
1697
1698 let mut allocator = MockAllocator::new();
1699 allocator
1700 .expect_allocate()
1701 .times(1)
1702 .return_once(|_| Ok(alloc));
1703
1704 let remote_allocator = RemoteProcessAllocator::new();
1705 let handle = tokio::spawn({
1706 let remote_allocator = remote_allocator.clone();
1707 async move {
1708 remote_allocator
1709 .start_with_allocator(serve_addr, allocator, None)
1710 .await
1711 }
1712 });
1713
1714 let alloc_key = ShortUuid::generate();
1715
1716 tx.send(RemoteProcessAllocatorMessage::Allocate {
1717 alloc_key: alloc_key.clone(),
1718 extent: extent.clone(),
1719 bootstrap_addr,
1720 hosts: vec![],
1721 client_context: None,
1722 })
1723 .await
1724 .unwrap();
1725
1726 let m = rx.recv().await.unwrap();
1728 assert_matches!(
1729 m, RemoteProcessProcStateMessage::Allocated { alloc_key: got_alloc_key, world_id: got_world_id }
1730 if got_world_id == world_id && got_alloc_key == alloc_key
1731 );
1732
1733 read_all_created(&mut rx, extent.num_ranks()).await;
1734 read_all_running(&mut rx, extent.num_ranks()).await;
1735
1736 tracing::info!("closing upstream");
1738 drop(rx);
1739 #[allow(clippy::disallowed_methods)]
1741 tokio::time::sleep(Duration::from_secs(2)).await;
1742 stop_rx.await.unwrap();
1744 next_tx.send(()).unwrap();
1746 remote_allocator.terminate();
1747 handle.await.unwrap().unwrap();
1748 }
1749
/// A `ProcState::Failed` yielded by the inner alloc is forwarded to the
/// client as an `Update`. A subsequent `Stop` calls the inner `stop()` and
/// the alloc finishes with `Done`.
#[timed_test::async_timed_test(timeout_secs = 15)]
async fn test_inner_alloc_failure() {
    let config = hyperactor::config::global::lock();
    // A 60s heartbeat interval keeps heartbeats out of the strict
    // single-message reads below.
    let _guard = config.override_key(
        hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
        Duration::from_secs(60),
    );
    hyperactor_telemetry::initialize_logging(ClockKind::default());
    let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
    let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
    let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();

    let extent = extent!(host = 1, gpu = 2);

    let tx = channel::dial(serve_addr.clone()).unwrap();

    let test_world_id: WorldId = id!(test_world_id);
    // NOTE(review): presumably `next()` blocks after the single `Failed`
    // state until `next_tx` fires. Confirm against MockAllocWrapper.
    let mut alloc = MockAllocWrapper::new_block_next(
        MockAlloc::new(),
        1,
    );
    let next_tx = alloc.notify_tx();
    alloc
        .alloc
        .expect_world_id()
        .return_const(test_world_id.clone());
    alloc.alloc.expect_extent().return_const(extent.clone());
    // Script: one `Failed`, then `None`.
    alloc
        .alloc
        .expect_next()
        .times(1)
        .return_const(Some(ProcState::Failed {
            world_id: test_world_id.clone(),
            description: "test".to_string(),
        }));
    alloc.alloc.expect_next().times(1).return_const(None);

    alloc.alloc.expect_stop().times(1).return_once(|| Ok(()));

    let mut allocator = MockAllocator::new();
    allocator
        .expect_allocate()
        .times(1)
        .return_once(|_| Ok(alloc));

    let remote_allocator = RemoteProcessAllocator::new();
    let handle = tokio::spawn({
        let remote_allocator = remote_allocator.clone();
        async move {
            remote_allocator
                .start_with_allocator(serve_addr, allocator, None)
                .await
        }
    });

    let alloc_key = ShortUuid::generate();
    tx.send(RemoteProcessAllocatorMessage::Allocate {
        alloc_key: alloc_key.clone(),
        extent: extent.clone(),
        bootstrap_addr,
        hosts: vec![],
        client_context: None,
    })
    .await
    .unwrap();

    let m = rx.recv().await.unwrap();
    assert_matches!(
        m,
        RemoteProcessProcStateMessage::Allocated { world_id: got_world_id, alloc_key: got_alloc_key }
        if test_world_id == got_world_id && alloc_key == got_alloc_key
    );

    let m = rx.recv().await.unwrap();
    assert_matches!(
        m,
        RemoteProcessProcStateMessage::Update(
            got_alloc_key,
            ProcState::Failed { world_id, description }
        ) if got_alloc_key == alloc_key && world_id == test_world_id && description == "test"
    );

    tracing::info!("stopping allocation");
    tx.send(RemoteProcessAllocatorMessage::Stop).await.unwrap();
    // Unblock the mock so it can yield the trailing `None`.
    next_tx.send(()).unwrap();
    let m = rx.recv().await.unwrap();
    assert_matches!(
        m,
        RemoteProcessProcStateMessage::Done(got_alloc_key)
        if got_alloc_key == alloc_key
    );

    remote_allocator.terminate();
    handle.await.unwrap().unwrap();
}
1850
/// The `trace_id` from the client's `ClientContext` must reach the inner
/// allocator as the `CLIENT_TRACE_ID_LABEL` entry of the spec's
/// `constraints.match_labels`.
#[timed_test::async_timed_test(timeout_secs = 15)]
async fn test_trace_id_propagation() {
    let config = hyperactor::config::global::lock();
    let _guard = config.override_key(
        hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
        Duration::from_secs(60),
    );
    hyperactor_telemetry::initialize_logging(ClockKind::default());
    let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
    let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
    let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();

    let extent = extent!(host = 1, gpu = 1);
    let tx = channel::dial(serve_addr.clone()).unwrap();
    let test_world_id: WorldId = id!(test_world_id);
    let test_trace_id = "test_trace_id_12345";

    // The alloc itself is trivial: its first `next()` already returns `None`.
    let mut alloc = MockAlloc::new();
    alloc.expect_world_id().return_const(test_world_id.clone());
    alloc.expect_extent().return_const(extent.clone());
    alloc.expect_next().return_const(None);

    let mut allocator = MockAllocator::new();
    // `withf` makes `allocate` match only when the label carries the trace id.
    allocator
        .expect_allocate()
        .times(1)
        .withf(move |spec: &AllocSpec| {
            spec.constraints
                .match_labels
                .get(CLIENT_TRACE_ID_LABEL)
                .is_some_and(|trace_id| trace_id == test_trace_id)
        })
        .return_once(|_| Ok(MockAllocWrapper::new(alloc)));

    let remote_allocator = RemoteProcessAllocator::new();
    let handle = tokio::spawn({
        let remote_allocator = remote_allocator.clone();
        async move {
            remote_allocator
                .start_with_allocator(serve_addr, allocator, None)
                .await
        }
    });

    let alloc_key = ShortUuid::generate();
    tx.send(RemoteProcessAllocatorMessage::Allocate {
        alloc_key: alloc_key.clone(),
        extent: extent.clone(),
        bootstrap_addr,
        hosts: vec![],
        client_context: Some(ClientContext {
            trace_id: test_trace_id.to_string(),
        }),
    })
    .await
    .unwrap();

    let m = rx.recv().await.unwrap();
    assert_matches!(
        m,
        RemoteProcessProcStateMessage::Allocated { alloc_key: got_alloc_key, world_id: got_world_id }
        if got_world_id == test_world_id && got_alloc_key == alloc_key
    );

    let m = rx.recv().await.unwrap();
    assert_matches!(
        m,
        RemoteProcessProcStateMessage::Done(got_alloc_key)
        if alloc_key == got_alloc_key
    );

    remote_allocator.terminate();
    handle.await.unwrap().unwrap();
}
1930
/// With no `ClientContext`, the spec handed to the inner allocator must have
/// no `match_labels` at all.
#[timed_test::async_timed_test(timeout_secs = 15)]
async fn test_trace_id_propagation_no_client_context() {
    let config = hyperactor::config::global::lock();
    let _guard = config.override_key(
        hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
        Duration::from_secs(60),
    );
    hyperactor_telemetry::initialize_logging(ClockKind::default());
    let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
    let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
    let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();

    let extent = extent!(host = 1, gpu = 1);
    let tx = channel::dial(serve_addr.clone()).unwrap();
    let test_world_id: WorldId = id!(test_world_id);

    let mut alloc = MockAlloc::new();
    alloc.expect_world_id().return_const(test_world_id.clone());
    alloc.expect_extent().return_const(extent.clone());
    alloc.expect_next().return_const(None);

    let mut allocator = MockAllocator::new();
    allocator
        .expect_allocate()
        .times(1)
        .withf(move |spec: &AllocSpec| {
            spec.constraints.match_labels.is_empty()
        })
        .return_once(|_| Ok(MockAllocWrapper::new(alloc)));

    let remote_allocator = RemoteProcessAllocator::new();
    let handle = tokio::spawn({
        let remote_allocator = remote_allocator.clone();
        async move {
            remote_allocator
                .start_with_allocator(serve_addr, allocator, None)
                .await
        }
    });

    let alloc_key = ShortUuid::generate();
    tx.send(RemoteProcessAllocatorMessage::Allocate {
        alloc_key: alloc_key.clone(),
        extent: extent.clone(),
        bootstrap_addr,
        hosts: vec![],
        client_context: None,
    })
    .await
    .unwrap();

    let m = rx.recv().await.unwrap();
    assert_matches!(
        m,
        RemoteProcessProcStateMessage::Allocated { alloc_key: got_alloc_key, world_id: got_world_id }
        if got_world_id == test_world_id && got_alloc_key == alloc_key
    );

    let m = rx.recv().await.unwrap();
    assert_matches!(
        m,
        RemoteProcessProcStateMessage::Done(got_alloc_key)
        if got_alloc_key == alloc_key
    );

    remote_allocator.terminate();
    handle.await.unwrap().unwrap();
}
2004}
2005
2006#[cfg(test)]
2007mod test_alloc {
2008 use std::os::unix::process::ExitStatusExt;
2009
2010 use hyperactor::clock::ClockKind;
2011 use hyperactor::config;
2012 use ndslice::extent;
2013 use nix::sys::signal;
2014 use nix::unistd::Pid;
2015 use timed_test::async_timed_test;
2016
2017 use super::*;
2018
/// End-to-end run of `RemoteProcessAlloc` against two real
/// `RemoteProcessAllocator` tasks, each spawning the bootstrap binary.
///
/// The test checks that:
/// - every point of a 2x2 extent is created and then running;
/// - no further states arrive while idle;
/// - `stop()` produces one `Stopped` (reason `Stopped`) per running proc;
/// - the stream ends with `None`.
#[async_timed_test(timeout_secs = 60)]
async fn test_alloc_simple() {
    let config = hyperactor::config::global::lock();
    let _guard1 = config.override_key(
        hyperactor::config::MESSAGE_DELIVERY_TIMEOUT,
        Duration::from_secs(1),
    );
    let _guard2 = config.override_key(
        hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
        Duration::from_millis(100),
    );
    hyperactor_telemetry::initialize_logging(ClockKind::default());

    let spec = AllocSpec {
        extent: extent!(host = 2, gpu = 2),
        constraints: Default::default(),
        proc_name: None,
        transport: ChannelTransport::Unix,
    };
    let world_id = WorldId("test_world_id".to_string());

    // Two remote allocators, each serving on its own Unix address.
    let task1_allocator = RemoteProcessAllocator::new();
    let task1_addr = ChannelAddr::any(ChannelTransport::Unix);
    let task1_addr_string = task1_addr.to_string();
    let task1_cmd = Command::new(crate::testresource::get(
        "monarch/hyperactor_mesh/bootstrap",
    ));
    let task2_allocator = RemoteProcessAllocator::new();
    let task2_addr = ChannelAddr::any(ChannelTransport::Unix);
    let task2_addr_string = task2_addr.to_string();
    let task2_cmd = Command::new(crate::testresource::get(
        "monarch/hyperactor_mesh/bootstrap",
    ));
    let task1_allocator_copy = task1_allocator.clone();
    let task1_allocator_handle = tokio::spawn(async move {
        tracing::info!("spawning task1");
        task1_allocator_copy
            .start(task1_cmd, task1_addr, None)
            .await
            .unwrap();
    });
    let task2_allocator_copy = task2_allocator.clone();
    let task2_allocator_handle = tokio::spawn(async move {
        task2_allocator_copy
            .start(task2_cmd, task2_addr, None)
            .await
            .unwrap();
    });

    // The mocked initializer advertises the two allocators as the hosts.
    let mut initializer = MockRemoteProcessAllocInitializer::new();
    initializer.expect_initialize_alloc().return_once(move || {
        Ok(vec![
            RemoteProcessAllocHost {
                hostname: task1_addr_string,
                id: "task1".to_string(),
            },
            RemoteProcessAllocHost {
                hostname: task2_addr_string,
                id: "task2".to_string(),
            },
        ])
    });
    let mut alloc = RemoteProcessAlloc::new(spec.clone(), world_id, 0, initializer)
        .await
        .unwrap();
    // Each proc must be Created before Running. `created` tracks procs not
    // yet running, and `proc_points` records every point seen.
    let mut created = HashSet::new();
    let mut running_procs = HashSet::new();
    let mut proc_points = HashSet::new();
    for _ in 0..spec.extent.num_ranks() * 2 {
        let proc_state = alloc.next().await.unwrap();
        tracing::debug!("test got message: {:?}", proc_state);
        match proc_state {
            ProcState::Created {
                create_key, point, ..
            } => {
                created.insert(create_key);
                proc_points.insert(point);
            }
            ProcState::Running { create_key, .. } => {
                assert!(created.remove(&create_key));
                running_procs.insert(create_key);
            }
            _ => panic!("expected Created or Running"),
        }
    }
    assert!(created.is_empty());
    assert!(
        spec.extent
            .points()
            .all(|point| proc_points.contains(&point))
    );

    // With all procs running, no further state may arrive within 1s.
    let timeout = hyperactor::clock::RealClock.now() + std::time::Duration::from_millis(1000);
    tokio::select! {
        _ = hyperactor::clock::RealClock.sleep_until(timeout) => {},
        _ = alloc.next() => panic!("expected no more items"),
    }

    alloc.stop().await.unwrap();
    for _ in 0..spec.extent.num_ranks() {
        let proc_state = alloc.next().await.unwrap();
        tracing::info!("test received next proc_state: {:?}", proc_state);
        match proc_state {
            ProcState::Stopped {
                create_key, reason, ..
            } => {
                assert!(running_procs.remove(&create_key));
                assert_eq!(reason, ProcStopReason::Stopped);
            }
            _ => panic!("expected stopped"),
        }
    }
    // Once every proc has stopped the stream is exhausted, and stays so.
    let proc_state = alloc.next().await;
    assert!(proc_state.is_none());
    let proc_state = alloc.next().await;
    assert!(proc_state.is_none());

    task1_allocator.terminate();
    task1_allocator_handle.await.unwrap();
    task2_allocator.terminate();
    task2_allocator_handle.await.unwrap();
}
2147
/// Aborting a host's `RemoteProcessAllocator` task must surface that host's
/// procs as `Stopped` with reason `HostWatchdog`.
///
/// The test aborts one host at a time. Once both hosts are gone, the alloc
/// stream yields `None`.
#[async_timed_test(timeout_secs = 60)]
async fn test_alloc_host_failure() {
    let config = hyperactor::config::global::lock();
    let _guard1 = config.override_key(
        hyperactor::config::MESSAGE_DELIVERY_TIMEOUT,
        Duration::from_secs(1),
    );
    let _guard2 = config.override_key(
        hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
        Duration::from_millis(100),
    );
    hyperactor_telemetry::initialize_logging(ClockKind::default());

    let spec = AllocSpec {
        extent: extent!(host = 2, gpu = 2),
        constraints: Default::default(),
        proc_name: None,
        transport: ChannelTransport::Unix,
    };
    let world_id = WorldId("test_world_id".to_string());

    let task1_allocator = RemoteProcessAllocator::new();
    let task1_addr = ChannelAddr::any(ChannelTransport::Unix);
    let task1_addr_string = task1_addr.to_string();
    let task1_cmd = Command::new(crate::testresource::get(
        "monarch/hyperactor_mesh/bootstrap",
    ));
    let task2_allocator = RemoteProcessAllocator::new();
    let task2_addr = ChannelAddr::any(ChannelTransport::Unix);
    let task2_addr_string = task2_addr.to_string();
    let task2_cmd = Command::new(crate::testresource::get(
        "monarch/hyperactor_mesh/bootstrap",
    ));
    let task1_allocator_copy = task1_allocator.clone();
    let task1_allocator_handle = tokio::spawn(async move {
        tracing::info!("spawning task1");
        task1_allocator_copy
            .start(task1_cmd, task1_addr, None)
            .await
            .unwrap();
        tracing::info!("task1 terminated");
    });
    let task2_allocator_copy = task2_allocator.clone();
    let task2_allocator_handle = tokio::spawn(async move {
        task2_allocator_copy
            .start(task2_cmd, task2_addr, None)
            .await
            .unwrap();
        tracing::info!("task2 terminated");
    });

    let mut initializer = MockRemoteProcessAllocInitializer::new();
    initializer.expect_initialize_alloc().return_once(move || {
        Ok(vec![
            RemoteProcessAllocHost {
                hostname: task1_addr_string,
                id: "task1".to_string(),
            },
            RemoteProcessAllocHost {
                hostname: task2_addr_string,
                id: "task2".to_string(),
            },
        ])
    });
    let mut alloc = RemoteProcessAlloc::new(spec.clone(), world_id, 0, initializer)
        .await
        .unwrap();
    // Drain the Created + Running states of every proc.
    for _ in 0..spec.extent.num_ranks() * 2 {
        match alloc.next().await {
            Some(ProcState::Created { .. }) | Some(ProcState::Running { .. }) => {}
            _ => panic!("expected Created or Running"),
        }
    }

    let timeout = hyperactor::clock::RealClock.now() + std::time::Duration::from_millis(1000);
    tokio::select! {
        _ = hyperactor::clock::RealClock
            .sleep_until(timeout) => {},
        _ = alloc.next() => panic!("expected no more items"),
    }

    // Kill the first host's allocator and give heartbeats two intervals to
    // notice before expecting its procs (half the ranks with 2 hosts) to stop.
    tracing::info!("aborting task1 allocator");
    task1_allocator_handle.abort();
    RealClock
        .sleep(config::global::get(config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL) * 2)
        .await;
    for _ in 0..spec.extent.num_ranks() / 2 {
        let proc_state = alloc.next().await.unwrap();
        tracing::info!("test received next proc_state: {:?}", proc_state);
        match proc_state {
            ProcState::Stopped { reason, .. } => {
                assert_eq!(reason, ProcStopReason::HostWatchdog);
            }
            _ => panic!("expected stopped"),
        }
    }
    // The surviving host's procs must remain quiet.
    let timeout = hyperactor::clock::RealClock.now() + std::time::Duration::from_millis(1000);
    tokio::select! {
        _ = hyperactor::clock::RealClock
            .sleep_until(timeout) => {},
        _ = alloc.next() => panic!("expected no more items"),
    }

    // Same for the second host, which accounts for the remaining ranks.
    tracing::info!("aborting task2 allocator");
    task2_allocator_handle.abort();
    RealClock
        .sleep(config::global::get(config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL) * 2)
        .await;
    for _ in 0..spec.extent.num_ranks() / 2 {
        let proc_state = alloc.next().await.unwrap();
        tracing::info!("test received next proc_state: {:?}", proc_state);
        match proc_state {
            ProcState::Stopped { reason, .. } => {
                assert_eq!(reason, ProcStopReason::HostWatchdog);
            }
            _ => panic!("expected stopped"),
        }
    }
    let proc_state = alloc.next().await;
    assert!(proc_state.is_none());
    let proc_state = alloc.next().await;
    assert!(proc_state.is_none());
}
2278
2279 #[async_timed_test(timeout_secs = 15)]
2280 async fn test_alloc_inner_alloc_failure() {
2281 unsafe {
2283 std::env::set_var("MONARCH_MESSAGE_DELIVERY_TIMEOUT_SECS", "1");
2284 }
2285 let config = hyperactor::config::global::lock();
2286 let _guard = config.override_key(
2287 hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
2288 Duration::from_millis(100),
2289 );
2290 hyperactor_telemetry::initialize_logging(ClockKind::default());
2291
2292 let spec = AllocSpec {
2293 extent: extent!(host = 2, gpu = 2),
2294 constraints: Default::default(),
2295 proc_name: None,
2296 transport: ChannelTransport::Unix,
2297 };
2298 let world_id = WorldId("test_world_id".to_string());
2299
2300 let task1_allocator = RemoteProcessAllocator::new();
2301 let task1_addr = ChannelAddr::any(ChannelTransport::Unix);
2302 let task1_addr_string = task1_addr.to_string();
2303 let task1_cmd = Command::new(crate::testresource::get(
2304 "monarch/hyperactor_mesh/bootstrap",
2305 ));
2306 let task2_allocator = RemoteProcessAllocator::new();
2307 let task2_addr = ChannelAddr::any(ChannelTransport::Unix);
2308 let task2_addr_string = task2_addr.to_string();
2309 let task2_cmd = Command::new("/caught/somewhere/in/time");
2311 let task1_allocator_copy = task1_allocator.clone();
2312 let task1_allocator_handle = tokio::spawn(async move {
2313 tracing::info!("spawning task1");
2314 task1_allocator_copy
2315 .start(task1_cmd, task1_addr, None)
2316 .await
2317 .unwrap();
2318 });
2319 let task2_allocator_copy = task2_allocator.clone();
2320 let task2_allocator_handle = tokio::spawn(async move {
2321 task2_allocator_copy
2322 .start(task2_cmd, task2_addr, None)
2323 .await
2324 .unwrap();
2325 });
2326
2327 let mut initializer = MockRemoteProcessAllocInitializer::new();
2328 initializer.expect_initialize_alloc().return_once(move || {
2329 Ok(vec![
2330 RemoteProcessAllocHost {
2331 hostname: task1_addr_string,
2332 id: "task1".to_string(),
2333 },
2334 RemoteProcessAllocHost {
2335 hostname: task2_addr_string,
2336 id: "task2".to_string(),
2337 },
2338 ])
2339 });
2340 let mut alloc = RemoteProcessAlloc::new(spec.clone(), world_id, 0, initializer)
2341 .await
2342 .unwrap();
2343 let mut created = HashSet::new();
2344 let mut started_procs = HashSet::new();
2345 let mut proc_points = HashSet::new();
2346 let mut failed = 0;
2347 for _ in 0..spec.extent.num_ranks() + 1 {
2349 let proc_state = alloc.next().await.unwrap();
2350 tracing::debug!("test got message: {:?}", proc_state);
2351 match proc_state {
2352 ProcState::Created {
2353 create_key, point, ..
2354 } => {
2355 created.insert(create_key);
2356 proc_points.insert(point);
2357 }
2358 ProcState::Running { create_key, .. } => {
2359 assert!(created.remove(&create_key));
2360 started_procs.insert(create_key);
2361 }
2362 ProcState::Failed { .. } => {
2363 failed += 1;
2364 }
2365 _ => panic!("expected Created, Running or Failed"),
2366 }
2367 }
2368 assert!(created.is_empty());
2369 assert_eq!(failed, 1);
2370 for rank in 0..spec.extent.num_ranks() / 2 {
2372 let point = spec.extent.point_of_rank(rank).unwrap();
2373 assert!(proc_points.contains(&point));
2374 }
2375
2376 let timeout = hyperactor::clock::RealClock.now() + std::time::Duration::from_millis(1000);
2378 tokio::select! {
2379 _ = hyperactor::clock::RealClock
2380 .sleep_until(timeout) => {},
2381 _ = alloc.next() => panic!("expected no more items"),
2382 }
2383
2384 alloc.stop().await.unwrap();
2386 for _ in 0..spec.extent.num_ranks() / 2 {
2387 let proc_state = alloc.next().await.unwrap();
2388 tracing::info!("test received next proc_state: {:?}", proc_state);
2389 match proc_state {
2390 ProcState::Stopped {
2391 create_key, reason, ..
2392 } => {
2393 assert!(started_procs.remove(&create_key));
2394 assert_eq!(reason, ProcStopReason::Stopped);
2395 }
2396 _ => panic!("expected stopped"),
2397 }
2398 }
2399 let proc_state = alloc.next().await;
2401 assert!(proc_state.is_none());
2402 let proc_state = alloc.next().await;
2404 assert!(proc_state.is_none());
2405
2406 task1_allocator.terminate();
2407 task1_allocator_handle.await.unwrap();
2408 task2_allocator.terminate();
2409 task2_allocator_handle.await.unwrap();
2410 }
2411
2412 #[tracing_test::traced_test]
2413 #[async_timed_test(timeout_secs = 60)]
2414 async fn test_remote_process_alloc_signal_handler() {
2415 let num_proc_meshes = 5;
2416 let hosts_per_proc_mesh = 5;
2417
2418 let pid_addr = ChannelAddr::any(ChannelTransport::Unix);
2419 let (pid_addr, mut pid_rx) = channel::serve::<u32>(pid_addr).unwrap();
2420
2421 let addresses = (0..(num_proc_meshes * hosts_per_proc_mesh))
2422 .map(|_| ChannelAddr::any(ChannelTransport::Unix).to_string())
2423 .collect::<Vec<_>>();
2424
2425 let remote_process_allocators = addresses
2426 .iter()
2427 .map(|addr| {
2428 Command::new(crate::testresource::get(
2429 "monarch/hyperactor_mesh/remote_process_allocator",
2430 ))
2431 .env("RUST_LOG", "info")
2432 .arg(format!("--addr={addr}"))
2433 .stdout(std::process::Stdio::piped())
2434 .spawn()
2435 .unwrap()
2436 })
2437 .collect::<Vec<_>>();
2438
2439 let done_allocating_addr = ChannelAddr::any(ChannelTransport::Unix);
2440 let (done_allocating_addr, mut done_allocating_rx) =
2441 channel::serve::<()>(done_allocating_addr).unwrap();
2442 let mut remote_process_alloc = Command::new(crate::testresource::get(
2443 "monarch/hyperactor_mesh/remote_process_alloc",
2444 ))
2445 .arg(format!("--done-allocating-addr={}", done_allocating_addr))
2446 .arg(format!("--addresses={}", addresses.join(",")))
2447 .arg(format!("--num-proc-meshes={}", num_proc_meshes))
2448 .arg(format!("--hosts-per-proc-mesh={}", hosts_per_proc_mesh))
2449 .arg(format!("--pid-addr={}", pid_addr))
2450 .spawn()
2451 .unwrap();
2452
2453 done_allocating_rx.recv().await.unwrap();
2454 let mut received_pids = Vec::new();
2455 while let Ok(pid) = pid_rx.recv().await {
2456 received_pids.push(pid);
2457 if received_pids.len() == remote_process_allocators.len() {
2458 break;
2459 }
2460 }
2461
2462 signal::kill(
2463 Pid::from_raw(remote_process_alloc.id().unwrap() as i32),
2464 signal::SIGINT,
2465 )
2466 .unwrap();
2467
2468 assert_eq!(
2469 remote_process_alloc.wait().await.unwrap().signal(),
2470 Some(signal::SIGINT as i32)
2471 );
2472
2473 RealClock.sleep(tokio::time::Duration::from_secs(5)).await;
2474
2475 for child_pid in received_pids {
2477 let pid_check = Command::new("kill")
2478 .arg("-0")
2479 .arg(child_pid.to_string())
2480 .output()
2481 .await
2482 .expect("Failed to check if PID is alive");
2483
2484 assert!(
2485 !pid_check.status.success(),
2486 "PID {} should no longer be alive",
2487 child_pid
2488 );
2489 }
2490
2491 for mut remote_process_allocator in remote_process_allocators {
2494 remote_process_allocator.kill().await.unwrap();
2495 }
2496 }
2497}