1use std::collections::HashMap;
10use std::collections::HashSet;
11use std::collections::VecDeque;
12use std::sync::Arc;
13use std::time::Duration;
14
15use anyhow::Context;
16use async_trait::async_trait;
17use dashmap::DashMap;
18use futures::FutureExt;
19use futures::future::join_all;
20use futures::future::select_all;
21use hyperactor::Named;
22use hyperactor::WorldId;
23use hyperactor::channel;
24use hyperactor::channel::ChannelAddr;
25use hyperactor::channel::ChannelRx;
26use hyperactor::channel::ChannelTransport;
27use hyperactor::channel::ChannelTx;
28use hyperactor::channel::Rx;
29use hyperactor::channel::TcpMode;
30use hyperactor::channel::Tx;
31use hyperactor::channel::TxStatus;
32use hyperactor::clock;
33use hyperactor::clock::Clock;
34use hyperactor::clock::RealClock;
35use hyperactor::config;
36use hyperactor::mailbox::DialMailboxRouter;
37use hyperactor::mailbox::MailboxServer;
38use hyperactor::observe_async;
39use hyperactor::observe_result;
40use hyperactor::reference::Reference;
41use hyperactor::serde_json;
42use mockall::automock;
43use ndslice::Region;
44use ndslice::Slice;
45use ndslice::View;
46use ndslice::ViewExt;
47use ndslice::view::Extent;
48use ndslice::view::Point;
49use serde::Deserialize;
50use serde::Serialize;
51use strum::AsRefStr;
52use tokio::io::AsyncWriteExt;
53use tokio::process::Command;
54use tokio::sync::mpsc::UnboundedReceiver;
55use tokio::sync::mpsc::UnboundedSender;
56use tokio::sync::mpsc::unbounded_channel;
57use tokio::task::JoinHandle;
58use tokio_stream::StreamExt;
59use tokio_stream::wrappers::WatchStream;
60use tokio_util::sync::CancellationToken;
61
62use crate::alloc::Alloc;
63use crate::alloc::AllocConstraints;
64use crate::alloc::AllocSpec;
65use crate::alloc::Allocator;
66use crate::alloc::AllocatorError;
67use crate::alloc::ProcState;
68use crate::alloc::ProcStopReason;
69use crate::alloc::ProcessAllocator;
70use crate::alloc::REMOTE_ALLOC_BOOTSTRAP_ADDR;
71use crate::alloc::process::CLIENT_TRACE_ID_LABEL;
72use crate::alloc::process::ClientContext;
73use crate::alloc::serve_with_config;
74use crate::alloc::with_unspecified_port_or_any;
75use crate::shortuuid::ShortUuid;
76
/// Control messages handled by the main loop of [`RemoteProcessAllocator`].
#[derive(Debug, Clone, Serialize, Deserialize, Named, AsRefStr)]
pub enum RemoteProcessAllocatorMessage {
    /// Request a new allocation. Any allocation that is still active on the
    /// receiving allocator is stopped before the new one is created.
    Allocate {
        /// Key identifying this allocation; it is echoed back in the
        /// `RemoteProcessProcStateMessage`s reported for the allocation.
        alloc_key: ShortUuid,
        /// Extent to allocate on the receiving allocator.
        extent: Extent,
        /// Address the allocator dials to report allocation state back to
        /// the requester.
        bootstrap_addr: ChannelAddr,
        /// Hostnames taking part in the overall allocation. They are written
        /// as JSON to the file named by the
        /// `TORCH_ELASTIC_CUSTOM_HOSTNAMES_LIST_FILE` environment variable
        /// when that variable is set.
        hosts: Vec<String>,
        /// Optional client context; when present its trace id is attached to
        /// the allocation constraints as a match label.
        client_context: Option<ClientContext>,
        /// Address on which the allocator serves the message forwarder for
        /// the allocated procs.
        forwarder_addr: ChannelAddr,
    },
    /// Stop the currently active allocation, if there is one.
    Stop,
    /// Keep-alive message. The allocator takes no action on receipt.
    HeartBeat,
}
104
/// State reports sent from a remote allocator to the address given as
/// `bootstrap_addr` in the allocate request.
#[derive(Debug, Clone, Serialize, Deserialize, Named, AsRefStr)]
pub enum RemoteProcessProcStateMessage {
    /// First message sent for an allocation, right after the bootstrap
    /// address has been dialed.
    Allocated {
        /// Key of the allocation, as supplied in the allocate request.
        alloc_key: ShortUuid,
        /// World id of the allocation created on the remote side.
        world_id: WorldId,
    },
    /// A proc state change for the allocation with the given key.
    Update(ShortUuid, ProcState),
    /// The allocation with the given key produced its last event.
    Done(ShortUuid),
    /// Keep-alive message sent periodically while the allocation is handled.
    HeartBeat,
}
122
/// Serves allocation requests (`RemoteProcessAllocatorMessage`) received on a
/// channel and runs at most one allocation at a time.
pub struct RemoteProcessAllocator {
    /// Cancelled by `terminate` to make the serving loop exit.
    cancel_token: CancellationToken,
}
127
128async fn conditional_sleeper<F: futures::Future<Output = ()>>(t: Option<F>) {
129 match t {
130 Some(timer) => timer.await,
131 None => futures::future::pending().await,
132 }
133}
134
135impl RemoteProcessAllocator {
136 pub fn new() -> Arc<Self> {
138 Arc::new(Self {
139 cancel_token: CancellationToken::new(),
140 })
141 }
142
143 pub fn terminate(&self) {
145 self.cancel_token.cancel();
146 }
147
148 #[hyperactor::instrument]
164 pub async fn start(
165 &self,
166 cmd: Command,
167 serve_addr: ChannelAddr,
168 timeout: Option<Duration>,
169 ) -> Result<(), anyhow::Error> {
170 let process_allocator = ProcessAllocator::new(cmd);
171 self.start_with_allocator(serve_addr, process_allocator, timeout)
172 .await
173 }
174
175 #[tracing::instrument(skip(self, process_allocator))]
179 pub async fn start_with_allocator<A: Allocator + Send + Sync + 'static>(
180 &self,
181 serve_addr: ChannelAddr,
182 mut process_allocator: A,
183 timeout: Option<Duration>,
184 ) -> Result<(), anyhow::Error>
185 where
186 <A as Allocator>::Alloc: Send,
187 <A as Allocator>::Alloc: Sync,
188 {
189 tracing::info!("starting remote allocator on: {}", serve_addr);
190 let (_, mut rx) = channel::serve(serve_addr.clone()).map_err(anyhow::Error::from)?;
191
192 struct ActiveAllocation {
193 handle: JoinHandle<()>,
194 cancel_token: CancellationToken,
195 }
196 #[observe_async("RemoteProcessAllocator")]
197 async fn ensure_previous_alloc_stopped(active_allocation: &mut Option<ActiveAllocation>) {
198 if let Some(active_allocation) = active_allocation.take() {
199 tracing::info!("previous alloc found, stopping");
200 active_allocation.cancel_token.cancel();
201 match active_allocation.handle.await {
202 Ok(_) => {
203 tracing::info!("allocation stopped.")
205 }
206 Err(e) => {
207 tracing::error!("allocation handler failed: {}", e);
208 }
209 }
210 }
211 }
212
213 let mut active_allocation: Option<ActiveAllocation> = None;
214 loop {
215 let sleep = conditional_sleeper(timeout.map(|t| RealClock.sleep(t)));
218 tokio::select! {
219 msg = rx.recv() => {
220 match msg {
221 Ok(RemoteProcessAllocatorMessage::Allocate {
222 alloc_key,
223 extent,
224 bootstrap_addr,
225 hosts,
226 client_context,
227 forwarder_addr,
228 }) => {
229 tracing::info!("received allocation request for {} with extent {}", alloc_key, extent);
230 ensure_previous_alloc_stopped(&mut active_allocation).await;
231
232 let mut constraints: AllocConstraints = Default::default();
234 if let Some(context) = &client_context {
235 constraints = AllocConstraints {
236 match_labels: HashMap::from([(
237 CLIENT_TRACE_ID_LABEL.to_string(),
238 context.trace_id.to_string(),
239 )]
240 )};
241 tracing::info!(
242 monarch_client_trace_id = context.trace_id.to_string(),
243 "allocating...",
244 );
245 }
246
247
248 let spec = AllocSpec {
249 extent,
250 constraints,
251 proc_name: None, transport: ChannelTransport::Unix,
253 proc_allocation_mode: Default::default(),
254 };
255
256 match process_allocator.allocate(spec.clone()).await {
257 Ok(alloc) => {
258 let cancel_token = CancellationToken::new();
259 active_allocation = Some(ActiveAllocation {
260 cancel_token: cancel_token.clone(),
261 handle: tokio::spawn(Self::handle_allocation_request(
262 Box::new(alloc) as Box<dyn Alloc + Send + Sync>,
263 alloc_key,
264 bootstrap_addr,
265 hosts,
266 cancel_token,
267 forwarder_addr,
268 )),
269 })
270 }
271 Err(e) => {
272 tracing::error!("allocation for {:?} failed: {}", spec, e);
273 continue;
274 }
275 }
276 }
277 Ok(RemoteProcessAllocatorMessage::Stop) => {
278 tracing::info!("received stop request");
279
280 ensure_previous_alloc_stopped(&mut active_allocation).await;
281 }
282 Ok(RemoteProcessAllocatorMessage::HeartBeat) => {}
286 Err(e) => {
287 tracing::error!("upstream channel error: {}", e);
288 continue;
289 }
290 }
291 }
292 _ = self.cancel_token.cancelled() => {
293 tracing::info!("main loop cancelled");
294
295 ensure_previous_alloc_stopped(&mut active_allocation).await;
296
297 break;
298 }
299 _ = sleep => {
300 if active_allocation.is_some() {
302 continue;
303 }
304 tracing::warn!("timeout of {} seconds elapsed without any allocations, exiting", timeout.unwrap_or_default().as_secs());
307 break;
308 }
309 }
310 }
311
312 Ok(())
313 }
314
315 #[tracing::instrument(skip(alloc, cancel_token))]
316 #[observe_async("RemoteProcessAllocator")]
317 async fn handle_allocation_request(
318 alloc: Box<dyn Alloc + Send + Sync>,
319 alloc_key: ShortUuid,
320 bootstrap_addr: ChannelAddr,
321 hosts: Vec<String>,
322 cancel_token: CancellationToken,
323 forwarder_addr: ChannelAddr,
324 ) {
325 tracing::info!("handle allocation request, bootstrap_addr: {bootstrap_addr}");
326 let (forwarder_addr, forwarder_rx) = match serve_with_config(forwarder_addr) {
328 Ok(v) => v,
329 Err(e) => {
330 tracing::error!("failed to to bootstrap forwarder actor: {}", e);
331 return;
332 }
333 };
334 let router = DialMailboxRouter::new();
335 let mailbox_handle = router.clone().serve(forwarder_rx);
336 tracing::info!("started forwarder on: {}", forwarder_addr);
337
338 if let Ok(hosts_file) = std::env::var("TORCH_ELASTIC_CUSTOM_HOSTNAMES_LIST_FILE") {
341 tracing::info!("writing hosts to {}", hosts_file);
342 #[derive(Serialize)]
343 struct Hosts {
344 hostnames: Vec<String>,
345 }
346 match serde_json::to_string(&Hosts { hostnames: hosts }) {
347 Ok(json) => match tokio::fs::File::create(&hosts_file).await {
348 Ok(mut file) => {
349 if file.write_all(json.as_bytes()).await.is_err() {
350 tracing::error!("failed to write hosts to {}", hosts_file);
351 return;
352 }
353 }
354 Err(e) => {
355 tracing::error!("failed to open hosts file {}: {}", hosts_file, e);
356 return;
357 }
358 },
359 Err(e) => {
360 tracing::error!("failed to serialize hosts: {}", e);
361 return;
362 }
363 }
364 }
365
366 Self::handle_allocation_loop(
367 alloc,
368 alloc_key,
369 bootstrap_addr,
370 router,
371 forwarder_addr,
372 cancel_token,
373 )
374 .await;
375
376 mailbox_handle.stop("alloc stopped");
377 if let Err(e) = mailbox_handle.await {
378 tracing::error!("failed to join forwarder: {}", e);
379 }
380 }
381
382 async fn handle_allocation_loop(
383 mut alloc: Box<dyn Alloc + Send + Sync>,
384 alloc_key: ShortUuid,
385 bootstrap_addr: ChannelAddr,
386 router: DialMailboxRouter,
387 forward_addr: ChannelAddr,
388 cancel_token: CancellationToken,
389 ) {
390 let world_id = alloc.world_id().clone();
391 tracing::info!("starting handle allocation loop for {}", world_id);
392 let tx = match channel::dial(bootstrap_addr) {
393 Ok(tx) => tx,
394 Err(err) => {
395 tracing::error!("failed to dial bootstrap address: {}", err);
396 return;
397 }
398 };
399 let message = RemoteProcessProcStateMessage::Allocated {
400 alloc_key: alloc_key.clone(),
401 world_id,
402 };
403 tracing::info!(name = message.as_ref(), "sending allocated message",);
404 if let Err(e) = tx.send(message).await {
405 tracing::error!("failed to send Allocated message: {}", e);
406 return;
407 }
408
409 let mut mesh_agents_by_create_key = HashMap::new();
410 let mut running = true;
411 let tx_status = tx.status().clone();
412 let mut tx_watcher = WatchStream::new(tx_status);
413 loop {
414 tokio::select! {
415 _ = cancel_token.cancelled(), if running => {
416 tracing::info!("cancelled, stopping allocation");
417 running = false;
418 if let Err(e) = alloc.stop().await {
419 tracing::error!("stop failed: {}", e);
420 break;
421 }
422 }
423 status = tx_watcher.next(), if running => {
424 match status {
425 Some(TxStatus::Closed) => {
426 tracing::error!("upstream channel state closed");
427 break;
428 },
429 _ => {
430 tracing::debug!("got channel event: {:?}", status.unwrap());
431 continue;
432 }
433 }
434 }
435 e = alloc.next() => {
436 match e {
437 Some(event) => {
438 tracing::debug!(name = event.as_ref(), "got event: {:?}", event);
439 let event = match event {
440 ProcState::Created { .. } => event,
441 ProcState::Running { create_key, proc_id, mesh_agent, addr } => {
442 tracing::debug!("remapping mesh_agent {}: addr {} -> {}", mesh_agent, addr, forward_addr);
444 mesh_agents_by_create_key.insert(create_key.clone(), mesh_agent.clone());
445 router.bind(mesh_agent.actor_id().proc_id().clone().into(), addr);
446 ProcState::Running { create_key, proc_id, mesh_agent, addr: forward_addr.clone() }
447 },
448 ProcState::Stopped { create_key, reason } => {
449 match mesh_agents_by_create_key.remove(&create_key) {
450 Some(mesh_agent) => {
451 tracing::debug!("unmapping mesh_agent {}", mesh_agent);
452 let agent_ref: Reference = mesh_agent.actor_id().proc_id().clone().into();
453 router.unbind(&agent_ref);
454 },
455 None => {
456 tracing::warn!("mesh_agent not found for create key {}", create_key);
457 }
458 }
459 ProcState::Stopped { create_key, reason }
460 },
461 ProcState::Failed { ref world_id, ref description } => {
462 tracing::error!("allocation failed for {}: {}", world_id, description);
463 event
464 }
465 };
466 tracing::debug!(name = event.as_ref(), "sending event: {:?}", event);
467 tx.post(RemoteProcessProcStateMessage::Update(alloc_key.clone(), event));
468 }
469 None => {
470 tracing::debug!("sending done");
471 tx.post(RemoteProcessProcStateMessage::Done(alloc_key.clone()));
472 running = false;
473 break;
474 }
475 }
476 }
477 _ = RealClock.sleep(config::global::get(config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL)) => {
478 tracing::trace!("sending heartbeat");
479 tx.post(RemoteProcessProcStateMessage::HeartBeat);
480 }
481 }
482 }
483 tracing::info!("allocation handler loop exited");
484 if running {
485 tracing::info!("stopping processes");
486 if let Err(e) = alloc.stop_and_wait().await {
487 tracing::error!("stop failed: {}", e);
488 return;
489 }
490 tracing::info!("stop finished");
491 }
492 }
493}
494
/// Identifier of a host taking part in a remote allocation; used as the key
/// for per-host state.
type HostId = String;
498
/// A host returned by a `RemoteProcessAllocInitializer`.
#[derive(Clone)]
pub struct RemoteProcessAllocHost {
    /// Identifier under which the host's state is tracked.
    pub id: HostId,
    /// Hostname used to build the remote allocator's address. For the Unix
    /// transport this string is used as the address itself.
    pub hostname: String,
}
508
/// Book-keeping for one host that was asked to allocate.
struct RemoteProcessAllocHostState {
    /// Key of the allocation requested from this host.
    alloc_key: ShortUuid,
    /// Id of the host this state belongs to.
    host_id: HostId,
    /// Channel to the host's remote allocator.
    tx: ChannelTx<RemoteProcessAllocatorMessage>,
    /// Create keys of procs reported as created and not yet stopped.
    active_procs: HashSet<ShortUuid>,
    /// Region of the overall extent assigned to this host.
    region: Region,
    /// World id reported by the host in its `Allocated` message, if received.
    world_id: Option<WorldId>,
    /// Set when the host reported a `Failed` proc state.
    failed: bool,
    /// Set once at least one proc has been reported as created on this host.
    allocated: bool,
}
528
/// Supplies the hosts a `RemoteProcessAlloc` allocates on. It is invoked
/// when the alloc starts, on the first call that needs the hosts.
#[automock]
#[async_trait]
pub trait RemoteProcessAllocInitializer {
    /// Returns the hosts to allocate on. An empty list is rejected by the
    /// caller as an error.
    async fn initialize_alloc(&mut self) -> Result<Vec<RemoteProcessAllocHost>, anyhow::Error>;
}
536
/// Per-host state map that mirrors each host's remote allocator address into
/// a shared map, keeping the two in sync on insert and remove.
struct HostStates {
    /// State per host id.
    inner: HashMap<HostId, RemoteProcessAllocHostState>,
    /// Remote allocator address per host id; shared with the signal cleanup
    /// callback registered by `RemoteProcessAlloc::new`.
    host_addresses: Arc<DashMap<HostId, ChannelAddr>>,
}
543
544impl HostStates {
545 fn new(host_addresses: Arc<DashMap<HostId, ChannelAddr>>) -> HostStates {
546 Self {
547 inner: HashMap::new(),
548 host_addresses,
549 }
550 }
551
552 fn insert(
553 &mut self,
554 host_id: HostId,
555 state: RemoteProcessAllocHostState,
556 address: ChannelAddr,
557 ) {
558 self.host_addresses.insert(host_id.clone(), address);
559 self.inner.insert(host_id, state);
560 }
561
562 fn get(&self, host_id: &HostId) -> Option<&RemoteProcessAllocHostState> {
563 self.inner.get(host_id)
564 }
565
566 fn get_mut(&mut self, host_id: &HostId) -> Option<&mut RemoteProcessAllocHostState> {
567 self.inner.get_mut(host_id)
568 }
569
570 fn remove(&mut self, host_id: &HostId) -> Option<RemoteProcessAllocHostState> {
571 self.host_addresses.remove(host_id);
572 self.inner.remove(host_id)
573 }
574
575 fn iter(&self) -> impl Iterator<Item = (&HostId, &RemoteProcessAllocHostState)> {
576 self.inner.iter()
577 }
578
579 fn iter_mut(&mut self) -> impl Iterator<Item = (&HostId, &mut RemoteProcessAllocHostState)> {
580 self.inner.iter_mut()
581 }
582
583 fn is_empty(&self) -> bool {
584 self.inner.is_empty()
585 }
586 }
588
/// An `Alloc` that spreads an allocation over several hosts by sending
/// allocate requests to a remote allocator on each host and merging the
/// state reports they send back.
pub struct RemoteProcessAlloc {
    /// Provides the list of hosts when the alloc is started.
    initializer: Box<dyn RemoteProcessAllocInitializer + Send + Sync>,
    /// Specification of the overall allocation.
    spec: AllocSpec,
    /// Port used when building remote allocator addresses for the network
    /// transports.
    remote_allocator_port: u16,
    /// World id of this alloc; also used in `Failed` events it generates.
    world_id: WorldId,
    /// Hosts used by this alloc, populated when the alloc is started.
    ordered_hosts: Vec<RemoteProcessAllocHost>,
    /// Set once starting has been attempted.
    started: bool,
    /// Cleared when no hosts are left; `next` then returns `None`.
    running: bool,
    /// Set once a `Failed` event has been returned from `next`.
    failed: bool,
    /// Maps an allocation key to the host it was requested from.
    alloc_to_host: HashMap<ShortUuid, HostId>,
    /// State per host that was asked to allocate.
    host_states: HostStates,
    // NOTE(review): only creation and removal are visible in this part of
    // the file; presumably holds per-world rank offsets -- confirm.
    world_offsets: HashMap<WorldId, usize>,
    /// Events queued for delivery; `next` drains these before anything else.
    event_queue: VecDeque<ProcState>,
    /// Sender handed to the comm watcher task, which reports the ids of
    /// hosts whose channel closed.
    comm_watcher_tx: UnboundedSender<HostId>,
    /// Receiving end for closed-host notifications, polled by `next`.
    comm_watcher_rx: UnboundedReceiver<HostId>,

    /// Address this alloc serves; sent to remote allocators so they can
    /// report state back.
    bootstrap_addr: ChannelAddr,
    /// Receiver for state messages from remote allocators.
    rx: ChannelRx<RemoteProcessProcStateMessage>,
    /// Guard returned when registering the signal cleanup callback that
    /// sends `Stop` to every known host address; held for the lifetime of
    /// the alloc.
    _signal_cleanup_guard: hyperactor::SignalCleanupGuard,
}
617
618impl RemoteProcessAlloc {
619 #[tracing::instrument(skip(initializer))]
624 #[observe_result("RemoteProcessAlloc")]
625 pub async fn new(
626 spec: AllocSpec,
627 world_id: WorldId,
628 remote_allocator_port: u16,
629 initializer: impl RemoteProcessAllocInitializer + Send + Sync + 'static,
630 ) -> Result<Self, anyhow::Error> {
631 let alloc_serve_addr = match config::global::try_get_cloned(REMOTE_ALLOC_BOOTSTRAP_ADDR) {
632 Some(addr_str) => addr_str.parse()?,
633 None => ChannelAddr::any(spec.transport.clone()),
634 };
635
636 let (bootstrap_addr, rx) = serve_with_config(alloc_serve_addr)?;
637
638 tracing::info!(
639 "starting alloc for {} on: {}",
640 world_id,
641 bootstrap_addr.clone()
642 );
643
644 let (comm_watcher_tx, comm_watcher_rx) = unbounded_channel();
645
646 let host_addresses = Arc::new(DashMap::<HostId, ChannelAddr>::new());
647 let host_addresses_for_signal = host_addresses.clone();
648
649 let signal_cleanup_guard =
651 hyperactor::register_signal_cleanup_scoped(Box::pin(async move {
652 join_all(host_addresses_for_signal.iter().map(|entry| async move {
653 let addr = entry.value().clone();
654 match channel::dial(addr.clone()) {
655 Ok(tx) => {
656 if let Err(e) = tx.send(RemoteProcessAllocatorMessage::Stop).await {
657 tracing::error!("Failed to send Stop to {}: {}", addr, e);
658 }
659 }
660 Err(e) => {
661 tracing::error!("Failed to dial {} during signal cleanup: {}", addr, e);
662 }
663 }
664 }))
665 .await;
666 }));
667
668 Ok(Self {
669 spec,
670 world_id,
671 remote_allocator_port,
672 initializer: Box::new(initializer),
673 world_offsets: HashMap::new(),
674 ordered_hosts: Vec::new(),
675 alloc_to_host: HashMap::new(),
676 host_states: HostStates::new(host_addresses),
677 bootstrap_addr,
678 event_queue: VecDeque::new(),
679 comm_watcher_tx,
680 comm_watcher_rx,
681 rx,
682 started: false,
683 running: true,
684 failed: false,
685 _signal_cleanup_guard: signal_cleanup_guard,
686 })
687 }
688
689 async fn start_comm_watcher(&self) {
697 let mut tx_watchers = Vec::new();
698 for host in &self.ordered_hosts {
699 let tx_status = self.host_states.get(&host.id).unwrap().tx.status().clone();
700 let watcher = WatchStream::new(tx_status);
701 tx_watchers.push((watcher, host.id.clone()));
702 }
703 assert!(!tx_watchers.is_empty());
704 let tx = self.comm_watcher_tx.clone();
705 tokio::spawn(async move {
706 loop {
707 let mut tx_status_futures = Vec::new();
708 for (watcher, _) in &mut tx_watchers {
709 let fut = watcher.next().boxed();
710 tx_status_futures.push(fut);
711 }
712 let (tx_status, index, _) = select_all(tx_status_futures).await;
713 let host_id = match tx_watchers.get(index) {
714 Some((_, host_id)) => host_id.clone(),
715 None => {
716 tracing::error!(
718 "got selected index {} with no matching host in {}",
719 index,
720 tx_watchers.len()
721 );
722 continue;
723 }
724 };
725 if let Some(tx_status) = tx_status {
726 tracing::debug!("host {} channel event: {:?}", host_id, tx_status);
727 if tx_status == TxStatus::Closed {
728 if tx.send(host_id.clone()).is_err() {
729 break;
731 }
732 tx_watchers.remove(index);
733 if tx_watchers.is_empty() {
734 break;
736 }
737 }
738 }
739 }
740 });
741 }
742
743 async fn ensure_started(&mut self) -> Result<(), anyhow::Error> {
748 if self.started || self.failed {
749 return Ok(());
750 }
751
752 self.started = true;
753 let hosts = self
754 .initializer
755 .initialize_alloc()
756 .await
757 .context("alloc initializer error")?;
758 if hosts.is_empty() {
759 anyhow::bail!("initializer returned empty list of hosts");
760 }
761 let hostnames: Vec<_> = hosts.iter().map(|e| e.hostname.clone()).collect();
764 tracing::info!("obtained {} hosts for this allocation", hostnames.len());
765
766 use crate::alloc::ProcAllocationMode;
768
769 let regions: Option<Vec<_>> = match self.spec.proc_allocation_mode {
771 ProcAllocationMode::ProcLevel => {
772 anyhow::ensure!(
774 self.spec.extent.len() >= 2,
775 "invalid extent: {}, expected at least 2 dimensions",
776 self.spec.extent
777 );
778 None
779 }
780 ProcAllocationMode::HostLevel => Some({
781 let num_points = self.spec.extent.num_ranks();
783 anyhow::ensure!(
784 hosts.len() >= num_points,
785 "HostLevel allocation mode requires {} hosts (one per point in extent {}), but only {} hosts were provided",
786 num_points,
787 self.spec.extent,
788 hosts.len()
789 );
790
791 let labels = self.spec.extent.labels().to_vec();
794
795 let extent_sizes = self.spec.extent.sizes();
797 let mut parent_strides = vec![1; extent_sizes.len()];
798 for i in (0..extent_sizes.len() - 1).rev() {
799 parent_strides[i] = parent_strides[i + 1] * extent_sizes[i + 1];
800 }
801
802 (0..num_points)
803 .map(|rank| {
804 let sizes = vec![1; labels.len()];
807 Region::new(
808 labels.clone(),
809 Slice::new(rank, sizes, parent_strides.clone()).unwrap(),
810 )
811 })
812 .collect()
813 }),
814 };
815
816 match self.spec.proc_allocation_mode {
817 ProcAllocationMode::ProcLevel => {
818 let split_dim = &self.spec.extent.labels()[self.spec.extent.len() - 1];
820 for (i, region) in self.spec.extent.group_by(split_dim)?.enumerate() {
821 let host = &hosts[i];
822 tracing::debug!("allocating: {} for host: {}", region, host.id);
823
824 let remote_addr = match self.spec.transport {
825 ChannelTransport::MetaTls(_) => {
826 format!("metatls!{}:{}", host.hostname, self.remote_allocator_port)
827 }
828 ChannelTransport::Tcp(TcpMode::Localhost) => {
829 format!("tcp![::1]:{}", self.remote_allocator_port)
831 }
832 ChannelTransport::Tcp(TcpMode::Hostname) => {
833 format!("tcp!{}:{}", host.hostname, self.remote_allocator_port)
834 }
835 ChannelTransport::Unix => host.hostname.clone(),
837 _ => {
838 anyhow::bail!(
839 "unsupported transport for host {}: {:?}",
840 host.id,
841 self.spec.transport,
842 );
843 }
844 };
845
846 tracing::debug!("dialing remote: {} for host {}", remote_addr, host.id);
847 let remote_addr = remote_addr.parse::<ChannelAddr>()?;
848 let tx = channel::dial(remote_addr.clone())
849 .map_err(anyhow::Error::from)
850 .context(format!(
851 "failed to dial remote {} for host {}",
852 remote_addr, host.id
853 ))?;
854
855 let alloc_key = ShortUuid::generate();
857 assert!(
858 self.alloc_to_host
859 .insert(alloc_key.clone(), host.id.clone())
860 .is_none()
861 );
862
863 let trace_id = hyperactor_telemetry::trace::get_or_create_trace_id();
864 let client_context = Some(ClientContext { trace_id });
865 let message = RemoteProcessAllocatorMessage::Allocate {
866 alloc_key: alloc_key.clone(),
867 extent: region.extent(),
868 bootstrap_addr: self.bootstrap_addr.clone(),
869 hosts: hostnames.clone(),
870 client_context,
871 forwarder_addr: with_unspecified_port_or_any(&remote_addr),
877 };
878 tracing::info!(
879 name = message.as_ref(),
880 "sending allocate message to workers"
881 );
882 tx.post(message);
883
884 self.host_states.insert(
885 host.id.clone(),
886 RemoteProcessAllocHostState {
887 alloc_key,
888 host_id: host.id.clone(),
889 tx,
890 active_procs: HashSet::new(),
891 region,
892 world_id: None,
893 failed: false,
894 allocated: false,
895 },
896 remote_addr,
897 );
898 }
899
900 self.ordered_hosts = hosts;
901 }
902 ProcAllocationMode::HostLevel => {
903 let regions = regions.unwrap();
904 let num_regions = regions.len();
905 for (i, region) in regions.into_iter().enumerate() {
906 let host = &hosts[i];
907 tracing::debug!("allocating: {} for host: {}", region, host.id);
908
909 let remote_addr = match self.spec.transport {
910 ChannelTransport::MetaTls(_) => {
911 format!("metatls!{}:{}", host.hostname, self.remote_allocator_port)
912 }
913 ChannelTransport::Tcp(TcpMode::Localhost) => {
914 format!("tcp![::1]:{}", self.remote_allocator_port)
916 }
917 ChannelTransport::Tcp(TcpMode::Hostname) => {
918 format!("tcp!{}:{}", host.hostname, self.remote_allocator_port)
919 }
920 ChannelTransport::Unix => host.hostname.clone(),
922 _ => {
923 anyhow::bail!(
924 "unsupported transport for host {}: {:?}",
925 host.id,
926 self.spec.transport,
927 );
928 }
929 };
930
931 tracing::debug!("dialing remote: {} for host {}", remote_addr, host.id);
932 let remote_addr = remote_addr.parse::<ChannelAddr>()?;
933 let tx = channel::dial(remote_addr.clone())
934 .map_err(anyhow::Error::from)
935 .context(format!(
936 "failed to dial remote {} for host {}",
937 remote_addr, host.id
938 ))?;
939
940 let alloc_key = ShortUuid::generate();
942 assert!(
943 self.alloc_to_host
944 .insert(alloc_key.clone(), host.id.clone())
945 .is_none()
946 );
947
948 let trace_id = hyperactor_telemetry::trace::get_or_create_trace_id();
949 let client_context = Some(ClientContext { trace_id });
950 let message = RemoteProcessAllocatorMessage::Allocate {
951 alloc_key: alloc_key.clone(),
952 extent: region.extent(),
953 bootstrap_addr: self.bootstrap_addr.clone(),
954 hosts: hostnames.clone(),
955 client_context,
956 forwarder_addr: with_unspecified_port_or_any(&remote_addr),
962 };
963 tracing::info!(
964 name = message.as_ref(),
965 "sending allocate message to workers"
966 );
967 tx.post(message);
968
969 self.host_states.insert(
970 host.id.clone(),
971 RemoteProcessAllocHostState {
972 alloc_key,
973 host_id: host.id.clone(),
974 tx,
975 active_procs: HashSet::new(),
976 region,
977 world_id: None,
978 failed: false,
979 allocated: false,
980 },
981 remote_addr,
982 );
983 }
984
985 self.ordered_hosts = hosts.into_iter().take(num_regions).collect();
988 }
989 }
990 self.start_comm_watcher().await;
991 self.started = true;
992
993 Ok(())
994 }
995
996 fn get_host_state_mut(
998 &mut self,
999 alloc_key: &ShortUuid,
1000 ) -> Result<&mut RemoteProcessAllocHostState, anyhow::Error> {
1001 let host_id: &HostId = self
1002 .alloc_to_host
1003 .get(alloc_key)
1004 .ok_or_else(|| anyhow::anyhow!("alloc with key {} not found", alloc_key))?;
1005
1006 self.host_states
1007 .get_mut(host_id)
1008 .ok_or_else(|| anyhow::anyhow!("no host state found for host {}", host_id))
1009 }
1010
1011 fn get_host_state(
1013 &self,
1014 alloc_key: &ShortUuid,
1015 ) -> Result<&RemoteProcessAllocHostState, anyhow::Error> {
1016 let host_id: &HostId = self
1017 .alloc_to_host
1018 .get(alloc_key)
1019 .ok_or_else(|| anyhow::anyhow!("alloc with key {} not found", alloc_key))?;
1020
1021 self.host_states
1022 .get(host_id)
1023 .ok_or_else(|| anyhow::anyhow!("no host state found for host {}", host_id))
1024 }
1025
1026 fn remove_host_state(
1027 &mut self,
1028 alloc_key: &ShortUuid,
1029 ) -> Result<RemoteProcessAllocHostState, anyhow::Error> {
1030 let host_id: &HostId = self
1031 .alloc_to_host
1032 .get(alloc_key)
1033 .ok_or_else(|| anyhow::anyhow!("alloc with key {} not found", alloc_key))?;
1034
1035 self.host_states
1036 .remove(host_id)
1037 .ok_or_else(|| anyhow::anyhow!("no host state found for host {}", host_id))
1038 }
1039
1040 fn add_proc_id_to_host_state(
1041 &mut self,
1042 alloc_key: &ShortUuid,
1043 create_key: &ShortUuid,
1044 ) -> Result<(), anyhow::Error> {
1045 let task_state = self.get_host_state_mut(alloc_key)?;
1046 if !task_state.active_procs.insert(create_key.clone()) {
1047 tracing::error!("proc with create key {} already in host state", create_key);
1049 }
1050 task_state.allocated = true;
1051 Ok(())
1052 }
1053
1054 fn remove_proc_from_host_state(
1055 &mut self,
1056 alloc_key: &ShortUuid,
1057 create_key: &ShortUuid,
1058 ) -> Result<(), anyhow::Error> {
1059 let task_state = self.get_host_state_mut(alloc_key)?;
1060 if !task_state.active_procs.remove(create_key) {
1061 tracing::error!("proc with create_key already in host state: {}", create_key);
1063 }
1064 Ok(())
1065 }
1066
1067 fn project_proc_into_global_extent(
1069 &self,
1070 alloc_key: &ShortUuid,
1071 point: &Point,
1072 ) -> Result<Point, anyhow::Error> {
1073 let global_rank = self
1074 .get_host_state(alloc_key)?
1075 .region
1076 .get(point.rank())
1077 .ok_or_else(|| {
1078 anyhow::anyhow!(
1079 "rank {} out of bounds for in alloc {}",
1080 point.rank(),
1081 alloc_key
1082 )
1083 })?;
1084 Ok(self.spec.extent.point_of_rank(global_rank)?)
1085 }
1086
1087 fn cleanup_host_channel_closed(
1089 &mut self,
1090 host_id: HostId,
1091 ) -> Result<Vec<ShortUuid>, anyhow::Error> {
1092 let state = match self.host_states.remove(&host_id) {
1093 Some(state) => state,
1094 None => {
1095 anyhow::bail!(
1097 "got channel closed event for host {} which has no known state",
1098 host_id
1099 );
1100 }
1101 };
1102 self.ordered_hosts.retain(|host| host.id != host_id);
1103 self.alloc_to_host.remove(&state.alloc_key);
1104 if let Some(world_id) = state.world_id {
1105 self.world_offsets.remove(&world_id);
1106 }
1107 let create_keys = state.active_procs.iter().cloned().collect();
1108
1109 Ok(create_keys)
1110 }
1111}
1112
1113#[async_trait]
1114impl Alloc for RemoteProcessAlloc {
/// Returns the next proc state event of the overall allocation, or `None`
/// when there is nothing more to report.
///
/// Queued events are delivered first. Otherwise the alloc is started if
/// necessary and the method waits for a message from a remote allocator, a
/// closed host channel, or the heartbeat deadline. `Created` events are
/// re-projected from the host's extent into the overall extent before they
/// are returned. Returning a `Failed` event marks the alloc as failed.
///
/// # Panics
/// Panics if a `Created` event reaches the final dispatch without an
/// allocation key, which the code treats as an illegal state.
async fn next(&mut self) -> Option<ProcState> {
    loop {
        // Events queued by an earlier iteration are delivered first.
        if let state @ Some(_) = self.event_queue.pop_front() {
            break state;
        }

        if !self.running {
            break None;
        }

        if let Err(e) = self.ensure_started().await {
            break Some(ProcState::Failed {
                world_id: self.world_id.clone(),
                description: format!("failed to ensure started: {:#}", e),
            });
        }

        let heartbeat_interval =
            config::global::get(config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL);
        // An absolute deadline, so that heartbeats go out on schedule even
        // while other branches of the select below keep firing.
        let mut heartbeat_time = hyperactor::clock::RealClock.now() + heartbeat_interval;
        // Set when the inner loop only queued events; the outer loop then
        // runs again to deliver them.
        let mut reloop = false;
        // The inner loop yields `Some((alloc_key, state))` for an event to
        // process, or `None` when there is no event to return.
        let update = loop {
            tokio::select! {
                msg = self.rx.recv() => {
                    tracing::debug!("got ProcState message from allocator: {:?}", msg);
                    match msg {
                        Ok(RemoteProcessProcStateMessage::Allocated { alloc_key, world_id }) => {
                            tracing::info!("remote alloc {}: allocated", alloc_key);
                            match self.get_host_state_mut(&alloc_key) {
                                Ok(state) => {
                                    state.world_id = Some(world_id.clone());
                                }
                                Err(err) => {
                                    tracing::error!(
                                        "received allocated message alloc: {} with no known state: {}",
                                        alloc_key, err,
                                    );
                                }
                            }
                        }
                        Ok(RemoteProcessProcStateMessage::Update(alloc_key, proc_state)) => {
                            // Update the per-host book-keeping, then hand the
                            // event to the dispatch after the inner loop.
                            let update = match proc_state {
                                ProcState::Created { ref create_key, .. } => {
                                    if let Err(e) = self.add_proc_id_to_host_state(&alloc_key, create_key) {
                                        tracing::error!("failed to add proc with create key {} host state: {}", create_key, e);
                                    }
                                    proc_state
                                }
                                ProcState::Stopped{ ref create_key, ..} => {
                                    if let Err(e) = self.remove_proc_from_host_state(&alloc_key, create_key) {
                                        tracing::error!("failed to remove proc with create key {} host state: {}", create_key, e);
                                    }
                                    proc_state
                                }
                                ProcState::Failed { ref world_id, ref description } => {
                                    match self.get_host_state_mut(&alloc_key) {
                                        Ok(state) => {
                                            state.failed = true;
                                            // Prefix the description with the failing host's id.
                                            ProcState::Failed {
                                                world_id: world_id.clone(),
                                                description: format!("host {} failed: {}", state.host_id, description),
                                            }
                                        }
                                        Err(e) => {
                                            tracing::error!("failed to find host state for world id: {}: {}", world_id, e);
                                            proc_state
                                        }
                                    }
                                }
                                _ => proc_state
                            };

                            break Some((Some(alloc_key), update));
                        }
                        Ok(RemoteProcessProcStateMessage::Done(alloc_key)) => {
                            tracing::info!("allocator {} is done", alloc_key);

                            if let Ok(state) = self.remove_host_state(&alloc_key) {
                                if !state.active_procs.is_empty() {
                                    tracing::error!("received done for alloc {} with active procs: {:?}", alloc_key, state.active_procs);
                                }
                            } else {
                                tracing::error!("received done for unknown alloc {}", alloc_key);
                            }

                            // Once the last host is done, the alloc is finished.
                            if self.host_states.is_empty() {
                                self.running = false;
                                break None;
                            }
                        }
                        Ok(RemoteProcessProcStateMessage::HeartBeat) => {}
                        Err(e) => {
                            break Some((None, ProcState::Failed {world_id: self.world_id.clone(), description: format!("error receiving events: {}", e)}));
                        }
                    }
                }

                _ = clock::RealClock.sleep_until(heartbeat_time) => {
                    // Send a heartbeat to every host and schedule the next one.
                    self.host_states.iter().for_each(|(_, host_state)| host_state.tx.post(RemoteProcessAllocatorMessage::HeartBeat));
                    heartbeat_time = hyperactor::clock::RealClock.now() + heartbeat_interval;
                }

                closed_host_id = self.comm_watcher_rx.recv() => {
                    if let Some(closed_host_id) = closed_host_id {
                        tracing::debug!("host {} channel closed, cleaning up", closed_host_id);
                        // A channel that closes before any proc was created on
                        // the host is reported as a failure of the alloc.
                        if let Some(state) = self.host_states.get(&closed_host_id)
                            && !state.allocated {
                                break Some((None, ProcState::Failed {
                                    world_id: self.world_id.clone(),
                                    description: format!(
                                        "no process has ever been allocated on {} before the channel is closed; \
                                        a common issue could be the channel was never established",
                                        closed_host_id
                                    )}));
                            }
                        let create_keys = match self.cleanup_host_channel_closed(closed_host_id) {
                            Ok(create_keys) => create_keys,
                            Err(err) => {
                                tracing::error!("failed to cleanup disconnected host: {}", err);
                                continue;
                            }
                        };
                        // Report every proc that was still active on the host as stopped.
                        for create_key in create_keys {
                            tracing::debug!("queuing Stopped state for proc with create key {}", create_key);
                            self.event_queue.push_back(
                                ProcState::Stopped {
                                    create_key,
                                    reason: ProcStopReason::HostWatchdog
                                }
                            );
                        }
                        if self.host_states.is_empty() {
                            tracing::info!("no more hosts left, stopping the alloc");
                            self.running = false;
                        }
                        // Go around the outer loop to deliver the queued events.
                        reloop = true;
                        break None;
                    } else {
                        tracing::warn!("unexpected comm watcher channel close");
                        break None;
                    }
                }
            }
        };

        if reloop {
            assert!(update.is_none());
            continue;
        }

        break match update {
            Some((
                Some(alloc_key),
                ProcState::Created {
                    create_key,
                    point,
                    pid,
                },
            )) => match self.project_proc_into_global_extent(&alloc_key, &point) {
                Ok(global_point) => {
                    tracing::debug!("reprojected coords: {} -> {}", point, global_point);
                    Some(ProcState::Created {
                        create_key,
                        point: global_point,
                        pid,
                    })
                }
                Err(e) => {
                    tracing::error!(
                        "failed to project coords for proc: {}.{}: {}",
                        alloc_key,
                        create_key,
                        e
                    );
                    None
                }
            },
            Some((None, ProcState::Created { .. })) => {
                panic!("illegal state: missing alloc_key for ProcState::Created event")
            }
            Some((_, update)) => {
                if let ProcState::Failed { description, .. } = &update {
                    tracing::error!(description);
                    self.failed = true;
                }
                Some(update)
            }
            None => None,
        };
    }
}
1316
1317 fn spec(&self) -> &AllocSpec {
1318 &self.spec
1319 }
1320
1321 fn extent(&self) -> &Extent {
1322 &self.spec.extent
1323 }
1324
1325 fn world_id(&self) -> &WorldId {
1326 &self.world_id
1327 }
1328
1329 async fn stop(&mut self) -> Result<(), AllocatorError> {
1330 tracing::info!("stopping alloc");
1331
1332 for (host_id, task_state) in self.host_states.iter_mut() {
1333 tracing::debug!("stopping alloc at host {}", host_id);
1334 task_state.tx.post(RemoteProcessAllocatorMessage::Stop);
1335 }
1336
1337 Ok(())
1338 }
1339
1340 fn client_router_addr(&self) -> ChannelAddr {
1349 with_unspecified_port_or_any(&self.bootstrap_addr)
1350 }
1351}
1352
1353impl Drop for RemoteProcessAlloc {
1354 fn drop(&mut self) {
1355 tracing::debug!("dropping RemoteProcessAlloc of world_id {}", self.world_id);
1356 }
1357}
1358
1359#[cfg(test)]
1360mod test {
1361 use std::assert_matches::assert_matches;
1362
1363 use hyperactor::ActorRef;
1364 use hyperactor::channel::ChannelRx;
1365 use hyperactor::clock::ClockKind;
1366 use hyperactor::id;
1367 use ndslice::extent;
1368 use tokio::sync::oneshot;
1369
1370 use super::*;
1371 use crate::alloc::ChannelTransport;
1372 use crate::alloc::MockAlloc;
1373 use crate::alloc::MockAllocWrapper;
1374 use crate::alloc::MockAllocator;
1375 use crate::alloc::ProcStopReason;
1376 use crate::alloc::with_unspecified_port_or_any;
1377 use crate::proc_mesh::mesh_agent::ProcMeshAgent;
1378
1379 async fn read_all_created(rx: &mut ChannelRx<RemoteProcessProcStateMessage>, alloc_len: usize) {
1380 let mut i: usize = 0;
1381 while i < alloc_len {
1382 let m = rx.recv().await.unwrap();
1383 match m {
1384 RemoteProcessProcStateMessage::Update(_, ProcState::Created { .. }) => i += 1,
1385 RemoteProcessProcStateMessage::HeartBeat => {}
1386 _ => panic!("unexpected message: {:?}", m),
1387 }
1388 }
1389 }
1390
1391 async fn read_all_running(rx: &mut ChannelRx<RemoteProcessProcStateMessage>, alloc_len: usize) {
1392 let mut i: usize = 0;
1393 while i < alloc_len {
1394 let m = rx.recv().await.unwrap();
1395 match m {
1396 RemoteProcessProcStateMessage::Update(_, ProcState::Running { .. }) => i += 1,
1397 RemoteProcessProcStateMessage::HeartBeat => {}
1398 _ => panic!("unexpected message: {:?}", m),
1399 }
1400 }
1401 }
1402
1403 async fn read_all_stopped(rx: &mut ChannelRx<RemoteProcessProcStateMessage>, alloc_len: usize) {
1404 let mut i: usize = 0;
1405 while i < alloc_len {
1406 let m = rx.recv().await.unwrap();
1407 match m {
1408 RemoteProcessProcStateMessage::Update(_, ProcState::Stopped { .. }) => i += 1,
1409 RemoteProcessProcStateMessage::HeartBeat => {}
1410 _ => panic!("unexpected message: {:?}", m),
1411 }
1412 }
1413 }
1414
1415 fn set_procstate_expectations(alloc: &mut MockAlloc, extent: Extent) {
1416 alloc.expect_extent().return_const(extent.clone());
1417 let mut create_keys = Vec::new();
1418 for point in extent.points() {
1419 let create_key = ShortUuid::generate();
1420 create_keys.push(create_key.clone());
1421 alloc.expect_next().times(1).return_once(move || {
1422 Some(ProcState::Created {
1423 create_key: create_key.clone(),
1424 point,
1425 pid: 0,
1426 })
1427 });
1428 }
1429 for (i, create_key) in create_keys
1430 .iter()
1431 .take(extent.num_ranks())
1432 .cloned()
1433 .enumerate()
1434 {
1435 let proc_id = format!("test[{i}]").parse().unwrap();
1436 let mesh_agent = ActorRef::<ProcMeshAgent>::attest(
1437 format!("test[{i}].mesh_agent[{i}]").parse().unwrap(),
1438 );
1439 alloc.expect_next().times(1).return_once(move || {
1440 Some(ProcState::Running {
1441 create_key,
1442 proc_id,
1443 addr: ChannelAddr::Unix("/proc0".parse().unwrap()),
1444 mesh_agent,
1445 })
1446 });
1447 }
1448 for create_key in create_keys.iter().take(extent.num_ranks()).cloned() {
1449 alloc.expect_next().times(1).return_once(|| {
1450 Some(ProcState::Stopped {
1451 create_key,
1452 reason: ProcStopReason::Unknown,
1453 })
1454 });
1455 }
1456 }
1457
1458 #[timed_test::async_timed_test(timeout_secs = 5)]
1459 async fn test_simple() {
1460 let config = hyperactor::config::global::lock();
1461 let _guard = config.override_key(
1462 hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
1463 Duration::from_millis(100),
1464 );
1465 hyperactor_telemetry::initialize_logging_for_test();
1466 let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
1467 let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
1468 let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();
1469
1470 let extent = extent!(host = 1, gpu = 2);
1471 let tx = channel::dial(serve_addr.clone()).unwrap();
1472
1473 let world_id: WorldId = id!(test_world_id);
1474 let mut alloc = MockAlloc::new();
1475 alloc.expect_world_id().return_const(world_id.clone());
1476 alloc.expect_extent().return_const(extent.clone());
1477
1478 set_procstate_expectations(&mut alloc, extent.clone());
1479
1480 alloc.expect_next().return_const(None);
1482
1483 let mut allocator = MockAllocator::new();
1484 let total_messages = extent.num_ranks() * 3 + 1;
1485 let mock_wrapper = MockAllocWrapper::new_block_next(
1486 alloc,
1487 total_messages,
1489 );
1490 allocator
1491 .expect_allocate()
1492 .times(1)
1493 .return_once(move |_| Ok(mock_wrapper));
1494
1495 let remote_allocator = RemoteProcessAllocator::new();
1496 let handle = tokio::spawn({
1497 let remote_allocator = remote_allocator.clone();
1498 async move {
1499 remote_allocator
1500 .start_with_allocator(serve_addr, allocator, None)
1501 .await
1502 }
1503 });
1504
1505 let alloc_key = ShortUuid::generate();
1506
1507 tx.send(RemoteProcessAllocatorMessage::Allocate {
1508 alloc_key: alloc_key.clone(),
1509 extent: extent.clone(),
1510 bootstrap_addr,
1511 hosts: vec![],
1512 client_context: None,
1513 forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
1514 })
1515 .await
1516 .unwrap();
1517
1518 let m = rx.recv().await.unwrap();
1520 assert_matches!(
1521 m, RemoteProcessProcStateMessage::Allocated { alloc_key: got_alloc_key, world_id: got_world_id }
1522 if got_world_id == world_id && got_alloc_key == alloc_key
1523 );
1524
1525 let mut rank: usize = 0;
1527 let mut create_keys = Vec::with_capacity(extent.num_ranks());
1528 while rank < extent.num_ranks() {
1529 let m = rx.recv().await.unwrap();
1530 match m {
1531 RemoteProcessProcStateMessage::Update(
1532 got_alloc_key,
1533 ProcState::Created {
1534 create_key, point, ..
1535 },
1536 ) => {
1537 let expected_point = extent.point_of_rank(rank).unwrap();
1538 assert_eq!(got_alloc_key, alloc_key);
1539 assert_eq!(point, expected_point);
1540 create_keys.push(create_key);
1541 rank += 1;
1542 }
1543 RemoteProcessProcStateMessage::HeartBeat => {}
1544 _ => panic!("unexpected message: {:?}", m),
1545 }
1546 }
1547 let mut rank: usize = 0;
1549 while rank < extent.num_ranks() {
1550 let m = rx.recv().await.unwrap();
1551 match m {
1552 RemoteProcessProcStateMessage::Update(
1553 got_alloc_key,
1554 ProcState::Running {
1555 create_key,
1556 proc_id,
1557 mesh_agent,
1558 addr: _,
1559 },
1560 ) => {
1561 assert_eq!(got_alloc_key, alloc_key);
1562 assert_eq!(create_key, create_keys[rank]);
1563 let expected_proc_id = format!("test[{}]", rank).parse().unwrap();
1564 let expected_mesh_agent = ActorRef::<ProcMeshAgent>::attest(
1565 format!("test[{}].mesh_agent[{}]", rank, rank)
1566 .parse()
1567 .unwrap(),
1568 );
1569 assert_eq!(proc_id, expected_proc_id);
1570 assert_eq!(mesh_agent, expected_mesh_agent);
1571 rank += 1;
1572 }
1573 RemoteProcessProcStateMessage::HeartBeat => {}
1574 _ => panic!("unexpected message: {:?}", m),
1575 }
1576 }
1577 let mut rank: usize = 0;
1579 while rank < extent.num_ranks() {
1580 let m = rx.recv().await.unwrap();
1581 match m {
1582 RemoteProcessProcStateMessage::Update(
1583 got_alloc_key,
1584 ProcState::Stopped {
1585 create_key,
1586 reason: ProcStopReason::Unknown,
1587 },
1588 ) => {
1589 assert_eq!(got_alloc_key, alloc_key);
1590 assert_eq!(create_key, create_keys[rank]);
1591 rank += 1;
1592 }
1593 RemoteProcessProcStateMessage::HeartBeat => {}
1594 _ => panic!("unexpected message: {:?}", m),
1595 }
1596 }
1597 loop {
1599 let m = rx.recv().await.unwrap();
1600 match m {
1601 RemoteProcessProcStateMessage::Done(got_alloc_key) => {
1602 assert_eq!(got_alloc_key, alloc_key);
1603 break;
1604 }
1605 RemoteProcessProcStateMessage::HeartBeat => {}
1606 _ => panic!("unexpected message: {:?}", m),
1607 }
1608 }
1609
1610 remote_allocator.terminate();
1611 handle.await.unwrap().unwrap();
1612 }
1613
1614 #[timed_test::async_timed_test(timeout_secs = 15)]
1615 async fn test_normal_stop() {
1616 let config = hyperactor::config::global::lock();
1617 let _guard = config.override_key(
1618 hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
1619 Duration::from_millis(100),
1620 );
1621 hyperactor_telemetry::initialize_logging_for_test();
1622 let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
1623 let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
1624 let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();
1625
1626 let extent = extent!(host = 1, gpu = 2);
1627 let tx = channel::dial(serve_addr.clone()).unwrap();
1628
1629 let world_id: WorldId = id!(test_world_id);
1630 let mut alloc = MockAllocWrapper::new_block_next(
1631 MockAlloc::new(),
1632 extent.num_ranks() * 2,
1634 );
1635 let next_tx = alloc.notify_tx();
1636 alloc.alloc.expect_world_id().return_const(world_id.clone());
1637 alloc.alloc.expect_extent().return_const(extent.clone());
1638
1639 set_procstate_expectations(&mut alloc.alloc, extent.clone());
1640
1641 alloc.alloc.expect_next().return_const(None);
1642 alloc.alloc.expect_stop().times(1).return_once(|| Ok(()));
1643
1644 let mut allocator = MockAllocator::new();
1645 allocator
1646 .expect_allocate()
1647 .times(1)
1648 .return_once(|_| Ok(alloc));
1649
1650 let remote_allocator = RemoteProcessAllocator::new();
1651 let handle = tokio::spawn({
1652 let remote_allocator = remote_allocator.clone();
1653 async move {
1654 remote_allocator
1655 .start_with_allocator(serve_addr, allocator, None)
1656 .await
1657 }
1658 });
1659
1660 let alloc_key = ShortUuid::generate();
1661 tx.send(RemoteProcessAllocatorMessage::Allocate {
1662 alloc_key: alloc_key.clone(),
1663 extent: extent.clone(),
1664 bootstrap_addr,
1665 hosts: vec![],
1666 client_context: None,
1667 forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
1668 })
1669 .await
1670 .unwrap();
1671
1672 let m = rx.recv().await.unwrap();
1674 assert_matches!(
1675 m,
1676 RemoteProcessProcStateMessage::Allocated { world_id: got_world_id, alloc_key: got_alloc_key }
1677 if world_id == got_world_id && alloc_key == got_alloc_key
1678 );
1679
1680 read_all_created(&mut rx, extent.num_ranks()).await;
1681 read_all_running(&mut rx, extent.num_ranks()).await;
1682
1683 tracing::info!("stopping allocation");
1685 tx.send(RemoteProcessAllocatorMessage::Stop).await.unwrap();
1686 next_tx.send(()).unwrap();
1688
1689 read_all_stopped(&mut rx, extent.num_ranks()).await;
1690
1691 remote_allocator.terminate();
1692 handle.await.unwrap().unwrap();
1693 }
1694
1695 #[timed_test::async_timed_test(timeout_secs = 15)]
1696 async fn test_realloc() {
1697 let config = hyperactor::config::global::lock();
1698 let _guard = config.override_key(
1699 hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
1700 Duration::from_millis(100),
1701 );
1702 hyperactor_telemetry::initialize_logging_for_test();
1703 let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
1704 let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
1705 let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();
1706
1707 let extent = extent!(host = 1, gpu = 2);
1708
1709 let tx = channel::dial(serve_addr.clone()).unwrap();
1710
1711 let world_id: WorldId = id!(test_world_id);
1712 let mut alloc1 = MockAllocWrapper::new_block_next(
1713 MockAlloc::new(),
1714 extent.num_ranks() * 2,
1716 );
1717 let next_tx1 = alloc1.notify_tx();
1718 alloc1
1719 .alloc
1720 .expect_world_id()
1721 .return_const(world_id.clone());
1722 alloc1.alloc.expect_extent().return_const(extent.clone());
1723
1724 set_procstate_expectations(&mut alloc1.alloc, extent.clone());
1725 alloc1.alloc.expect_next().return_const(None);
1726 alloc1.alloc.expect_stop().times(1).return_once(|| Ok(()));
1727 let mut alloc2 = MockAllocWrapper::new_block_next(
1729 MockAlloc::new(),
1730 extent.num_ranks() * 2,
1732 );
1733 let next_tx2 = alloc2.notify_tx();
1734 alloc2
1735 .alloc
1736 .expect_world_id()
1737 .return_const(world_id.clone());
1738 alloc2.alloc.expect_extent().return_const(extent.clone());
1739 set_procstate_expectations(&mut alloc2.alloc, extent.clone());
1740 alloc2.alloc.expect_next().return_const(None);
1741 alloc2.alloc.expect_stop().times(1).return_once(|| Ok(()));
1742
1743 let mut allocator = MockAllocator::new();
1744 allocator
1745 .expect_allocate()
1746 .times(1)
1747 .return_once(|_| Ok(alloc1));
1748 allocator
1750 .expect_allocate()
1751 .times(1)
1752 .return_once(|_| Ok(alloc2));
1753
1754 let remote_allocator = RemoteProcessAllocator::new();
1755 let handle = tokio::spawn({
1756 let remote_allocator = remote_allocator.clone();
1757 async move {
1758 remote_allocator
1759 .start_with_allocator(serve_addr, allocator, None)
1760 .await
1761 }
1762 });
1763
1764 let alloc_key = ShortUuid::generate();
1765
1766 tx.send(RemoteProcessAllocatorMessage::Allocate {
1767 alloc_key: alloc_key.clone(),
1768 extent: extent.clone(),
1769 bootstrap_addr: bootstrap_addr.clone(),
1770 hosts: vec![],
1771 client_context: None,
1772 forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
1773 })
1774 .await
1775 .unwrap();
1776
1777 let m = rx.recv().await.unwrap();
1779 assert_matches!(
1780 m,
1781 RemoteProcessProcStateMessage::Allocated { world_id: got_world_id, alloc_key: got_alloc_key }
1782 if got_world_id == world_id && got_alloc_key == alloc_key
1783 );
1784
1785 read_all_created(&mut rx, extent.num_ranks()).await;
1786 read_all_running(&mut rx, extent.num_ranks()).await;
1787
1788 let alloc_key = ShortUuid::generate();
1789
1790 tx.send(RemoteProcessAllocatorMessage::Allocate {
1792 alloc_key: alloc_key.clone(),
1793 extent: extent.clone(),
1794 bootstrap_addr,
1795 hosts: vec![],
1796 client_context: None,
1797 forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
1798 })
1799 .await
1800 .unwrap();
1801 next_tx1.send(()).unwrap();
1803 read_all_stopped(&mut rx, extent.num_ranks()).await;
1805 let m = rx.recv().await.unwrap();
1806 assert_matches!(m, RemoteProcessProcStateMessage::Done(_));
1807 let m = rx.recv().await.unwrap();
1808 assert_matches!(
1809 m,
1810 RemoteProcessProcStateMessage::Allocated { world_id: got_world_id, alloc_key: got_alloc_key }
1811 if got_world_id == world_id && got_alloc_key == alloc_key
1812 );
1813 read_all_created(&mut rx, extent.num_ranks()).await;
1815 read_all_running(&mut rx, extent.num_ranks()).await;
1816 tracing::info!("stopping allocation");
1818 tx.send(RemoteProcessAllocatorMessage::Stop).await.unwrap();
1819 next_tx2.send(()).unwrap();
1821
1822 read_all_stopped(&mut rx, extent.num_ranks()).await;
1823
1824 remote_allocator.terminate();
1825 handle.await.unwrap().unwrap();
1826 }
1827
1828 #[timed_test::async_timed_test(timeout_secs = 15)]
1829 async fn test_upstream_closed() {
1830 let config = hyperactor::config::global::lock();
1832 let _guard1 = config.override_key(
1833 hyperactor::config::MESSAGE_DELIVERY_TIMEOUT,
1834 Duration::from_secs(1),
1835 );
1836 let _guard2 = config.override_key(
1837 hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
1838 Duration::from_millis(100),
1839 );
1840
1841 hyperactor_telemetry::initialize_logging_for_test();
1842 let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
1843 let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
1844 let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();
1845
1846 let extent = extent!(host = 1, gpu = 2);
1847
1848 let tx = channel::dial(serve_addr.clone()).unwrap();
1849
1850 let world_id: WorldId = id!(test_world_id);
1851 let mut alloc = MockAllocWrapper::new_block_next(
1852 MockAlloc::new(),
1853 extent.num_ranks() * 2,
1855 );
1856 let next_tx = alloc.notify_tx();
1857 alloc.alloc.expect_world_id().return_const(world_id.clone());
1858 alloc.alloc.expect_extent().return_const(extent.clone());
1859
1860 set_procstate_expectations(&mut alloc.alloc, extent.clone());
1861
1862 alloc.alloc.expect_next().return_const(None);
1863 let (stop_tx, stop_rx) = oneshot::channel();
1866 alloc.alloc.expect_stop().times(1).return_once(|| {
1867 stop_tx.send(()).unwrap();
1868 Ok(())
1869 });
1870
1871 let mut allocator = MockAllocator::new();
1872 allocator
1873 .expect_allocate()
1874 .times(1)
1875 .return_once(|_| Ok(alloc));
1876
1877 let remote_allocator = RemoteProcessAllocator::new();
1878 let handle = tokio::spawn({
1879 let remote_allocator = remote_allocator.clone();
1880 async move {
1881 remote_allocator
1882 .start_with_allocator(serve_addr, allocator, None)
1883 .await
1884 }
1885 });
1886
1887 let alloc_key = ShortUuid::generate();
1888
1889 tx.send(RemoteProcessAllocatorMessage::Allocate {
1890 alloc_key: alloc_key.clone(),
1891 extent: extent.clone(),
1892 bootstrap_addr,
1893 hosts: vec![],
1894 client_context: None,
1895 forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
1896 })
1897 .await
1898 .unwrap();
1899
1900 let m = rx.recv().await.unwrap();
1902 assert_matches!(
1903 m, RemoteProcessProcStateMessage::Allocated { alloc_key: got_alloc_key, world_id: got_world_id }
1904 if got_world_id == world_id && got_alloc_key == alloc_key
1905 );
1906
1907 read_all_created(&mut rx, extent.num_ranks()).await;
1908 read_all_running(&mut rx, extent.num_ranks()).await;
1909
1910 tracing::info!("closing upstream");
1912 drop(rx);
1913 #[allow(clippy::disallowed_methods)]
1915 tokio::time::sleep(Duration::from_secs(2)).await;
1916 stop_rx.await.unwrap();
1918 next_tx.send(()).unwrap();
1920 remote_allocator.terminate();
1921 handle.await.unwrap().unwrap();
1922 }
1923
1924 #[timed_test::async_timed_test(timeout_secs = 15)]
1925 async fn test_inner_alloc_failure() {
1926 let config = hyperactor::config::global::lock();
1927 let _guard = config.override_key(
1928 hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
1929 Duration::from_secs(60),
1930 );
1931 hyperactor_telemetry::initialize_logging_for_test();
1932 let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
1933 let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
1934 let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();
1935
1936 let extent = extent!(host = 1, gpu = 2);
1937
1938 let tx = channel::dial(serve_addr.clone()).unwrap();
1939
1940 let test_world_id: WorldId = id!(test_world_id);
1941 let mut alloc = MockAllocWrapper::new_block_next(
1942 MockAlloc::new(),
1943 1,
1945 );
1946 let next_tx = alloc.notify_tx();
1947 alloc
1948 .alloc
1949 .expect_world_id()
1950 .return_const(test_world_id.clone());
1951 alloc.alloc.expect_extent().return_const(extent.clone());
1952 alloc
1953 .alloc
1954 .expect_next()
1955 .times(1)
1956 .return_const(Some(ProcState::Failed {
1957 world_id: test_world_id.clone(),
1958 description: "test".to_string(),
1959 }));
1960 alloc.alloc.expect_next().times(1).return_const(None);
1961
1962 alloc.alloc.expect_stop().times(1).return_once(|| Ok(()));
1963
1964 let mut allocator = MockAllocator::new();
1965 allocator
1966 .expect_allocate()
1967 .times(1)
1968 .return_once(|_| Ok(alloc));
1969
1970 let remote_allocator = RemoteProcessAllocator::new();
1971 let handle = tokio::spawn({
1972 let remote_allocator = remote_allocator.clone();
1973 async move {
1974 remote_allocator
1975 .start_with_allocator(serve_addr, allocator, None)
1976 .await
1977 }
1978 });
1979
1980 let alloc_key = ShortUuid::generate();
1981 tx.send(RemoteProcessAllocatorMessage::Allocate {
1982 alloc_key: alloc_key.clone(),
1983 extent: extent.clone(),
1984 bootstrap_addr,
1985 hosts: vec![],
1986 client_context: None,
1987 forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
1988 })
1989 .await
1990 .unwrap();
1991
1992 let m = rx.recv().await.unwrap();
1994 assert_matches!(
1995 m,
1996 RemoteProcessProcStateMessage::Allocated { world_id: got_world_id, alloc_key: got_alloc_key }
1997 if test_world_id == got_world_id && alloc_key == got_alloc_key
1998 );
1999
2000 let m = rx.recv().await.unwrap();
2002 assert_matches!(
2003 m,
2004 RemoteProcessProcStateMessage::Update(
2005 got_alloc_key,
2006 ProcState::Failed { world_id, description }
2007 ) if got_alloc_key == alloc_key && world_id == test_world_id && description == "test"
2008 );
2009
2010 tracing::info!("stopping allocation");
2011 tx.send(RemoteProcessAllocatorMessage::Stop).await.unwrap();
2012 next_tx.send(()).unwrap();
2014 let m = rx.recv().await.unwrap();
2016 assert_matches!(
2017 m,
2018 RemoteProcessProcStateMessage::Done(got_alloc_key)
2019 if got_alloc_key == alloc_key
2020 );
2021
2022 remote_allocator.terminate();
2023 handle.await.unwrap().unwrap();
2024 }
2025
2026 #[timed_test::async_timed_test(timeout_secs = 15)]
2027 async fn test_trace_id_propagation() {
2028 let config = hyperactor::config::global::lock();
2029 let _guard = config.override_key(
2030 hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
2031 Duration::from_secs(60),
2032 );
2033 hyperactor_telemetry::initialize_logging(ClockKind::default());
2034 let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
2035 let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
2036 let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();
2037
2038 let extent = extent!(host = 1, gpu = 1);
2039 let tx = channel::dial(serve_addr.clone()).unwrap();
2040 let test_world_id: WorldId = id!(test_world_id);
2041 let test_trace_id = "test_trace_id_12345";
2042
2043 let mut alloc = MockAlloc::new();
2045 alloc.expect_world_id().return_const(test_world_id.clone());
2046 alloc.expect_extent().return_const(extent.clone());
2047 alloc.expect_next().return_const(None);
2048
2049 let mut allocator = MockAllocator::new();
2051 allocator
2052 .expect_allocate()
2053 .times(1)
2054 .withf(move |spec: &AllocSpec| {
2055 spec.constraints
2057 .match_labels
2058 .get(CLIENT_TRACE_ID_LABEL)
2059 .is_some_and(|trace_id| trace_id == test_trace_id)
2060 })
2061 .return_once(|_| Ok(MockAllocWrapper::new(alloc)));
2062
2063 let remote_allocator = RemoteProcessAllocator::new();
2064 let handle = tokio::spawn({
2065 let remote_allocator = remote_allocator.clone();
2066 async move {
2067 remote_allocator
2068 .start_with_allocator(serve_addr, allocator, None)
2069 .await
2070 }
2071 });
2072
2073 let alloc_key = ShortUuid::generate();
2074 tx.send(RemoteProcessAllocatorMessage::Allocate {
2075 alloc_key: alloc_key.clone(),
2076 extent: extent.clone(),
2077 bootstrap_addr,
2078 hosts: vec![],
2079 client_context: Some(ClientContext {
2080 trace_id: test_trace_id.to_string(),
2081 }),
2082 forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
2083 })
2084 .await
2085 .unwrap();
2086
2087 let m = rx.recv().await.unwrap();
2089 assert_matches!(
2090 m,
2091 RemoteProcessProcStateMessage::Allocated { alloc_key: got_alloc_key, world_id: got_world_id }
2092 if got_world_id == test_world_id && got_alloc_key == alloc_key
2093 );
2094
2095 let m = rx.recv().await.unwrap();
2097 assert_matches!(
2098 m,
2099 RemoteProcessProcStateMessage::Done(got_alloc_key)
2100 if alloc_key == got_alloc_key
2101 );
2102
2103 remote_allocator.terminate();
2104 handle.await.unwrap().unwrap();
2105 }
2106
2107 #[timed_test::async_timed_test(timeout_secs = 15)]
2108 async fn test_trace_id_propagation_no_client_context() {
2109 let config = hyperactor::config::global::lock();
2110 let _guard = config.override_key(
2111 hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
2112 Duration::from_secs(60),
2113 );
2114 hyperactor_telemetry::initialize_logging(ClockKind::default());
2115 let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
2116 let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
2117 let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();
2118
2119 let extent = extent!(host = 1, gpu = 1);
2120 let tx = channel::dial(serve_addr.clone()).unwrap();
2121 let test_world_id: WorldId = id!(test_world_id);
2122
2123 let mut alloc = MockAlloc::new();
2125 alloc.expect_world_id().return_const(test_world_id.clone());
2126 alloc.expect_extent().return_const(extent.clone());
2127 alloc.expect_next().return_const(None);
2128
2129 let mut allocator = MockAllocator::new();
2131 allocator
2132 .expect_allocate()
2133 .times(1)
2134 .withf(move |spec: &AllocSpec| {
2135 spec.constraints.match_labels.is_empty()
2137 })
2138 .return_once(|_| Ok(MockAllocWrapper::new(alloc)));
2139
2140 let remote_allocator = RemoteProcessAllocator::new();
2141 let handle = tokio::spawn({
2142 let remote_allocator = remote_allocator.clone();
2143 async move {
2144 remote_allocator
2145 .start_with_allocator(serve_addr, allocator, None)
2146 .await
2147 }
2148 });
2149
2150 let alloc_key = ShortUuid::generate();
2151 tx.send(RemoteProcessAllocatorMessage::Allocate {
2152 alloc_key: alloc_key.clone(),
2153 extent: extent.clone(),
2154 bootstrap_addr,
2155 hosts: vec![],
2156 client_context: None,
2157 forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
2158 })
2159 .await
2160 .unwrap();
2161
2162 let m = rx.recv().await.unwrap();
2164 assert_matches!(
2165 m,
2166 RemoteProcessProcStateMessage::Allocated { alloc_key: got_alloc_key, world_id: got_world_id }
2167 if got_world_id == test_world_id && got_alloc_key == alloc_key
2168 );
2169
2170 let m = rx.recv().await.unwrap();
2172 assert_matches!(
2173 m,
2174 RemoteProcessProcStateMessage::Done(got_alloc_key)
2175 if got_alloc_key == alloc_key
2176 );
2177
2178 remote_allocator.terminate();
2179 handle.await.unwrap().unwrap();
2180 }
2181}
2182
2183#[cfg(test)]
2184mod test_alloc {
2185 use std::os::unix::process::ExitStatusExt;
2186
2187 use hyperactor::clock::ClockKind;
2188 use hyperactor::config;
2189 use ndslice::extent;
2190 use nix::sys::signal;
2191 use nix::unistd::Pid;
2192 use timed_test::async_timed_test;
2193
2194 use super::*;
2195
/// Starts two `RemoteProcessAllocator`s on the bootstrap test binary and drives a
/// `RemoteProcessAlloc` across both: every proc must report `Created` then `Running`, the
/// created points must cover the whole extent, no further event may arrive for one second,
/// and after `stop()` every running proc must report `Stopped` before `next()` yields `None`.
#[async_timed_test(timeout_secs = 60)]
#[cfg(fbcode_build)]
async fn test_alloc_simple() {
    let config = hyperactor::config::global::lock();
    let _delivery_timeout_guard = config.override_key(
        hyperactor::config::MESSAGE_DELIVERY_TIMEOUT,
        Duration::from_secs(1),
    );
    let _heartbeat_guard = config.override_key(
        hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
        Duration::from_millis(100),
    );
    hyperactor_telemetry::initialize_logging(ClockKind::default());

    let spec = AllocSpec {
        extent: extent!(host = 2, gpu = 2),
        constraints: Default::default(),
        proc_name: None,
        transport: ChannelTransport::Unix,
        proc_allocation_mode: Default::default(),
    };
    let world_id = WorldId("test_world_id".to_string());

    // Two remote allocators, each with its own address and bootstrap command.
    let task1_allocator = RemoteProcessAllocator::new();
    let task1_addr = ChannelAddr::any(ChannelTransport::Unix);
    let task1_addr_string = task1_addr.to_string();
    let task1_cmd = Command::new(crate::testresource::get(
        "monarch/hyperactor_mesh/bootstrap",
    ));
    let task2_allocator = RemoteProcessAllocator::new();
    let task2_addr = ChannelAddr::any(ChannelTransport::Unix);
    let task2_addr_string = task2_addr.to_string();
    let task2_cmd = Command::new(crate::testresource::get(
        "monarch/hyperactor_mesh/bootstrap",
    ));

    let task1_allocator_handle = tokio::spawn({
        let server = task1_allocator.clone();
        async move {
            tracing::info!("spawning task1");
            server.start(task1_cmd, task1_addr, None).await.unwrap();
        }
    });
    let task2_allocator_handle = tokio::spawn({
        let server = task2_allocator.clone();
        async move {
            server.start(task2_cmd, task2_addr, None).await.unwrap();
        }
    });

    // The initializer hands the alloc both hosts.
    let mut initializer = MockRemoteProcessAllocInitializer::new();
    initializer.expect_initialize_alloc().return_once(move || {
        let hosts = vec![
            RemoteProcessAllocHost {
                hostname: task1_addr_string,
                id: "task1".to_string(),
            },
            RemoteProcessAllocHost {
                hostname: task2_addr_string,
                id: "task2".to_string(),
            },
        ];
        Ok(hosts)
    });
    let mut alloc = RemoteProcessAlloc::new(spec.clone(), world_id, 0, initializer)
        .await
        .unwrap();

    // Each rank contributes one Created and one Running event.
    let mut created = HashSet::new();
    let mut running_procs = HashSet::new();
    let mut proc_points = HashSet::new();
    let expected_events = spec.extent.num_ranks() * 2;
    for _ in 0..expected_events {
        let proc_state = alloc.next().await.unwrap();
        tracing::debug!("test got message: {:?}", proc_state);
        match proc_state {
            ProcState::Created {
                create_key, point, ..
            } => {
                created.insert(create_key);
                proc_points.insert(point);
            }
            ProcState::Running { create_key, .. } => {
                assert!(created.remove(&create_key));
                running_procs.insert(create_key);
            }
            _ => panic!("expected Created or Running"),
        }
    }
    assert!(created.is_empty());
    assert!(
        spec.extent
            .points()
            .all(|point| proc_points.contains(&point))
    );

    // No further event may arrive within one second.
    let timeout = hyperactor::clock::RealClock.now() + Duration::from_secs(1);
    tokio::select! {
        _ = hyperactor::clock::RealClock.sleep_until(timeout) => {},
        _ = alloc.next() => panic!("expected no more items"),
    }

    alloc.stop().await.unwrap();

    // Every running proc must report Stopped with reason `Stopped`.
    for _ in 0..spec.extent.num_ranks() {
        let proc_state = alloc.next().await.unwrap();
        tracing::info!("test received next proc_state: {:?}", proc_state);
        match proc_state {
            ProcState::Stopped {
                create_key, reason, ..
            } => {
                assert!(running_procs.remove(&create_key));
                assert_eq!(reason, ProcStopReason::Stopped);
            }
            _ => panic!("expected stopped"),
        }
    }

    // Once drained, `next()` must yield `None`, and keep doing so on a second call.
    for _ in 0..2 {
        let proc_state = alloc.next().await;
        assert!(proc_state.is_none());
    }

    task1_allocator.terminate();
    task1_allocator_handle.await.unwrap();
    task2_allocator.terminate();
    task2_allocator_handle.await.unwrap();
}
2326
// Host-failure scenario. Two remote allocators are served from background
// tasks and a `RemoteProcessAlloc` is built on top of them. After the startup
// events have been consumed, each allocator task is aborted in turn, and after
// each abort half of the ranks must be reported as `Stopped` with
// `ProcStopReason::HostWatchdog`. The stream must end with `None`.
#[async_timed_test(timeout_secs = 60)]
#[cfg(fbcode_build)]
async fn test_alloc_host_failure() {
    // Test-scoped configuration: 1s message delivery timeout and a 100ms
    // allocator heartbeat. The guards stay alive until the test returns.
    let config_lock = hyperactor::config::global::lock();
    let _delivery_timeout_guard = config_lock.override_key(
        hyperactor::config::MESSAGE_DELIVERY_TIMEOUT,
        Duration::from_secs(1),
    );
    let _heartbeat_guard = config_lock.override_key(
        hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
        Duration::from_millis(100),
    );
    hyperactor_telemetry::initialize_logging(ClockKind::default());

    let spec = AllocSpec {
        extent: extent!(host = 2, gpu = 2),
        constraints: Default::default(),
        proc_name: None,
        transport: ChannelTransport::Unix,
        proc_allocation_mode: Default::default(),
    };
    let total_ranks = spec.extent.num_ranks();
    let world_id = WorldId("test_world_id".to_string());

    // Per-host allocator, listen address and bootstrap command.
    let host1_allocator = RemoteProcessAllocator::new();
    let host1_addr = ChannelAddr::any(ChannelTransport::Unix);
    let host1_addr_string = host1_addr.to_string();
    let host1_cmd = Command::new(crate::testresource::get(
        "monarch/hyperactor_mesh/bootstrap",
    ));
    let host2_allocator = RemoteProcessAllocator::new();
    let host2_addr = ChannelAddr::any(ChannelTransport::Unix);
    let host2_addr_string = host2_addr.to_string();
    let host2_cmd = Command::new(crate::testresource::get(
        "monarch/hyperactor_mesh/bootstrap",
    ));

    // Each allocator is served from its own task; the task handles are what
    // gets aborted later on. The original allocator values are kept alive in
    // this scope and only clones move into the tasks.
    let host1_task = tokio::spawn({
        let allocator = host1_allocator.clone();
        async move {
            tracing::info!("spawning task1");
            allocator.start(host1_cmd, host1_addr, None).await.unwrap();
            tracing::info!("task1 terminated");
        }
    });
    let host2_task = tokio::spawn({
        let allocator = host2_allocator.clone();
        async move {
            allocator.start(host2_cmd, host2_addr, None).await.unwrap();
            tracing::info!("task2 terminated");
        }
    });

    // The mocked initializer hands back the two hosts exactly once.
    let hosts = vec![
        RemoteProcessAllocHost {
            hostname: host1_addr_string,
            id: "task1".to_string(),
        },
        RemoteProcessAllocHost {
            hostname: host2_addr_string,
            id: "task2".to_string(),
        },
    ];
    let mut initializer = MockRemoteProcessAllocInitializer::new();
    initializer
        .expect_initialize_alloc()
        .return_once(move || Ok(hosts));

    let mut alloc = RemoteProcessAlloc::new(spec.clone(), world_id, 0, initializer)
        .await
        .unwrap();

    // Startup phase: `2 * num_ranks` events, each of which must be either
    // Created or Running.
    for _ in 0..total_ranks * 2 {
        let event = alloc.next().await;
        if !matches!(
            event,
            Some(ProcState::Created { .. }) | Some(ProcState::Running { .. })
        ) {
            panic!("expected Created or Running");
        }
    }

    // Nothing else may arrive during the next second.
    let quiet_until = RealClock.now() + Duration::from_millis(1000);
    tokio::select! {
        _ = RealClock.sleep_until(quiet_until) => {},
        _ = alloc.next() => panic!("expected no more items"),
    }

    // Abort the first allocator task and wait two heartbeat intervals before
    // reading the resulting events.
    tracing::info!("aborting task1 allocator");
    host1_task.abort();
    let first_wait = config::global::get(config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL) * 2;
    RealClock.sleep(first_wait).await;
    for _ in 0..total_ranks / 2 {
        let event = alloc.next().await.unwrap();
        tracing::info!("test received next proc_state: {:?}", event);
        let ProcState::Stopped { reason, .. } = event else {
            panic!("expected stopped");
        };
        assert_eq!(reason, ProcStopReason::HostWatchdog);
    }

    // Again, no further events for one second while the second task still runs.
    let quiet_until = RealClock.now() + Duration::from_millis(1000);
    tokio::select! {
        _ = RealClock.sleep_until(quiet_until) => {},
        _ = alloc.next() => panic!("expected no more items"),
    }

    // Same procedure for the second allocator task.
    tracing::info!("aborting task2 allocator");
    host2_task.abort();
    let second_wait = config::global::get(config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL) * 2;
    RealClock.sleep(second_wait).await;
    for _ in 0..total_ranks / 2 {
        let event = alloc.next().await.unwrap();
        tracing::info!("test received next proc_state: {:?}", event);
        let ProcState::Stopped { reason, .. } = event else {
            panic!("expected stopped");
        };
        assert_eq!(reason, ProcStopReason::HostWatchdog);
    }

    // The stream is exhausted now, and stays exhausted on repeated polling.
    assert!(alloc.next().await.is_none());
    assert!(alloc.next().await.is_none());
}
2459
// Inner-allocation failure scenario. Two remote allocators are started, but
// the second one is given a command path that is expected not to resolve to a
// bootstrap binary. The test requires exactly one `Failed` event next to the
// Created/Running events, then stops the alloc and expects `num_ranks / 2`
// `Stopped` events followed by `None`.
#[async_timed_test(timeout_secs = 15)]
#[cfg(fbcode_build)]
async fn test_alloc_inner_alloc_failure() {
    // SAFETY: NOTE(review): `set_var` is only sound while no other thread is
    // reading or writing the process environment -- confirm that holds at this
    // point of the test.
    unsafe {
        std::env::set_var("MONARCH_MESSAGE_DELIVERY_TIMEOUT_SECS", "1");
    }

    // Test-scoped 100ms allocator heartbeat; the guard lives until the test
    // returns.
    let config_lock = hyperactor::config::global::lock();
    let _heartbeat_guard = config_lock.override_key(
        hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
        Duration::from_millis(100),
    );
    hyperactor_telemetry::initialize_logging_for_test();

    let spec = AllocSpec {
        extent: extent!(host = 2, gpu = 2),
        constraints: Default::default(),
        proc_name: None,
        transport: ChannelTransport::Unix,
        proc_allocation_mode: Default::default(),
    };
    let total_ranks = spec.extent.num_ranks();
    let world_id = WorldId("test_world_id".to_string());

    let host1_allocator = RemoteProcessAllocator::new();
    let host1_addr = ChannelAddr::any(ChannelTransport::Unix);
    let host1_addr_string = host1_addr.to_string();
    let host1_cmd = Command::new(crate::testresource::get(
        "monarch/hyperactor_mesh/bootstrap",
    ));
    let host2_allocator = RemoteProcessAllocator::new();
    let host2_addr = ChannelAddr::any(ChannelTransport::Unix);
    let host2_addr_string = host2_addr.to_string();
    // Deliberately bogus command path; presumably this is what produces the
    // single Failed event asserted further down.
    let host2_cmd = Command::new("/caught/somewhere/in/time");

    // Serve both allocators from background tasks. Clones move into the tasks
    // so the originals can be used for `terminate()` at the end.
    let host1_task = tokio::spawn({
        let allocator = host1_allocator.clone();
        async move {
            tracing::info!("spawning task1");
            allocator.start(host1_cmd, host1_addr, None).await.unwrap();
        }
    });
    let host2_task = tokio::spawn({
        let allocator = host2_allocator.clone();
        async move {
            allocator.start(host2_cmd, host2_addr, None).await.unwrap();
        }
    });

    // The mocked initializer hands back the two hosts exactly once.
    let hosts = vec![
        RemoteProcessAllocHost {
            hostname: host1_addr_string,
            id: "task1".to_string(),
        },
        RemoteProcessAllocHost {
            hostname: host2_addr_string,
            id: "task2".to_string(),
        },
    ];
    let mut initializer = MockRemoteProcessAllocInitializer::new();
    initializer
        .expect_initialize_alloc()
        .return_once(move || Ok(hosts));

    let mut alloc = RemoteProcessAlloc::new(spec.clone(), world_id, 0, initializer)
        .await
        .unwrap();

    // Bookkeeping for the startup phase:
    // - keys seen as Created but not yet Running,
    // - keys that reached Running,
    // - points reported by Created events,
    // - number of Failed events.
    let mut awaiting_running = HashSet::new();
    let mut running_keys = HashSet::new();
    let mut seen_points = HashSet::new();
    let mut failure_count = 0;

    // `num_ranks + 1` events are consumed here; every Running must follow a
    // Created with the same key.
    for _ in 0..total_ranks + 1 {
        let event = alloc.next().await.unwrap();
        tracing::debug!("test got message: {:?}", event);
        match event {
            ProcState::Created {
                create_key, point, ..
            } => {
                awaiting_running.insert(create_key);
                seen_points.insert(point);
            }
            ProcState::Running { create_key, .. } => {
                assert!(awaiting_running.remove(&create_key));
                running_keys.insert(create_key);
            }
            ProcState::Failed { .. } => failure_count += 1,
            _ => panic!("expected Created, Running or Failed"),
        }
    }
    assert!(awaiting_running.is_empty());
    assert_eq!(failure_count, 1);

    // The first half of the ranks must all have been reported as Created.
    let first_half_points = (0..total_ranks / 2).map(|rank| spec.extent.point_of_rank(rank).unwrap());
    for point in first_half_points {
        assert!(seen_points.contains(&point));
    }

    // Nothing else may arrive during the next second.
    let quiet_until = RealClock.now() + Duration::from_millis(1000);
    tokio::select! {
        _ = RealClock.sleep_until(quiet_until) => {},
        _ = alloc.next() => panic!("expected no more items"),
    }

    // Stopping the alloc must yield one Stopped event per running proc.
    alloc.stop().await.unwrap();
    for _ in 0..total_ranks / 2 {
        let event = alloc.next().await.unwrap();
        tracing::info!("test received next proc_state: {:?}", event);
        let ProcState::Stopped {
            create_key, reason, ..
        } = event
        else {
            panic!("expected stopped");
        };
        assert!(running_keys.remove(&create_key));
        assert_eq!(reason, ProcStopReason::Stopped);
    }

    // The stream is exhausted now, and stays exhausted on repeated polling.
    assert!(alloc.next().await.is_none());
    assert!(alloc.next().await.is_none());

    // Shut both allocators down and join their serving tasks.
    host1_allocator.terminate();
    host1_task.await.unwrap();
    host2_allocator.terminate();
    host2_task.await.unwrap();
}
2595
// Signal-handling scenario driven entirely through external binaries.
//
// 1. Spawn one `remote_process_allocator` child per address
//    (`num_proc_meshes * hosts_per_proc_mesh` of them).
// 2. Spawn the `remote_process_alloc` binary, wait for its "done allocating"
//    message and collect the pids it reports on `pid_addr`.
// 3. Send SIGINT to `remote_process_alloc` and require that it exits by that
//    signal.
// 4. After a 5s grace period, require that `kill -0` fails for every reported
//    pid, then kill the allocator children.
#[tracing_test::traced_test]
#[async_timed_test(timeout_secs = 60)]
#[cfg(fbcode_build)]
async fn test_remote_process_alloc_signal_handler() {
    let num_proc_meshes = 5;
    let hosts_per_proc_mesh = 5;

    // Channel on which pids are reported back to this test.
    let pid_addr = ChannelAddr::any(ChannelTransport::Unix);
    let (pid_addr, mut pid_rx) = channel::serve::<u32>(pid_addr).unwrap();

    // One listen address per allocator process.
    let addresses: Vec<_> = (0..(num_proc_meshes * hosts_per_proc_mesh))
        .map(|_| ChannelAddr::any(ChannelTransport::Unix).to_string())
        .collect();

    let remote_process_allocators: Vec<_> = addresses
        .iter()
        .map(|addr| {
            let mut allocator_cmd = Command::new(crate::testresource::get(
                "monarch/hyperactor_mesh/remote_process_allocator",
            ));
            allocator_cmd
                .env("RUST_LOG", "info")
                .arg(format!("--addr={addr}"))
                .stdout(std::process::Stdio::piped());
            allocator_cmd.spawn().unwrap()
        })
        .collect();

    // Channel on which the alloc binary signals that allocation has finished.
    let done_allocating_addr = ChannelAddr::any(ChannelTransport::Unix);
    let (done_allocating_addr, mut done_allocating_rx) =
        channel::serve::<()>(done_allocating_addr).unwrap();

    let mut alloc_cmd = Command::new(crate::testresource::get(
        "monarch/hyperactor_mesh/remote_process_alloc",
    ));
    alloc_cmd
        .arg(format!("--done-allocating-addr={}", done_allocating_addr))
        .arg(format!("--addresses={}", addresses.join(",")))
        .arg(format!("--num-proc-meshes={}", num_proc_meshes))
        .arg(format!("--hosts-per-proc-mesh={}", hosts_per_proc_mesh))
        .arg(format!("--pid-addr={}", pid_addr));
    let mut remote_process_alloc = alloc_cmd.spawn().unwrap();

    done_allocating_rx.recv().await.unwrap();

    // Collect pids until one per allocator has arrived or the channel reports
    // an error, whichever happens first.
    let expected_pid_count = remote_process_allocators.len();
    let mut received_pids = Vec::new();
    loop {
        let Ok(pid) = pid_rx.recv().await else {
            break;
        };
        received_pids.push(pid);
        if received_pids.len() == expected_pid_count {
            break;
        }
    }

    // Interrupt the alloc process and check that it was terminated by SIGINT.
    let alloc_pid = Pid::from_raw(remote_process_alloc.id().unwrap() as i32);
    signal::kill(alloc_pid, signal::SIGINT).unwrap();

    let exit_status = remote_process_alloc.wait().await.unwrap();
    assert_eq!(exit_status.signal(), Some(signal::SIGINT as i32));

    // Grace period before probing the reported pids.
    RealClock.sleep(Duration::from_secs(5)).await;

    // `kill -0 <pid>` succeeds only while the pid can still be signalled, so
    // it has to fail for every pid that was reported.
    for child_pid in received_pids {
        let still_alive = Command::new("kill")
            .arg("-0")
            .arg(child_pid.to_string())
            .output()
            .await
            .expect("Failed to check if PID is alive")
            .status
            .success();

        assert!(
            !still_alive,
            "PID {} should no longer be alive",
            child_pid
        );
    }

    // Tear down the allocator processes spawned at the top of the test.
    for mut allocator_child in remote_process_allocators {
        allocator_child.kill().await.unwrap();
    }
}
2682}