1use std::collections::HashMap;
10use std::collections::HashSet;
11use std::collections::VecDeque;
12use std::sync::Arc;
13use std::time::Duration;
14
15use anyhow::Context;
16use async_trait::async_trait;
17use dashmap::DashMap;
18use futures::FutureExt;
19use futures::future::join_all;
20use futures::future::select_all;
21use hyperactor::channel;
22use hyperactor::channel::ChannelAddr;
23use hyperactor::channel::ChannelRx;
24use hyperactor::channel::ChannelTransport;
25use hyperactor::channel::ChannelTx;
26use hyperactor::channel::Rx;
27use hyperactor::channel::TcpMode;
28use hyperactor::channel::Tx;
29use hyperactor::channel::TxStatus;
30use hyperactor::internal_macro_support::serde_json;
31use hyperactor::mailbox::DialMailboxRouter;
32use hyperactor::mailbox::MailboxServer;
33use hyperactor::observe_async;
34use hyperactor::observe_result;
35use mockall::automock;
36use ndslice::Region;
37use ndslice::Slice;
38use ndslice::View;
39use ndslice::ViewExt;
40use ndslice::view::Extent;
41use ndslice::view::Point;
42use serde::Deserialize;
43use serde::Serialize;
44use strum::AsRefStr;
45use tokio::io::AsyncWriteExt;
46use tokio::process::Command;
47use tokio::sync::mpsc::UnboundedReceiver;
48use tokio::sync::mpsc::UnboundedSender;
49use tokio::sync::mpsc::unbounded_channel;
50use tokio::task::JoinHandle;
51use tokio_stream::StreamExt;
52use tokio_stream::wrappers::WatchStream;
53use tokio_util::sync::CancellationToken;
54use typeuri::Named;
55
56use crate::alloc::Alloc;
57use crate::alloc::AllocConstraints;
58use crate::alloc::AllocName;
59use crate::alloc::AllocSpec;
60use crate::alloc::Allocator;
61use crate::alloc::AllocatorError;
62use crate::alloc::ProcState;
63use crate::alloc::ProcStopReason;
64use crate::alloc::ProcessAllocator;
65use crate::alloc::process::CLIENT_TRACE_ID_LABEL;
66use crate::alloc::process::ClientContext;
67use crate::alloc::with_unspecified_port_or_any;
68use crate::shortuuid::ShortUuid;
69
/// Control messages accepted by a `RemoteProcessAllocator` serving on a
/// channel address.
#[derive(Debug, Clone, Serialize, Deserialize, Named, AsRefStr)]
pub enum RemoteProcessAllocatorMessage {
    /// Request a new allocation. Any previously active allocation on this
    /// allocator is stopped first (see `start_with_allocator`).
    Allocate {
        /// Client-generated key identifying this allocation in replies.
        alloc_key: ShortUuid,
        /// Shape of the allocation to create on this host.
        extent: Extent,
        /// Address the allocator dials to stream
        /// `RemoteProcessProcStateMessage`s back to the client.
        bootstrap_addr: ChannelAddr,
        /// Hostnames participating in the overall allocation; written to
        /// the file named by TORCH_ELASTIC_CUSTOM_HOSTNAMES_LIST_FILE when
        /// that env var is set.
        hosts: Vec<String>,
        /// Optional client tracing context, propagated as an alloc
        /// constraint label.
        client_context: Option<ClientContext>,
        /// Address on which the mailbox forwarder is served.
        forwarder_addr: ChannelAddr,
    },
    /// Stop the currently active allocation, if any.
    Stop,
    /// Keep-alive probe; handled as a no-op.
    HeartBeat,
}
wirevalue::register_type!(RemoteProcessAllocatorMessage);
98
/// Status messages streamed from a `RemoteProcessAllocator` back to the
/// client's bootstrap address.
#[derive(Debug, Clone, Serialize, Deserialize, Named, AsRefStr)]
pub enum RemoteProcessProcStateMessage {
    /// Sent once after a successful `Allocate`, echoing the client's alloc
    /// key and reporting the name of the newly created alloc.
    Allocated {
        alloc_key: ShortUuid,
        alloc_name: AllocName,
    },
    /// A `ProcState` transition for the allocation with the given key.
    Update(ShortUuid, ProcState),
    /// The allocation produced its final event and is finished.
    Done(ShortUuid),
    /// Keep-alive probe sent periodically while an allocation is running.
    HeartBeat,
}
wirevalue::register_type!(RemoteProcessProcStateMessage);
117
/// Serves allocation requests on a channel address, bridging a local
/// `Allocator` to remote clients. Shutdown is signalled via `terminate`.
pub struct RemoteProcessAllocator {
    // Cancels the serve loop (and any active allocation) when triggered.
    cancel_token: CancellationToken,
}
122
123async fn conditional_sleeper<F: futures::Future<Output = ()>>(t: Option<F>) {
124 match t {
125 Some(timer) => timer.await,
126 None => futures::future::pending().await,
127 }
128}
129
impl RemoteProcessAllocator {
    /// Creates an allocator with a fresh cancellation token; returned in an
    /// `Arc` so `terminate` can be invoked from another task.
    pub fn new() -> Arc<Self> {
        Arc::new(Self {
            cancel_token: CancellationToken::new(),
        })
    }

    /// Signals the serve loop (and any active allocation) to shut down.
    pub fn terminate(&self) {
        self.cancel_token.cancel();
    }

    /// Serves allocation requests on `serve_addr`, spawning procs with a
    /// `ProcessAllocator` built from `cmd`.
    ///
    /// If `timeout` is set, the loop exits after that long with no active
    /// allocation. Returns when terminated, timed out, or if serving the
    /// address fails.
    #[hyperactor::instrument]
    pub async fn start(
        &self,
        cmd: Command,
        serve_addr: ChannelAddr,
        timeout: Option<Duration>,
    ) -> Result<(), anyhow::Error> {
        let process_allocator = ProcessAllocator::new(cmd);
        self.start_with_allocator(serve_addr, process_allocator, timeout)
            .await
    }

    /// Main serve loop: receives `RemoteProcessAllocatorMessage`s on
    /// `serve_addr` and drives at most one active allocation at a time.
    ///
    /// A new `Allocate` request stops the previous allocation before the
    /// next one starts; `Stop` cancels the active allocation; `HeartBeat`
    /// is a no-op. The loop exits on `terminate()` or when `timeout`
    /// elapses with nothing allocated.
    #[hyperactor::instrument(fields(addr=serve_addr.to_string()))]
    pub async fn start_with_allocator<A: Allocator + Send + Sync + 'static>(
        &self,
        serve_addr: ChannelAddr,
        mut process_allocator: A,
        timeout: Option<Duration>,
    ) -> Result<(), anyhow::Error>
    where
        <A as Allocator>::Alloc: Send,
        <A as Allocator>::Alloc: Sync,
    {
        tracing::info!("starting remote allocator on: {}", serve_addr);
        let (_, mut rx) = channel::serve(serve_addr.clone()).map_err(anyhow::Error::from)?;

        // Handle to the allocation currently being driven, if any.
        struct ActiveAllocation {
            handle: JoinHandle<()>,
            cancel_token: CancellationToken,
        }
        // Cancels and joins the previous allocation task, if one exists.
        #[observe_async("RemoteProcessAllocator")]
        async fn ensure_previous_alloc_stopped(active_allocation: &mut Option<ActiveAllocation>) {
            if let Some(active_allocation) = active_allocation.take() {
                tracing::info!("previous alloc found, stopping");
                active_allocation.cancel_token.cancel();
                match active_allocation.handle.await {
                    Ok(_) => {
                        tracing::info!("allocation stopped.")
                    }
                    Err(e) => {
                        tracing::error!("allocation handler failed: {}", e);
                    }
                }
            }
        }

        let mut active_allocation: Option<ActiveAllocation> = None;
        loop {
            // Pends forever when no timeout is configured, so the timeout
            // arm of the select below only fires for a real idle timeout.
            let sleep = conditional_sleeper(timeout.map(|t| tokio::time::sleep(t)));
            tokio::select! {
                msg = rx.recv() => {
                    match msg {
                        Ok(RemoteProcessAllocatorMessage::Allocate {
                            alloc_key,
                            extent,
                            bootstrap_addr,
                            hosts,
                            client_context,
                            forwarder_addr,
                        }) => {
                            tracing::info!("received allocation request for {} with extent {}", alloc_key, extent);
                            // Only one allocation may be active at a time.
                            ensure_previous_alloc_stopped(&mut active_allocation).await;

                            // Propagate the client trace id as an alloc
                            // constraint label, when provided.
                            let mut constraints: AllocConstraints = Default::default();
                            if let Some(context) = &client_context {
                                constraints = AllocConstraints {
                                    match_labels: HashMap::from([(
                                        CLIENT_TRACE_ID_LABEL.to_string(),
                                        context.trace_id.to_string(),
                                    )]
                                )};
                                tracing::info!(
                                    monarch_client_trace_id = context.trace_id.to_string(),
                                    "allocating...",
                                );
                            }

                            // Local procs always use Unix transport; the
                            // forwarder bridges them to the client.
                            let spec = AllocSpec {
                                extent,
                                constraints,
                                proc_name: None, transport: ChannelTransport::Unix,
                                proc_allocation_mode: Default::default(),
                            };

                            match process_allocator.allocate(spec.clone()).await {
                                Ok(alloc) => {
                                    let cancel_token = CancellationToken::new();
                                    active_allocation = Some(ActiveAllocation {
                                        cancel_token: cancel_token.clone(),
                                        handle: tokio::spawn(Self::handle_allocation_request(
                                            Box::new(alloc) as Box<dyn Alloc + Send + Sync>,
                                            alloc_key,
                                            bootstrap_addr,
                                            hosts,
                                            cancel_token,
                                            forwarder_addr,
                                        )),
                                    })
                                }
                                Err(e) => {
                                    // Allocation failure is logged but does
                                    // not terminate the serve loop.
                                    tracing::error!("allocation for {:?} failed: {}", spec, e);
                                    continue;
                                }
                            }
                        }
                        Ok(RemoteProcessAllocatorMessage::Stop) => {
                            tracing::info!("received stop request");

                            ensure_previous_alloc_stopped(&mut active_allocation).await;
                        }
                        // Heartbeats only verify channel liveness.
                        Ok(RemoteProcessAllocatorMessage::HeartBeat) => {}
                        Err(e) => {
                            tracing::error!("upstream channel error: {}", e);
                            continue;
                        }
                    }
                }
                _ = self.cancel_token.cancelled() => {
                    tracing::info!("main loop cancelled");

                    ensure_previous_alloc_stopped(&mut active_allocation).await;

                    break;
                }
                _ = sleep => {
                    // Idle timeout: only exit when nothing is allocated.
                    if active_allocation.is_some() {
                        continue;
                    }
                    tracing::warn!("timeout of {} seconds elapsed without any allocations, exiting", timeout.unwrap_or_default().as_secs());
                    break;
                }
            }
        }

        Ok(())
    }

    /// Runs one allocation end to end: serves a mailbox forwarder,
    /// optionally writes the host list for torch-elastic, then enters the
    /// event-forwarding loop until the alloc is done or cancelled.
    #[tracing::instrument(skip(alloc, cancel_token))]
    #[observe_async("RemoteProcessAllocator")]
    async fn handle_allocation_request(
        alloc: Box<dyn Alloc + Send + Sync>,
        alloc_key: ShortUuid,
        bootstrap_addr: ChannelAddr,
        hosts: Vec<String>,
        cancel_token: CancellationToken,
        forwarder_addr: ChannelAddr,
    ) {
        tracing::info!("handle allocation request, bootstrap_addr: {bootstrap_addr}");
        // Serve a router that forwards client messages to local procs.
        let (forwarder_addr, forwarder_rx) = match channel::serve(forwarder_addr) {
            Ok(v) => v,
            Err(e) => {
                tracing::error!("failed to to bootstrap forwarder actor: {}", e);
                return;
            }
        };
        let router = DialMailboxRouter::new();
        let mailbox_handle = router.clone().serve(forwarder_rx);
        tracing::info!("started forwarder on: {}", forwarder_addr);

        // Optionally expose the participating hostnames to torch-elastic
        // as a JSON file named by this env var. Any failure here aborts
        // the allocation.
        if let Ok(hosts_file) = std::env::var("TORCH_ELASTIC_CUSTOM_HOSTNAMES_LIST_FILE") {
            tracing::info!("writing hosts to {}", hosts_file);
            #[derive(Serialize)]
            struct Hosts {
                hostnames: Vec<String>,
            }
            match serde_json::to_string(&Hosts { hostnames: hosts }) {
                Ok(json) => match tokio::fs::File::create(&hosts_file).await {
                    Ok(mut file) => {
                        if file.write_all(json.as_bytes()).await.is_err() {
                            tracing::error!("failed to write hosts to {}", hosts_file);
                            return;
                        }
                    }
                    Err(e) => {
                        tracing::error!("failed to open hosts file {}: {}", hosts_file, e);
                        return;
                    }
                },
                Err(e) => {
                    tracing::error!("failed to serialize hosts: {}", e);
                    return;
                }
            }
        }

        Self::handle_allocation_loop(
            alloc,
            alloc_key,
            bootstrap_addr,
            router,
            forwarder_addr,
            cancel_token,
        )
        .await;

        // Tear down the forwarder once the allocation loop has exited.
        mailbox_handle.stop("alloc stopped");
        if let Err(e) = mailbox_handle.await {
            tracing::error!("failed to join forwarder: {}", e);
        }
    }

    /// Forwards `ProcState` events from `alloc` to the client at
    /// `bootstrap_addr`, rewriting `Running` addresses to point at the
    /// local forwarder and (un)binding mesh agents in `router` as procs
    /// come and go. Exits when the alloc finishes, the upstream channel
    /// closes, or `cancel_token` fires.
    async fn handle_allocation_loop(
        mut alloc: Box<dyn Alloc + Send + Sync>,
        alloc_key: ShortUuid,
        bootstrap_addr: ChannelAddr,
        router: DialMailboxRouter,
        forward_addr: ChannelAddr,
        cancel_token: CancellationToken,
    ) {
        let alloc_name = alloc.alloc_name().clone();
        tracing::info!("starting handle allocation loop for {}", alloc_name);
        let tx = match channel::dial(bootstrap_addr) {
            Ok(tx) => tx,
            Err(err) => {
                tracing::error!("failed to dial bootstrap address: {}", err);
                return;
            }
        };
        // Tell the client which alloc name serves its alloc key.
        let message = RemoteProcessProcStateMessage::Allocated {
            alloc_key: alloc_key.clone(),
            alloc_name,
        };
        tracing::info!(name = message.as_ref(), "sending allocated message",);
        if let Err(e) = tx.send(message).await {
            tracing::error!("failed to send Allocated message: {}", e);
            return;
        }

        // Tracks which mesh agent belongs to each live proc so Stopped
        // events can unbind them from the router.
        let mut mesh_agents_by_create_key = HashMap::new();
        let mut running = true;
        let tx_status = tx.status().clone();
        let mut tx_watcher = WatchStream::new(tx_status);
        loop {
            tokio::select! {
                // Cancellation: ask the alloc to stop, then keep draining
                // its events (the `running` guard disables this arm).
                _ = cancel_token.cancelled(), if running => {
                    tracing::info!("cancelled, stopping allocation");
                    running = false;
                    if let Err(e) = alloc.stop().await {
                        tracing::error!("stop failed: {}", e);
                        break;
                    }
                }
                // Upstream channel status: bail out if the client is gone.
                status = tx_watcher.next(), if running => {
                    match status {
                        Some(TxStatus::Closed) => {
                            tracing::error!("upstream channel state closed");
                            break;
                        },
                        _ => {
                            tracing::debug!("got channel event: {:?}", status.unwrap());
                            continue;
                        }
                    }
                }
                e = alloc.next() => {
                    match e {
                        Some(event) => {
                            tracing::debug!(name = event.as_ref(), "got event: {:?}", event);
                            let event = match event {
                                ProcState::Created { .. } => event,
                                // Rewrite the proc address to the local
                                // forwarder and bind the agent for routing.
                                ProcState::Running { create_key, proc_id, mesh_agent, addr } => {
                                    tracing::debug!("remapping mesh_agent {}: addr {} -> {}", mesh_agent, addr, forward_addr);
                                    mesh_agents_by_create_key.insert(create_key.clone(), mesh_agent.clone());
                                    router.bind(mesh_agent.actor_id().proc_id().clone().into(), addr);
                                    ProcState::Running { create_key, proc_id, mesh_agent, addr: forward_addr.clone() }
                                },
                                // Unbind the agent that was mapped when the
                                // proc entered Running.
                                ProcState::Stopped { create_key, reason } => {
                                    match mesh_agents_by_create_key.remove(&create_key) {
                                        Some(mesh_agent) => {
                                            tracing::debug!("unmapping mesh_agent {}", mesh_agent);
                                            let agent_ref: hyperactor::reference::Reference = mesh_agent.actor_id().proc_id().clone().into();
                                            router.unbind(&agent_ref);
                                        },
                                        None => {
                                            tracing::warn!("mesh_agent not found for create key {}", create_key);
                                        }
                                    }
                                    ProcState::Stopped { create_key, reason }
                                },
                                ProcState::Failed { ref alloc_name, ref description } => {
                                    tracing::error!("allocation failed for {}: {}", alloc_name, description);
                                    event
                                }
                            };
                            tracing::debug!(name = event.as_ref(), "sending event: {:?}", event);
                            tx.post(RemoteProcessProcStateMessage::Update(alloc_key.clone(), event));
                        }
                        None => {
                            // Alloc exhausted: report Done and exit.
                            tracing::debug!("sending done");
                            if let Err(e) = tx.send(RemoteProcessProcStateMessage::Done(alloc_key.clone())).await {
                                tracing::error!("failed to send Done message: {}", e);
                            }
                            running = false;
                            break;
                        }
                    }
                }
                // Periodic liveness signal so the client's channel watcher
                // can detect a dead host.
                _ = tokio::time::sleep(hyperactor_config::global::get(hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL)) => {
                    tracing::trace!("sending heartbeat");
                    tx.post(RemoteProcessProcStateMessage::HeartBeat);
                }
            }
        }
        tracing::info!("allocation handler loop exited");
        // If we exited while the alloc was still live (e.g. upstream
        // channel closed), make sure its processes are torn down.
        if running {
            tracing::info!("stopping processes");
            if let Err(e) = alloc.stop_and_wait().await {
                tracing::error!("stop failed: {}", e);
                return;
            }
            tracing::info!("stop finished");
        }
    }
}
498
/// Identifier of a host participating in an allocation (opaque string,
/// supplied by the initializer).
type HostId = String;
502
/// A host made available to a `RemoteProcessAlloc` by its initializer.
#[derive(Clone)]
pub struct RemoteProcessAllocHost {
    /// Unique id for this host; used as the key into host state.
    pub id: HostId,
    /// Hostname used to derive the remote allocator's channel address.
    pub hostname: String,
}
512
/// Client-side bookkeeping for one remote host's allocation.
struct RemoteProcessAllocHostState {
    // Key identifying the allocation requested on this host.
    alloc_key: ShortUuid,
    // Id of the host this state belongs to.
    host_id: HostId,
    // Channel for sending allocator messages (Allocate/Stop/HeartBeat).
    tx: ChannelTx<RemoteProcessAllocatorMessage>,
    // Create keys of procs currently known to be live on this host.
    active_procs: HashSet<ShortUuid>,
    // Region of the global extent assigned to this host.
    region: Region,
    // Name reported by the remote side in its Allocated message, once seen.
    alloc_name: Option<AllocName>,
    // True once a Failed event has been observed for this host.
    failed: bool,
    // True once at least one proc has been created on this host.
    allocated: bool,
}
532
/// Supplies the set of hosts for a `RemoteProcessAlloc`; mockable in tests
/// via `mockall`.
#[automock]
#[async_trait]
pub trait RemoteProcessAllocInitializer {
    /// Returns the hosts to allocate on. An empty list is treated as an
    /// error by the caller (`ensure_started`).
    async fn initialize_alloc(&mut self) -> Result<Vec<RemoteProcessAllocHost>, anyhow::Error>;
}
540
/// Host-state map that mirrors each host's channel address into a shared
/// `DashMap` so the signal-cleanup task can reach hosts without borrowing
/// the owning `RemoteProcessAlloc`.
struct HostStates {
    // Authoritative per-host state, keyed by host id.
    inner: HashMap<HostId, RemoteProcessAllocHostState>,
    // Shared view of host addresses, kept in sync with `inner`.
    host_addresses: Arc<DashMap<HostId, ChannelAddr>>,
}
547
548impl HostStates {
549 fn new(host_addresses: Arc<DashMap<HostId, ChannelAddr>>) -> HostStates {
550 Self {
551 inner: HashMap::new(),
552 host_addresses,
553 }
554 }
555
556 fn insert(
557 &mut self,
558 host_id: HostId,
559 state: RemoteProcessAllocHostState,
560 address: ChannelAddr,
561 ) {
562 self.host_addresses.insert(host_id.clone(), address);
563 self.inner.insert(host_id, state);
564 }
565
566 fn get(&self, host_id: &HostId) -> Option<&RemoteProcessAllocHostState> {
567 self.inner.get(host_id)
568 }
569
570 fn get_mut(&mut self, host_id: &HostId) -> Option<&mut RemoteProcessAllocHostState> {
571 self.inner.get_mut(host_id)
572 }
573
574 fn remove(&mut self, host_id: &HostId) -> Option<RemoteProcessAllocHostState> {
575 self.host_addresses.remove(host_id);
576 self.inner.remove(host_id)
577 }
578
579 fn iter(&self) -> impl Iterator<Item = (&HostId, &RemoteProcessAllocHostState)> {
580 self.inner.iter()
581 }
582
583 fn iter_mut(&mut self) -> impl Iterator<Item = (&HostId, &mut RemoteProcessAllocHostState)> {
584 self.inner.iter_mut()
585 }
586
587 fn is_empty(&self) -> bool {
588 self.inner.is_empty()
589 }
590 }
592
/// Client-side allocation spanning one or more remote hosts, driven by
/// `RemoteProcessAllocatorMessage`s and consumed through the `Alloc` trait.
pub struct RemoteProcessAlloc {
    // Provides the host list on first use (see `ensure_started`).
    initializer: Box<dyn RemoteProcessAllocInitializer + Send + Sync>,
    // Requested allocation spec (extent, transport, constraints, mode).
    spec: AllocSpec,
    // Port the remote allocators listen on (TCP/MetaTLS transports).
    remote_allocator_port: u16,
    // Name of this client-side alloc; used in synthesized Failed events.
    alloc_name: AllocName,
    // Hosts in allocation order; pruned as host channels close.
    ordered_hosts: Vec<RemoteProcessAllocHost>,
    // True once `ensure_started` has begun (prevents re-initialization).
    started: bool,
    // False once all hosts are done/disconnected; `next` then yields None.
    running: bool,
    // True once a Failed event has been surfaced to the caller.
    failed: bool,
    // Maps alloc keys back to the host that owns them.
    alloc_to_host: HashMap<ShortUuid, HostId>,
    // Per-host state, plus the shared address map for signal cleanup.
    host_states: HostStates,
    // Locally queued events (e.g. synthesized Stopped on host loss).
    event_queue: VecDeque<ProcState>,
    // The comm watcher reports closed host channels through this pair.
    comm_watcher_tx: UnboundedSender<HostId>,
    comm_watcher_rx: UnboundedReceiver<HostId>,

    // Address the remote allocators send ProcState messages to.
    bootstrap_addr: ChannelAddr,
    // Receiver for messages from all remote allocators.
    rx: ChannelRx<RemoteProcessProcStateMessage>,
    // Sends Stop to all known hosts when the process receives a signal.
    _signal_cleanup_guard: hyperactor::SignalCleanupGuard,
}
620
impl RemoteProcessAlloc {
    /// Creates a new alloc client: serves a bootstrap channel for remote
    /// `ProcState` messages and registers a signal-cleanup hook that sends
    /// `Stop` to every known host. Host initialization is deferred to the
    /// first `next()` call (see `ensure_started`).
    #[tracing::instrument(skip(initializer))]
    #[observe_result("RemoteProcessAlloc")]
    pub async fn new(
        spec: AllocSpec,
        alloc_name: AllocName,
        remote_allocator_port: u16,
        initializer: impl RemoteProcessAllocInitializer + Send + Sync + 'static,
    ) -> Result<Self, anyhow::Error> {
        let alloc_serve_addr = ChannelAddr::any(spec.transport.clone());

        let (bootstrap_addr, rx) = channel::serve(alloc_serve_addr)?;

        tracing::info!(
            "starting alloc for {} on: {}",
            alloc_name,
            bootstrap_addr.clone()
        );

        let (comm_watcher_tx, comm_watcher_rx) = unbounded_channel();

        // Shared with the signal handler so cleanup can reach host
        // addresses without borrowing `self`.
        let host_addresses = Arc::new(DashMap::<HostId, ChannelAddr>::new());
        let host_addresses_for_signal = host_addresses.clone();

        // Best-effort: on a fatal signal, tell every remote allocator to
        // stop its allocation.
        let signal_cleanup_guard =
            hyperactor::register_signal_cleanup_scoped(Box::pin(async move {
                join_all(host_addresses_for_signal.iter().map(|entry| async move {
                    let addr = entry.value().clone();
                    match channel::dial(addr.clone()) {
                        Ok(tx) => {
                            if let Err(e) = tx.send(RemoteProcessAllocatorMessage::Stop).await {
                                tracing::error!("Failed to send Stop to {}: {}", addr, e);
                            }
                        }
                        Err(e) => {
                            tracing::error!("Failed to dial {} during signal cleanup: {}", addr, e);
                        }
                    }
                }))
                .await;
            }));

        Ok(Self {
            spec,
            alloc_name,
            remote_allocator_port,
            initializer: Box::new(initializer),
            ordered_hosts: Vec::new(),
            alloc_to_host: HashMap::new(),
            host_states: HostStates::new(host_addresses),
            bootstrap_addr,
            event_queue: VecDeque::new(),
            comm_watcher_tx,
            comm_watcher_rx,
            rx,
            started: false,
            running: true,
            failed: false,
            _signal_cleanup_guard: signal_cleanup_guard,
        })
    }

    /// Spawns a task that watches every host's tx channel status and
    /// reports closed hosts over `comm_watcher_tx`. The task exits once
    /// all watched channels are closed or the receiver is dropped.
    ///
    /// Panics if called with no hosts (must run after `ensure_started`
    /// has populated `ordered_hosts`).
    async fn start_comm_watcher(&self) {
        let mut tx_watchers = Vec::new();
        for host in &self.ordered_hosts {
            let tx_status = self.host_states.get(&host.id).unwrap().tx.status().clone();
            let watcher = WatchStream::new(tx_status);
            tx_watchers.push((watcher, host.id.clone()));
        }
        assert!(!tx_watchers.is_empty());
        let tx = self.comm_watcher_tx.clone();
        tokio::spawn(async move {
            loop {
                // Race all watchers; select_all yields the first status
                // change together with the index of its watcher.
                let mut tx_status_futures = Vec::new();
                for (watcher, _) in &mut tx_watchers {
                    let fut = watcher.next().boxed();
                    tx_status_futures.push(fut);
                }
                let (tx_status, index, _) = select_all(tx_status_futures).await;
                let host_id = match tx_watchers.get(index) {
                    Some((_, host_id)) => host_id.clone(),
                    None => {
                        tracing::error!(
                            "got selected index {} with no matching host in {}",
                            index,
                            tx_watchers.len()
                        );
                        continue;
                    }
                };
                if let Some(tx_status) = tx_status {
                    tracing::debug!("host {} channel event: {:?}", host_id, tx_status);
                    if tx_status == TxStatus::Closed {
                        if tx.send(host_id.clone()).is_err() {
                            // Receiver dropped; nobody left to notify.
                            break;
                        }
                        // Stop watching the closed host.
                        tx_watchers.remove(index);
                        if tx_watchers.is_empty() {
                            break;
                        }
                    }
                }
            }
        });
    }

    /// Lazily performs initial setup on first use: obtains hosts from the
    /// initializer, derives per-host regions from the spec's allocation
    /// mode, dials each remote allocator and sends it an `Allocate`
    /// request. No-op once started or after a failure.
    async fn ensure_started(&mut self) -> Result<(), anyhow::Error> {
        if self.started || self.failed {
            return Ok(());
        }

        // Set eagerly so a failed attempt is not retried by later calls.
        // NOTE(review): also set again at the end of this function.
        self.started = true;
        let hosts = self
            .initializer
            .initialize_alloc()
            .await
            .context("alloc initializer error")?;
        if hosts.is_empty() {
            anyhow::bail!("initializer returned empty list of hosts");
        }
        let hostnames: Vec<_> = hosts.iter().map(|e| e.hostname.clone()).collect();
        tracing::info!("obtained {} hosts for this allocation", hostnames.len());

        use crate::alloc::ProcAllocationMode;

        // HostLevel mode precomputes a single-point region per rank;
        // ProcLevel derives regions by grouping the extent below.
        let regions: Option<Vec<_>> = match self.spec.proc_allocation_mode {
            ProcAllocationMode::ProcLevel => {
                anyhow::ensure!(
                    self.spec.extent.len() >= 2,
                    "invalid extent: {}, expected at least 2 dimensions",
                    self.spec.extent
                );
                None
            }
            ProcAllocationMode::HostLevel => Some({
                let num_points = self.spec.extent.num_ranks();
                anyhow::ensure!(
                    hosts.len() >= num_points,
                    "HostLevel allocation mode requires {} hosts (one per point in extent {}), but only {} hosts were provided",
                    num_points,
                    self.spec.extent,
                    hosts.len()
                );

                let labels = self.spec.extent.labels().to_vec();

                // Row-major strides over the extent's sizes.
                let extent_sizes = self.spec.extent.sizes();
                let mut parent_strides = vec![1; extent_sizes.len()];
                for i in (0..extent_sizes.len() - 1).rev() {
                    parent_strides[i] = parent_strides[i + 1] * extent_sizes[i + 1];
                }

                // One single-point region per rank, offset by that rank.
                (0..num_points)
                    .map(|rank| {
                        let sizes = vec![1; labels.len()];
                        Region::new(
                            labels.clone(),
                            Slice::new(rank, sizes, parent_strides.clone()).unwrap(),
                        )
                    })
                    .collect()
            }),
        };

        // NOTE(review): the two branches below duplicate the per-host
        // dial/allocate logic and differ only in how regions are derived
        // and how `ordered_hosts` is set — a candidate for extraction
        // into a shared helper.
        match self.spec.proc_allocation_mode {
            ProcAllocationMode::ProcLevel => {
                // Split the extent on its innermost dimension: one region
                // (group) per host.
                let split_dim = &self.spec.extent.labels()[self.spec.extent.len() - 1];
                for (i, region) in self.spec.extent.group_by(split_dim)?.enumerate() {
                    let host = &hosts[i];
                    tracing::debug!("allocating: {} for host: {}", region, host.id);

                    // Build the remote allocator's address from the
                    // configured transport.
                    let remote_addr = match self.spec.transport {
                        ChannelTransport::MetaTls(_) => {
                            format!("metatls!{}:{}", host.hostname, self.remote_allocator_port)
                        }
                        ChannelTransport::Tcp(TcpMode::Localhost) => {
                            format!("tcp![::1]:{}", self.remote_allocator_port)
                        }
                        ChannelTransport::Tcp(TcpMode::Hostname) => {
                            format!("tcp!{}:{}", host.hostname, self.remote_allocator_port)
                        }
                        // For Unix transport the "hostname" is the address.
                        ChannelTransport::Unix => host.hostname.clone(),
                        _ => {
                            anyhow::bail!(
                                "unsupported transport for host {}: {:?}",
                                host.id,
                                self.spec.transport,
                            );
                        }
                    };

                    tracing::debug!("dialing remote: {} for host {}", remote_addr, host.id);
                    let remote_addr = remote_addr.parse::<ChannelAddr>()?;
                    let tx = channel::dial(remote_addr.clone())
                        .map_err(anyhow::Error::from)
                        .context(format!(
                            "failed to dial remote {} for host {}",
                            remote_addr, host.id
                        ))?;

                    // Fresh key per host; must be globally unique.
                    let alloc_key = ShortUuid::generate();
                    assert!(
                        self.alloc_to_host
                            .insert(alloc_key.clone(), host.id.clone())
                            .is_none()
                    );

                    let trace_id = hyperactor_telemetry::trace::get_or_create_trace_id();
                    let client_context = Some(ClientContext { trace_id });
                    let message = RemoteProcessAllocatorMessage::Allocate {
                        alloc_key: alloc_key.clone(),
                        extent: region.extent(),
                        bootstrap_addr: self.bootstrap_addr.clone(),
                        hosts: hostnames.clone(),
                        client_context,
                        forwarder_addr: with_unspecified_port_or_any(&remote_addr),
                    };
                    tracing::info!(
                        name = message.as_ref(),
                        "sending allocate message to workers"
                    );
                    tx.post(message);

                    self.host_states.insert(
                        host.id.clone(),
                        RemoteProcessAllocHostState {
                            alloc_key,
                            host_id: host.id.clone(),
                            tx,
                            active_procs: HashSet::new(),
                            region,
                            alloc_name: None,
                            failed: false,
                            allocated: false,
                        },
                        remote_addr,
                    );
                }

                self.ordered_hosts = hosts;
            }
            ProcAllocationMode::HostLevel => {
                // `regions` is always Some in this mode (computed above).
                let regions = regions.unwrap();
                let num_regions = regions.len();
                for (i, region) in regions.into_iter().enumerate() {
                    let host = &hosts[i];
                    tracing::debug!("allocating: {} for host: {}", region, host.id);

                    // Build the remote allocator's address from the
                    // configured transport.
                    let remote_addr = match self.spec.transport {
                        ChannelTransport::MetaTls(_) => {
                            format!("metatls!{}:{}", host.hostname, self.remote_allocator_port)
                        }
                        ChannelTransport::Tcp(TcpMode::Localhost) => {
                            format!("tcp![::1]:{}", self.remote_allocator_port)
                        }
                        ChannelTransport::Tcp(TcpMode::Hostname) => {
                            format!("tcp!{}:{}", host.hostname, self.remote_allocator_port)
                        }
                        // For Unix transport the "hostname" is the address.
                        ChannelTransport::Unix => host.hostname.clone(),
                        _ => {
                            anyhow::bail!(
                                "unsupported transport for host {}: {:?}",
                                host.id,
                                self.spec.transport,
                            );
                        }
                    };

                    tracing::debug!("dialing remote: {} for host {}", remote_addr, host.id);
                    let remote_addr = remote_addr.parse::<ChannelAddr>()?;
                    let tx = channel::dial(remote_addr.clone())
                        .map_err(anyhow::Error::from)
                        .context(format!(
                            "failed to dial remote {} for host {}",
                            remote_addr, host.id
                        ))?;

                    // Fresh key per host; must be globally unique.
                    let alloc_key = ShortUuid::generate();
                    assert!(
                        self.alloc_to_host
                            .insert(alloc_key.clone(), host.id.clone())
                            .is_none()
                    );

                    let trace_id = hyperactor_telemetry::trace::get_or_create_trace_id();
                    let client_context = Some(ClientContext { trace_id });
                    let message = RemoteProcessAllocatorMessage::Allocate {
                        alloc_key: alloc_key.clone(),
                        extent: region.extent(),
                        bootstrap_addr: self.bootstrap_addr.clone(),
                        hosts: hostnames.clone(),
                        client_context,
                        forwarder_addr: with_unspecified_port_or_any(&remote_addr),
                    };
                    tracing::info!(
                        name = message.as_ref(),
                        "sending allocate message to workers"
                    );
                    tx.post(message);

                    self.host_states.insert(
                        host.id.clone(),
                        RemoteProcessAllocHostState {
                            alloc_key,
                            host_id: host.id.clone(),
                            tx,
                            active_procs: HashSet::new(),
                            region,
                            alloc_name: None,
                            failed: false,
                            allocated: false,
                        },
                        remote_addr,
                    );
                }

                // Only the hosts actually assigned a region are retained.
                self.ordered_hosts = hosts.into_iter().take(num_regions).collect();
            }
        }
        self.start_comm_watcher().await;
        self.started = true;

        Ok(())
    }

    /// Resolves the mutable host state owning `alloc_key`.
    fn get_host_state_mut(
        &mut self,
        alloc_key: &ShortUuid,
    ) -> Result<&mut RemoteProcessAllocHostState, anyhow::Error> {
        let host_id: &HostId = self
            .alloc_to_host
            .get(alloc_key)
            .ok_or_else(|| anyhow::anyhow!("alloc with key {} not found", alloc_key))?;

        self.host_states
            .get_mut(host_id)
            .ok_or_else(|| anyhow::anyhow!("no host state found for host {}", host_id))
    }

    /// Resolves the host state owning `alloc_key`.
    fn get_host_state(
        &self,
        alloc_key: &ShortUuid,
    ) -> Result<&RemoteProcessAllocHostState, anyhow::Error> {
        let host_id: &HostId = self
            .alloc_to_host
            .get(alloc_key)
            .ok_or_else(|| anyhow::anyhow!("alloc with key {} not found", alloc_key))?;

        self.host_states
            .get(host_id)
            .ok_or_else(|| anyhow::anyhow!("no host state found for host {}", host_id))
    }

    /// Removes and returns the host state owning `alloc_key`.
    /// Note: the `alloc_to_host` entry itself is left in place here.
    fn remove_host_state(
        &mut self,
        alloc_key: &ShortUuid,
    ) -> Result<RemoteProcessAllocHostState, anyhow::Error> {
        let host_id: &HostId = self
            .alloc_to_host
            .get(alloc_key)
            .ok_or_else(|| anyhow::anyhow!("alloc with key {} not found", alloc_key))?;

        self.host_states
            .remove(host_id)
            .ok_or_else(|| anyhow::anyhow!("no host state found for host {}", host_id))
    }

    /// Records a newly created proc against its host, marking the host as
    /// having allocated at least once.
    fn add_proc_id_to_host_state(
        &mut self,
        alloc_key: &ShortUuid,
        create_key: &ShortUuid,
    ) -> Result<(), anyhow::Error> {
        let task_state = self.get_host_state_mut(alloc_key)?;
        if !task_state.active_procs.insert(create_key.clone()) {
            tracing::error!("proc with create key {} already in host state", create_key);
        }
        task_state.allocated = true;
        Ok(())
    }

    /// Removes a stopped proc from its host's active set.
    fn remove_proc_from_host_state(
        &mut self,
        alloc_key: &ShortUuid,
        create_key: &ShortUuid,
    ) -> Result<(), anyhow::Error> {
        let task_state = self.get_host_state_mut(alloc_key)?;
        if !task_state.active_procs.remove(create_key) {
            tracing::error!("proc with create_key already in host state: {}", create_key);
        }
        Ok(())
    }

    /// Translates a host-local proc point into a point in the global
    /// extent, via the region assigned to the owning host.
    fn project_proc_into_global_extent(
        &self,
        alloc_key: &ShortUuid,
        point: &Point,
    ) -> Result<Point, anyhow::Error> {
        let global_rank = self
            .get_host_state(alloc_key)?
            .region
            .get(point.rank())
            .ok_or_else(|| {
                anyhow::anyhow!(
                    "rank {} out of bounds for in alloc {}",
                    point.rank(),
                    alloc_key
                )
            })?;
        Ok(self.spec.extent.point_of_rank(global_rank)?)
    }

    /// Removes all state for a host whose channel closed, returning the
    /// create keys of procs that were still active there (so the caller
    /// can synthesize Stopped events for them).
    fn cleanup_host_channel_closed(
        &mut self,
        host_id: HostId,
    ) -> Result<Vec<ShortUuid>, anyhow::Error> {
        let state = match self.host_states.remove(&host_id) {
            Some(state) => state,
            None => {
                anyhow::bail!(
                    "got channel closed event for host {} which has no known state",
                    host_id
                );
            }
        };
        self.ordered_hosts.retain(|host| host.id != host_id);
        self.alloc_to_host.remove(&state.alloc_key);
        let create_keys = state.active_procs.iter().cloned().collect();

        Ok(create_keys)
    }
}
1108
1109#[async_trait]
1110impl Alloc for RemoteProcessAlloc {
1111 async fn next(&mut self) -> Option<ProcState> {
1112 loop {
1113 if let state @ Some(_) = self.event_queue.pop_front() {
1114 break state;
1115 }
1116
1117 if !self.running {
1118 break None;
1119 }
1120
1121 if let Err(e) = self.ensure_started().await {
1122 break Some(ProcState::Failed {
1123 alloc_name: self.alloc_name.clone(),
1124 description: format!("failed to ensure started: {:#}", e),
1125 });
1126 }
1127
1128 let heartbeat_interval = hyperactor_config::global::get(
1129 hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
1130 );
1131 let mut heartbeat_time = tokio::time::Instant::now() + heartbeat_interval;
1132 let mut reloop = false;
1134 let update = loop {
1135 tokio::select! {
1136 msg = self.rx.recv() => {
1137 tracing::debug!("got ProcState message from allocator: {:?}", msg);
1138 match msg {
1139 Ok(RemoteProcessProcStateMessage::Allocated { alloc_key, alloc_name }) => {
1140 tracing::info!("remote alloc {}: allocated", alloc_key);
1141 match self.get_host_state_mut(&alloc_key) {
1142 Ok(state) => {
1143 state.alloc_name = Some(alloc_name.clone());
1144 }
1145 Err(err) => {
1146 tracing::error!(
1148 "received allocated message alloc: {} with no known state: {}",
1149 alloc_key, err,
1150 );
1151 }
1152 }
1153 }
1154 Ok(RemoteProcessProcStateMessage::Update(alloc_key, proc_state)) => {
1155 let update = match proc_state {
1156 ProcState::Created { ref create_key, .. } => {
1157 if let Err(e) = self.add_proc_id_to_host_state(&alloc_key, create_key) {
1158 tracing::error!("failed to add proc with create key {} host state: {}", create_key, e);
1159 }
1160 proc_state
1161 }
1162 ProcState::Stopped{ ref create_key, ..} => {
1163 if let Err(e) = self.remove_proc_from_host_state(&alloc_key, create_key) {
1164 tracing::error!("failed to remove proc with create key {} host state: {}", create_key, e);
1165 }
1166 proc_state
1167 }
1168 ProcState::Failed { ref alloc_name, ref description } => {
1169 match self.get_host_state_mut(&alloc_key) {
1170 Ok(state) => {
1171 state.failed = true;
1172 ProcState::Failed {
1173 alloc_name: alloc_name.clone(),
1174 description: format!("host {} failed: {}", state.host_id, description),
1175 }
1176 }
1177 Err(e) => {
1178 tracing::error!("failed to find host state for alloc: {}: {}", alloc_name, e);
1179 proc_state
1180 }
1181 }
1182 }
1183 _ => proc_state
1184 };
1185
1186 break Some((Some(alloc_key), update));
1187 }
1188 Ok(RemoteProcessProcStateMessage::Done(alloc_key)) => {
1189 tracing::info!("allocator {} is done", alloc_key);
1190
1191 if let Ok(state) = self.remove_host_state(&alloc_key) {
1192 if !state.active_procs.is_empty() {
1193 tracing::error!("received done for alloc {} with active procs: {:?}", alloc_key, state.active_procs);
1194 }
1195 } else {
1196 tracing::error!("received done for unknown alloc {}", alloc_key);
1197 }
1198
1199 if self.host_states.is_empty() {
1200 self.running = false;
1201 break None;
1202 }
1203 }
1204 Ok(RemoteProcessProcStateMessage::HeartBeat) => {}
1208 Err(e) => {
1209 break Some((None, ProcState::Failed {alloc_name: self.alloc_name.clone(), description: format!("error receiving events: {}", e)}));
1210 }
1211 }
1212 }
1213
1214 _ = tokio::time::sleep_until(heartbeat_time) => {
1215 self.host_states.iter().for_each(|(_, host_state)| host_state.tx.post(RemoteProcessAllocatorMessage::HeartBeat));
1216 heartbeat_time = tokio::time::Instant::now() + heartbeat_interval;
1217 }
1218
1219 closed_host_id = self.comm_watcher_rx.recv() => {
1220 if let Some(closed_host_id) = closed_host_id {
1221 tracing::debug!("host {} channel closed, cleaning up", closed_host_id);
1222 if let Some(state) = self.host_states.get(&closed_host_id)
1223 && !state.allocated {
1224 break Some((None, ProcState::Failed {
1225 alloc_name: self.alloc_name.clone(),
1226 description: format!(
1227 "no process has ever been allocated on {}, connection is not established; \
1228 a common issue could be allocator port mismatch, or transport not enabled on client",
1229 closed_host_id
1230 )}));
1231 }
1232 let create_keys = match self.cleanup_host_channel_closed(closed_host_id) {
1233 Ok(create_keys) => create_keys,
1234 Err(err) => {
1235 tracing::error!("failed to cleanup disconnected host: {}", err);
1236 continue;
1237 }
1238 };
1239 for create_key in create_keys {
1240 tracing::debug!("queuing Stopped state for proc with create key {}", create_key);
1241 self.event_queue.push_back(
1242 ProcState::Stopped {
1243 create_key,
1244 reason: ProcStopReason::HostWatchdog
1245 }
1246 );
1247 }
1248 if self.host_states.is_empty() {
1250 tracing::info!("no more hosts left, stopping the alloc");
1251 self.running = false;
1252 }
1253 reloop = true;
1255 break None;
1256 } else {
1257 tracing::warn!("unexpected comm watcher channel close");
1259 break None;
1260 }
1261 }
1262 }
1263 };
1264
1265 if reloop {
1266 assert!(update.is_none());
1269 continue;
1270 }
1271
1272 break match update {
1273 Some((
1274 Some(alloc_key),
1275 ProcState::Created {
1276 create_key,
1277 point,
1278 pid,
1279 },
1280 )) => match self.project_proc_into_global_extent(&alloc_key, &point) {
1281 Ok(global_point) => {
1282 tracing::debug!("reprojected coords: {} -> {}", point, global_point);
1283 Some(ProcState::Created {
1284 create_key,
1285 point: global_point,
1286 pid,
1287 })
1288 }
1289 Err(e) => {
1290 tracing::error!(
1291 "failed to project coords for proc: {}.{}: {}",
1292 alloc_key,
1293 create_key,
1294 e
1295 );
1296 None
1297 }
1298 },
1299 Some((None, ProcState::Created { .. })) => {
1300 panic!("illegal state: missing alloc_key for ProcState::Created event")
1301 }
1302 Some((_, update)) => {
1303 if let ProcState::Failed { description, .. } = &update {
1304 tracing::error!(description);
1305 self.failed = true;
1306 }
1307 Some(update)
1308 }
1309 None => None,
1310 };
1311 }
1312 }
1313
    /// The [`AllocSpec`] this alloc was created with.
    fn spec(&self) -> &AllocSpec {
        &self.spec
    }
1317
    /// The extent of this alloc, taken from its spec.
    fn extent(&self) -> &Extent {
        &self.spec.extent
    }
1321
    /// The name identifying this alloc.
    fn alloc_name(&self) -> &AllocName {
        &self.alloc_name
    }
1325
1326 async fn stop(&mut self) -> Result<(), AllocatorError> {
1327 tracing::info!("stopping alloc");
1328
1329 for (host_id, task_state) in self.host_states.iter_mut() {
1330 tracing::debug!("stopping alloc at host {}", host_id);
1331 task_state.tx.post(RemoteProcessAllocatorMessage::Stop);
1332 }
1333
1334 Ok(())
1335 }
1336
    /// Address clients use to reach this alloc's router: derived from the
    /// bootstrap address with its port made unspecified/any (see
    /// `with_unspecified_port_or_any`).
    fn client_router_addr(&self) -> ChannelAddr {
        with_unspecified_port_or_any(&self.bootstrap_addr)
    }
1348}
1349
impl Drop for RemoteProcessAlloc {
    fn drop(&mut self) {
        // Purely observational: emit a debug log so alloc lifetimes can be
        // traced. No remote cleanup happens here; shutdown is driven by
        // `stop()`.
        tracing::debug!(
            "dropping RemoteProcessAlloc of alloc_name {}",
            self.alloc_name
        );
    }
}
1358
1359#[cfg(test)]
1360mod test {
1361 use std::assert_matches::assert_matches;
1362
1363 use hyperactor::channel::ChannelRx;
1364 use hyperactor::reference as hyperactor_reference;
1365 use hyperactor::testing::ids::test_proc_id;
1366 use ndslice::extent;
1367 use tokio::sync::oneshot;
1368
1369 use super::*;
1370 use crate::alloc::ChannelTransport;
1371 use crate::alloc::MockAlloc;
1372 use crate::alloc::MockAllocWrapper;
1373 use crate::alloc::MockAllocator;
1374 use crate::alloc::ProcStopReason;
1375 use crate::alloc::with_unspecified_port_or_any;
1376 use crate::proc_agent::ProcAgent;
1377
1378 async fn read_all_created(rx: &mut ChannelRx<RemoteProcessProcStateMessage>, alloc_len: usize) {
1379 let mut i: usize = 0;
1380 while i < alloc_len {
1381 let m = rx.recv().await.unwrap();
1382 match m {
1383 RemoteProcessProcStateMessage::Update(_, ProcState::Created { .. }) => i += 1,
1384 RemoteProcessProcStateMessage::HeartBeat => {}
1385 _ => panic!("unexpected message: {:?}", m),
1386 }
1387 }
1388 }
1389
1390 async fn read_all_running(rx: &mut ChannelRx<RemoteProcessProcStateMessage>, alloc_len: usize) {
1391 let mut i: usize = 0;
1392 while i < alloc_len {
1393 let m = rx.recv().await.unwrap();
1394 match m {
1395 RemoteProcessProcStateMessage::Update(_, ProcState::Running { .. }) => i += 1,
1396 RemoteProcessProcStateMessage::HeartBeat => {}
1397 _ => panic!("unexpected message: {:?}", m),
1398 }
1399 }
1400 }
1401
1402 async fn read_all_stopped(rx: &mut ChannelRx<RemoteProcessProcStateMessage>, alloc_len: usize) {
1403 let mut i: usize = 0;
1404 while i < alloc_len {
1405 let m = rx.recv().await.unwrap();
1406 match m {
1407 RemoteProcessProcStateMessage::Update(_, ProcState::Stopped { .. }) => i += 1,
1408 RemoteProcessProcStateMessage::HeartBeat => {}
1409 _ => panic!("unexpected message: {:?}", m),
1410 }
1411 }
1412 }
1413
    /// Program `alloc` with a full happy-path lifecycle for every rank in
    /// `extent`: one `Created`, then one `Running`, then one `Stopped` event
    /// per proc. mockall matches expectations in declaration order, so a
    /// consumer of `next()` observes all Created events, then all Running,
    /// then all Stopped.
    fn set_procstate_expectations(alloc: &mut MockAlloc, extent: Extent) {
        alloc.expect_extent().return_const(extent.clone());
        // One Created per point, each with a fresh create key that the
        // Running/Stopped expectations below reuse.
        let mut create_keys = Vec::new();
        for point in extent.points() {
            let create_key = ShortUuid::generate();
            create_keys.push(create_key.clone());
            alloc.expect_next().times(1).return_once(move || {
                Some(ProcState::Created {
                    create_key: create_key.clone(),
                    point,
                    pid: 0,
                })
            });
        }
        // Running events: rank index `i` determines the proc id and the
        // attested mesh agent reference.
        for (i, create_key) in create_keys
            .iter()
            .take(extent.num_ranks())
            .cloned()
            .enumerate()
        {
            let proc_id = test_proc_id(&format!("{i}"));
            let mesh_agent = hyperactor_reference::ActorRef::<ProcAgent>::attest(
                proc_id.actor_id("mesh_agent", i),
            );
            alloc.expect_next().times(1).return_once(move || {
                Some(ProcState::Running {
                    create_key,
                    proc_id,
                    addr: ChannelAddr::Unix("/proc0".parse().unwrap()),
                    mesh_agent,
                })
            });
        }
        // Finally, one Stopped per proc.
        for create_key in create_keys.iter().take(extent.num_ranks()).cloned() {
            alloc.expect_next().times(1).return_once(|| {
                Some(ProcState::Stopped {
                    create_key,
                    reason: ProcStopReason::Unknown,
                })
            });
        }
    }
1456
    /// Happy path against a fully mocked allocator: a single `Allocate`
    /// request yields `Allocated`, then `Created`/`Running`/`Stopped` updates
    /// for every rank (in rank order), and finally `Done`.
    #[timed_test::async_timed_test(timeout_secs = 60)]
    async fn test_simple() {
        // Fast heartbeats so the test doesn't idle on the heartbeat timer.
        let config = hyperactor_config::global::lock();
        let _guard = config.override_key(
            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
            Duration::from_millis(100),
        );
        hyperactor_telemetry::initialize_logging_for_test();
        let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
        let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
        let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();

        let extent = extent!(host = 1, gpu = 2);
        let tx = channel::dial(serve_addr.clone()).unwrap();

        let alloc_name: AllocName = AllocName("test_alloc_name".to_string());
        let mut alloc = MockAlloc::new();
        alloc.expect_alloc_name().return_const(alloc_name.clone());
        alloc.expect_extent().return_const(extent.clone());

        set_procstate_expectations(&mut alloc, extent.clone());

        alloc.expect_next().return_const(None);

        let mut allocator = MockAllocator::new();
        // Created + Running + Stopped per rank, plus the terminal None.
        let total_messages = extent.num_ranks() * 3 + 1;
        let mock_wrapper = MockAllocWrapper::new_block_next(
            alloc,
            total_messages,
        );
        allocator
            .expect_allocate()
            .times(1)
            .return_once(move |_| Ok(mock_wrapper));

        let remote_allocator = RemoteProcessAllocator::new();
        let handle = tokio::spawn({
            let remote_allocator = remote_allocator.clone();
            async move {
                remote_allocator
                    .start_with_allocator(serve_addr, allocator, None)
                    .await
            }
        });

        let alloc_key = ShortUuid::generate();

        tx.send(RemoteProcessAllocatorMessage::Allocate {
            alloc_key: alloc_key.clone(),
            extent: extent.clone(),
            bootstrap_addr,
            hosts: vec![],
            client_context: None,
            forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
        })
        .await
        .unwrap();

        // First response must acknowledge the allocation with our key/name.
        let m = rx.recv().await.unwrap();
        assert_matches!(
            m, RemoteProcessProcStateMessage::Allocated { alloc_key: got_alloc_key, alloc_name: got_alloc_name }
            if got_alloc_name == alloc_name && got_alloc_key == alloc_key
        );

        // One Created per rank, in rank order; collect the create keys so the
        // later Running/Stopped events can be correlated.
        let mut rank: usize = 0;
        let mut create_keys = Vec::with_capacity(extent.num_ranks());
        while rank < extent.num_ranks() {
            let m = rx.recv().await.unwrap();
            match m {
                RemoteProcessProcStateMessage::Update(
                    got_alloc_key,
                    ProcState::Created {
                        create_key, point, ..
                    },
                ) => {
                    let expected_point = extent.point_of_rank(rank).unwrap();
                    assert_eq!(got_alloc_key, alloc_key);
                    assert_eq!(point, expected_point);
                    create_keys.push(create_key);
                    rank += 1;
                }
                RemoteProcessProcStateMessage::HeartBeat => {}
                _ => panic!("unexpected message: {:?}", m),
            }
        }
        // One Running per rank, carrying the proc id and mesh agent set up by
        // `set_procstate_expectations`.
        let mut rank: usize = 0;
        while rank < extent.num_ranks() {
            let m = rx.recv().await.unwrap();
            match m {
                RemoteProcessProcStateMessage::Update(
                    got_alloc_key,
                    ProcState::Running {
                        create_key,
                        proc_id,
                        mesh_agent,
                        addr: _,
                    },
                ) => {
                    assert_eq!(got_alloc_key, alloc_key);
                    assert_eq!(create_key, create_keys[rank]);
                    let expected_proc_id = test_proc_id(&format!("{}", rank));
                    let expected_mesh_agent = hyperactor_reference::ActorRef::<ProcAgent>::attest(
                        expected_proc_id.actor_id("mesh_agent", rank),
                    );
                    assert_eq!(proc_id, expected_proc_id);
                    assert_eq!(mesh_agent, expected_mesh_agent);
                    rank += 1;
                }
                RemoteProcessProcStateMessage::HeartBeat => {}
                _ => panic!("unexpected message: {:?}", m),
            }
        }
        // One Stopped per rank, matching the recorded create keys.
        let mut rank: usize = 0;
        while rank < extent.num_ranks() {
            let m = rx.recv().await.unwrap();
            match m {
                RemoteProcessProcStateMessage::Update(
                    got_alloc_key,
                    ProcState::Stopped {
                        create_key,
                        reason: ProcStopReason::Unknown,
                    },
                ) => {
                    assert_eq!(got_alloc_key, alloc_key);
                    assert_eq!(create_key, create_keys[rank]);
                    rank += 1;
                }
                RemoteProcessProcStateMessage::HeartBeat => {}
                _ => panic!("unexpected message: {:?}", m),
            }
        }
        // After all procs have stopped, the allocator reports Done.
        loop {
            let m = rx.recv().await.unwrap();
            match m {
                RemoteProcessProcStateMessage::Done(got_alloc_key) => {
                    assert_eq!(got_alloc_key, alloc_key);
                    break;
                }
                RemoteProcessProcStateMessage::HeartBeat => {}
                _ => panic!("unexpected message: {:?}", m),
            }
        }

        remote_allocator.terminate();
        handle.await.unwrap().unwrap();
    }
1610
    /// Procs are allocated and running, then the client sends `Stop`: every
    /// proc must report `Stopped` and the inner alloc's `stop()` must be
    /// invoked exactly once.
    #[timed_test::async_timed_test(timeout_secs = 60)]
    #[cfg_attr(not(fbcode_build), ignore)]
    async fn test_normal_stop() {
        let config = hyperactor_config::global::lock();
        let _guard = config.override_key(
            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
            Duration::from_millis(100),
        );
        hyperactor_telemetry::initialize_logging_for_test();
        let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
        let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
        let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();

        let extent = extent!(host = 1, gpu = 2);
        let tx = channel::dial(serve_addr.clone()).unwrap();

        let alloc_name: AllocName = AllocName("test_alloc_name".to_string());
        // Block the inner alloc's next() after Created + Running (2 events
        // per rank) so Stop is injected at a deterministic point.
        let mut alloc = MockAllocWrapper::new_block_next(
            MockAlloc::new(),
            extent.num_ranks() * 2,
        );
        let next_tx = alloc.notify_tx();
        alloc
            .alloc
            .expect_alloc_name()
            .return_const(alloc_name.clone());
        alloc.alloc.expect_extent().return_const(extent.clone());

        set_procstate_expectations(&mut alloc.alloc, extent.clone());

        alloc.alloc.expect_next().return_const(None);
        alloc.alloc.expect_stop().times(1).return_once(|| Ok(()));

        let mut allocator = MockAllocator::new();
        allocator
            .expect_allocate()
            .times(1)
            .return_once(|_| Ok(alloc));

        let remote_allocator = RemoteProcessAllocator::new();
        let handle = tokio::spawn({
            let remote_allocator = remote_allocator.clone();
            async move {
                remote_allocator
                    .start_with_allocator(serve_addr, allocator, None)
                    .await
            }
        });

        let alloc_key = ShortUuid::generate();
        tx.send(RemoteProcessAllocatorMessage::Allocate {
            alloc_key: alloc_key.clone(),
            extent: extent.clone(),
            bootstrap_addr,
            hosts: vec![],
            client_context: None,
            forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
        })
        .await
        .unwrap();

        let m = rx.recv().await.unwrap();
        assert_matches!(
            m,
            RemoteProcessProcStateMessage::Allocated { alloc_name: got_alloc_name, alloc_key: got_alloc_key }
            if alloc_name == got_alloc_name && alloc_key == got_alloc_key
        );

        read_all_created(&mut rx, extent.num_ranks()).await;
        read_all_running(&mut rx, extent.num_ranks()).await;

        tracing::info!("stopping allocation");
        tx.send(RemoteProcessAllocatorMessage::Stop).await.unwrap();
        // Unblock the inner alloc so the Stopped events can flow.
        next_tx.send(()).unwrap();

        read_all_stopped(&mut rx, extent.num_ranks()).await;

        remote_allocator.terminate();
        handle.await.unwrap().unwrap();
    }
1698
    /// A second `Allocate` while one alloc is active: the first alloc is
    /// stopped and fully drained (Stopped* then Done) before the second is
    /// acknowledged and run through its own lifecycle.
    #[timed_test::async_timed_test(timeout_secs = 60)]
    #[cfg_attr(not(fbcode_build), ignore)]
    async fn test_realloc() {
        let config = hyperactor_config::global::lock();
        let _guard = config.override_key(
            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
            Duration::from_millis(100),
        );
        hyperactor_telemetry::initialize_logging_for_test();
        let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
        let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
        let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();

        let extent = extent!(host = 1, gpu = 2);

        let tx = channel::dial(serve_addr.clone()).unwrap();

        let alloc_name: AllocName = AllocName("test_alloc_name".to_string());
        // First alloc: blocks after Created + Running per rank so the
        // reallocation can interrupt it deterministically.
        let mut alloc1 = MockAllocWrapper::new_block_next(
            MockAlloc::new(),
            extent.num_ranks() * 2,
        );
        let next_tx1 = alloc1.notify_tx();
        alloc1
            .alloc
            .expect_alloc_name()
            .return_const(alloc_name.clone());
        alloc1.alloc.expect_extent().return_const(extent.clone());

        set_procstate_expectations(&mut alloc1.alloc, extent.clone());
        alloc1.alloc.expect_next().return_const(None);
        alloc1.alloc.expect_stop().times(1).return_once(|| Ok(()));
        // Second alloc: identical shape, used after the realloc.
        let mut alloc2 = MockAllocWrapper::new_block_next(
            MockAlloc::new(),
            extent.num_ranks() * 2,
        );
        let next_tx2 = alloc2.notify_tx();
        alloc2
            .alloc
            .expect_alloc_name()
            .return_const(alloc_name.clone());
        alloc2.alloc.expect_extent().return_const(extent.clone());
        set_procstate_expectations(&mut alloc2.alloc, extent.clone());
        alloc2.alloc.expect_next().return_const(None);
        alloc2.alloc.expect_stop().times(1).return_once(|| Ok(()));

        // mockall consumes these two allocate expectations in order:
        // alloc1 first, alloc2 on the second request.
        let mut allocator = MockAllocator::new();
        allocator
            .expect_allocate()
            .times(1)
            .return_once(|_| Ok(alloc1));
        allocator
            .expect_allocate()
            .times(1)
            .return_once(|_| Ok(alloc2));

        let remote_allocator = RemoteProcessAllocator::new();
        let handle = tokio::spawn({
            let remote_allocator = remote_allocator.clone();
            async move {
                remote_allocator
                    .start_with_allocator(serve_addr, allocator, None)
                    .await
            }
        });

        let alloc_key = ShortUuid::generate();

        tx.send(RemoteProcessAllocatorMessage::Allocate {
            alloc_key: alloc_key.clone(),
            extent: extent.clone(),
            bootstrap_addr: bootstrap_addr.clone(),
            hosts: vec![],
            client_context: None,
            forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
        })
        .await
        .unwrap();

        let m = rx.recv().await.unwrap();
        assert_matches!(
            m,
            RemoteProcessProcStateMessage::Allocated { alloc_name: got_alloc_name, alloc_key: got_alloc_key }
            if got_alloc_name == alloc_name && got_alloc_key == alloc_key
        );

        read_all_created(&mut rx, extent.num_ranks()).await;
        read_all_running(&mut rx, extent.num_ranks()).await;

        // Reallocate with a fresh key, preempting the first alloc.
        let alloc_key = ShortUuid::generate();

        tx.send(RemoteProcessAllocatorMessage::Allocate {
            alloc_key: alloc_key.clone(),
            extent: extent.clone(),
            bootstrap_addr,
            hosts: vec![],
            client_context: None,
            forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
        })
        .await
        .unwrap();
        // Unblock alloc1 so its teardown (Stopped* + Done) can be observed.
        next_tx1.send(()).unwrap();
        read_all_stopped(&mut rx, extent.num_ranks()).await;
        let m = rx.recv().await.unwrap();
        assert_matches!(m, RemoteProcessProcStateMessage::Done(_));
        // The second alloc is acknowledged only after the first is drained.
        let m = rx.recv().await.unwrap();
        assert_matches!(
            m,
            RemoteProcessProcStateMessage::Allocated { alloc_name: got_alloc_name, alloc_key: got_alloc_key }
            if got_alloc_name == alloc_name && got_alloc_key == alloc_key
        );
        read_all_created(&mut rx, extent.num_ranks()).await;
        read_all_running(&mut rx, extent.num_ranks()).await;
        tracing::info!("stopping allocation");
        tx.send(RemoteProcessAllocatorMessage::Stop).await.unwrap();
        next_tx2.send(()).unwrap();

        read_all_stopped(&mut rx, extent.num_ranks()).await;

        remote_allocator.terminate();
        handle.await.unwrap().unwrap();
    }
1835
    /// When the upstream (client) channel goes away, the allocator must
    /// detect the failure and proactively stop the running inner alloc.
    #[timed_test::async_timed_test(timeout_secs = 60)]
    async fn test_upstream_closed() {
        let config = hyperactor_config::global::lock();
        // Short delivery timeout so the dropped upstream is noticed quickly.
        let _guard1 = config.override_key(
            hyperactor::config::MESSAGE_DELIVERY_TIMEOUT,
            Duration::from_secs(1),
        );
        let _guard2 = config.override_key(
            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
            Duration::from_millis(100),
        );

        hyperactor_telemetry::initialize_logging_for_test();
        let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
        let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
        let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();

        let extent = extent!(host = 1, gpu = 2);

        let tx = channel::dial(serve_addr.clone()).unwrap();

        let alloc_name: AllocName = AllocName("test_alloc_name".to_string());
        let mut alloc = MockAllocWrapper::new_block_next(
            MockAlloc::new(),
            extent.num_ranks() * 2,
        );
        let next_tx = alloc.notify_tx();
        alloc
            .alloc
            .expect_alloc_name()
            .return_const(alloc_name.clone());
        alloc.alloc.expect_extent().return_const(extent.clone());

        set_procstate_expectations(&mut alloc.alloc, extent.clone());

        alloc.alloc.expect_next().return_const(None);
        // Signal through a oneshot when stop() is actually invoked so the
        // test can await that side effect.
        let (stop_tx, stop_rx) = oneshot::channel();
        alloc.alloc.expect_stop().times(1).return_once(|| {
            stop_tx.send(()).unwrap();
            Ok(())
        });

        let mut allocator = MockAllocator::new();
        allocator
            .expect_allocate()
            .times(1)
            .return_once(|_| Ok(alloc));

        let remote_allocator = RemoteProcessAllocator::new();
        let handle = tokio::spawn({
            let remote_allocator = remote_allocator.clone();
            async move {
                remote_allocator
                    .start_with_allocator(serve_addr, allocator, None)
                    .await
            }
        });

        let alloc_key = ShortUuid::generate();

        tx.send(RemoteProcessAllocatorMessage::Allocate {
            alloc_key: alloc_key.clone(),
            extent: extent.clone(),
            bootstrap_addr,
            hosts: vec![],
            client_context: None,
            forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
        })
        .await
        .unwrap();

        let m = rx.recv().await.unwrap();
        assert_matches!(
            m, RemoteProcessProcStateMessage::Allocated { alloc_key: got_alloc_key, alloc_name: got_alloc_name }
            if got_alloc_name == alloc_name && got_alloc_key == alloc_key
        );

        read_all_created(&mut rx, extent.num_ranks()).await;
        read_all_running(&mut rx, extent.num_ranks()).await;

        tracing::info!("closing upstream");
        drop(rx);
        // Give heartbeats time to fail past the delivery timeout.
        tokio::time::sleep(Duration::from_secs(2)).await;
        stop_rx.await.unwrap();
        next_tx.send(()).unwrap();
        remote_allocator.terminate();
        handle.await.unwrap().unwrap();
    }
1933
    /// The inner alloc emits `ProcState::Failed`: the failure must be
    /// forwarded upstream as an `Update`, and a subsequent `Stop` still
    /// completes the alloc with `Done`.
    #[timed_test::async_timed_test(timeout_secs = 60)]
    async fn test_inner_alloc_failure() {
        let config = hyperactor_config::global::lock();
        // Long heartbeat interval: heartbeats are irrelevant to this test.
        let _guard = config.override_key(
            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
            Duration::from_mins(1),
        );
        hyperactor_telemetry::initialize_logging_for_test();
        let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
        let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
        let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();

        let extent = extent!(host = 1, gpu = 2);

        let tx = channel::dial(serve_addr.clone()).unwrap();

        let test_alloc_name: AllocName = AllocName("test_alloc_name".to_string());
        // Block after the single Failed event so Stop can be injected.
        let mut alloc = MockAllocWrapper::new_block_next(
            MockAlloc::new(),
            1,
        );
        let next_tx = alloc.notify_tx();
        alloc
            .alloc
            .expect_alloc_name()
            .return_const(test_alloc_name.clone());
        alloc.alloc.expect_extent().return_const(extent.clone());
        // next() yields exactly one Failed, then None.
        alloc
            .alloc
            .expect_next()
            .times(1)
            .return_const(Some(ProcState::Failed {
                alloc_name: test_alloc_name.clone(),
                description: "test".to_string(),
            }));
        alloc.alloc.expect_next().times(1).return_const(None);

        alloc.alloc.expect_stop().times(1).return_once(|| Ok(()));

        let mut allocator = MockAllocator::new();
        allocator
            .expect_allocate()
            .times(1)
            .return_once(|_| Ok(alloc));

        let remote_allocator = RemoteProcessAllocator::new();
        let handle = tokio::spawn({
            let remote_allocator = remote_allocator.clone();
            async move {
                remote_allocator
                    .start_with_allocator(serve_addr, allocator, None)
                    .await
            }
        });

        let alloc_key = ShortUuid::generate();
        tx.send(RemoteProcessAllocatorMessage::Allocate {
            alloc_key: alloc_key.clone(),
            extent: extent.clone(),
            bootstrap_addr,
            hosts: vec![],
            client_context: None,
            forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
        })
        .await
        .unwrap();

        let m = rx.recv().await.unwrap();
        assert_matches!(
            m,
            RemoteProcessProcStateMessage::Allocated { alloc_name: got_alloc_name, alloc_key: got_alloc_key }
            if test_alloc_name == got_alloc_name && alloc_key == got_alloc_key
        );

        // The Failed state must be relayed upstream verbatim.
        let m = rx.recv().await.unwrap();
        assert_matches!(
            m,
            RemoteProcessProcStateMessage::Update(
                got_alloc_key,
                ProcState::Failed { alloc_name, description }
            ) if got_alloc_key == alloc_key && alloc_name == test_alloc_name && description == "test"
        );

        tracing::info!("stopping allocation");
        tx.send(RemoteProcessAllocatorMessage::Stop).await.unwrap();
        next_tx.send(()).unwrap();
        let m = rx.recv().await.unwrap();
        assert_matches!(
            m,
            RemoteProcessProcStateMessage::Done(got_alloc_key)
            if got_alloc_key == alloc_key
        );

        remote_allocator.terminate();
        handle.await.unwrap().unwrap();
    }
2035
    /// A `ClientContext` trace id supplied in the `Allocate` message must be
    /// propagated into the `AllocSpec` match labels under
    /// `CLIENT_TRACE_ID_LABEL` by the time `allocate` is called.
    #[timed_test::async_timed_test(timeout_secs = 60)]
    async fn test_trace_id_propagation() {
        let config = hyperactor_config::global::lock();
        let _guard = config.override_key(
            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
            Duration::from_mins(1),
        );
        hyperactor_telemetry::initialize_logging(hyperactor_telemetry::DefaultTelemetryClock {});
        let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
        let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
        let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();

        let extent = extent!(host = 1, gpu = 1);
        let tx = channel::dial(serve_addr.clone()).unwrap();
        let test_alloc_name: AllocName = AllocName("test_alloc_name".to_string());
        let test_trace_id = "test_trace_id_12345";

        let mut alloc = MockAlloc::new();
        alloc
            .expect_alloc_name()
            .return_const(test_alloc_name.clone());
        alloc.expect_extent().return_const(extent.clone());
        alloc.expect_next().return_const(None);

        let mut allocator = MockAllocator::new();
        // The withf predicate is the actual assertion: allocate must see the
        // trace id in the spec's match labels.
        allocator
            .expect_allocate()
            .times(1)
            .withf(move |spec: &AllocSpec| {
                spec.constraints
                    .match_labels
                    .get(CLIENT_TRACE_ID_LABEL)
                    .is_some_and(|trace_id| trace_id == test_trace_id)
            })
            .return_once(|_| Ok(MockAllocWrapper::new(alloc)));

        let remote_allocator = RemoteProcessAllocator::new();
        let handle = tokio::spawn({
            let remote_allocator = remote_allocator.clone();
            async move {
                remote_allocator
                    .start_with_allocator(serve_addr, allocator, None)
                    .await
            }
        });

        let alloc_key = ShortUuid::generate();
        tx.send(RemoteProcessAllocatorMessage::Allocate {
            alloc_key: alloc_key.clone(),
            extent: extent.clone(),
            bootstrap_addr,
            hosts: vec![],
            client_context: Some(ClientContext {
                trace_id: test_trace_id.to_string(),
            }),
            forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
        })
        .await
        .unwrap();

        let m = rx.recv().await.unwrap();
        assert_matches!(
            m,
            RemoteProcessProcStateMessage::Allocated { alloc_key: got_alloc_key, alloc_name: got_alloc_name }
            if got_alloc_name == test_alloc_name && got_alloc_key == alloc_key
        );

        let m = rx.recv().await.unwrap();
        assert_matches!(
            m,
            RemoteProcessProcStateMessage::Done(got_alloc_key)
            if alloc_key == got_alloc_key
        );

        remote_allocator.terminate();
        handle.await.unwrap().unwrap();
    }
2118
    /// Without a `ClientContext`, no trace id label may be injected: the
    /// spec's match labels observed by `allocate` must be empty.
    #[timed_test::async_timed_test(timeout_secs = 60)]
    async fn test_trace_id_propagation_no_client_context() {
        let config = hyperactor_config::global::lock();
        let _guard = config.override_key(
            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
            Duration::from_mins(1),
        );
        hyperactor_telemetry::initialize_logging(hyperactor_telemetry::DefaultTelemetryClock {});
        let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
        let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
        let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();

        let extent = extent!(host = 1, gpu = 1);
        let tx = channel::dial(serve_addr.clone()).unwrap();
        let test_alloc_name: AllocName = AllocName("test_alloc_name".to_string());

        let mut alloc = MockAlloc::new();
        alloc
            .expect_alloc_name()
            .return_const(test_alloc_name.clone());
        alloc.expect_extent().return_const(extent.clone());
        alloc.expect_next().return_const(None);

        let mut allocator = MockAllocator::new();
        // The withf predicate is the assertion: no labels without a context.
        allocator
            .expect_allocate()
            .times(1)
            .withf(move |spec: &AllocSpec| {
                spec.constraints.match_labels.is_empty()
            })
            .return_once(|_| Ok(MockAllocWrapper::new(alloc)));

        let remote_allocator = RemoteProcessAllocator::new();
        let handle = tokio::spawn({
            let remote_allocator = remote_allocator.clone();
            async move {
                remote_allocator
                    .start_with_allocator(serve_addr, allocator, None)
                    .await
            }
        });

        let alloc_key = ShortUuid::generate();
        tx.send(RemoteProcessAllocatorMessage::Allocate {
            alloc_key: alloc_key.clone(),
            extent: extent.clone(),
            bootstrap_addr,
            hosts: vec![],
            client_context: None,
            forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
        })
        .await
        .unwrap();

        let m = rx.recv().await.unwrap();
        assert_matches!(
            m,
            RemoteProcessProcStateMessage::Allocated { alloc_key: got_alloc_key, alloc_name: got_alloc_name }
            if got_alloc_name == test_alloc_name && got_alloc_key == alloc_key
        );

        let m = rx.recv().await.unwrap();
        assert_matches!(
            m,
            RemoteProcessProcStateMessage::Done(got_alloc_key)
            if got_alloc_key == alloc_key
        );

        remote_allocator.terminate();
        handle.await.unwrap().unwrap();
    }
2195}
2196
2197#[cfg(test)]
2198mod test_alloc {
2199 use std::os::unix::process::ExitStatusExt;
2200
2201 use hyperactor_config;
2202 use ndslice::extent;
2203 use nix::sys::signal;
2204 use nix::unistd::Pid;
2205 use timed_test::async_timed_test;
2206
2207 use super::*;
2208
    /// Integration test (fbcode only): two real `RemoteProcessAllocator`
    /// tasks bootstrap procs for a host=2 x gpu=2 extent. Verifies one
    /// Created + Running per rank, quiescence while idle, then a clean
    /// Stopped for every proc after `stop()`, and `None` once drained.
    #[async_timed_test(timeout_secs = 60)]
    #[cfg(fbcode_build)]
    async fn test_alloc_simple() {
        let config = hyperactor_config::global::lock();
        let _guard1 = config.override_key(
            hyperactor::config::MESSAGE_DELIVERY_TIMEOUT,
            Duration::from_secs(1),
        );
        let _guard2 = config.override_key(
            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
            Duration::from_millis(100),
        );
        hyperactor_telemetry::initialize_logging(hyperactor_telemetry::DefaultTelemetryClock {});

        let spec = AllocSpec {
            extent: extent!(host = 2, gpu = 2),
            constraints: Default::default(),
            proc_name: None,
            transport: ChannelTransport::Unix,
            proc_allocation_mode: Default::default(),
        };
        let alloc_name = AllocName("test_alloc_name".to_string());

        // Two allocator tasks, each serving on its own Unix address and
        // spawning the real bootstrap binary.
        let task1_allocator = RemoteProcessAllocator::new();
        let task1_addr = ChannelAddr::any(ChannelTransport::Unix);
        let task1_addr_string = task1_addr.to_string();
        let task1_cmd = Command::new(crate::testresource::get(
            "monarch/hyperactor_mesh/bootstrap",
        ));
        let task2_allocator = RemoteProcessAllocator::new();
        let task2_addr = ChannelAddr::any(ChannelTransport::Unix);
        let task2_addr_string = task2_addr.to_string();
        let task2_cmd = Command::new(crate::testresource::get(
            "monarch/hyperactor_mesh/bootstrap",
        ));
        let task1_allocator_copy = task1_allocator.clone();
        let task1_allocator_handle = tokio::spawn(async move {
            tracing::info!("spawning task1");
            task1_allocator_copy
                .start(task1_cmd, task1_addr, None)
                .await
                .unwrap();
        });
        let task2_allocator_copy = task2_allocator.clone();
        let task2_allocator_handle = tokio::spawn(async move {
            task2_allocator_copy
                .start(task2_cmd, task2_addr, None)
                .await
                .unwrap();
        });

        // The initializer hands the alloc both task addresses as hosts.
        let mut initializer = MockRemoteProcessAllocInitializer::new();
        initializer.expect_initialize_alloc().return_once(move || {
            Ok(vec![
                RemoteProcessAllocHost {
                    hostname: task1_addr_string,
                    id: "task1".to_string(),
                },
                RemoteProcessAllocHost {
                    hostname: task2_addr_string,
                    id: "task2".to_string(),
                },
            ])
        });
        let mut alloc = RemoteProcessAlloc::new(spec.clone(), alloc_name, 0, initializer)
            .await
            .unwrap();
        // Each rank must yield exactly one Created followed by one Running,
        // in any interleaving across hosts.
        let mut created = HashSet::new();
        let mut running_procs = HashSet::new();
        let mut proc_points = HashSet::new();
        for _ in 0..spec.extent.num_ranks() * 2 {
            let proc_state = alloc.next().await.unwrap();
            tracing::debug!("test got message: {:?}", proc_state);
            match proc_state {
                ProcState::Created {
                    create_key, point, ..
                } => {
                    created.insert(create_key);
                    proc_points.insert(point);
                }
                ProcState::Running { create_key, .. } => {
                    assert!(created.remove(&create_key));
                    running_procs.insert(create_key);
                }
                _ => panic!("expected Created or Running"),
            }
        }
        assert!(created.is_empty());
        // Every point of the extent must have been covered.
        assert!(
            spec.extent
                .points()
                .all(|point| proc_points.contains(&point))
        );

        // While idle, no further events should arrive within one second.
        let timeout = tokio::time::Instant::now() + std::time::Duration::from_millis(1000);
        tokio::select! {
            _ = tokio::time::sleep_until(timeout) => {},
            _ = alloc.next() => panic!("expected no more items"),
        }

        alloc.stop().await.unwrap();
        for _ in 0..spec.extent.num_ranks() {
            let proc_state = alloc.next().await.unwrap();
            tracing::info!("test received next proc_state: {:?}", proc_state);
            match proc_state {
                ProcState::Stopped {
                    create_key, reason, ..
                } => {
                    assert!(running_procs.remove(&create_key));
                    assert_eq!(reason, ProcStopReason::Stopped);
                }
                _ => panic!("expected stopped"),
            }
        }
        // After a full drain, next() returns None — repeatedly.
        let proc_state = alloc.next().await;
        assert!(proc_state.is_none());
        let proc_state = alloc.next().await;
        assert!(proc_state.is_none());

        task1_allocator.terminate();
        task1_allocator_handle.await.unwrap();
        task2_allocator.terminate();
        task2_allocator_handle.await.unwrap();
    }
2339
    /// Verifies that procs on a dead host are reported as `Stopped` with
    /// `ProcStopReason::HostWatchdog`.
    ///
    /// Spawns two `RemoteProcessAllocator` tasks (one per simulated host),
    /// allocates a `host = 2, gpu = 2` extent across them, then aborts each
    /// allocator task in turn and asserts that the watchdog surfaces a
    /// `HostWatchdog` stop event for every proc on the aborted host.
    #[async_timed_test(timeout_secs = 60)]
    #[cfg(fbcode_build)]
    async fn test_alloc_host_failure() {
        // Shorten the delivery timeout and heartbeat interval so a dead host
        // is detected quickly; the guards restore global config on drop.
        let config = hyperactor_config::global::lock();
        let _guard1 = config.override_key(
            hyperactor::config::MESSAGE_DELIVERY_TIMEOUT,
            Duration::from_secs(1),
        );
        let _guard2 = config.override_key(
            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
            Duration::from_millis(100),
        );
        hyperactor_telemetry::initialize_logging(hyperactor_telemetry::DefaultTelemetryClock {});

        // 2 hosts x 2 gpus = 4 ranks, two per allocator task.
        let spec = AllocSpec {
            extent: extent!(host = 2, gpu = 2),
            constraints: Default::default(),
            proc_name: None,
            transport: ChannelTransport::Unix,
            proc_allocation_mode: Default::default(),
        };
        let alloc_name = AllocName("test_alloc_name".to_string());

        // One RemoteProcessAllocator per simulated host, each serving the
        // bootstrap binary on its own unix channel address.
        let task1_allocator = RemoteProcessAllocator::new();
        let task1_addr = ChannelAddr::any(ChannelTransport::Unix);
        let task1_addr_string = task1_addr.to_string();
        let task1_cmd = Command::new(crate::testresource::get(
            "monarch/hyperactor_mesh/bootstrap",
        ));
        let task2_allocator = RemoteProcessAllocator::new();
        let task2_addr = ChannelAddr::any(ChannelTransport::Unix);
        let task2_addr_string = task2_addr.to_string();
        let task2_cmd = Command::new(crate::testresource::get(
            "monarch/hyperactor_mesh/bootstrap",
        ));
        let task1_allocator_copy = task1_allocator.clone();
        let task1_allocator_handle = tokio::spawn(async move {
            tracing::info!("spawning task1");
            task1_allocator_copy
                .start(task1_cmd, task1_addr, None)
                .await
                .unwrap();
            tracing::info!("task1 terminated");
        });
        let task2_allocator_copy = task2_allocator.clone();
        let task2_allocator_handle = tokio::spawn(async move {
            task2_allocator_copy
                .start(task2_cmd, task2_addr, None)
                .await
                .unwrap();
            tracing::info!("task2 terminated");
        });

        // The mock initializer hands both hosts to the alloc exactly once.
        let mut initializer = MockRemoteProcessAllocInitializer::new();
        initializer.expect_initialize_alloc().return_once(move || {
            Ok(vec![
                RemoteProcessAllocHost {
                    hostname: task1_addr_string,
                    id: "task1".to_string(),
                },
                RemoteProcessAllocHost {
                    hostname: task2_addr_string,
                    id: "task2".to_string(),
                },
            ])
        });
        let mut alloc = RemoteProcessAlloc::new(spec.clone(), alloc_name, 0, initializer)
            .await
            .unwrap();
        // Drain one Created and one Running event per rank (2 events/rank).
        for _ in 0..spec.extent.num_ranks() * 2 {
            match alloc.next().await {
                Some(ProcState::Created { .. }) | Some(ProcState::Running { .. }) => {}
                _ => panic!("expected Created or Running"),
            }
        }

        // With every proc running, no further events should arrive for now;
        // the select! races a 1s deadline against the event stream.
        let timeout = tokio::time::Instant::now() + std::time::Duration::from_millis(1000);
        tokio::select! {
            _ = tokio::time::sleep_until(timeout) => {},
            _ = alloc.next() => panic!("expected no more items"),
        }

        // Kill host 1's allocator task, then wait past two heartbeat
        // intervals so the watchdog can notice the missing heartbeats.
        tracing::info!("aborting task1 allocator");
        task1_allocator_handle.abort();
        tokio::time::sleep(
            hyperactor_config::global::get(hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL)
                * 2,
        )
        .await;
        // Half the ranks (host 1's share) must be stopped by the watchdog.
        for _ in 0..spec.extent.num_ranks() / 2 {
            let proc_state = alloc.next().await.unwrap();
            tracing::info!("test received next proc_state: {:?}", proc_state);
            match proc_state {
                ProcState::Stopped { reason, .. } => {
                    assert_eq!(reason, ProcStopReason::HostWatchdog);
                }
                _ => panic!("expected stopped"),
            }
        }
        // Host 2 is still healthy: the stream must stay quiet again.
        let timeout = tokio::time::Instant::now() + std::time::Duration::from_millis(1000);
        tokio::select! {
            _ = tokio::time::sleep_until(timeout) => {},
            _ = alloc.next() => panic!("expected no more items"),
        }

        // Now kill host 2's allocator the same way and expect the remaining
        // half of the ranks to be watchdog-stopped.
        tracing::info!("aborting task2 allocator");
        task2_allocator_handle.abort();
        tokio::time::sleep(
            hyperactor_config::global::get(hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL)
                * 2,
        )
        .await;
        for _ in 0..spec.extent.num_ranks() / 2 {
            let proc_state = alloc.next().await.unwrap();
            tracing::info!("test received next proc_state: {:?}", proc_state);
            match proc_state {
                ProcState::Stopped { reason, .. } => {
                    assert_eq!(reason, ProcStopReason::HostWatchdog);
                }
                _ => panic!("expected stopped"),
            }
        }
        // All procs stopped: the event stream is exhausted and stays None on
        // repeated polls.
        let proc_state = alloc.next().await;
        assert!(proc_state.is_none());
        let proc_state = alloc.next().await;
        assert!(proc_state.is_none());
    }
2474
    /// Verifies partial-failure handling when one host's inner allocation
    /// cannot spawn its processes.
    ///
    /// Host 1 runs the real bootstrap binary while host 2 is given a
    /// nonexistent command path, so its allocation fails. The test expects:
    /// the surviving half of the ranks to go Created -> Running, exactly one
    /// `ProcState::Failed` event, and a clean `Stopped` for each running proc
    /// after `stop()`.
    #[async_timed_test(timeout_secs = 60)]
    #[cfg(fbcode_build)]
    async fn test_alloc_inner_alloc_failure() {
        // SAFETY: `set_var` is unsafe because of concurrent environment
        // access; it is called once here, at the start of the test, before
        // any of the tasks below are spawned.
        unsafe {
            std::env::set_var("MONARCH_MESSAGE_DELIVERY_TIMEOUT_SECS", "1");
        }

        // Shorten the heartbeat so failures surface quickly; the guard
        // restores the global config on drop.
        let config = hyperactor_config::global::lock();
        let _guard = config.override_key(
            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
            Duration::from_millis(100),
        );
        hyperactor_telemetry::initialize_logging_for_test();

        // 2 hosts x 2 gpus = 4 ranks; host 2's two ranks will never start.
        let spec = AllocSpec {
            extent: extent!(host = 2, gpu = 2),
            constraints: Default::default(),
            proc_name: None,
            transport: ChannelTransport::Unix,
            proc_allocation_mode: Default::default(),
        };
        let alloc_name = AllocName("test_alloc_name".to_string());

        let task1_allocator = RemoteProcessAllocator::new();
        let task1_addr = ChannelAddr::any(ChannelTransport::Unix);
        let task1_addr_string = task1_addr.to_string();
        let task1_cmd = Command::new(crate::testresource::get(
            "monarch/hyperactor_mesh/bootstrap",
        ));
        let task2_allocator = RemoteProcessAllocator::new();
        let task2_addr = ChannelAddr::any(ChannelTransport::Unix);
        let task2_addr_string = task2_addr.to_string();
        // Deliberately bogus binary path: host 2's process spawn must fail.
        let task2_cmd = Command::new("/caught/somewhere/in/time");
        let task1_allocator_copy = task1_allocator.clone();
        let task1_allocator_handle = tokio::spawn(async move {
            tracing::info!("spawning task1");
            task1_allocator_copy
                .start(task1_cmd, task1_addr, None)
                .await
                .unwrap();
        });
        let task2_allocator_copy = task2_allocator.clone();
        let task2_allocator_handle = tokio::spawn(async move {
            task2_allocator_copy
                .start(task2_cmd, task2_addr, None)
                .await
                .unwrap();
        });

        // The mock initializer hands both hosts to the alloc exactly once.
        let mut initializer = MockRemoteProcessAllocInitializer::new();
        initializer.expect_initialize_alloc().return_once(move || {
            Ok(vec![
                RemoteProcessAllocHost {
                    hostname: task1_addr_string,
                    id: "task1".to_string(),
                },
                RemoteProcessAllocHost {
                    hostname: task2_addr_string,
                    id: "task2".to_string(),
                },
            ])
        });
        let mut alloc = RemoteProcessAlloc::new(spec.clone(), alloc_name, 0, initializer)
            .await
            .unwrap();
        let mut created = HashSet::new();
        let mut started_procs = HashSet::new();
        let mut proc_points = HashSet::new();
        let mut failed = 0;
        // Expected events: Created+Running for host 1's ranks plus one
        // Failed for host 2 — num_ranks() + 1 in total.
        for _ in 0..spec.extent.num_ranks() + 1 {
            let proc_state = alloc.next().await.unwrap();
            tracing::debug!("test got message: {:?}", proc_state);
            match proc_state {
                ProcState::Created {
                    create_key, point, ..
                } => {
                    created.insert(create_key);
                    proc_points.insert(point);
                }
                ProcState::Running { create_key, .. } => {
                    // Every Running proc must have been Created first.
                    assert!(created.remove(&create_key));
                    started_procs.insert(create_key);
                }
                ProcState::Failed { .. } => {
                    failed += 1;
                }
                _ => panic!("expected Created, Running or Failed"),
            }
        }
        assert!(created.is_empty());
        assert_eq!(failed, 1);
        // The healthy host is expected to cover the first half of the rank
        // space; each of those points must have appeared in a Created event.
        for rank in 0..spec.extent.num_ranks() / 2 {
            let point = spec.extent.point_of_rank(rank).unwrap();
            assert!(proc_points.contains(&point));
        }

        // No further events should arrive while the alloc is idle.
        let timeout = tokio::time::Instant::now() + std::time::Duration::from_millis(1000);
        tokio::select! {
            _ = tokio::time::sleep_until(timeout) => {},
            _ = alloc.next() => panic!("expected no more items"),
        }

        // Stopping the alloc must stop exactly the procs that were running
        // (host 1's half), each with reason Stopped.
        alloc.stop().await.unwrap();
        for _ in 0..spec.extent.num_ranks() / 2 {
            let proc_state = alloc.next().await.unwrap();
            tracing::info!("test received next proc_state: {:?}", proc_state);
            match proc_state {
                ProcState::Stopped {
                    create_key, reason, ..
                } => {
                    assert!(started_procs.remove(&create_key));
                    assert_eq!(reason, ProcStopReason::Stopped);
                }
                _ => panic!("expected stopped"),
            }
        }
        // Stream is exhausted after all stop events are delivered.
        let proc_state = alloc.next().await;
        assert!(proc_state.is_none());
        let proc_state = alloc.next().await;
        assert!(proc_state.is_none());

        // Shut the allocator tasks down cleanly and join them.
        task1_allocator.terminate();
        task1_allocator_handle.await.unwrap();
        task2_allocator.terminate();
        task2_allocator_handle.await.unwrap();
    }
2609
    /// Verifies that SIGINT-ing the `remote_process_alloc` driver tears down
    /// every child process it caused to be spawned.
    ///
    /// Launches a fleet of `remote_process_allocator` server processes plus a
    /// driver binary that allocates across them and reports child PIDs back
    /// over a unix channel. After the driver is killed with SIGINT, every
    /// reported child PID must no longer exist.
    #[async_timed_test(timeout_secs = 180)]
    #[cfg(fbcode_build)]
    async fn test_remote_process_alloc_signal_handler() {
        hyperactor_telemetry::initialize_logging_for_test();
        let num_proc_meshes = 5;
        let hosts_per_proc_mesh = 5;

        // Channel on which the driver reports the PIDs of spawned children.
        let pid_addr = ChannelAddr::any(ChannelTransport::Unix);
        let (pid_addr, mut pid_rx) = channel::serve::<u32>(pid_addr).unwrap();

        // One unix address per allocator server process (5 x 5 = 25).
        let addresses = (0..(num_proc_meshes * hosts_per_proc_mesh))
            .map(|_| ChannelAddr::any(ChannelTransport::Unix).to_string())
            .collect::<Vec<_>>();

        // Spawn one remote_process_allocator server per address.
        let remote_process_allocators = addresses
            .iter()
            .map(|addr| {
                Command::new(crate::testresource::get(
                    "monarch/hyperactor_mesh/remote_process_allocator",
                ))
                .env("RUST_LOG", "info")
                .arg(format!("--addr={addr}"))
                .stdout(std::process::Stdio::piped())
                .spawn()
                .unwrap()
            })
            .collect::<Vec<_>>();

        // Channel the driver signals once all of its allocations are done.
        let done_allocating_addr = ChannelAddr::any(ChannelTransport::Unix);
        let (done_allocating_addr, mut done_allocating_rx) =
            channel::serve::<()>(done_allocating_addr).unwrap();
        let mut remote_process_alloc = Command::new(crate::testresource::get(
            "monarch/hyperactor_mesh/remote_process_alloc",
        ))
        .arg(format!("--done-allocating-addr={}", done_allocating_addr))
        .arg(format!("--addresses={}", addresses.join(",")))
        .arg(format!("--num-proc-meshes={}", num_proc_meshes))
        .arg(format!("--hosts-per-proc-mesh={}", hosts_per_proc_mesh))
        .arg(format!("--pid-addr={}", pid_addr))
        .spawn()
        .unwrap();

        // Wait for allocation to complete, then collect one reported child
        // PID per allocator process before proceeding.
        done_allocating_rx.recv().await.unwrap();
        let mut received_pids = Vec::new();
        while let Ok(pid) = pid_rx.recv().await {
            received_pids.push(pid);
            if received_pids.len() == remote_process_allocators.len() {
                break;
            }
        }

        // Interrupt the driver; its signal handler is expected to trigger
        // cleanup of everything it allocated.
        signal::kill(
            Pid::from_raw(remote_process_alloc.id().unwrap() as i32),
            signal::SIGINT,
        )
        .unwrap();

        // The driver must die from the signal, not exit normally.
        assert_eq!(
            remote_process_alloc.wait().await.unwrap().signal(),
            Some(signal::SIGINT as i32)
        );

        // Grace period for the allocators to reap their children.
        tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;

        // `kill -0` probes a PID for existence without delivering a signal;
        // a non-success status means the process is gone, which is what we
        // require for every reported child.
        for child_pid in received_pids {
            let pid_check = Command::new("kill")
                .arg("-0")
                .arg(child_pid.to_string())
                .output()
                .await
                .expect("Failed to check if PID is alive");

            assert!(
                !pid_check.status.success(),
                "PID {} should no longer be alive",
                child_pid
            );
        }

        // Tear down the allocator server processes themselves.
        for mut remote_process_allocator in remote_process_allocators {
            remote_process_allocator.kill().await.unwrap();
        }
    }
2696}