use std::collections::HashMap;
use std::collections::HashSet;
use std::collections::VecDeque;
use std::sync::Arc;
use std::time::Duration;

use anyhow::Context;
use async_trait::async_trait;
use dashmap::DashMap;
use futures::FutureExt;
use futures::future::join_all;
use futures::future::select_all;
use hyperactor::WorldId;
use hyperactor::channel;
use hyperactor::channel::ChannelAddr;
use hyperactor::channel::ChannelRx;
use hyperactor::channel::ChannelTransport;
use hyperactor::channel::ChannelTx;
use hyperactor::channel::Rx;
use hyperactor::channel::TcpMode;
use hyperactor::channel::Tx;
use hyperactor::channel::TxStatus;
use hyperactor::clock;
use hyperactor::clock::Clock;
use hyperactor::clock::RealClock;
use hyperactor::internal_macro_support::serde_json;
use hyperactor::mailbox::DialMailboxRouter;
use hyperactor::mailbox::MailboxServer;
use hyperactor::observe_async;
use hyperactor::observe_result;
use hyperactor::reference::Reference;
use mockall::automock;
use ndslice::Region;
use ndslice::Slice;
use ndslice::View;
use ndslice::ViewExt;
use ndslice::view::Extent;
use ndslice::view::Point;
use serde::Deserialize;
use serde::Serialize;
use strum::AsRefStr;
use tokio::io::AsyncWriteExt;
use tokio::process::Command;
use tokio::sync::mpsc::UnboundedReceiver;
use tokio::sync::mpsc::UnboundedSender;
use tokio::sync::mpsc::unbounded_channel;
use tokio::task::JoinHandle;
use tokio_stream::StreamExt;
use tokio_stream::wrappers::WatchStream;
use tokio_util::sync::CancellationToken;
use typeuri::Named;

use crate::alloc::Alloc;
use crate::alloc::AllocConstraints;
use crate::alloc::AllocSpec;
use crate::alloc::Allocator;
use crate::alloc::AllocatorError;
use crate::alloc::ProcState;
use crate::alloc::ProcStopReason;
use crate::alloc::ProcessAllocator;
use crate::alloc::process::CLIENT_TRACE_ID_LABEL;
use crate::alloc::process::ClientContext;
use crate::alloc::with_unspecified_port_or_any;
use crate::shortuuid::ShortUuid;

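/// Control messages handled by a `RemoteProcessAllocator` serving on a
/// remote host.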
#[derive(Debug, Clone, Serialize, Deserialize, Named, AsRefStr)]
pub enum RemoteProcessAllocatorMessage {
    /// Request a new allocation. Any allocation already in progress is
    /// stopped first.
    Allocate {
        /// Key identifying this allocation in subsequent messages.
        alloc_key: ShortUuid,
        /// Extent of procs to allocate on this host.
        extent: Extent,
        /// Address to which `RemoteProcessProcStateMessage`s are sent back.
        bootstrap_addr: ChannelAddr,
        /// Hostnames of all hosts participating in the allocation.
        hosts: Vec<String>,
        /// Optional client context (e.g. trace id), propagated as alloc labels.
        client_context: Option<ClientContext>,
        /// Address on which the mailbox forwarder should serve.
        forwarder_addr: ChannelAddr,
    },
    /// Stop the current allocation, if any.
    Stop,
    /// Keep-alive probe from the client.
    HeartBeat,
}
wirevalue::register_type!(RemoteProcessAllocatorMessage);

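/// Proc state updates sent by a `RemoteProcessAllocator` back to the
/// client's bootstrap address.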
#[derive(Debug, Clone, Serialize, Deserialize, Named, AsRefStr)]
pub enum RemoteProcessProcStateMessage {
    /// The allocation was accepted and assigned to the given world.
    Allocated {
        alloc_key: ShortUuid,
        world_id: WorldId,
    },
    /// A proc state update for the given allocation.
    Update(ShortUuid, ProcState),
    /// The allocation has finished producing events.
    Done(ShortUuid),
    /// Keep-alive probe from the allocator.
    HeartBeat,
}
wirevalue::register_type!(RemoteProcessProcStateMessage);

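/// A remote process allocator service: listens on a channel for
/// `RemoteProcessAllocatorMessage`s and runs at most one active
/// allocation at a time.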
pub struct RemoteProcessAllocator {
    cancel_token: CancellationToken,
}

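/// Awaits the timer if one is given; otherwise pends forever, so the
/// enclosing `select!` branch never fires when no timeout is configured.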
async fn conditional_sleeper<F: futures::Future<Output = ()>>(t: Option<F>) {
    match t {
        Some(timer) => timer.await,
        None => futures::future::pending().await,
    }
}

impl RemoteProcessAllocator {
    /// Create a new allocator service. It does nothing until `start` is
    /// called.
    pub fn new() -> Arc<Self> {
        Arc::new(Self {
            cancel_token: CancellationToken::new(),
        })
    }

    /// Signal the service to shut down, stopping any active allocation.
    pub fn terminate(&self) {
        self.cancel_token.cancel();
    }

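    /// Start serving allocation requests on `serve_addr`, bootstrapping
    /// procs with `cmd`. If `timeout` is set and it elapses with no
    /// active allocation, the service exits.
    ///
    /// A minimal usage sketch (the bootstrap binary path is hypothetical):
    ///
    /// ```ignore
    /// let allocator = RemoteProcessAllocator::new();
    /// let serve_addr: ChannelAddr = "tcp![::]:26600".parse()?;
    /// allocator
    ///     .start(Command::new("/path/to/bootstrap"), serve_addr, None)
    ///     .await?;
    /// ```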
    #[hyperactor::instrument]
    pub async fn start(
        &self,
        cmd: Command,
        serve_addr: ChannelAddr,
        timeout: Option<Duration>,
    ) -> Result<(), anyhow::Error> {
        let process_allocator = ProcessAllocator::new(cmd);
        self.start_with_allocator(serve_addr, process_allocator, timeout)
            .await
    }

    #[hyperactor::instrument(fields(addr=serve_addr.to_string()))]
    pub async fn start_with_allocator<A: Allocator + Send + Sync + 'static>(
        &self,
        serve_addr: ChannelAddr,
        mut process_allocator: A,
        timeout: Option<Duration>,
    ) -> Result<(), anyhow::Error>
    where
        <A as Allocator>::Alloc: Send,
        <A as Allocator>::Alloc: Sync,
    {
        tracing::info!("starting remote allocator on: {}", serve_addr);
        let (_, mut rx) = channel::serve(serve_addr.clone()).map_err(anyhow::Error::from)?;

        struct ActiveAllocation {
            handle: JoinHandle<()>,
            cancel_token: CancellationToken,
        }
        #[observe_async("RemoteProcessAllocator")]
        async fn ensure_previous_alloc_stopped(active_allocation: &mut Option<ActiveAllocation>) {
            if let Some(active_allocation) = active_allocation.take() {
                tracing::info!("previous alloc found, stopping");
                active_allocation.cancel_token.cancel();
                match active_allocation.handle.await {
                    Ok(_) => {
                        tracing::info!("allocation stopped.")
                    }
                    Err(e) => {
                        tracing::error!("allocation handler failed: {}", e);
                    }
                }
            }
        }

        let mut active_allocation: Option<ActiveAllocation> = None;
        loop {
            let sleep = conditional_sleeper(timeout.map(|t| RealClock.sleep(t)));
            tokio::select! {
                msg = rx.recv() => {
                    match msg {
                        Ok(RemoteProcessAllocatorMessage::Allocate {
                            alloc_key,
                            extent,
                            bootstrap_addr,
                            hosts,
                            client_context,
                            forwarder_addr,
                        }) => {
                            tracing::info!("received allocation request for {} with extent {}", alloc_key, extent);
                            ensure_previous_alloc_stopped(&mut active_allocation).await;

                            let mut constraints: AllocConstraints = Default::default();
                            if let Some(context) = &client_context {
                                constraints = AllocConstraints {
                                    match_labels: HashMap::from([(
                                        CLIENT_TRACE_ID_LABEL.to_string(),
                                        context.trace_id.to_string(),
                                    )]),
                                };
                                tracing::info!(
                                    monarch_client_trace_id = context.trace_id.to_string(),
                                    "allocating...",
                                );
                            }

                            let spec = AllocSpec {
                                extent,
                                constraints,
                                proc_name: None,
                                transport: ChannelTransport::Unix,
                                proc_allocation_mode: Default::default(),
                            };

                            match process_allocator.allocate(spec.clone()).await {
                                Ok(alloc) => {
                                    let cancel_token = CancellationToken::new();
                                    active_allocation = Some(ActiveAllocation {
                                        cancel_token: cancel_token.clone(),
                                        handle: tokio::spawn(Self::handle_allocation_request(
                                            Box::new(alloc) as Box<dyn Alloc + Send + Sync>,
                                            alloc_key,
                                            bootstrap_addr,
                                            hosts,
                                            cancel_token,
                                            forwarder_addr,
                                        )),
                                    })
                                }
                                Err(e) => {
                                    tracing::error!("allocation for {:?} failed: {}", spec, e);
                                    continue;
                                }
                            }
                        }
                        Ok(RemoteProcessAllocatorMessage::Stop) => {
                            tracing::info!("received stop request");

                            ensure_previous_alloc_stopped(&mut active_allocation).await;
                        }
                        Ok(RemoteProcessAllocatorMessage::HeartBeat) => {}
                        Err(e) => {
                            tracing::error!("upstream channel error: {}", e);
                            continue;
                        }
                    }
                }
                _ = self.cancel_token.cancelled() => {
                    tracing::info!("main loop cancelled");

                    ensure_previous_alloc_stopped(&mut active_allocation).await;

                    break;
                }
                _ = sleep => {
                    if active_allocation.is_some() {
                        continue;
                    }
                    tracing::warn!("timeout of {} seconds elapsed without any allocations, exiting", timeout.unwrap_or_default().as_secs());
                    break;
                }
            }
        }

        Ok(())
    }

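    /// Handle a single allocation request: serve a mailbox forwarder for
    /// the procs, optionally write the host list for torch elastic, then
    /// run the allocation loop until it completes or is cancelled.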
    #[tracing::instrument(skip(alloc, cancel_token))]
    #[observe_async("RemoteProcessAllocator")]
    async fn handle_allocation_request(
        alloc: Box<dyn Alloc + Send + Sync>,
        alloc_key: ShortUuid,
        bootstrap_addr: ChannelAddr,
        hosts: Vec<String>,
        cancel_token: CancellationToken,
        forwarder_addr: ChannelAddr,
    ) {
        tracing::info!("handle allocation request, bootstrap_addr: {bootstrap_addr}");
        let (forwarder_addr, forwarder_rx) = match channel::serve(forwarder_addr) {
            Ok(v) => v,
            Err(e) => {
                tracing::error!("failed to bootstrap forwarder actor: {}", e);
                return;
            }
        };
        let router = DialMailboxRouter::new();
        let mailbox_handle = router.clone().serve(forwarder_rx);
        tracing::info!("started forwarder on: {}", forwarder_addr);

        if let Ok(hosts_file) = std::env::var("TORCH_ELASTIC_CUSTOM_HOSTNAMES_LIST_FILE") {
            tracing::info!("writing hosts to {}", hosts_file);
            #[derive(Serialize)]
            struct Hosts {
                hostnames: Vec<String>,
            }
            match serde_json::to_string(&Hosts { hostnames: hosts }) {
                Ok(json) => match tokio::fs::File::create(&hosts_file).await {
                    Ok(mut file) => {
                        if file.write_all(json.as_bytes()).await.is_err() {
                            tracing::error!("failed to write hosts to {}", hosts_file);
                            return;
                        }
                    }
                    Err(e) => {
                        tracing::error!("failed to open hosts file {}: {}", hosts_file, e);
                        return;
                    }
                },
                Err(e) => {
                    tracing::error!("failed to serialize hosts: {}", e);
                    return;
                }
            }
        }

        Self::handle_allocation_loop(
            alloc,
            alloc_key,
            bootstrap_addr,
            router,
            forwarder_addr,
            cancel_token,
        )
        .await;

        mailbox_handle.stop("alloc stopped");
        if let Err(e) = mailbox_handle.await {
            tracing::error!("failed to join forwarder: {}", e);
        }
    }

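    /// Drive one allocation: report `Allocated`, forward proc state
    /// updates upstream (rewriting proc addresses to go through the local
    /// forwarder), send heartbeats, and stop the alloc on cancellation.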
    async fn handle_allocation_loop(
        mut alloc: Box<dyn Alloc + Send + Sync>,
        alloc_key: ShortUuid,
        bootstrap_addr: ChannelAddr,
        router: DialMailboxRouter,
        forward_addr: ChannelAddr,
        cancel_token: CancellationToken,
    ) {
        let world_id = alloc.world_id().clone();
        tracing::info!("starting handle allocation loop for {}", world_id);
        let tx = match channel::dial(bootstrap_addr) {
            Ok(tx) => tx,
            Err(err) => {
                tracing::error!("failed to dial bootstrap address: {}", err);
                return;
            }
        };
        let message = RemoteProcessProcStateMessage::Allocated {
            alloc_key: alloc_key.clone(),
            world_id,
        };
        tracing::info!(name = message.as_ref(), "sending allocated message");
        if let Err(e) = tx.send(message).await {
            tracing::error!("failed to send Allocated message: {}", e);
            return;
        }

        let mut mesh_agents_by_create_key = HashMap::new();
        let mut running = true;
        let tx_status = tx.status().clone();
        let mut tx_watcher = WatchStream::new(tx_status);
        loop {
            tokio::select! {
                _ = cancel_token.cancelled(), if running => {
                    tracing::info!("cancelled, stopping allocation");
                    running = false;
                    if let Err(e) = alloc.stop().await {
                        tracing::error!("stop failed: {}", e);
                        break;
                    }
                }
                status = tx_watcher.next(), if running => {
                    match status {
                        Some(TxStatus::Closed) => {
                            tracing::error!("upstream channel state closed");
                            break;
                        },
                        _ => {
                            tracing::debug!("got channel event: {:?}", status.unwrap());
                            continue;
                        }
                    }
                }
                e = alloc.next() => {
                    match e {
                        Some(event) => {
                            tracing::debug!(name = event.as_ref(), "got event: {:?}", event);
                            let event = match event {
                                ProcState::Created { .. } => event,
                                ProcState::Running { create_key, proc_id, mesh_agent, addr } => {
                                    tracing::debug!("remapping mesh_agent {}: addr {} -> {}", mesh_agent, addr, forward_addr);
                                    mesh_agents_by_create_key.insert(create_key.clone(), mesh_agent.clone());
                                    router.bind(mesh_agent.actor_id().proc_id().clone().into(), addr);
                                    ProcState::Running { create_key, proc_id, mesh_agent, addr: forward_addr.clone() }
                                },
                                ProcState::Stopped { create_key, reason } => {
                                    match mesh_agents_by_create_key.remove(&create_key) {
                                        Some(mesh_agent) => {
                                            tracing::debug!("unmapping mesh_agent {}", mesh_agent);
                                            let agent_ref: Reference = mesh_agent.actor_id().proc_id().clone().into();
                                            router.unbind(&agent_ref);
                                        },
                                        None => {
                                            tracing::warn!("mesh_agent not found for create key {}", create_key);
                                        }
                                    }
                                    ProcState::Stopped { create_key, reason }
                                },
                                ProcState::Failed { ref world_id, ref description } => {
                                    tracing::error!("allocation failed for {}: {}", world_id, description);
                                    event
                                }
                            };
                            tracing::debug!(name = event.as_ref(), "sending event: {:?}", event);
                            tx.post(RemoteProcessProcStateMessage::Update(alloc_key.clone(), event));
                        }
                        None => {
                            tracing::debug!("sending done");
                            tx.post(RemoteProcessProcStateMessage::Done(alloc_key.clone()));
                            running = false;
                            break;
                        }
                    }
                }
                _ = RealClock.sleep(hyperactor_config::global::get(hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL)) => {
                    tracing::trace!("sending heartbeat");
                    tx.post(RemoteProcessProcStateMessage::HeartBeat);
                }
            }
        }
        tracing::info!("allocation handler loop exited");
        if running {
            tracing::info!("stopping processes");
            if let Err(e) = alloc.stop_and_wait().await {
                tracing::error!("stop failed: {}", e);
                return;
            }
            tracing::info!("stop finished");
        }
    }
}

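/// Identifier for a host participating in a remote allocation.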
type HostId = String;

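/// A host available to a `RemoteProcessAlloc`.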
#[derive(Clone)]
pub struct RemoteProcessAllocHost {
    /// Unique id of the host.
    pub id: HostId,
    /// Hostname used to dial the host's remote allocator.
    pub hostname: String,
}

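/// Per-host bookkeeping kept by `RemoteProcessAlloc` for an active
/// remote allocation.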
struct RemoteProcessAllocHostState {
    /// Key of the allocation served by this host.
    alloc_key: ShortUuid,
    /// Id of the host.
    host_id: HostId,
    /// Channel to the host's remote allocator.
    tx: ChannelTx<RemoteProcessAllocatorMessage>,
    /// Create keys of procs currently active on this host.
    active_procs: HashSet<ShortUuid>,
    /// Region of the global extent assigned to this host.
    region: Region,
    /// World id reported by the host's allocator, once allocated.
    world_id: Option<WorldId>,
    /// Whether this host's allocation has failed.
    failed: bool,
    /// Whether any proc has ever been allocated on this host.
    allocated: bool,
}

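/// Provides the set of hosts for a `RemoteProcessAlloc`, resolved when
/// proc states are first requested.
///
/// A minimal implementation sketch for a fixed host list (the struct
/// name is illustrative):
///
/// ```ignore
/// struct StaticInitializer {
///     hosts: Vec<RemoteProcessAllocHost>,
/// }
///
/// #[async_trait]
/// impl RemoteProcessAllocInitializer for StaticInitializer {
///     async fn initialize_alloc(&mut self) -> Result<Vec<RemoteProcessAllocHost>, anyhow::Error> {
///         Ok(self.hosts.clone())
///     }
/// }
/// ```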
#[automock]
#[async_trait]
pub trait RemoteProcessAllocInitializer {
    /// Returns the hosts to allocate on. Called once, at first use.
    async fn initialize_alloc(&mut self) -> Result<Vec<RemoteProcessAllocHost>, anyhow::Error>;
}

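/// Host state map that also mirrors host addresses into a shared
/// `DashMap` so the signal-cleanup handler can reach them.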
struct HostStates {
    inner: HashMap<HostId, RemoteProcessAllocHostState>,
    host_addresses: Arc<DashMap<HostId, ChannelAddr>>,
}

impl HostStates {
    fn new(host_addresses: Arc<DashMap<HostId, ChannelAddr>>) -> HostStates {
        Self {
            inner: HashMap::new(),
            host_addresses,
        }
    }

    fn insert(
        &mut self,
        host_id: HostId,
        state: RemoteProcessAllocHostState,
        address: ChannelAddr,
    ) {
        self.host_addresses.insert(host_id.clone(), address);
        self.inner.insert(host_id, state);
    }

    fn get(&self, host_id: &HostId) -> Option<&RemoteProcessAllocHostState> {
        self.inner.get(host_id)
    }

    fn get_mut(&mut self, host_id: &HostId) -> Option<&mut RemoteProcessAllocHostState> {
        self.inner.get_mut(host_id)
    }

    fn remove(&mut self, host_id: &HostId) -> Option<RemoteProcessAllocHostState> {
        self.host_addresses.remove(host_id);
        self.inner.remove(host_id)
    }

    fn iter(&self) -> impl Iterator<Item = (&HostId, &RemoteProcessAllocHostState)> {
        self.inner.iter()
    }

    fn iter_mut(&mut self) -> impl Iterator<Item = (&HostId, &mut RemoteProcessAllocHostState)> {
        self.inner.iter_mut()
    }

    fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }
}

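/// A client-side `Alloc` whose procs are hosted by remote
/// `RemoteProcessAllocator` services, one per host.
///
/// A usage sketch (assuming `spec`, `world_id`, and an initializer are
/// in scope; the port value is illustrative):
///
/// ```ignore
/// let alloc = RemoteProcessAlloc::new(spec, world_id, 26600, initializer).await?;
/// ```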
pub struct RemoteProcessAlloc {
    initializer: Box<dyn RemoteProcessAllocInitializer + Send + Sync>,
    spec: AllocSpec,
    remote_allocator_port: u16,
    world_id: WorldId,
    ordered_hosts: Vec<RemoteProcessAllocHost>,
    started: bool,
    running: bool,
    failed: bool,
    alloc_to_host: HashMap<ShortUuid, HostId>,
    host_states: HostStates,
    world_offsets: HashMap<WorldId, usize>,
    event_queue: VecDeque<ProcState>,
    comm_watcher_tx: UnboundedSender<HostId>,
    comm_watcher_rx: UnboundedReceiver<HostId>,

    bootstrap_addr: ChannelAddr,
    rx: ChannelRx<RemoteProcessProcStateMessage>,
    _signal_cleanup_guard: hyperactor::SignalCleanupGuard,
}

impl RemoteProcessAlloc {
    /// Create a new `RemoteProcessAlloc`. Hosts are obtained from
    /// `initializer`, and allocation requests are sent to each host's
    /// remote allocator on `remote_allocator_port` when proc states are
    /// first requested.
    #[tracing::instrument(skip(initializer))]
    #[observe_result("RemoteProcessAlloc")]
    pub async fn new(
        spec: AllocSpec,
        world_id: WorldId,
        remote_allocator_port: u16,
        initializer: impl RemoteProcessAllocInitializer + Send + Sync + 'static,
    ) -> Result<Self, anyhow::Error> {
        let alloc_serve_addr = ChannelAddr::any(spec.transport.clone());

        let (bootstrap_addr, rx) = channel::serve(alloc_serve_addr)?;

        tracing::info!(
            "starting alloc for {} on: {}",
            world_id,
            bootstrap_addr.clone()
        );

        let (comm_watcher_tx, comm_watcher_rx) = unbounded_channel();

        let host_addresses = Arc::new(DashMap::<HostId, ChannelAddr>::new());
        let host_addresses_for_signal = host_addresses.clone();

        let signal_cleanup_guard =
            hyperactor::register_signal_cleanup_scoped(Box::pin(async move {
                join_all(host_addresses_for_signal.iter().map(|entry| async move {
                    let addr = entry.value().clone();
                    match channel::dial(addr.clone()) {
                        Ok(tx) => {
                            if let Err(e) = tx.send(RemoteProcessAllocatorMessage::Stop).await {
                                tracing::error!("Failed to send Stop to {}: {}", addr, e);
                            }
                        }
                        Err(e) => {
                            tracing::error!("Failed to dial {} during signal cleanup: {}", addr, e);
                        }
                    }
                }))
                .await;
            }));

        Ok(Self {
            spec,
            world_id,
            remote_allocator_port,
            initializer: Box::new(initializer),
            world_offsets: HashMap::new(),
            ordered_hosts: Vec::new(),
            alloc_to_host: HashMap::new(),
            host_states: HostStates::new(host_addresses),
            bootstrap_addr,
            event_queue: VecDeque::new(),
            comm_watcher_tx,
            comm_watcher_rx,
            rx,
            started: false,
            running: true,
            failed: false,
            _signal_cleanup_guard: signal_cleanup_guard,
        })
    }

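    /// Spawn a task that watches every host channel and reports the host
    /// id on `comm_watcher_tx` whenever a channel transitions to
    /// `TxStatus::Closed`.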
    async fn start_comm_watcher(&self) {
        let mut tx_watchers = Vec::new();
        for host in &self.ordered_hosts {
            let tx_status = self.host_states.get(&host.id).unwrap().tx.status().clone();
            let watcher = WatchStream::new(tx_status);
            tx_watchers.push((watcher, host.id.clone()));
        }
        assert!(!tx_watchers.is_empty());
        let tx = self.comm_watcher_tx.clone();
        tokio::spawn(async move {
            loop {
                let mut tx_status_futures = Vec::new();
                for (watcher, _) in &mut tx_watchers {
                    let fut = watcher.next().boxed();
                    tx_status_futures.push(fut);
                }
                let (tx_status, index, _) = select_all(tx_status_futures).await;
                let host_id = match tx_watchers.get(index) {
                    Some((_, host_id)) => host_id.clone(),
                    None => {
                        tracing::error!(
                            "got selected index {} with no matching host in {}",
                            index,
                            tx_watchers.len()
                        );
                        continue;
                    }
                };
                if let Some(tx_status) = tx_status {
                    tracing::debug!("host {} channel event: {:?}", host_id, tx_status);
                    if tx_status == TxStatus::Closed {
                        if tx.send(host_id.clone()).is_err() {
                            break;
                        }
                        tx_watchers.remove(index);
                        if tx_watchers.is_empty() {
                            break;
                        }
                    }
                }
            }
        });
    }

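    /// Initialize hosts and send an `Allocate` request to every host's
    /// remote allocator. Called lazily on the first `next()`; a no-op on
    /// subsequent calls.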
    async fn ensure_started(&mut self) -> Result<(), anyhow::Error> {
        if self.started || self.failed {
            return Ok(());
        }

        self.started = true;
        let hosts = self
            .initializer
            .initialize_alloc()
            .await
            .context("alloc initializer error")?;
        if hosts.is_empty() {
            anyhow::bail!("initializer returned empty list of hosts");
        }
        let hostnames: Vec<_> = hosts.iter().map(|e| e.hostname.clone()).collect();
        tracing::info!("obtained {} hosts for this allocation", hostnames.len());

        use crate::alloc::ProcAllocationMode;

        let regions: Option<Vec<_>> = match self.spec.proc_allocation_mode {
            ProcAllocationMode::ProcLevel => {
                anyhow::ensure!(
                    self.spec.extent.len() >= 2,
                    "invalid extent: {}, expected at least 2 dimensions",
                    self.spec.extent
                );
                None
            }
            ProcAllocationMode::HostLevel => Some({
                let num_points = self.spec.extent.num_ranks();
                anyhow::ensure!(
                    hosts.len() >= num_points,
                    "HostLevel allocation mode requires {} hosts (one per point in extent {}), but only {} hosts were provided",
                    num_points,
                    self.spec.extent,
                    hosts.len()
                );

                let labels = self.spec.extent.labels().to_vec();

                let extent_sizes = self.spec.extent.sizes();
                let mut parent_strides = vec![1; extent_sizes.len()];
                for i in (0..extent_sizes.len() - 1).rev() {
                    parent_strides[i] = parent_strides[i + 1] * extent_sizes[i + 1];
                }

                (0..num_points)
                    .map(|rank| {
                        let sizes = vec![1; labels.len()];
                        Region::new(
                            labels.clone(),
                            Slice::new(rank, sizes, parent_strides.clone()).unwrap(),
                        )
                    })
                    .collect()
            }),
        };

        match self.spec.proc_allocation_mode {
            ProcAllocationMode::ProcLevel => {
                let split_dim = &self.spec.extent.labels()[self.spec.extent.len() - 1];
                for (i, region) in self.spec.extent.group_by(split_dim)?.enumerate() {
                    let host = &hosts[i];
                    tracing::debug!("allocating: {} for host: {}", region, host.id);

                    let remote_addr = match self.spec.transport {
                        ChannelTransport::MetaTls(_) => {
                            format!("metatls!{}:{}", host.hostname, self.remote_allocator_port)
                        }
                        ChannelTransport::Tcp(TcpMode::Localhost) => {
                            format!("tcp![::1]:{}", self.remote_allocator_port)
                        }
                        ChannelTransport::Tcp(TcpMode::Hostname) => {
                            format!("tcp!{}:{}", host.hostname, self.remote_allocator_port)
                        }
                        ChannelTransport::Unix => host.hostname.clone(),
                        _ => {
                            anyhow::bail!(
                                "unsupported transport for host {}: {:?}",
                                host.id,
                                self.spec.transport,
                            );
                        }
                    };

                    tracing::debug!("dialing remote: {} for host {}", remote_addr, host.id);
                    let remote_addr = remote_addr.parse::<ChannelAddr>()?;
                    let tx = channel::dial(remote_addr.clone())
                        .map_err(anyhow::Error::from)
                        .context(format!(
                            "failed to dial remote {} for host {}",
                            remote_addr, host.id
                        ))?;

                    let alloc_key = ShortUuid::generate();
                    assert!(
                        self.alloc_to_host
                            .insert(alloc_key.clone(), host.id.clone())
                            .is_none()
                    );

                    let trace_id = hyperactor_telemetry::trace::get_or_create_trace_id();
                    let client_context = Some(ClientContext { trace_id });
                    let message = RemoteProcessAllocatorMessage::Allocate {
                        alloc_key: alloc_key.clone(),
                        extent: region.extent(),
                        bootstrap_addr: self.bootstrap_addr.clone(),
                        hosts: hostnames.clone(),
                        client_context,
                        forwarder_addr: with_unspecified_port_or_any(&remote_addr),
                    };
                    tracing::info!(
                        name = message.as_ref(),
                        "sending allocate message to workers"
                    );
                    tx.post(message);

                    self.host_states.insert(
                        host.id.clone(),
                        RemoteProcessAllocHostState {
                            alloc_key,
                            host_id: host.id.clone(),
                            tx,
                            active_procs: HashSet::new(),
                            region,
                            world_id: None,
                            failed: false,
                            allocated: false,
                        },
                        remote_addr,
                    );
                }

                self.ordered_hosts = hosts;
            }
            ProcAllocationMode::HostLevel => {
                let regions = regions.unwrap();
                let num_regions = regions.len();
                for (i, region) in regions.into_iter().enumerate() {
                    let host = &hosts[i];
                    tracing::debug!("allocating: {} for host: {}", region, host.id);

                    let remote_addr = match self.spec.transport {
                        ChannelTransport::MetaTls(_) => {
                            format!("metatls!{}:{}", host.hostname, self.remote_allocator_port)
                        }
                        ChannelTransport::Tcp(TcpMode::Localhost) => {
                            format!("tcp![::1]:{}", self.remote_allocator_port)
                        }
                        ChannelTransport::Tcp(TcpMode::Hostname) => {
                            format!("tcp!{}:{}", host.hostname, self.remote_allocator_port)
                        }
                        ChannelTransport::Unix => host.hostname.clone(),
                        _ => {
                            anyhow::bail!(
                                "unsupported transport for host {}: {:?}",
                                host.id,
                                self.spec.transport,
                            );
                        }
                    };

                    tracing::debug!("dialing remote: {} for host {}", remote_addr, host.id);
                    let remote_addr = remote_addr.parse::<ChannelAddr>()?;
                    let tx = channel::dial(remote_addr.clone())
                        .map_err(anyhow::Error::from)
                        .context(format!(
                            "failed to dial remote {} for host {}",
                            remote_addr, host.id
                        ))?;

                    let alloc_key = ShortUuid::generate();
                    assert!(
                        self.alloc_to_host
                            .insert(alloc_key.clone(), host.id.clone())
                            .is_none()
                    );

                    let trace_id = hyperactor_telemetry::trace::get_or_create_trace_id();
                    let client_context = Some(ClientContext { trace_id });
                    let message = RemoteProcessAllocatorMessage::Allocate {
                        alloc_key: alloc_key.clone(),
                        extent: region.extent(),
                        bootstrap_addr: self.bootstrap_addr.clone(),
                        hosts: hostnames.clone(),
                        client_context,
                        forwarder_addr: with_unspecified_port_or_any(&remote_addr),
                    };
                    tracing::info!(
                        name = message.as_ref(),
                        "sending allocate message to workers"
                    );
                    tx.post(message);

                    self.host_states.insert(
                        host.id.clone(),
                        RemoteProcessAllocHostState {
                            alloc_key,
                            host_id: host.id.clone(),
                            tx,
                            active_procs: HashSet::new(),
                            region,
                            world_id: None,
                            failed: false,
                            allocated: false,
                        },
                        remote_addr,
                    );
                }

                self.ordered_hosts = hosts.into_iter().take(num_regions).collect();
            }
        }
        self.start_comm_watcher().await;
        self.started = true;

        Ok(())
    }

    fn get_host_state_mut(
        &mut self,
        alloc_key: &ShortUuid,
    ) -> Result<&mut RemoteProcessAllocHostState, anyhow::Error> {
        let host_id: &HostId = self
            .alloc_to_host
            .get(alloc_key)
            .ok_or_else(|| anyhow::anyhow!("alloc with key {} not found", alloc_key))?;

        self.host_states
            .get_mut(host_id)
            .ok_or_else(|| anyhow::anyhow!("no host state found for host {}", host_id))
    }

    fn get_host_state(
        &self,
        alloc_key: &ShortUuid,
    ) -> Result<&RemoteProcessAllocHostState, anyhow::Error> {
        let host_id: &HostId = self
            .alloc_to_host
            .get(alloc_key)
            .ok_or_else(|| anyhow::anyhow!("alloc with key {} not found", alloc_key))?;

        self.host_states
            .get(host_id)
            .ok_or_else(|| anyhow::anyhow!("no host state found for host {}", host_id))
    }

    fn remove_host_state(
        &mut self,
        alloc_key: &ShortUuid,
    ) -> Result<RemoteProcessAllocHostState, anyhow::Error> {
        let host_id: &HostId = self
            .alloc_to_host
            .get(alloc_key)
            .ok_or_else(|| anyhow::anyhow!("alloc with key {} not found", alloc_key))?;

        self.host_states
            .remove(host_id)
            .ok_or_else(|| anyhow::anyhow!("no host state found for host {}", host_id))
    }

    fn add_proc_id_to_host_state(
        &mut self,
        alloc_key: &ShortUuid,
        create_key: &ShortUuid,
    ) -> Result<(), anyhow::Error> {
        let task_state = self.get_host_state_mut(alloc_key)?;
        if !task_state.active_procs.insert(create_key.clone()) {
            tracing::error!("proc with create key {} already in host state", create_key);
        }
        task_state.allocated = true;
        Ok(())
    }

    fn remove_proc_from_host_state(
        &mut self,
        alloc_key: &ShortUuid,
        create_key: &ShortUuid,
    ) -> Result<(), anyhow::Error> {
        let task_state = self.get_host_state_mut(alloc_key)?;
        if !task_state.active_procs.remove(create_key) {
            tracing::error!("proc with create key {} not found in host state", create_key);
        }
        Ok(())
    }

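    /// Map a host-local proc point back into the allocation's global
    /// extent using the host's assigned region.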
    fn project_proc_into_global_extent(
        &self,
        alloc_key: &ShortUuid,
        point: &Point,
    ) -> Result<Point, anyhow::Error> {
        let global_rank = self
            .get_host_state(alloc_key)?
            .region
            .get(point.rank())
            .ok_or_else(|| {
                anyhow::anyhow!(
                    "rank {} out of bounds in alloc {}",
                    point.rank(),
                    alloc_key
                )
            })?;
        Ok(self.spec.extent.point_of_rank(global_rank)?)
    }

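    /// Remove all state for a host whose channel closed, returning the
    /// create keys of procs that were still active there.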
    fn cleanup_host_channel_closed(
        &mut self,
        host_id: HostId,
    ) -> Result<Vec<ShortUuid>, anyhow::Error> {
        let state = match self.host_states.remove(&host_id) {
            Some(state) => state,
            None => {
                anyhow::bail!(
                    "got channel closed event for host {} which has no known state",
                    host_id
                );
            }
        };
        self.ordered_hosts.retain(|host| host.id != host_id);
        self.alloc_to_host.remove(&state.alloc_key);
        if let Some(world_id) = state.world_id {
            self.world_offsets.remove(&world_id);
        }
        let create_keys = state.active_procs.iter().cloned().collect();

        Ok(create_keys)
    }
}

#[async_trait]
impl Alloc for RemoteProcessAlloc {
    async fn next(&mut self) -> Option<ProcState> {
        loop {
            if let state @ Some(_) = self.event_queue.pop_front() {
                break state;
            }

            if !self.running {
                break None;
            }

            if let Err(e) = self.ensure_started().await {
                break Some(ProcState::Failed {
                    world_id: self.world_id.clone(),
                    description: format!("failed to ensure started: {:#}", e),
                });
            }

            let heartbeat_interval = hyperactor_config::global::get(
                hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
            );
            let mut heartbeat_time = hyperactor::clock::RealClock.now() + heartbeat_interval;
            let mut reloop = false;
            let update = loop {
                tokio::select! {
                    msg = self.rx.recv() => {
                        tracing::debug!("got ProcState message from allocator: {:?}", msg);
                        match msg {
                            Ok(RemoteProcessProcStateMessage::Allocated { alloc_key, world_id }) => {
                                tracing::info!("remote alloc {}: allocated", alloc_key);
                                match self.get_host_state_mut(&alloc_key) {
                                    Ok(state) => {
                                        state.world_id = Some(world_id.clone());
                                    }
                                    Err(err) => {
                                        tracing::error!(
                                            "received allocated message for alloc {} with no known state: {}",
                                            alloc_key, err,
                                        );
                                    }
                                }
                            }
                            Ok(RemoteProcessProcStateMessage::Update(alloc_key, proc_state)) => {
                                let update = match proc_state {
                                    ProcState::Created { ref create_key, .. } => {
                                        if let Err(e) = self.add_proc_id_to_host_state(&alloc_key, create_key) {
                                            tracing::error!("failed to add proc with create key {} to host state: {}", create_key, e);
                                        }
                                        proc_state
                                    }
                                    ProcState::Stopped { ref create_key, .. } => {
                                        if let Err(e) = self.remove_proc_from_host_state(&alloc_key, create_key) {
                                            tracing::error!("failed to remove proc with create key {} from host state: {}", create_key, e);
                                        }
                                        proc_state
                                    }
                                    ProcState::Failed { ref world_id, ref description } => {
                                        match self.get_host_state_mut(&alloc_key) {
                                            Ok(state) => {
                                                state.failed = true;
                                                ProcState::Failed {
                                                    world_id: world_id.clone(),
                                                    description: format!("host {} failed: {}", state.host_id, description),
                                                }
                                            }
                                            Err(e) => {
                                                tracing::error!("failed to find host state for world id: {}: {}", world_id, e);
                                                proc_state
                                            }
                                        }
                                    }
                                    _ => proc_state
                                };

                                break Some((Some(alloc_key), update));
                            }
                            Ok(RemoteProcessProcStateMessage::Done(alloc_key)) => {
                                tracing::info!("allocator {} is done", alloc_key);

                                if let Ok(state) = self.remove_host_state(&alloc_key) {
                                    if !state.active_procs.is_empty() {
                                        tracing::error!("received done for alloc {} with active procs: {:?}", alloc_key, state.active_procs);
                                    }
                                } else {
                                    tracing::error!("received done for unknown alloc {}", alloc_key);
                                }

                                if self.host_states.is_empty() {
                                    self.running = false;
                                    break None;
                                }
                            }
                            Ok(RemoteProcessProcStateMessage::HeartBeat) => {}
                            Err(e) => {
                                break Some((None, ProcState::Failed { world_id: self.world_id.clone(), description: format!("error receiving events: {}", e) }));
                            }
                        }
                    }

                    _ = clock::RealClock.sleep_until(heartbeat_time) => {
                        self.host_states.iter().for_each(|(_, host_state)| host_state.tx.post(RemoteProcessAllocatorMessage::HeartBeat));
                        heartbeat_time = hyperactor::clock::RealClock.now() + heartbeat_interval;
                    }

                    closed_host_id = self.comm_watcher_rx.recv() => {
                        if let Some(closed_host_id) = closed_host_id {
                            tracing::debug!("host {} channel closed, cleaning up", closed_host_id);
                            if let Some(state) = self.host_states.get(&closed_host_id)
                                && !state.allocated {
                                break Some((None, ProcState::Failed {
                                    world_id: self.world_id.clone(),
                                    description: format!(
                                        "no process was ever allocated on {} and the connection could not be established; \
                                        common causes are an allocator port mismatch or a transport not enabled on the client",
                                        closed_host_id
                                    )}));
                            }
                            let create_keys = match self.cleanup_host_channel_closed(closed_host_id) {
                                Ok(create_keys) => create_keys,
                                Err(err) => {
                                    tracing::error!("failed to cleanup disconnected host: {}", err);
                                    continue;
                                }
                            };
                            for create_key in create_keys {
                                tracing::debug!("queuing Stopped state for proc with create key {}", create_key);
                                self.event_queue.push_back(
                                    ProcState::Stopped {
                                        create_key,
                                        reason: ProcStopReason::HostWatchdog
                                    }
                                );
                            }
                            if self.host_states.is_empty() {
                                tracing::info!("no more hosts left, stopping the alloc");
                                self.running = false;
                            }
                            reloop = true;
                            break None;
                        } else {
                            tracing::warn!("unexpected comm watcher channel close");
                            break None;
                        }
                    }
                }
            };

            if reloop {
                assert!(update.is_none());
                continue;
            }

            break match update {
                Some((
                    Some(alloc_key),
                    ProcState::Created {
                        create_key,
                        point,
                        pid,
                    },
                )) => match self.project_proc_into_global_extent(&alloc_key, &point) {
                    Ok(global_point) => {
                        tracing::debug!("reprojected coords: {} -> {}", point, global_point);
                        Some(ProcState::Created {
                            create_key,
                            point: global_point,
                            pid,
                        })
                    }
                    Err(e) => {
                        tracing::error!(
                            "failed to project coords for proc: {}.{}: {}",
                            alloc_key,
                            create_key,
                            e
                        );
                        None
                    }
                },
                Some((None, ProcState::Created { .. })) => {
                    panic!("illegal state: missing alloc_key for ProcState::Created event")
                }
                Some((_, update)) => {
                    if let ProcState::Failed { description, .. } = &update {
                        tracing::error!(description);
                        self.failed = true;
                    }
                    Some(update)
                }
                None => None,
            };
        }
    }

    fn spec(&self) -> &AllocSpec {
        &self.spec
    }

    fn extent(&self) -> &Extent {
        &self.spec.extent
    }

    fn world_id(&self) -> &WorldId {
        &self.world_id
    }

    async fn stop(&mut self) -> Result<(), AllocatorError> {
        tracing::info!("stopping alloc");

        for (host_id, task_state) in self.host_states.iter_mut() {
            tracing::debug!("stopping alloc at host {}", host_id);
            task_state.tx.post(RemoteProcessAllocatorMessage::Stop);
        }

        Ok(())
    }

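    /// Router address for this alloc: the bootstrap address with its
    /// port left unspecified.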
    fn client_router_addr(&self) -> ChannelAddr {
        with_unspecified_port_or_any(&self.bootstrap_addr)
    }
}

impl Drop for RemoteProcessAlloc {
    fn drop(&mut self) {
        tracing::debug!("dropping RemoteProcessAlloc of world_id {}", self.world_id);
    }
}

#[cfg(test)]
mod test {
    use std::assert_matches::assert_matches;

    use hyperactor::ActorRef;
    use hyperactor::channel::ChannelRx;
    use hyperactor::clock::ClockKind;
    use hyperactor::id;
    use ndslice::extent;
    use tokio::sync::oneshot;

    use super::*;
    use crate::alloc::ChannelTransport;
    use crate::alloc::MockAlloc;
    use crate::alloc::MockAllocWrapper;
    use crate::alloc::MockAllocator;
    use crate::alloc::ProcStopReason;
    use crate::alloc::with_unspecified_port_or_any;
    use crate::proc_mesh::mesh_agent::ProcMeshAgent;

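    // Helpers that drain the expected number of Created / Running / Stopped
    // updates from the bootstrap channel, ignoring interleaved heartbeats.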
    async fn read_all_created(rx: &mut ChannelRx<RemoteProcessProcStateMessage>, alloc_len: usize) {
        let mut i: usize = 0;
        while i < alloc_len {
            let m = rx.recv().await.unwrap();
            match m {
                RemoteProcessProcStateMessage::Update(_, ProcState::Created { .. }) => i += 1,
                RemoteProcessProcStateMessage::HeartBeat => {}
                _ => panic!("unexpected message: {:?}", m),
            }
        }
    }

    async fn read_all_running(rx: &mut ChannelRx<RemoteProcessProcStateMessage>, alloc_len: usize) {
        let mut i: usize = 0;
        while i < alloc_len {
            let m = rx.recv().await.unwrap();
            match m {
                RemoteProcessProcStateMessage::Update(_, ProcState::Running { .. }) => i += 1,
                RemoteProcessProcStateMessage::HeartBeat => {}
                _ => panic!("unexpected message: {:?}", m),
            }
        }
    }

    async fn read_all_stopped(rx: &mut ChannelRx<RemoteProcessProcStateMessage>, alloc_len: usize) {
        let mut i: usize = 0;
        while i < alloc_len {
            let m = rx.recv().await.unwrap();
            match m {
                RemoteProcessProcStateMessage::Update(_, ProcState::Stopped { .. }) => i += 1,
                RemoteProcessProcStateMessage::HeartBeat => {}
                _ => panic!("unexpected message: {:?}", m),
            }
        }
    }

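    /// Program `alloc` to emit Created, Running, and Stopped for every
    /// point in `extent`, in that order.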
    fn set_procstate_expectations(alloc: &mut MockAlloc, extent: Extent) {
        alloc.expect_extent().return_const(extent.clone());
        let mut create_keys = Vec::new();
        for point in extent.points() {
            let create_key = ShortUuid::generate();
            create_keys.push(create_key.clone());
            alloc.expect_next().times(1).return_once(move || {
                Some(ProcState::Created {
                    create_key: create_key.clone(),
                    point,
                    pid: 0,
                })
            });
        }
        for (i, create_key) in create_keys
            .iter()
            .take(extent.num_ranks())
            .cloned()
            .enumerate()
        {
            let proc_id = format!("test[{i}]").parse().unwrap();
            let mesh_agent = ActorRef::<ProcMeshAgent>::attest(
                format!("test[{i}].mesh_agent[{i}]").parse().unwrap(),
            );
            alloc.expect_next().times(1).return_once(move || {
                Some(ProcState::Running {
                    create_key,
                    proc_id,
                    addr: ChannelAddr::Unix("/proc0".parse().unwrap()),
                    mesh_agent,
                })
            });
        }
        for create_key in create_keys.iter().take(extent.num_ranks()).cloned() {
            alloc.expect_next().times(1).return_once(|| {
                Some(ProcState::Stopped {
                    create_key,
                    reason: ProcStopReason::Unknown,
                })
            });
        }
    }

    #[timed_test::async_timed_test(timeout_secs = 5)]
    async fn test_simple() {
        let config = hyperactor_config::global::lock();
        let _guard = config.override_key(
            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
            Duration::from_millis(100),
        );
        hyperactor_telemetry::initialize_logging_for_test();
        let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
        let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
        let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();

        let extent = extent!(host = 1, gpu = 2);
        let tx = channel::dial(serve_addr.clone()).unwrap();

        let world_id: WorldId = id!(test_world_id);
        let mut alloc = MockAlloc::new();
        alloc.expect_world_id().return_const(world_id.clone());
        alloc.expect_extent().return_const(extent.clone());

        set_procstate_expectations(&mut alloc, extent.clone());

        alloc.expect_next().return_const(None);

        let mut allocator = MockAllocator::new();
        // Created + Running + Stopped per rank, plus the final None.
        let total_messages = extent.num_ranks() * 3 + 1;
        let mock_wrapper = MockAllocWrapper::new_block_next(alloc, total_messages);
        allocator
            .expect_allocate()
            .times(1)
            .return_once(move |_| Ok(mock_wrapper));

        let remote_allocator = RemoteProcessAllocator::new();
        let handle = tokio::spawn({
            let remote_allocator = remote_allocator.clone();
            async move {
                remote_allocator
                    .start_with_allocator(serve_addr, allocator, None)
                    .await
            }
        });

        let alloc_key = ShortUuid::generate();

        tx.send(RemoteProcessAllocatorMessage::Allocate {
            alloc_key: alloc_key.clone(),
            extent: extent.clone(),
            bootstrap_addr,
            hosts: vec![],
            client_context: None,
            forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
        })
        .await
        .unwrap();

        let m = rx.recv().await.unwrap();
        assert_matches!(
            m, RemoteProcessProcStateMessage::Allocated { alloc_key: got_alloc_key, world_id: got_world_id }
            if got_world_id == world_id && got_alloc_key == alloc_key
        );

        let mut rank: usize = 0;
        let mut create_keys = Vec::with_capacity(extent.num_ranks());
        while rank < extent.num_ranks() {
            let m = rx.recv().await.unwrap();
            match m {
                RemoteProcessProcStateMessage::Update(
                    got_alloc_key,
                    ProcState::Created {
                        create_key, point, ..
                    },
                ) => {
                    let expected_point = extent.point_of_rank(rank).unwrap();
                    assert_eq!(got_alloc_key, alloc_key);
                    assert_eq!(point, expected_point);
                    create_keys.push(create_key);
                    rank += 1;
                }
                RemoteProcessProcStateMessage::HeartBeat => {}
                _ => panic!("unexpected message: {:?}", m),
            }
        }
        let mut rank: usize = 0;
        while rank < extent.num_ranks() {
            let m = rx.recv().await.unwrap();
            match m {
                RemoteProcessProcStateMessage::Update(
                    got_alloc_key,
                    ProcState::Running {
                        create_key,
                        proc_id,
                        mesh_agent,
                        addr: _,
                    },
                ) => {
                    assert_eq!(got_alloc_key, alloc_key);
                    assert_eq!(create_key, create_keys[rank]);
                    let expected_proc_id = format!("test[{}]", rank).parse().unwrap();
                    let expected_mesh_agent = ActorRef::<ProcMeshAgent>::attest(
                        format!("test[{}].mesh_agent[{}]", rank, rank)
                            .parse()
                            .unwrap(),
                    );
                    assert_eq!(proc_id, expected_proc_id);
                    assert_eq!(mesh_agent, expected_mesh_agent);
                    rank += 1;
                }
                RemoteProcessProcStateMessage::HeartBeat => {}
                _ => panic!("unexpected message: {:?}", m),
            }
        }
        let mut rank: usize = 0;
        while rank < extent.num_ranks() {
            let m = rx.recv().await.unwrap();
            match m {
                RemoteProcessProcStateMessage::Update(
                    got_alloc_key,
                    ProcState::Stopped {
                        create_key,
                        reason: ProcStopReason::Unknown,
                    },
                ) => {
                    assert_eq!(got_alloc_key, alloc_key);
                    assert_eq!(create_key, create_keys[rank]);
                    rank += 1;
                }
                RemoteProcessProcStateMessage::HeartBeat => {}
                _ => panic!("unexpected message: {:?}", m),
            }
        }
        loop {
            let m = rx.recv().await.unwrap();
            match m {
                RemoteProcessProcStateMessage::Done(got_alloc_key) => {
                    assert_eq!(got_alloc_key, alloc_key);
                    break;
                }
                RemoteProcessProcStateMessage::HeartBeat => {}
                _ => panic!("unexpected message: {:?}", m),
            }
        }

        remote_allocator.terminate();
        handle.await.unwrap().unwrap();
    }

    #[timed_test::async_timed_test(timeout_secs = 15)]
    async fn test_normal_stop() {
        let config = hyperactor_config::global::lock();
        let _guard = config.override_key(
            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
            Duration::from_millis(100),
        );
        hyperactor_telemetry::initialize_logging_for_test();
        let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
        let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
        let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();

        let extent = extent!(host = 1, gpu = 2);
        let tx = channel::dial(serve_addr.clone()).unwrap();

        let world_id: WorldId = id!(test_world_id);
        // Block next() after the Created and Running events for each rank.
        let mut alloc = MockAllocWrapper::new_block_next(MockAlloc::new(), extent.num_ranks() * 2);
        let next_tx = alloc.notify_tx();
        alloc.alloc.expect_world_id().return_const(world_id.clone());
        alloc.alloc.expect_extent().return_const(extent.clone());

        set_procstate_expectations(&mut alloc.alloc, extent.clone());

        alloc.alloc.expect_next().return_const(None);
        alloc.alloc.expect_stop().times(1).return_once(|| Ok(()));

        let mut allocator = MockAllocator::new();
        allocator
            .expect_allocate()
            .times(1)
            .return_once(|_| Ok(alloc));

        let remote_allocator = RemoteProcessAllocator::new();
        let handle = tokio::spawn({
            let remote_allocator = remote_allocator.clone();
            async move {
                remote_allocator
                    .start_with_allocator(serve_addr, allocator, None)
                    .await
            }
        });

        let alloc_key = ShortUuid::generate();
        tx.send(RemoteProcessAllocatorMessage::Allocate {
            alloc_key: alloc_key.clone(),
            extent: extent.clone(),
            bootstrap_addr,
            hosts: vec![],
            client_context: None,
            forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
        })
        .await
        .unwrap();

        let m = rx.recv().await.unwrap();
        assert_matches!(
            m,
            RemoteProcessProcStateMessage::Allocated { world_id: got_world_id, alloc_key: got_alloc_key }
            if world_id == got_world_id && alloc_key == got_alloc_key
        );

        read_all_created(&mut rx, extent.num_ranks()).await;
        read_all_running(&mut rx, extent.num_ranks()).await;

        tracing::info!("stopping allocation");
        tx.send(RemoteProcessAllocatorMessage::Stop).await.unwrap();
        next_tx.send(()).unwrap();

        read_all_stopped(&mut rx, extent.num_ranks()).await;

        remote_allocator.terminate();
        handle.await.unwrap().unwrap();
    }

    #[timed_test::async_timed_test(timeout_secs = 15)]
    async fn test_realloc() {
        let config = hyperactor_config::global::lock();
        let _guard = config.override_key(
            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
            Duration::from_millis(100),
        );
        hyperactor_telemetry::initialize_logging_for_test();
        let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
        let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
        let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();

        let extent = extent!(host = 1, gpu = 2);

        let tx = channel::dial(serve_addr.clone()).unwrap();

        let world_id: WorldId = id!(test_world_id);
        // Block next() after the Created and Running events for each rank.
        let mut alloc1 = MockAllocWrapper::new_block_next(MockAlloc::new(), extent.num_ranks() * 2);
        let next_tx1 = alloc1.notify_tx();
        alloc1
            .alloc
            .expect_world_id()
            .return_const(world_id.clone());
        alloc1.alloc.expect_extent().return_const(extent.clone());

        set_procstate_expectations(&mut alloc1.alloc, extent.clone());
        alloc1.alloc.expect_next().return_const(None);
        alloc1.alloc.expect_stop().times(1).return_once(|| Ok(()));
        let mut alloc2 = MockAllocWrapper::new_block_next(MockAlloc::new(), extent.num_ranks() * 2);
        let next_tx2 = alloc2.notify_tx();
        alloc2
            .alloc
            .expect_world_id()
            .return_const(world_id.clone());
        alloc2.alloc.expect_extent().return_const(extent.clone());
        set_procstate_expectations(&mut alloc2.alloc, extent.clone());
        alloc2.alloc.expect_next().return_const(None);
        alloc2.alloc.expect_stop().times(1).return_once(|| Ok(()));

        let mut allocator = MockAllocator::new();
        allocator
            .expect_allocate()
            .times(1)
            .return_once(|_| Ok(alloc1));
        allocator
            .expect_allocate()
            .times(1)
            .return_once(|_| Ok(alloc2));

        let remote_allocator = RemoteProcessAllocator::new();
        let handle = tokio::spawn({
            let remote_allocator = remote_allocator.clone();
            async move {
                remote_allocator
                    .start_with_allocator(serve_addr, allocator, None)
                    .await
            }
        });

        let alloc_key = ShortUuid::generate();

        tx.send(RemoteProcessAllocatorMessage::Allocate {
            alloc_key: alloc_key.clone(),
            extent: extent.clone(),
            bootstrap_addr: bootstrap_addr.clone(),
            hosts: vec![],
            client_context: None,
            forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
        })
        .await
        .unwrap();

        let m = rx.recv().await.unwrap();
        assert_matches!(
            m,
            RemoteProcessProcStateMessage::Allocated { world_id: got_world_id, alloc_key: got_alloc_key }
            if got_world_id == world_id && got_alloc_key == alloc_key
        );

        read_all_created(&mut rx, extent.num_ranks()).await;
        read_all_running(&mut rx, extent.num_ranks()).await;

        let alloc_key = ShortUuid::generate();

        tx.send(RemoteProcessAllocatorMessage::Allocate {
            alloc_key: alloc_key.clone(),
            extent: extent.clone(),
            bootstrap_addr,
            hosts: vec![],
            client_context: None,
            forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
        })
        .await
        .unwrap();
        next_tx1.send(()).unwrap();
        read_all_stopped(&mut rx, extent.num_ranks()).await;
        let m = rx.recv().await.unwrap();
        assert_matches!(m, RemoteProcessProcStateMessage::Done(_));
        let m = rx.recv().await.unwrap();
        assert_matches!(
            m,
            RemoteProcessProcStateMessage::Allocated { world_id: got_world_id, alloc_key: got_alloc_key }
            if got_world_id == world_id && got_alloc_key == alloc_key
        );
        read_all_created(&mut rx, extent.num_ranks()).await;
        read_all_running(&mut rx, extent.num_ranks()).await;
        tracing::info!("stopping allocation");
        tx.send(RemoteProcessAllocatorMessage::Stop).await.unwrap();
        next_tx2.send(()).unwrap();

        read_all_stopped(&mut rx, extent.num_ranks()).await;

        remote_allocator.terminate();
        handle.await.unwrap().unwrap();
    }

    #[timed_test::async_timed_test(timeout_secs = 15)]
    async fn test_upstream_closed() {
        let config = hyperactor_config::global::lock();
        let _guard1 = config.override_key(
            hyperactor::config::MESSAGE_DELIVERY_TIMEOUT,
            Duration::from_secs(1),
        );
        let _guard2 = config.override_key(
            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
            Duration::from_millis(100),
        );

        hyperactor_telemetry::initialize_logging_for_test();
        let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
        let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
        let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();

        let extent = extent!(host = 1, gpu = 2);

        let tx = channel::dial(serve_addr.clone()).unwrap();

        let world_id: WorldId = id!(test_world_id);
        // Block next() after the Created and Running events for each rank.
        let mut alloc = MockAllocWrapper::new_block_next(MockAlloc::new(), extent.num_ranks() * 2);
        let next_tx = alloc.notify_tx();
        alloc.alloc.expect_world_id().return_const(world_id.clone());
        alloc.alloc.expect_extent().return_const(extent.clone());

        set_procstate_expectations(&mut alloc.alloc, extent.clone());

        alloc.alloc.expect_next().return_const(None);
        let (stop_tx, stop_rx) = oneshot::channel();
        alloc.alloc.expect_stop().times(1).return_once(|| {
            stop_tx.send(()).unwrap();
            Ok(())
        });

        let mut allocator = MockAllocator::new();
        allocator
            .expect_allocate()
            .times(1)
            .return_once(|_| Ok(alloc));

        let remote_allocator = RemoteProcessAllocator::new();
        let handle = tokio::spawn({
            let remote_allocator = remote_allocator.clone();
            async move {
                remote_allocator
                    .start_with_allocator(serve_addr, allocator, None)
                    .await
            }
        });

        let alloc_key = ShortUuid::generate();

        tx.send(RemoteProcessAllocatorMessage::Allocate {
            alloc_key: alloc_key.clone(),
            extent: extent.clone(),
            bootstrap_addr,
            hosts: vec![],
            client_context: None,
            forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
        })
        .await
        .unwrap();

        let m = rx.recv().await.unwrap();
        assert_matches!(
            m, RemoteProcessProcStateMessage::Allocated { alloc_key: got_alloc_key, world_id: got_world_id }
            if got_world_id == world_id && got_alloc_key == alloc_key
        );

        read_all_created(&mut rx, extent.num_ranks()).await;
        read_all_running(&mut rx, extent.num_ranks()).await;

        tracing::info!("closing upstream");
        drop(rx);
        #[allow(clippy::disallowed_methods)]
        tokio::time::sleep(Duration::from_secs(2)).await;
        stop_rx.await.unwrap();
        next_tx.send(()).unwrap();
        remote_allocator.terminate();
        handle.await.unwrap().unwrap();
    }

    #[timed_test::async_timed_test(timeout_secs = 15)]
    async fn test_inner_alloc_failure() {
        let config = hyperactor_config::global::lock();
        let _guard = config.override_key(
            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
            Duration::from_mins(1),
        );
        hyperactor_telemetry::initialize_logging_for_test();
        let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
        let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
        let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();

        let extent = extent!(host = 1, gpu = 2);

        let tx = channel::dial(serve_addr.clone()).unwrap();

        let test_world_id: WorldId = id!(test_world_id);
        // Block next() after the single Failed event.
        let mut alloc = MockAllocWrapper::new_block_next(MockAlloc::new(), 1);
        let next_tx = alloc.notify_tx();
        alloc
            .alloc
            .expect_world_id()
            .return_const(test_world_id.clone());
        alloc.alloc.expect_extent().return_const(extent.clone());
        alloc
            .alloc
            .expect_next()
            .times(1)
            .return_const(Some(ProcState::Failed {
                world_id: test_world_id.clone(),
                description: "test".to_string(),
            }));
        alloc.alloc.expect_next().times(1).return_const(None);

        alloc.alloc.expect_stop().times(1).return_once(|| Ok(()));

        let mut allocator = MockAllocator::new();
        allocator
            .expect_allocate()
            .times(1)
            .return_once(|_| Ok(alloc));

        let remote_allocator = RemoteProcessAllocator::new();
        let handle = tokio::spawn({
            let remote_allocator = remote_allocator.clone();
            async move {
                remote_allocator
                    .start_with_allocator(serve_addr, allocator, None)
                    .await
            }
        });

        let alloc_key = ShortUuid::generate();
        tx.send(RemoteProcessAllocatorMessage::Allocate {
            alloc_key: alloc_key.clone(),
            extent: extent.clone(),
            bootstrap_addr,
            hosts: vec![],
            client_context: None,
            forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
        })
        .await
        .unwrap();

        let m = rx.recv().await.unwrap();
        assert_matches!(
            m,
            RemoteProcessProcStateMessage::Allocated { world_id: got_world_id, alloc_key: got_alloc_key }
            if test_world_id == got_world_id && alloc_key == got_alloc_key
        );

        let m = rx.recv().await.unwrap();
        assert_matches!(
            m,
            RemoteProcessProcStateMessage::Update(
                got_alloc_key,
                ProcState::Failed { world_id, description }
            ) if got_alloc_key == alloc_key && world_id == test_world_id && description == "test"
        );

        tracing::info!("stopping allocation");
        tx.send(RemoteProcessAllocatorMessage::Stop).await.unwrap();
        next_tx.send(()).unwrap();
        let m = rx.recv().await.unwrap();
        assert_matches!(
            m,
            RemoteProcessProcStateMessage::Done(got_alloc_key)
            if got_alloc_key == alloc_key
        );

        remote_allocator.terminate();
        handle.await.unwrap().unwrap();
    }

2023 #[timed_test::async_timed_test(timeout_secs = 15)]
2024 async fn test_trace_id_propagation() {
2025 let config = hyperactor_config::global::lock();
2026 let _guard = config.override_key(
2027 hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
2028 Duration::from_mins(1),
2029 );
2030 hyperactor_telemetry::initialize_logging(ClockKind::default());
2031 let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
2032 let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
2033 let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();
2034
2035 let extent = extent!(host = 1, gpu = 1);
2036 let tx = channel::dial(serve_addr.clone()).unwrap();
2037 let test_world_id: WorldId = id!(test_world_id);
2038 let test_trace_id = "test_trace_id_12345";
2039
        let mut alloc = MockAlloc::new();
        alloc.expect_world_id().return_const(test_world_id.clone());
        alloc.expect_extent().return_const(extent.clone());
        alloc.expect_next().return_const(None);

        let mut allocator = MockAllocator::new();
        allocator
            .expect_allocate()
            .times(1)
            .withf(move |spec: &AllocSpec| {
                spec.constraints
                    .match_labels
                    .get(CLIENT_TRACE_ID_LABEL)
                    .is_some_and(|trace_id| trace_id == test_trace_id)
            })
            .return_once(|_| Ok(MockAllocWrapper::new(alloc)));

        let remote_allocator = RemoteProcessAllocator::new();
        let handle = tokio::spawn({
            let remote_allocator = remote_allocator.clone();
            async move {
                remote_allocator
                    .start_with_allocator(serve_addr, allocator, None)
                    .await
            }
        });

        let alloc_key = ShortUuid::generate();
        tx.send(RemoteProcessAllocatorMessage::Allocate {
            alloc_key: alloc_key.clone(),
            extent: extent.clone(),
            bootstrap_addr,
            hosts: vec![],
            client_context: Some(ClientContext {
                trace_id: test_trace_id.to_string(),
            }),
            forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
        })
        .await
        .unwrap();

        let m = rx.recv().await.unwrap();
        assert_matches!(
            m,
            RemoteProcessProcStateMessage::Allocated { alloc_key: got_alloc_key, world_id: got_world_id }
            if got_world_id == test_world_id && got_alloc_key == alloc_key
        );

        let m = rx.recv().await.unwrap();
        assert_matches!(
            m,
            RemoteProcessProcStateMessage::Done(got_alloc_key)
            if alloc_key == got_alloc_key
        );

        remote_allocator.terminate();
        handle.await.unwrap().unwrap();
    }

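    // Same flow as above, but without a ClientContext: no trace-id label
    // should be injected into the AllocSpec constraints.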
    #[timed_test::async_timed_test(timeout_secs = 15)]
    async fn test_trace_id_propagation_no_client_context() {
        let config = hyperactor_config::global::lock();
        let _guard = config.override_key(
            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
            Duration::from_secs(60), // 1 minute; keeps heartbeats out of the way
        );
        hyperactor_telemetry::initialize_logging(ClockKind::default());
        let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
        let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
        let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();

        let extent = extent!(host = 1, gpu = 1);
        let tx = channel::dial(serve_addr.clone()).unwrap();
        let test_world_id: WorldId = id!(test_world_id);

        let mut alloc = MockAlloc::new();
        alloc.expect_world_id().return_const(test_world_id.clone());
        alloc.expect_extent().return_const(extent.clone());
        alloc.expect_next().return_const(None);

        let mut allocator = MockAllocator::new();
        allocator
            .expect_allocate()
            .times(1)
            .withf(move |spec: &AllocSpec| spec.constraints.match_labels.is_empty())
            .return_once(|_| Ok(MockAllocWrapper::new(alloc)));

        let remote_allocator = RemoteProcessAllocator::new();
        let handle = tokio::spawn({
            let remote_allocator = remote_allocator.clone();
            async move {
                remote_allocator
                    .start_with_allocator(serve_addr, allocator, None)
                    .await
            }
        });

        let alloc_key = ShortUuid::generate();
        tx.send(RemoteProcessAllocatorMessage::Allocate {
            alloc_key: alloc_key.clone(),
            extent: extent.clone(),
            bootstrap_addr,
            hosts: vec![],
            client_context: None,
            forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
        })
        .await
        .unwrap();

        let m = rx.recv().await.unwrap();
        assert_matches!(
            m,
            RemoteProcessProcStateMessage::Allocated { alloc_key: got_alloc_key, world_id: got_world_id }
            if got_world_id == test_world_id && got_alloc_key == alloc_key
        );

        let m = rx.recv().await.unwrap();
        assert_matches!(
            m,
            RemoteProcessProcStateMessage::Done(got_alloc_key)
            if got_alloc_key == alloc_key
        );

        remote_allocator.terminate();
        handle.await.unwrap().unwrap();
    }
}

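// End-to-end tests: drive RemoteProcessAlloc against real
// RemoteProcessAllocator instances speaking over Unix channels.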
#[cfg(test)]
mod test_alloc {
    use std::os::unix::process::ExitStatusExt;

    use hyperactor::clock::ClockKind;
    use hyperactor_config;
    use ndslice::extent;
    use nix::sys::signal;
    use nix::unistd::Pid;
    use timed_test::async_timed_test;

    use super::*;

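    // Happy path: two hosts, each served by a real RemoteProcessAllocator.
    // Every rank should go Created then Running, and stopping the alloc
    // should drain a Stopped event for each running proc.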
    #[async_timed_test(timeout_secs = 60)]
    #[cfg(fbcode_build)]
    async fn test_alloc_simple() {
        let config = hyperactor_config::global::lock();
        let _guard1 = config.override_key(
            hyperactor::config::MESSAGE_DELIVERY_TIMEOUT,
            Duration::from_secs(1),
        );
        let _guard2 = config.override_key(
            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
            Duration::from_millis(100),
        );
        hyperactor_telemetry::initialize_logging(ClockKind::default());

        let spec = AllocSpec {
            extent: extent!(host = 2, gpu = 2),
            constraints: Default::default(),
            proc_name: None,
            transport: ChannelTransport::Unix,
            proc_allocation_mode: Default::default(),
        };
        let world_id = WorldId("test_world_id".to_string());

        let task1_allocator = RemoteProcessAllocator::new();
        let task1_addr = ChannelAddr::any(ChannelTransport::Unix);
        let task1_addr_string = task1_addr.to_string();
        let task1_cmd = Command::new(crate::testresource::get(
            "monarch/hyperactor_mesh/bootstrap",
        ));
        let task2_allocator = RemoteProcessAllocator::new();
        let task2_addr = ChannelAddr::any(ChannelTransport::Unix);
        let task2_addr_string = task2_addr.to_string();
        let task2_cmd = Command::new(crate::testresource::get(
            "monarch/hyperactor_mesh/bootstrap",
        ));
        let task1_allocator_copy = task1_allocator.clone();
        let task1_allocator_handle = tokio::spawn(async move {
            tracing::info!("spawning task1");
            task1_allocator_copy
                .start(task1_cmd, task1_addr, None)
                .await
                .unwrap();
        });
        let task2_allocator_copy = task2_allocator.clone();
        let task2_allocator_handle = tokio::spawn(async move {
            task2_allocator_copy
                .start(task2_cmd, task2_addr, None)
                .await
                .unwrap();
        });

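        // The mocked initializer simply hands back the two hosts whose
        // allocators we just started.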
        let mut initializer = MockRemoteProcessAllocInitializer::new();
        initializer.expect_initialize_alloc().return_once(move || {
            Ok(vec![
                RemoteProcessAllocHost {
                    hostname: task1_addr_string,
                    id: "task1".to_string(),
                },
                RemoteProcessAllocHost {
                    hostname: task2_addr_string,
                    id: "task2".to_string(),
                },
            ])
        });
        let mut alloc = RemoteProcessAlloc::new(spec.clone(), world_id, 0, initializer)
            .await
            .unwrap();
        let mut created = HashSet::new();
        let mut running_procs = HashSet::new();
        let mut proc_points = HashSet::new();
        for _ in 0..spec.extent.num_ranks() * 2 {
            let proc_state = alloc.next().await.unwrap();
            tracing::debug!("test got message: {:?}", proc_state);
            match proc_state {
                ProcState::Created {
                    create_key, point, ..
                } => {
                    created.insert(create_key);
                    proc_points.insert(point);
                }
                ProcState::Running { create_key, .. } => {
                    assert!(created.remove(&create_key));
                    running_procs.insert(create_key);
                }
                _ => panic!("expected Created or Running"),
            }
        }
        assert!(created.is_empty());
        assert!(
            spec.extent
                .points()
                .all(|point| proc_points.contains(&point))
        );

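        // Prove quiescence: nothing else should arrive until we stop the
        // alloc, so racing next() against a short sleep must hit the sleep.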
        let timeout = hyperactor::clock::RealClock.now() + std::time::Duration::from_millis(1000);
        tokio::select! {
            _ = hyperactor::clock::RealClock.sleep_until(timeout) => {},
            _ = alloc.next() => panic!("expected no more items"),
        }

        alloc.stop().await.unwrap();
        for _ in 0..spec.extent.num_ranks() {
            let proc_state = alloc.next().await.unwrap();
            tracing::info!("test received next proc_state: {:?}", proc_state);
            match proc_state {
                ProcState::Stopped {
                    create_key, reason, ..
                } => {
                    assert!(running_procs.remove(&create_key));
                    assert_eq!(reason, ProcStopReason::Stopped);
                }
                _ => panic!("expected stopped"),
            }
        }
        let proc_state = alloc.next().await;
        assert!(proc_state.is_none());
        let proc_state = alloc.next().await;
        assert!(proc_state.is_none());

        task1_allocator.terminate();
        task1_allocator_handle.await.unwrap();
        task2_allocator.terminate();
        task2_allocator_handle.await.unwrap();
    }

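    // Host failure: aborting an allocator's serving task stops its
    // heartbeats, so the client should observe HostWatchdog stops for every
    // proc on that host.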
    #[async_timed_test(timeout_secs = 60)]
    #[cfg(fbcode_build)]
    async fn test_alloc_host_failure() {
        let config = hyperactor_config::global::lock();
        let _guard1 = config.override_key(
            hyperactor::config::MESSAGE_DELIVERY_TIMEOUT,
            Duration::from_secs(1),
        );
        let _guard2 = config.override_key(
            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
            Duration::from_millis(100),
        );
        hyperactor_telemetry::initialize_logging(ClockKind::default());

        let spec = AllocSpec {
            extent: extent!(host = 2, gpu = 2),
            constraints: Default::default(),
            proc_name: None,
            transport: ChannelTransport::Unix,
            proc_allocation_mode: Default::default(),
        };
        let world_id = WorldId("test_world_id".to_string());

        let task1_allocator = RemoteProcessAllocator::new();
        let task1_addr = ChannelAddr::any(ChannelTransport::Unix);
        let task1_addr_string = task1_addr.to_string();
        let task1_cmd = Command::new(crate::testresource::get(
            "monarch/hyperactor_mesh/bootstrap",
        ));
        let task2_allocator = RemoteProcessAllocator::new();
        let task2_addr = ChannelAddr::any(ChannelTransport::Unix);
        let task2_addr_string = task2_addr.to_string();
        let task2_cmd = Command::new(crate::testresource::get(
            "monarch/hyperactor_mesh/bootstrap",
        ));
        let task1_allocator_copy = task1_allocator.clone();
        let task1_allocator_handle = tokio::spawn(async move {
            tracing::info!("spawning task1");
            task1_allocator_copy
                .start(task1_cmd, task1_addr, None)
                .await
                .unwrap();
            tracing::info!("task1 terminated");
        });
        let task2_allocator_copy = task2_allocator.clone();
        let task2_allocator_handle = tokio::spawn(async move {
            task2_allocator_copy
                .start(task2_cmd, task2_addr, None)
                .await
                .unwrap();
            tracing::info!("task2 terminated");
        });

        let mut initializer = MockRemoteProcessAllocInitializer::new();
        initializer.expect_initialize_alloc().return_once(move || {
            Ok(vec![
                RemoteProcessAllocHost {
                    hostname: task1_addr_string,
                    id: "task1".to_string(),
                },
                RemoteProcessAllocHost {
                    hostname: task2_addr_string,
                    id: "task2".to_string(),
                },
            ])
        });
        let mut alloc = RemoteProcessAlloc::new(spec.clone(), world_id, 0, initializer)
            .await
            .unwrap();
        for _ in 0..spec.extent.num_ranks() * 2 {
            match alloc.next().await {
                Some(ProcState::Created { .. }) | Some(ProcState::Running { .. }) => {}
                _ => panic!("expected Created or Running"),
            }
        }

        let timeout = hyperactor::clock::RealClock.now() + std::time::Duration::from_millis(1000);
        tokio::select! {
            _ = hyperactor::clock::RealClock.sleep_until(timeout) => {},
            _ = alloc.next() => panic!("expected no more items"),
        }

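        // Abort the first allocator task and wait two heartbeat intervals so
        // the host watchdog can notice the missed heartbeats.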
        tracing::info!("aborting task1 allocator");
        task1_allocator_handle.abort();
        RealClock
            .sleep(
                hyperactor_config::global::get(
                    hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
                ) * 2,
            )
            .await;
        for _ in 0..spec.extent.num_ranks() / 2 {
            let proc_state = alloc.next().await.unwrap();
            tracing::info!("test received next proc_state: {:?}", proc_state);
            match proc_state {
                ProcState::Stopped { reason, .. } => {
                    assert_eq!(reason, ProcStopReason::HostWatchdog);
                }
                _ => panic!("expected stopped"),
            }
        }
        let timeout = hyperactor::clock::RealClock.now() + std::time::Duration::from_millis(1000);
        tokio::select! {
            _ = hyperactor::clock::RealClock.sleep_until(timeout) => {},
            _ = alloc.next() => panic!("expected no more items"),
        }

        tracing::info!("aborting task2 allocator");
        task2_allocator_handle.abort();
        RealClock
            .sleep(
                hyperactor_config::global::get(
                    hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
                ) * 2,
            )
            .await;
        for _ in 0..spec.extent.num_ranks() / 2 {
            let proc_state = alloc.next().await.unwrap();
            tracing::info!("test received next proc_state: {:?}", proc_state);
            match proc_state {
                ProcState::Stopped { reason, .. } => {
                    assert_eq!(reason, ProcStopReason::HostWatchdog);
                }
                _ => panic!("expected stopped"),
            }
        }
        let proc_state = alloc.next().await;
        assert!(proc_state.is_none());
        let proc_state = alloc.next().await;
        assert!(proc_state.is_none());
    }

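    // Inner allocation failure: task2 points at a binary that cannot be
    // spawned, so half the ranks come up normally and a single Failed event
    // is surfaced for the bad host.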
    #[async_timed_test(timeout_secs = 15)]
    #[cfg(fbcode_build)]
    async fn test_alloc_inner_alloc_failure() {
        unsafe {
            std::env::set_var("MONARCH_MESSAGE_DELIVERY_TIMEOUT_SECS", "1");
        }

        let config = hyperactor_config::global::lock();
        let _guard = config.override_key(
            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
            Duration::from_millis(100),
        );
        hyperactor_telemetry::initialize_logging_for_test();

        let spec = AllocSpec {
            extent: extent!(host = 2, gpu = 2),
            constraints: Default::default(),
            proc_name: None,
            transport: ChannelTransport::Unix,
            proc_allocation_mode: Default::default(),
        };
        let world_id = WorldId("test_world_id".to_string());

        let task1_allocator = RemoteProcessAllocator::new();
        let task1_addr = ChannelAddr::any(ChannelTransport::Unix);
        let task1_addr_string = task1_addr.to_string();
        let task1_cmd = Command::new(crate::testresource::get(
            "monarch/hyperactor_mesh/bootstrap",
        ));
        let task2_allocator = RemoteProcessAllocator::new();
        let task2_addr = ChannelAddr::any(ChannelTransport::Unix);
        let task2_addr_string = task2_addr.to_string();
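        // Deliberately bogus binary path: spawning it should fail, which is
        // what drives the Failed proc state asserted below.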
        let task2_cmd = Command::new("/caught/somewhere/in/time");
        let task1_allocator_copy = task1_allocator.clone();
        let task1_allocator_handle = tokio::spawn(async move {
            tracing::info!("spawning task1");
            task1_allocator_copy
                .start(task1_cmd, task1_addr, None)
                .await
                .unwrap();
        });
        let task2_allocator_copy = task2_allocator.clone();
        let task2_allocator_handle = tokio::spawn(async move {
            task2_allocator_copy
                .start(task2_cmd, task2_addr, None)
                .await
                .unwrap();
        });

        let mut initializer = MockRemoteProcessAllocInitializer::new();
        initializer.expect_initialize_alloc().return_once(move || {
            Ok(vec![
                RemoteProcessAllocHost {
                    hostname: task1_addr_string,
                    id: "task1".to_string(),
                },
                RemoteProcessAllocHost {
                    hostname: task2_addr_string,
                    id: "task2".to_string(),
                },
            ])
        });
        let mut alloc = RemoteProcessAlloc::new(spec.clone(), world_id, 0, initializer)
            .await
            .unwrap();
        let mut created = HashSet::new();
        let mut started_procs = HashSet::new();
        let mut proc_points = HashSet::new();
        let mut failed = 0;
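        // Two events (Created, Running) per healthy rank, plus one Failed
        // event for the host whose allocation never started.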
        for _ in 0..spec.extent.num_ranks() + 1 {
            let proc_state = alloc.next().await.unwrap();
            tracing::debug!("test got message: {:?}", proc_state);
            match proc_state {
                ProcState::Created {
                    create_key, point, ..
                } => {
                    created.insert(create_key);
                    proc_points.insert(point);
                }
                ProcState::Running { create_key, .. } => {
                    assert!(created.remove(&create_key));
                    started_procs.insert(create_key);
                }
                ProcState::Failed { .. } => {
                    failed += 1;
                }
                _ => panic!("expected Created, Running or Failed"),
            }
        }
        assert!(created.is_empty());
        assert_eq!(failed, 1);
        for rank in 0..spec.extent.num_ranks() / 2 {
            let point = spec.extent.point_of_rank(rank).unwrap();
            assert!(proc_points.contains(&point));
        }

        let timeout = hyperactor::clock::RealClock.now() + std::time::Duration::from_millis(1000);
        tokio::select! {
            _ = hyperactor::clock::RealClock.sleep_until(timeout) => {},
            _ = alloc.next() => panic!("expected no more items"),
        }

        alloc.stop().await.unwrap();
        for _ in 0..spec.extent.num_ranks() / 2 {
            let proc_state = alloc.next().await.unwrap();
            tracing::info!("test received next proc_state: {:?}", proc_state);
            match proc_state {
                ProcState::Stopped {
                    create_key, reason, ..
                } => {
                    assert!(started_procs.remove(&create_key));
                    assert_eq!(reason, ProcStopReason::Stopped);
                }
                _ => panic!("expected stopped"),
            }
        }
        let proc_state = alloc.next().await;
        assert!(proc_state.is_none());
        let proc_state = alloc.next().await;
        assert!(proc_state.is_none());

        task1_allocator.terminate();
        task1_allocator_handle.await.unwrap();
        task2_allocator.terminate();
        task2_allocator_handle.await.unwrap();
    }

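    // SIGINT handling: killing the client-side remote_process_alloc binary
    // should cause every process it spawned via the remote allocators to be
    // cleaned up.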
    #[async_timed_test(timeout_secs = 60)]
    #[cfg(fbcode_build)]
    async fn test_remote_process_alloc_signal_handler() {
        hyperactor_telemetry::initialize_logging_for_test();
        let num_proc_meshes = 5;
        let hosts_per_proc_mesh = 5;

        let pid_addr = ChannelAddr::any(ChannelTransport::Unix);
        let (pid_addr, mut pid_rx) = channel::serve::<u32>(pid_addr).unwrap();

        let addresses = (0..(num_proc_meshes * hosts_per_proc_mesh))
            .map(|_| ChannelAddr::any(ChannelTransport::Unix).to_string())
            .collect::<Vec<_>>();

        let remote_process_allocators = addresses
            .iter()
            .map(|addr| {
                Command::new(crate::testresource::get(
                    "monarch/hyperactor_mesh/remote_process_allocator",
                ))
                .env("RUST_LOG", "info")
                .arg(format!("--addr={addr}"))
                .stdout(std::process::Stdio::piped())
                .spawn()
                .unwrap()
            })
            .collect::<Vec<_>>();

        let done_allocating_addr = ChannelAddr::any(ChannelTransport::Unix);
        let (done_allocating_addr, mut done_allocating_rx) =
            channel::serve::<()>(done_allocating_addr).unwrap();
        let mut remote_process_alloc = Command::new(crate::testresource::get(
            "monarch/hyperactor_mesh/remote_process_alloc",
        ))
        .arg(format!("--done-allocating-addr={}", done_allocating_addr))
        .arg(format!("--addresses={}", addresses.join(",")))
        .arg(format!("--num-proc-meshes={}", num_proc_meshes))
        .arg(format!("--hosts-per-proc-mesh={}", hosts_per_proc_mesh))
        .arg(format!("--pid-addr={}", pid_addr))
        .spawn()
        .unwrap();

        done_allocating_rx.recv().await.unwrap();
        let mut received_pids = Vec::new();
        while let Ok(pid) = pid_rx.recv().await {
            received_pids.push(pid);
            if received_pids.len() == remote_process_allocators.len() {
                break;
            }
        }

        signal::kill(
            Pid::from_raw(remote_process_alloc.id().unwrap() as i32),
            signal::SIGINT,
        )
        .unwrap();

        assert_eq!(
            remote_process_alloc.wait().await.unwrap().signal(),
            Some(signal::SIGINT as i32)
        );

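        // Give signal propagation and child reaping a moment to finish
        // before probing the child PIDs.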
        RealClock.sleep(tokio::time::Duration::from_secs(5)).await;

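        // `kill -0` delivers no signal; it only checks whether the PID still
        // exists, so a non-zero exit status means the process is gone.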
        for child_pid in received_pids {
            let pid_check = Command::new("kill")
                .arg("-0")
                .arg(child_pid.to_string())
                .output()
                .await
                .expect("Failed to check if PID is alive");

            assert!(
                !pid_check.status.success(),
                "PID {} should no longer be alive",
                child_pid
            );
        }

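        // Finally, shut down the RemoteProcessAllocator wrapper processes
        // themselves.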
        for mut remote_process_allocator in remote_process_allocators {
            remote_process_allocator.kill().await.unwrap();
        }
    }
}