use std::collections::HashMap;
use std::collections::HashSet;
use std::collections::VecDeque;
use std::sync::Arc;
use std::time::Duration;

use anyhow::Context;
use async_trait::async_trait;
use dashmap::DashMap;
use futures::FutureExt;
use futures::future::join_all;
use futures::future::select_all;
use hyperactor::Named;
use hyperactor::ProcId;
use hyperactor::WorldId;
use hyperactor::channel;
use hyperactor::channel::ChannelAddr;
use hyperactor::channel::ChannelRx;
use hyperactor::channel::ChannelTransport;
use hyperactor::channel::ChannelTx;
use hyperactor::channel::Rx;
use hyperactor::channel::Tx;
use hyperactor::channel::TxStatus;
use hyperactor::clock;
use hyperactor::clock::Clock;
use hyperactor::clock::RealClock;
use hyperactor::config;
use hyperactor::mailbox::DialMailboxRouter;
use hyperactor::mailbox::MailboxServer;
use hyperactor::reference::Reference;
use hyperactor::serde_json;
use mockall::automock;
use ndslice::View;
use ndslice::ViewExt;
use ndslice::view::Extent;
use ndslice::view::Point;
use serde::Deserialize;
use serde::Serialize;
use tokio::io::AsyncWriteExt;
use tokio::process::Command;
use tokio::sync::mpsc::UnboundedReceiver;
use tokio::sync::mpsc::UnboundedSender;
use tokio::sync::mpsc::unbounded_channel;
use tokio::task::JoinHandle;
use tokio_stream::StreamExt;
use tokio_stream::wrappers::WatchStream;
use tokio_util::sync::CancellationToken;

use crate::alloc::Alloc;
use crate::alloc::AllocConstraints;
use crate::alloc::AllocSpec;
use crate::alloc::Allocator;
use crate::alloc::AllocatorError;
use crate::alloc::ProcState;
use crate::alloc::ProcStopReason;
use crate::alloc::ProcessAllocator;
use crate::alloc::process::CLIENT_TRACE_ID_LABEL;
use crate::alloc::process::ClientContext;
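
/// Control messages accepted by a [`RemoteProcessAllocator`] over its serve
/// channel: request an allocation for a view of the global extent, stop the
/// currently active allocation, or keep the connection alive via heartbeats.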
#[derive(Debug, Clone, Serialize, Deserialize, Named)]
pub enum RemoteProcessAllocatorMessage {
    Allocate {
        view: View,
        bootstrap_addr: ChannelAddr,
        hosts: Vec<String>,
        client_context: Option<ClientContext>,
    },
    Stop,
    HeartBeat,
}
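
/// Proc state updates sent from a [`RemoteProcessAllocator`] back to the
/// client's bootstrap address: an `Allocated` acknowledgement, per-proc
/// `Update`s, a final `Done`, and periodic heartbeats.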
#[derive(Debug, Clone, Serialize, Deserialize, Named)]
pub enum RemoteProcessProcStateMessage {
    Allocated { world_id: WorldId, view: View },
    Update(ProcState),
    Done(WorldId),
    HeartBeat,
}
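
/// A per-host allocator service: listens for [`RemoteProcessAllocatorMessage`]s
/// and drives a local [`ProcessAllocator`] (or any [`Allocator`]) on behalf of
/// a remote [`RemoteProcessAlloc`].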
pub struct RemoteProcessAllocator {
    cancel_token: CancellationToken,
}

async fn conditional_sleeper<F: futures::Future<Output = ()>>(t: Option<F>) {
    match t {
        Some(timer) => timer.await,
        None => futures::future::pending().await,
    }
}
impl RemoteProcessAllocator {
    pub fn new() -> Arc<Self> {
        Arc::new(Self {
            cancel_token: CancellationToken::new(),
        })
    }

    pub fn terminate(&self) {
        self.cancel_token.cancel();
    }
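
    /// Start a remote process allocator serving on `serve_addr`, using `cmd`
    /// to bootstrap each allocated proc. If `timeout` is set, the allocator
    /// exits after running that long without an active allocation.
    ///
    /// Example (sketch only; the bootstrap binary path and serve address are
    /// placeholders):
    /// ```ignore
    /// let allocator = RemoteProcessAllocator::new();
    /// allocator
    ///     .start(
    ///         Command::new("/path/to/bootstrap"),
    ///         "tcp!0.0.0.0:26600".parse()?,
    ///         None,
    ///     )
    ///     .await?;
    /// ```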
    pub async fn start(
        &self,
        cmd: Command,
        serve_addr: ChannelAddr,
        timeout: Option<Duration>,
    ) -> Result<(), anyhow::Error> {
        let process_allocator = ProcessAllocator::new(cmd);
        self.start_with_allocator(serve_addr, process_allocator, timeout)
            .await
    }
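
    /// Like [`Self::start`], but with a caller-provided [`Allocator`]. This is
    /// the entry point used by tests with mock allocators.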
    pub async fn start_with_allocator<A: Allocator + Send + Sync + 'static>(
        &self,
        serve_addr: ChannelAddr,
        mut process_allocator: A,
        timeout: Option<Duration>,
    ) -> Result<(), anyhow::Error>
    where
        <A as Allocator>::Alloc: Send,
        <A as Allocator>::Alloc: Sync,
    {
        tracing::info!("starting remote allocator on: {}", serve_addr);
        let (_, mut rx) = channel::serve(serve_addr.clone())
            .await
            .map_err(anyhow::Error::from)?;

        struct ActiveAllocation {
            handle: JoinHandle<()>,
            cancel_token: CancellationToken,
        }
        async fn ensure_previous_alloc_stopped(active_allocation: &mut Option<ActiveAllocation>) {
            if let Some(active_allocation) = active_allocation.take() {
                tracing::info!("previous alloc found, stopping");
                active_allocation.cancel_token.cancel();
                match active_allocation.handle.await {
                    Ok(_) => {
                        tracing::info!("allocation stopped.")
                    }
                    Err(e) => {
                        tracing::error!("allocation handler failed: {}", e);
                    }
                }
            }
        }

        let mut active_allocation: Option<ActiveAllocation> = None;
        loop {
            let sleep = conditional_sleeper(timeout.map(|t| RealClock.sleep(t)));
            tokio::select! {
                msg = rx.recv() => {
                    match msg {
                        Ok(RemoteProcessAllocatorMessage::Allocate {
                            view,
                            bootstrap_addr,
                            hosts,
                            client_context,
                        }) => {
                            tracing::info!("received allocation request for view: {}", view);
                            ensure_previous_alloc_stopped(&mut active_allocation).await;
                            tracing::info!("allocating...");

                            let mut constraints: AllocConstraints = Default::default();
                            if let Some(context) = &client_context {
                                constraints = AllocConstraints {
                                    match_labels: HashMap::from([(
                                        CLIENT_TRACE_ID_LABEL.to_string(),
                                        context.trace_id.to_string(),
                                    )]),
                                };
                            }
                            let spec = AllocSpec {
                                extent: view.extent(),
                                constraints,
                            };

                            match process_allocator.allocate(spec.clone()).await {
                                Ok(alloc) => {
                                    let cancel_token = CancellationToken::new();
                                    active_allocation = Some(ActiveAllocation {
                                        cancel_token: cancel_token.clone(),
                                        handle: tokio::spawn(Self::handle_allocation_request(
                                            Box::new(alloc) as Box<dyn Alloc + Send + Sync>,
                                            view,
                                            serve_addr.transport(),
                                            bootstrap_addr,
                                            hosts,
                                            cancel_token,
                                        )),
                                    })
                                }
                                Err(e) => {
                                    tracing::error!("allocation for {:?} failed: {}", spec, e);
                                    continue;
                                }
                            }
                        }
                        Ok(RemoteProcessAllocatorMessage::Stop) => {
                            tracing::info!("received stop request");

                            ensure_previous_alloc_stopped(&mut active_allocation).await;
                        }
                        Ok(RemoteProcessAllocatorMessage::HeartBeat) => {}
                        Err(e) => {
                            tracing::error!("upstream channel error: {}", e);
                            continue;
                        }
                    }
                }
                _ = self.cancel_token.cancelled() => {
                    tracing::info!("main loop cancelled");

                    ensure_previous_alloc_stopped(&mut active_allocation).await;

                    break;
                }
                _ = sleep => {
                    if active_allocation.is_some() {
                        continue;
                    }
                    tracing::warn!(
                        "timeout of {} seconds elapsed without any allocations, exiting",
                        timeout.unwrap_or_default().as_secs()
                    );
                    break;
                }
            }
        }

        Ok(())
    }
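
    /// Serve a single allocation: spawn a mailbox forwarder for the procs'
    /// mesh agents, optionally write the host list to the file named by
    /// `TORCH_ELASTIC_CUSTOM_HOSTNAMES_LIST_FILE`, then run
    /// [`Self::handle_allocation_loop`] until the alloc finishes or is
    /// cancelled.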
    async fn handle_allocation_request(
        alloc: Box<dyn Alloc + Send + Sync>,
        view: View,
        serve_transport: ChannelTransport,
        bootstrap_addr: ChannelAddr,
        hosts: Vec<String>,
        cancel_token: CancellationToken,
    ) {
        tracing::info!("handle allocation request, bootstrap_addr: {bootstrap_addr}");
        let (forwarder_addr, forwarder_rx) =
            match channel::serve(ChannelAddr::any(serve_transport)).await {
                Ok(v) => v,
                Err(e) => {
                    tracing::error!("failed to bootstrap forwarder actor: {}", e);
                    return;
                }
            };
        let router = DialMailboxRouter::new();
        let mailbox_handle = router.clone().serve(forwarder_rx);
        tracing::info!("started forwarder on: {}", forwarder_addr);

        if let Ok(hosts_file) = std::env::var("TORCH_ELASTIC_CUSTOM_HOSTNAMES_LIST_FILE") {
            tracing::info!("writing hosts to {}", hosts_file);
            #[derive(Serialize)]
            struct Hosts {
                hostnames: Vec<String>,
            }
            match serde_json::to_string(&Hosts { hostnames: hosts }) {
                Ok(json) => match tokio::fs::File::create(&hosts_file).await {
                    Ok(mut file) => {
                        if file.write_all(json.as_bytes()).await.is_err() {
                            tracing::error!("failed to write hosts to {}", hosts_file);
                            return;
                        }
                    }
                    Err(e) => {
                        tracing::error!("failed to open hosts file {}: {}", hosts_file, e);
                        return;
                    }
                },
                Err(e) => {
                    tracing::error!("failed to serialize hosts: {}", e);
                    return;
                }
            }
        }

        Self::handle_allocation_loop(
            alloc,
            view,
            bootstrap_addr,
            router,
            forwarder_addr,
            cancel_token,
        )
        .await;

        mailbox_handle.stop("alloc stopped");
        if let Err(e) = mailbox_handle.await {
            tracing::error!("failed to join forwarder: {}", e);
        }
    }
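
    /// Forward proc state events from the local alloc to the client at
    /// `bootstrap_addr`, remapping each proc's mesh agent address to the local
    /// forwarder so that client traffic is routed through this process.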
    async fn handle_allocation_loop(
        mut alloc: Box<dyn Alloc + Send + Sync>,
        view: View,
        bootstrap_addr: ChannelAddr,
        router: DialMailboxRouter,
        forward_addr: ChannelAddr,
        cancel_token: CancellationToken,
    ) {
        tracing::info!("starting handle allocation loop");
        let tx = match channel::dial(bootstrap_addr) {
            Ok(tx) => tx,
            Err(err) => {
                tracing::error!("failed to dial bootstrap address: {}", err);
                return;
            }
        };
        if let Err(e) = tx
            .send(RemoteProcessProcStateMessage::Allocated {
                world_id: alloc.world_id().clone(),
                view,
            })
            .await
        {
            tracing::error!("failed to send Allocated message: {}", e);
            return;
        }

        let mut mesh_agents_by_proc_id = HashMap::new();
        let mut running = true;
        let tx_status = tx.status().clone();
        let mut tx_watcher = WatchStream::new(tx_status);
        loop {
            tokio::select! {
                _ = cancel_token.cancelled(), if running => {
                    tracing::info!("cancelled, stopping allocation");
                    running = false;
                    if let Err(e) = alloc.stop().await {
                        tracing::error!("stop failed: {}", e);
                        break;
                    }
                }
                status = tx_watcher.next(), if running => {
                    match status {
                        Some(TxStatus::Closed) => {
                            tracing::error!("upstream channel state closed");
                            break;
                        },
                        _ => {
                            tracing::debug!("got channel event: {:?}", status.unwrap());
                            continue;
                        }
                    }
                }
                e = alloc.next() => {
                    match e {
                        Some(event) => {
                            tracing::debug!("got event: {:?}", event);
                            let event = match event {
                                ProcState::Created { .. } => event,
                                ProcState::Running { proc_id, mesh_agent, addr } => {
                                    tracing::debug!("remapping mesh_agent {}: addr {} -> {}", mesh_agent, addr, forward_addr);
                                    mesh_agents_by_proc_id.insert(proc_id.clone(), mesh_agent.clone());
                                    router.bind(mesh_agent.actor_id().proc_id().clone().into(), addr);
                                    ProcState::Running { proc_id, mesh_agent, addr: forward_addr.clone() }
                                },
                                ProcState::Stopped { proc_id, reason } => {
                                    match mesh_agents_by_proc_id.remove(&proc_id) {
                                        Some(mesh_agent) => {
                                            tracing::debug!("unmapping mesh_agent {}", mesh_agent);
                                            let agent_ref: Reference = mesh_agent.actor_id().proc_id().clone().into();
                                            router.unbind(&agent_ref);
                                        },
                                        None => {
                                            tracing::warn!("mesh_agent not found for proc_id: {}", proc_id);
                                        }
                                    }
                                    ProcState::Stopped { proc_id, reason }
                                },
                                ProcState::Failed { ref world_id, ref description } => {
                                    tracing::error!("allocation failed for {}: {}", world_id, description);
                                    event
                                }
                            };
                            tracing::debug!("sending event: {:?}", event);
                            tx.post(RemoteProcessProcStateMessage::Update(event));
                        }
                        None => {
                            tracing::debug!("sending done");
                            tx.post(RemoteProcessProcStateMessage::Done(alloc.world_id().clone()));
                            running = false;
                            break;
                        }
                    }
                }
                _ = RealClock.sleep(config::global::get(config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL)) => {
                    tracing::trace!("sending heartbeat");
                    tx.post(RemoteProcessProcStateMessage::HeartBeat);
                }
            }
        }
        tracing::info!("allocation handler loop exited");
        if running {
            tracing::info!("stopping processes");
            if let Err(e) = alloc.stop_and_wait().await {
                tracing::error!("stop failed: {}", e);
                return;
            }
            tracing::info!("stop finished");
        }
    }
}
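
/// Opaque identifier for a host participating in an allocation.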
type HostId = String;
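
/// A host available to a [`RemoteProcessAlloc`], as returned by a
/// [`RemoteProcessAllocInitializer`].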
#[derive(Clone)]
pub struct RemoteProcessAllocHost {
    pub id: HostId,
    pub hostname: String,
}
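
/// Book-keeping for one remote host: its channel to the remote allocator, the
/// procs currently active on it, its rank offset into the global extent, and
/// allocation/failure flags.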
struct RemoteProcessAllocHostState {
    host_id: HostId,
    tx: ChannelTx<RemoteProcessAllocatorMessage>,
    active_procs: HashSet<ProcId>,
    offset: usize,
    world_id: Option<WorldId>,
    failed: bool,
    allocated: bool,
}
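
/// Provides the set of hosts for a [`RemoteProcessAlloc`]. Implementations
/// decide where the hosts come from (e.g. a static list or an external
/// scheduler).
///
/// A minimal sketch (the `StaticHosts` type is illustrative, not part of this
/// crate):
/// ```ignore
/// struct StaticHosts(Vec<RemoteProcessAllocHost>);
///
/// #[async_trait]
/// impl RemoteProcessAllocInitializer for StaticHosts {
///     async fn initialize_alloc(&mut self) -> Result<Vec<RemoteProcessAllocHost>, anyhow::Error> {
///         Ok(self.0.clone())
///     }
/// }
/// ```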
#[automock]
#[async_trait]
pub trait RemoteProcessAllocInitializer {
    async fn initialize_alloc(&mut self) -> Result<Vec<RemoteProcessAllocHost>, anyhow::Error>;
}
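
/// Per-host state keyed by [`HostId`], plus a shared address map used by the
/// signal-cleanup handler to reach every remote allocator on shutdown.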
struct HostStates {
    inner: HashMap<HostId, RemoteProcessAllocHostState>,
    host_addresses: Arc<DashMap<HostId, ChannelAddr>>,
}

impl HostStates {
    fn new(host_addresses: Arc<DashMap<HostId, ChannelAddr>>) -> HostStates {
        Self {
            inner: HashMap::new(),
            host_addresses,
        }
    }

    fn insert(
        &mut self,
        host_id: HostId,
        state: RemoteProcessAllocHostState,
        address: ChannelAddr,
    ) {
        self.host_addresses.insert(host_id.clone(), address);
        self.inner.insert(host_id, state);
    }

    fn get(&self, host_id: &HostId) -> Option<&RemoteProcessAllocHostState> {
        self.inner.get(host_id)
    }

    fn get_mut(&mut self, host_id: &HostId) -> Option<&mut RemoteProcessAllocHostState> {
        self.inner.get_mut(host_id)
    }

    fn remove(&mut self, host_id: &HostId) -> Option<RemoteProcessAllocHostState> {
        self.host_addresses.remove(host_id);
        self.inner.remove(host_id)
    }

    fn iter(&self) -> impl Iterator<Item = (&HostId, &RemoteProcessAllocHostState)> {
        self.inner.iter()
    }

    fn iter_mut(&mut self) -> impl Iterator<Item = (&HostId, &mut RemoteProcessAllocHostState)> {
        self.inner.iter_mut()
    }

    fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }
}
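
/// Client-side allocation spanning multiple hosts: dials each host's
/// [`RemoteProcessAllocator`], fans out `Allocate` requests for slices of the
/// global extent, and merges the resulting proc state streams back into a
/// single [`Alloc`].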
pub struct RemoteProcessAlloc {
    initializer: Box<dyn RemoteProcessAllocInitializer + Send + Sync>,
    spec: AllocSpec,
    remote_allocator_port: u16,
    transport: ChannelTransport,
    world_id: WorldId,
    ordered_hosts: Vec<RemoteProcessAllocHost>,
    started: bool,
    running: bool,
    failed: bool,
    hosts_by_offset: HashMap<usize, HostId>,
    host_states: HostStates,
    world_offsets: HashMap<WorldId, usize>,
    event_queue: VecDeque<ProcState>,
    comm_watcher_tx: UnboundedSender<HostId>,
    comm_watcher_rx: UnboundedReceiver<HostId>,

    bootstrap_addr: ChannelAddr,
    rx: ChannelRx<RemoteProcessProcStateMessage>,
    _signal_cleanup_guard: hyperactor::SignalCleanupGuard,
}
impl RemoteProcessAlloc {
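    /// Create a new `RemoteProcessAlloc`. Hosts are obtained lazily from
    /// `initializer` on the first call to [`Alloc::next`]; each host is
    /// expected to run a [`RemoteProcessAllocator`] on `remote_allocator_port`.
    ///
    /// Example (sketch; `MyInitializer` is a placeholder implementing
    /// [`RemoteProcessAllocInitializer`]):
    /// ```ignore
    /// let mut alloc = RemoteProcessAlloc::new(
    ///     AllocSpec {
    ///         extent: extent!(host = 2, gpu = 2),
    ///         constraints: Default::default(),
    ///     },
    ///     WorldId("test_world_id".to_string()),
    ///     ChannelTransport::Tcp,
    ///     26600,
    ///     MyInitializer::default(),
    /// )
    /// .await?;
    /// while let Some(state) = alloc.next().await {
    ///     // handle ProcState::Created / Running / Stopped / Failed
    /// }
    /// ```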
    pub async fn new(
        spec: AllocSpec,
        world_id: WorldId,
        transport: ChannelTransport,
        remote_allocator_port: u16,
        initializer: impl RemoteProcessAllocInitializer + Send + Sync + 'static,
    ) -> Result<Self, anyhow::Error> {
        let (bootstrap_addr, rx) = channel::serve(ChannelAddr::any(transport.clone()))
            .await
            .map_err(anyhow::Error::from)?;

        tracing::info!(
            "starting alloc for {} on: {}",
            world_id,
            bootstrap_addr.clone()
        );

        let (comm_watcher_tx, comm_watcher_rx) = unbounded_channel();

        let host_addresses = Arc::new(DashMap::<HostId, ChannelAddr>::new());
        let host_addresses_for_signal = host_addresses.clone();

        let signal_cleanup_guard =
            hyperactor::register_signal_cleanup_scoped(Box::pin(async move {
                join_all(host_addresses_for_signal.iter().map(|entry| async move {
                    let addr = entry.value().clone();
                    match channel::dial(addr.clone()) {
                        Ok(tx) => {
                            if let Err(e) = tx.send(RemoteProcessAllocatorMessage::Stop).await {
                                tracing::error!("Failed to send Stop to {}: {}", addr, e);
                            }
                        }
                        Err(e) => {
                            tracing::error!("Failed to dial {} during signal cleanup: {}", addr, e);
                        }
                    }
                }))
                .await;
            }));

        Ok(Self {
            spec,
            world_id,
            transport,
            remote_allocator_port,
            initializer: Box::new(initializer),
            world_offsets: HashMap::new(),
            ordered_hosts: Vec::new(),
            hosts_by_offset: HashMap::new(),
            host_states: HostStates::new(host_addresses),
            bootstrap_addr,
            event_queue: VecDeque::new(),
            comm_watcher_tx,
            comm_watcher_rx,
            rx,
            started: false,
            running: true,
            failed: false,
            _signal_cleanup_guard: signal_cleanup_guard,
        })
    }
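
    /// Spawn a background task that watches every host channel and notifies
    /// the main loop (via `comm_watcher_tx`) when a channel closes.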
    async fn start_comm_watcher(&self) {
        let mut tx_watchers = Vec::new();
        for host in &self.ordered_hosts {
            let tx_status = self.host_states.get(&host.id).unwrap().tx.status().clone();
            let watcher = WatchStream::new(tx_status);
            tx_watchers.push((watcher, host.id.clone()));
        }
        assert!(!tx_watchers.is_empty());
        let tx = self.comm_watcher_tx.clone();
        tokio::spawn(async move {
            loop {
                let mut tx_status_futures = Vec::new();
                for (watcher, _) in &mut tx_watchers {
                    let fut = watcher.next().boxed();
                    tx_status_futures.push(fut);
                }
                let (tx_status, index, _) = select_all(tx_status_futures).await;
                let host_id = match tx_watchers.get(index) {
                    Some((_, host_id)) => host_id.clone(),
                    None => {
                        tracing::error!(
                            "got selected index {} with no matching host in {}",
                            index,
                            tx_watchers.len()
                        );
                        continue;
                    }
                };
                if let Some(tx_status) = tx_status {
                    tracing::debug!("host {} channel event: {:?}", host_id, tx_status);
                    if tx_status == TxStatus::Closed {
                        if tx.send(host_id.clone()).is_err() {
                            break;
                        }
                        tx_watchers.remove(index);
                        if tx_watchers.is_empty() {
                            break;
                        }
                    }
                }
            }
        });
    }
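
    /// Lazily bootstrap the allocation on first use: fetch hosts from the
    /// initializer, dial each host's remote allocator, and send it an
    /// `Allocate` request for its slice of the global extent.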
    async fn ensure_started(&mut self) -> Result<(), anyhow::Error> {
        if self.started || self.failed {
            return Ok(());
        }

        self.started = true;
        let hosts = self
            .initializer
            .initialize_alloc()
            .await
            .context("alloc initializer error")?;
        if hosts.is_empty() {
            anyhow::bail!("Initializer returned empty list of hosts");
        }
        let hostnames: Vec<_> = hosts.iter().map(|e| e.hostname.clone()).collect();
        tracing::info!("obtained {} hosts for this allocation", hostnames.len());

        anyhow::ensure!(
            self.spec.extent.len() >= 2,
            "invalid extent: {}, expected at least 2 dimensions",
            self.spec.extent
        );

        let split_dim = &self.spec.extent.labels()[self.spec.extent.len() - 1];
        for (i, view) in self.spec.extent.group_by(split_dim)?.enumerate() {
            let host = &hosts[i];
            tracing::debug!("allocating: {} for host: {}", view, host.id);

            let remote_addr = match self.transport {
                ChannelTransport::MetaTls(_) => {
                    format!("metatls!{}:{}", host.hostname, self.remote_allocator_port)
                }
                ChannelTransport::Tcp => {
                    format!("tcp!{}:{}", host.hostname, self.remote_allocator_port)
                }
                ChannelTransport::Unix => host.hostname.clone(),
                _ => {
                    anyhow::bail!(
                        "unsupported transport for host {}: {:?}",
                        host.id,
                        self.transport
                    );
                }
            };

            let offset = offset_of_view(&view).unwrap();

            tracing::debug!("dialing remote: {} for host {}", remote_addr, host.id);
            let remote_addr = remote_addr.parse::<ChannelAddr>()?;
            let tx = channel::dial(remote_addr.clone())
                .map_err(anyhow::Error::from)
                .context(format!(
                    "failed to dial remote {} for host {}",
                    remote_addr, host.id
                ))?;

            let trace_id = hyperactor_telemetry::trace::get_or_create_trace_id();
            let client_context = Some(ClientContext { trace_id });

            tx.post(RemoteProcessAllocatorMessage::Allocate {
                view,
                bootstrap_addr: self.bootstrap_addr.clone(),
                hosts: hostnames.clone(),
                client_context,
            });

            self.hosts_by_offset.insert(offset, host.id.clone());
            self.host_states.insert(
                host.id.clone(),
                RemoteProcessAllocHostState {
                    host_id: host.id.clone(),
                    tx,
                    active_procs: HashSet::new(),
                    offset,
                    world_id: None,
                    failed: false,
                    allocated: false,
                },
                remote_addr,
            );
        }

        self.ordered_hosts = hosts;
        self.start_comm_watcher().await;
        self.started = true;

        Ok(())
    }
    fn host_id_for_world_id(&self, world_id: &WorldId) -> Option<HostId> {
        let offset = self.world_offsets.get(world_id)?;
        self.hosts_by_offset.get(offset).cloned()
    }

    fn host_state_for_world_id(
        &mut self,
        world_id: &WorldId,
    ) -> Result<&mut RemoteProcessAllocHostState, anyhow::Error> {
        if let Some(host_id) = self.host_id_for_world_id(world_id) {
            if let Some(task_state) = self.host_states.get_mut(&host_id) {
                Ok(task_state)
            } else {
                anyhow::bail!(
                    "task state not found for world id: {}, host id: {}",
                    world_id,
                    host_id
                );
            }
        } else {
            anyhow::bail!("task not found for world id: {}", world_id);
        }
    }

    fn host_state_for_proc_id(
        &mut self,
        proc_id: &ProcId,
    ) -> Result<&mut RemoteProcessAllocHostState, anyhow::Error> {
        self.host_state_for_world_id(
            proc_id
                .world_id()
                .expect("proc must be ranked for host state lookup"),
        )
    }
    fn add_proc_id_to_host_state(&mut self, proc_id: &ProcId) -> Result<(), anyhow::Error> {
        let task_state = self.host_state_for_proc_id(proc_id)?;
        if !task_state.active_procs.insert(proc_id.clone()) {
            tracing::error!("proc id already in host state: {}", proc_id);
        }
        task_state.allocated = true;
        Ok(())
    }
    fn remove_proc_from_host_state(&mut self, proc_id: &ProcId) -> Result<(), anyhow::Error> {
        let task_state = self.host_state_for_proc_id(proc_id)?;
        if !task_state.active_procs.remove(proc_id) {
            tracing::error!("proc id not found in host state: {}", proc_id);
        }
        Ok(())
    }
    fn project_proc_into_global_extent(
        &self,
        proc_id: &ProcId,
        point: &Point,
    ) -> Result<Point, anyhow::Error> {
        let world_id = proc_id
            .world_id()
            .expect("proc must be ranked for point mapping");
        let offset = self
            .world_offsets
            .get(world_id)
            .ok_or_else(|| anyhow::anyhow!("could not find offset for world id: {}", world_id))?;

        Ok(self.spec.extent.point_of_rank(offset + point.rank())?)
    }
    fn cleanup_host_channel_closed(
        &mut self,
        host_id: HostId,
    ) -> Result<Vec<ProcId>, anyhow::Error> {
        let state = match self.host_states.remove(&host_id) {
            Some(state) => state,
            None => {
                anyhow::bail!(
                    "got channel closed event for host {} which has no known state",
                    host_id
                );
            }
        };
        self.ordered_hosts.retain(|host| host.id != host_id);
        self.hosts_by_offset.remove(&state.offset);
        if let Some(world_id) = state.world_id {
            self.world_offsets.remove(&world_id);
        }
        let proc_ids = state.active_procs.iter().cloned().collect();

        Ok(proc_ids)
    }
}

#[async_trait]
impl Alloc for RemoteProcessAlloc {
    async fn next(&mut self) -> Option<ProcState> {
        loop {
            if let state @ Some(_) = self.event_queue.pop_front() {
                break state;
            }

            if !self.running {
                break None;
            }

            if let Err(e) = self.ensure_started().await {
                break Some(ProcState::Failed {
                    world_id: self.world_id.clone(),
                    description: format!("failed to ensure started: {:#}", e),
                });
            }

            let heartbeat_interval =
                config::global::get(config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL);
            let mut heartbeat_time = hyperactor::clock::RealClock.now() + heartbeat_interval;
            let mut reloop = false;
            let update = loop {
                tokio::select! {
                    msg = self.rx.recv() => {
                        tracing::debug!("got ProcState message from allocator: {:?}", msg);
                        match msg {
                            Ok(RemoteProcessProcStateMessage::Allocated { world_id, view }) => {
                                let Some(offset) = offset_of_view(&view) else {
                                    tracing::error!(
                                        "allocated {}: cannot derive offset from allocated view: {}",
                                        world_id,
                                        view
                                    );
                                    continue;
                                };
                                tracing::info!("received allocated world id: {}", world_id);
                                match self.hosts_by_offset.get(&offset) {
                                    Some(host_id) => {
                                        match self.host_states.get_mut(host_id) {
                                            Some(state) => {
                                                state.world_id = Some(world_id.clone());
                                                if let Some(old_entry) = self.world_offsets.insert(world_id.clone(), offset) {
                                                    tracing::warn!(
                                                        "got allocated for duplicate world id: {} with known offset: {}",
                                                        world_id,
                                                        old_entry
                                                    );
                                                }
                                            }
                                            None => {
                                                tracing::error!(
                                                    "got allocated for host ID: {} with no known state",
                                                    host_id
                                                );
                                            }
                                        }
                                    }
                                    None => {
                                        tracing::error!(
                                            "got allocated for unknown world: {}, view: {}",
                                            world_id,
                                            view
                                        );
                                    }
                                }
                            }
                            Ok(RemoteProcessProcStateMessage::Update(proc_state)) => {
                                break match proc_state {
                                    ProcState::Created { ref proc_id, .. } => {
                                        if let Err(e) = self.add_proc_id_to_host_state(proc_id) {
                                            tracing::error!("failed to add proc id to host state: {}", e);
                                        }
                                        Some(proc_state)
                                    }
                                    ProcState::Stopped { ref proc_id, .. } => {
                                        if let Err(e) = self.remove_proc_from_host_state(proc_id) {
                                            tracing::error!("failed to remove proc id from host state: {}", e);
                                        }
                                        Some(proc_state)
                                    }
                                    ProcState::Failed { ref world_id, ref description } => {
                                        match self.host_state_for_world_id(world_id) {
                                            Ok(state) => {
                                                state.failed = true;
                                                Some(ProcState::Failed {
                                                    world_id: world_id.clone(),
                                                    description: format!("host {} failed: {}", state.host_id, description),
                                                })
                                            }
                                            Err(e) => {
                                                tracing::error!("failed to find host state for world id: {}: {}", world_id, e);
                                                Some(proc_state)
                                            }
                                        }
                                    }
                                    _ => Some(proc_state),
                                };
                            }
                            Ok(RemoteProcessProcStateMessage::Done(world_id)) => {
                                tracing::info!("allocator world_id: {} is done", world_id);
                                if let Some(host_id) = self.host_id_for_world_id(&world_id) {
                                    if let Some(state) = self.host_states.get(&host_id) {
                                        if !state.active_procs.is_empty() {
                                            tracing::error!("received done for world id: {} with active procs: {:?}", world_id, state.active_procs);
                                        }
                                    } else {
                                        tracing::error!("received done for unknown state world id: {}", world_id);
                                    }
                                } else {
                                    tracing::error!("received done for unknown world id: {}", world_id);
                                }
                                if self.world_offsets.remove(&world_id).is_none() {
                                    tracing::error!("received done for unknown world id: {}", world_id);
                                } else if self.world_offsets.is_empty() {
                                    self.running = false;
                                    break None;
                                }
                            }
                            Ok(RemoteProcessProcStateMessage::HeartBeat) => {}
                            Err(e) => {
                                break Some(ProcState::Failed { world_id: self.world_id.clone(), description: format!("error receiving events: {}", e) });
                            }
                        }
                    }

                    _ = clock::RealClock.sleep_until(heartbeat_time) => {
                        self.host_states.iter().for_each(|(_, host_state)| host_state.tx.post(RemoteProcessAllocatorMessage::HeartBeat));
                        heartbeat_time = hyperactor::clock::RealClock.now() + heartbeat_interval;
                    }

                    closed_host_id = self.comm_watcher_rx.recv() => {
                        if let Some(closed_host_id) = closed_host_id {
                            tracing::debug!("host {} channel closed, cleaning up", closed_host_id);
                            if let Some(state) = self.host_states.get(&closed_host_id) {
                                if !state.allocated {
                                    break Some(ProcState::Failed {
                                        world_id: self.world_id.clone(),
                                        description: format!(
                                            "no process was ever allocated on {} before its channel closed; \
                                            a common cause is that the channel was never established",
                                            closed_host_id
                                        ),
                                    });
                                }
                            }
                            let proc_ids = match self.cleanup_host_channel_closed(closed_host_id) {
                                Ok(proc_ids) => proc_ids,
                                Err(err) => {
                                    tracing::error!("failed to cleanup disconnected host: {}", err);
                                    continue;
                                }
                            };
                            for proc_id in proc_ids {
                                tracing::debug!("queuing Stopped state for {}", proc_id);
                                self.event_queue.push_back(ProcState::Stopped { proc_id, reason: ProcStopReason::HostWatchdog });
                            }
                            if self.host_states.is_empty() {
                                tracing::info!("no more hosts left, stopping the alloc");
                                self.running = false;
                            }
                            reloop = true;
                            break None;
                        } else {
                            tracing::warn!("unexpected comm watcher channel close");
                            break None;
                        }
                    }
                }
            };

            if reloop {
                assert!(update.is_none());
                continue;
            }

            break match update {
                Some(ProcState::Created {
                    proc_id,
                    point,
                    pid,
                }) => match self.project_proc_into_global_extent(&proc_id, &point) {
                    Ok(global_point) => {
                        tracing::debug!("reprojected coords: {} -> {}", point, global_point);
                        Some(ProcState::Created {
                            proc_id,
                            point: global_point,
                            pid,
                        })
                    }
                    Err(e) => {
                        tracing::error!("failed to project coords for proc: {}: {}", proc_id, e);
                        None
                    }
                },

                Some(ProcState::Failed {
                    world_id: _,
                    ref description,
                }) => {
                    tracing::error!(description);
                    self.failed = true;
                    update
                }

                _ => update,
            };
        }
    }

    fn extent(&self) -> &Extent {
        &self.spec.extent
    }

    fn world_id(&self) -> &WorldId {
        &self.world_id
    }

    fn transport(&self) -> ChannelTransport {
        self.transport.clone()
    }

    async fn stop(&mut self) -> Result<(), AllocatorError> {
        tracing::info!("stopping alloc");

        for (host_id, task_state) in self.host_states.iter_mut() {
            tracing::debug!("stopping alloc at host {}", host_id);
            task_state.tx.post(RemoteProcessAllocatorMessage::Stop);
        }

        Ok(())
    }
}

impl Drop for RemoteProcessAlloc {
    fn drop(&mut self) {
        tracing::debug!("dropping RemoteProcessAlloc of world_id {}", self.world_id);
    }
}
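
/// Rank offset of `view` within the global extent: the rank of its first
/// point. Returns `None` for an empty view.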
fn offset_of_view(view: &View) -> Option<usize> {
    let (_point, offset) = view.iter().next()?;
    Some(offset)
}
1168#[cfg(test)]
1169mod test {
1170 use std::assert_matches::assert_matches;
1171
1172 use hyperactor::ActorRef;
1173 use hyperactor::channel::ChannelRx;
1174 use hyperactor::clock::ClockKind;
1175 use hyperactor::id;
1176 use ndslice::extent;
1177 use tokio::sync::oneshot;
1178
1179 use super::*;
1180 use crate::alloc::ChannelTransport;
1181 use crate::alloc::MockAlloc;
1182 use crate::alloc::MockAllocWrapper;
1183 use crate::alloc::MockAllocator;
1184 use crate::alloc::ProcStopReason;
1185 use crate::proc_mesh::mesh_agent::MeshAgent;
1186
1187 async fn read_all_created(rx: &mut ChannelRx<RemoteProcessProcStateMessage>, alloc_len: usize) {
1188 let mut i: usize = 0;
1189 while i < alloc_len {
1190 let m = rx.recv().await.unwrap();
1191 match m {
1192 RemoteProcessProcStateMessage::Update(ProcState::Created { .. }) => i += 1,
1193 RemoteProcessProcStateMessage::HeartBeat => {}
1194 _ => panic!("unexpected message: {:?}", m),
1195 }
1196 }
1197 }
1198
1199 async fn read_all_running(rx: &mut ChannelRx<RemoteProcessProcStateMessage>, alloc_len: usize) {
1200 let mut i: usize = 0;
1201 while i < alloc_len {
1202 let m = rx.recv().await.unwrap();
1203 match m {
1204 RemoteProcessProcStateMessage::Update(ProcState::Running { .. }) => i += 1,
1205 RemoteProcessProcStateMessage::HeartBeat => {}
1206 _ => panic!("unexpected message: {:?}", m),
1207 }
1208 }
1209 }
1210
1211 async fn read_all_stopped(rx: &mut ChannelRx<RemoteProcessProcStateMessage>, alloc_len: usize) {
1212 let mut i: usize = 0;
1213 while i < alloc_len {
1214 let m = rx.recv().await.unwrap();
1215 match m {
1216 RemoteProcessProcStateMessage::Update(ProcState::Stopped { .. }) => i += 1,
1217 RemoteProcessProcStateMessage::HeartBeat => {}
1218 _ => panic!("unexpected message: {:?}", m),
1219 }
1220 }
1221 }
1222
1223 fn set_procstate_expectations(alloc: &mut MockAlloc, extent: Extent) {
1224 alloc.expect_extent().return_const(extent.clone());
1225 for (i, point) in extent.points().enumerate() {
1226 let proc_id = format!("test[{}]", i).parse().unwrap();
1227 alloc.expect_next().times(1).return_once(|| {
1228 Some(ProcState::Created {
1229 proc_id,
1230 point,
1231 pid: 0,
1232 })
1233 });
1234 }
1235 for i in 0..extent.num_ranks() {
1236 let proc_id = format!("test[{}]", i).parse().unwrap();
1237 let mesh_agent = ActorRef::<MeshAgent>::attest(
1238 format!("test[{}].mesh_agent[{}]", i, i).parse().unwrap(),
1239 );
1240 alloc.expect_next().times(1).return_once(|| {
1241 Some(ProcState::Running {
1242 proc_id,
1243 addr: ChannelAddr::Unix("/proc0".parse().unwrap()),
1244 mesh_agent,
1245 })
1246 });
1247 }
1248 for i in 0..extent.num_ranks() {
1249 let proc_id = format!("test[{}]", i).parse().unwrap();
1250 alloc.expect_next().times(1).return_once(|| {
1251 Some(ProcState::Stopped {
1252 proc_id,
1253 reason: ProcStopReason::Unknown,
1254 })
1255 });
1256 }
1257 }
1258
1259 #[timed_test::async_timed_test(timeout_secs = 5)]
1260 async fn test_simple() {
1261 let config = hyperactor::config::global::lock();
1262 let _guard = config.override_key(
1263 hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
1264 Duration::from_millis(100),
1265 );
1266 hyperactor_telemetry::initialize_logging(ClockKind::default());
1267 let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
1268 let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
1269 let (_, mut rx) = channel::serve(bootstrap_addr.clone()).await.unwrap();
1270
1271 let extent = extent!(host = 1, gpu = 2);
1272 let tx = channel::dial(serve_addr.clone()).unwrap();
1273
1274 let world_id: WorldId = id!(test_world_id);
1275 let mut alloc = MockAlloc::new();
1276 alloc.expect_world_id().return_const(world_id.clone());
1277 alloc.expect_extent().return_const(extent.clone());
1278
1279 set_procstate_expectations(&mut alloc, extent.clone());
1280
1281 alloc.expect_next().return_const(None);
1283
1284 let mut allocator = MockAllocator::new();
1285 let total_messages = extent.num_ranks() * 3 + 1;
1286 let mock_wrapper = MockAllocWrapper::new_block_next(
1287 alloc,
1288 total_messages,
1290 );
1291 allocator
1292 .expect_allocate()
1293 .times(1)
1294 .return_once(move |_| Ok(mock_wrapper));
1295
1296 let remote_allocator = RemoteProcessAllocator::new();
1297 let handle = tokio::spawn({
1298 let remote_allocator = remote_allocator.clone();
1299 async move {
1300 remote_allocator
1301 .start_with_allocator(serve_addr, allocator, None)
1302 .await
1303 }
1304 });
1305
1306 tx.send(RemoteProcessAllocatorMessage::Allocate {
1307 view: extent.clone().into(),
1308 bootstrap_addr,
1309 hosts: vec![],
1310 client_context: None,
1311 })
1312 .await
1313 .unwrap();
1314
1315 let m = rx.recv().await.unwrap();
1317 assert_matches!(
1318 m, RemoteProcessProcStateMessage::Allocated { world_id, view }
1319 if world_id == world_id && extent == view.extent()
1320 );
1321
1322 let mut rank: usize = 0;
1324 while rank < extent.num_ranks() {
1325 let m = rx.recv().await.unwrap();
1326 match m {
1327 RemoteProcessProcStateMessage::Update(ProcState::Created {
1328 proc_id,
1329 point,
1330 ..
1331 }) => {
1332 let expected_proc_id = format!("test[{}]", rank).parse().unwrap();
1333 let expected_point = extent.point_of_rank(rank).unwrap();
1334 assert_eq!(proc_id, expected_proc_id);
1335 assert_eq!(point, expected_point);
1336 rank += 1;
1337 }
1338 RemoteProcessProcStateMessage::HeartBeat => {}
1339 _ => panic!("unexpected message: {:?}", m),
1340 }
1341 }
1342 let mut rank: usize = 0;
1344 while rank < extent.num_ranks() {
1345 let m = rx.recv().await.unwrap();
1346 match m {
1347 RemoteProcessProcStateMessage::Update(ProcState::Running {
1348 proc_id,
1349 mesh_agent,
1350 addr: _,
1351 }) => {
1352 let expected_proc_id = format!("test[{}]", rank).parse().unwrap();
1353 let expected_mesh_agent = ActorRef::<MeshAgent>::attest(
1354 format!("test[{}].mesh_agent[{}]", rank, rank)
1355 .parse()
1356 .unwrap(),
1357 );
1358 assert_eq!(proc_id, expected_proc_id);
1359 assert_eq!(mesh_agent, expected_mesh_agent);
1360 rank += 1;
1361 }
1362 RemoteProcessProcStateMessage::HeartBeat => {}
1363 _ => panic!("unexpected message: {:?}", m),
1364 }
1365 }
1366 let mut rank: usize = 0;
1368 while rank < extent.num_ranks() {
1369 let m = rx.recv().await.unwrap();
1370 match m {
1371 RemoteProcessProcStateMessage::Update(ProcState::Stopped {
1372 proc_id,
1373 reason: ProcStopReason::Unknown,
1374 }) => {
1375 let expected_proc_id = format!("test[{}]", rank).parse().unwrap();
1376 assert_eq!(proc_id, expected_proc_id);
1377 rank += 1;
1378 }
1379 RemoteProcessProcStateMessage::HeartBeat => {}
1380 _ => panic!("unexpected message: {:?}", m),
1381 }
1382 }
1383 loop {
1385 let m = rx.recv().await.unwrap();
1386 match m {
1387 RemoteProcessProcStateMessage::Done(id) => {
1388 assert_eq!(id, world_id);
1389 break;
1390 }
1391 RemoteProcessProcStateMessage::HeartBeat => {}
1392 _ => panic!("unexpected message: {:?}", m),
1393 }
1394 }
1395
1396 remote_allocator.terminate();
1397 handle.await.unwrap().unwrap();
1398 }
1399
1400 #[timed_test::async_timed_test(timeout_secs = 15)]
1401 async fn test_normal_stop() {
1402 let config = hyperactor::config::global::lock();
1403 let _guard = config.override_key(
1404 hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
1405 Duration::from_millis(100),
1406 );
1407 hyperactor_telemetry::initialize_logging(ClockKind::default());
1408 let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
1409 let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
1410 let (_, mut rx) = channel::serve(bootstrap_addr.clone()).await.unwrap();
1411
1412 let extent = extent!(host = 1, gpu = 2);
1413 let tx = channel::dial(serve_addr.clone()).unwrap();
1414
1415 let world_id: WorldId = id!(test_world_id);
1416 let mut alloc = MockAllocWrapper::new_block_next(
1417 MockAlloc::new(),
1418 extent.num_ranks() * 2,
1420 );
1421 let next_tx = alloc.notify_tx();
1422 alloc.alloc.expect_world_id().return_const(world_id.clone());
1423 alloc.alloc.expect_extent().return_const(extent.clone());
1424
1425 set_procstate_expectations(&mut alloc.alloc, extent.clone());
1426
1427 alloc.alloc.expect_next().return_const(None);
1428 alloc.alloc.expect_stop().times(1).return_once(|| Ok(()));
1429
1430 let mut allocator = MockAllocator::new();
1431 allocator
1432 .expect_allocate()
1433 .times(1)
1434 .return_once(|_| Ok(alloc));
1435
1436 let remote_allocator = RemoteProcessAllocator::new();
1437 let handle = tokio::spawn({
1438 let remote_allocator = remote_allocator.clone();
1439 async move {
1440 remote_allocator
1441 .start_with_allocator(serve_addr, allocator, None)
1442 .await
1443 }
1444 });
1445
1446 tx.send(RemoteProcessAllocatorMessage::Allocate {
1447 view: extent.clone().into(),
1448 bootstrap_addr,
1449 hosts: vec![],
1450 client_context: None,
1451 })
1452 .await
1453 .unwrap();
1454
1455 let m = rx.recv().await.unwrap();
1457 assert_matches!(
1458 m,
1459 RemoteProcessProcStateMessage::Allocated { world_id, view }
1460 if world_id == world_id && view.extent() == extent
1461 );
1462
1463 read_all_created(&mut rx, extent.num_ranks()).await;
1464 read_all_running(&mut rx, extent.num_ranks()).await;
1465
1466 tracing::info!("stopping allocation");
1468 tx.send(RemoteProcessAllocatorMessage::Stop).await.unwrap();
1469 next_tx.send(()).unwrap();
1471
1472 read_all_stopped(&mut rx, extent.num_ranks()).await;
1473
1474 remote_allocator.terminate();
1475 handle.await.unwrap().unwrap();
1476 }
1477
1478 #[timed_test::async_timed_test(timeout_secs = 15)]
1479 async fn test_realloc() {
1480 let config = hyperactor::config::global::lock();
1481 let _guard = config.override_key(
1482 hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
1483 Duration::from_millis(100),
1484 );
1485 hyperactor_telemetry::initialize_logging(ClockKind::default());
1486 let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
1487 let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
1488 let (_, mut rx) = channel::serve(bootstrap_addr.clone()).await.unwrap();
1489
1490 let extent = extent!(host = 1, gpu = 2);
1491
1492 let tx = channel::dial(serve_addr.clone()).unwrap();
1493
1494 let world_id: WorldId = id!(test_world_id);
1495 let mut alloc1 = MockAllocWrapper::new_block_next(
1496 MockAlloc::new(),
1497 extent.num_ranks() * 2,
1499 );
1500 let next_tx1 = alloc1.notify_tx();
1501 alloc1
1502 .alloc
1503 .expect_world_id()
1504 .return_const(world_id.clone());
1505 alloc1.alloc.expect_extent().return_const(extent.clone());
1506
1507 set_procstate_expectations(&mut alloc1.alloc, extent.clone());
1508 alloc1.alloc.expect_next().return_const(None);
1509 alloc1.alloc.expect_stop().times(1).return_once(|| Ok(()));
1510 let mut alloc2 = MockAllocWrapper::new_block_next(
1512 MockAlloc::new(),
1513 extent.num_ranks() * 2,
1515 );
1516 let next_tx2 = alloc2.notify_tx();
1517 alloc2
1518 .alloc
1519 .expect_world_id()
1520 .return_const(world_id.clone());
1521 alloc2.alloc.expect_extent().return_const(extent.clone());
1522 set_procstate_expectations(&mut alloc2.alloc, extent.clone());
1523 alloc2.alloc.expect_next().return_const(None);
1524 alloc2.alloc.expect_stop().times(1).return_once(|| Ok(()));
1525
1526 let mut allocator = MockAllocator::new();
1527 allocator
1528 .expect_allocate()
1529 .times(1)
1530 .return_once(|_| Ok(alloc1));
1531 allocator
1533 .expect_allocate()
1534 .times(1)
1535 .return_once(|_| Ok(alloc2));
1536
1537 let remote_allocator = RemoteProcessAllocator::new();
1538 let handle = tokio::spawn({
1539 let remote_allocator = remote_allocator.clone();
1540 async move {
1541 remote_allocator
1542 .start_with_allocator(serve_addr, allocator, None)
1543 .await
1544 }
1545 });
1546
1547 tx.send(RemoteProcessAllocatorMessage::Allocate {
1548 view: extent.clone().into(),
1549 bootstrap_addr: bootstrap_addr.clone(),
1550 hosts: vec![],
1551 client_context: None,
1552 })
1553 .await
1554 .unwrap();
1555
1556 let m = rx.recv().await.unwrap();
1558 assert_matches!(
1559 m,
1560 RemoteProcessProcStateMessage::Allocated { world_id, view }
1561 if world_id == world_id && extent == view.extent()
1562 );
1563
1564 read_all_created(&mut rx, extent.num_ranks()).await;
1565 read_all_running(&mut rx, extent.num_ranks()).await;
1566
1567 tx.send(RemoteProcessAllocatorMessage::Allocate {
1569 view: extent.clone().into(),
1570 bootstrap_addr,
1571 hosts: vec![],
1572 client_context: None,
1573 })
1574 .await
1575 .unwrap();
1576 next_tx1.send(()).unwrap();
1578 read_all_stopped(&mut rx, extent.num_ranks()).await;
1580 let m = rx.recv().await.unwrap();
1581 assert_matches!(m, RemoteProcessProcStateMessage::Done(_));
1582 let m = rx.recv().await.unwrap();
1583 assert_matches!(
1584 m,
1585 RemoteProcessProcStateMessage::Allocated { world_id, view }
1586 if world_id == world_id && extent == view.extent()
1587 );
1588 read_all_created(&mut rx, extent.num_ranks()).await;
1590 read_all_running(&mut rx, extent.num_ranks()).await;
1591 tracing::info!("stopping allocation");
1593 tx.send(RemoteProcessAllocatorMessage::Stop).await.unwrap();
1594 next_tx2.send(()).unwrap();
1596
1597 read_all_stopped(&mut rx, extent.num_ranks()).await;
1598
1599 remote_allocator.terminate();
1600 handle.await.unwrap().unwrap();
1601 }
1602
1603 #[timed_test::async_timed_test(timeout_secs = 15)]
1604 async fn test_upstream_closed() {
1605 let config = hyperactor::config::global::lock();
1607 let _guard1 = config.override_key(
1608 hyperactor::config::MESSAGE_DELIVERY_TIMEOUT,
1609 Duration::from_secs(1),
1610 );
1611 let _guard2 = config.override_key(
1612 hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
1613 Duration::from_millis(100),
1614 );
1615
1616 hyperactor_telemetry::initialize_logging(ClockKind::default());
1617 let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
1618 let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
1619 let (_, mut rx) = channel::serve(bootstrap_addr.clone()).await.unwrap();
1620
1621 let extent = extent!(host = 1, gpu = 2);
1622
1623 let tx = channel::dial(serve_addr.clone()).unwrap();
1624
1625 let world_id: WorldId = id!(test_world_id);
1626 let mut alloc = MockAllocWrapper::new_block_next(
1627 MockAlloc::new(),
1628 extent.num_ranks() * 2,
1630 );
1631 let next_tx = alloc.notify_tx();
1632 alloc.alloc.expect_world_id().return_const(world_id.clone());
1633 alloc.alloc.expect_extent().return_const(extent.clone());
1634
1635 set_procstate_expectations(&mut alloc.alloc, extent.clone());
1636
1637 alloc.alloc.expect_next().return_const(None);
1638 let (stop_tx, stop_rx) = oneshot::channel();
1641 alloc.alloc.expect_stop().times(1).return_once(|| {
1642 stop_tx.send(()).unwrap();
1643 Ok(())
1644 });
1645
1646 let mut allocator = MockAllocator::new();
1647 allocator
1648 .expect_allocate()
1649 .times(1)
1650 .return_once(|_| Ok(alloc));
1651
1652 let remote_allocator = RemoteProcessAllocator::new();
1653 let handle = tokio::spawn({
1654 let remote_allocator = remote_allocator.clone();
1655 async move {
1656 remote_allocator
1657 .start_with_allocator(serve_addr, allocator, None)
1658 .await
1659 }
1660 });
1661
1662 tx.send(RemoteProcessAllocatorMessage::Allocate {
1663 view: extent.clone().into(),
1664 bootstrap_addr,
1665 hosts: vec![],
1666 client_context: None,
1667 })
1668 .await
1669 .unwrap();
1670
1671 let m = rx.recv().await.unwrap();
1673 assert_matches!(
1674 m,
1675 RemoteProcessProcStateMessage::Allocated { world_id, view }
1676 if world_id == world_id && view.extent() == extent
1677 );
1678
1679 read_all_created(&mut rx, extent.num_ranks()).await;
1680 read_all_running(&mut rx, extent.num_ranks()).await;
1681
1682 tracing::info!("closing upstream");
1684 drop(rx);
1685 #[allow(clippy::disallowed_methods)]
1687 tokio::time::sleep(Duration::from_secs(2)).await;
1688 stop_rx.await.unwrap();
1690 next_tx.send(()).unwrap();
1692 remote_allocator.terminate();
1693 handle.await.unwrap().unwrap();
1694 }
1695
1696 #[timed_test::async_timed_test(timeout_secs = 15)]
1697 async fn test_inner_alloc_failure() {
1698 let config = hyperactor::config::global::lock();
1699 let _guard = config.override_key(
1700 hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
1701 Duration::from_secs(60),
1702 );
1703 hyperactor_telemetry::initialize_logging(ClockKind::default());
1704 let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
1705 let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
1706 let (_, mut rx) = channel::serve(bootstrap_addr.clone()).await.unwrap();
1707
1708 let extent = extent!(host = 1, gpu = 2);
1709
1710 let tx = channel::dial(serve_addr.clone()).unwrap();
1711
1712 let test_world_id: WorldId = id!(test_world_id);
1713 let mut alloc = MockAllocWrapper::new_block_next(
1714 MockAlloc::new(),
1715 1,
1717 );
1718 let next_tx = alloc.notify_tx();
1719 alloc
1720 .alloc
1721 .expect_world_id()
1722 .return_const(test_world_id.clone());
1723 alloc.alloc.expect_extent().return_const(extent.clone());
1724 alloc
1725 .alloc
1726 .expect_next()
1727 .times(1)
1728 .return_const(Some(ProcState::Failed {
1729 world_id: test_world_id.clone(),
1730 description: "test".to_string(),
1731 }));
1732 alloc.alloc.expect_next().times(1).return_const(None);
1733
1734 alloc.alloc.expect_stop().times(1).return_once(|| Ok(()));
1735
1736 let mut allocator = MockAllocator::new();
1737 allocator
1738 .expect_allocate()
1739 .times(1)
1740 .return_once(|_| Ok(alloc));
1741
1742 let remote_allocator = RemoteProcessAllocator::new();
1743 let handle = tokio::spawn({
1744 let remote_allocator = remote_allocator.clone();
1745 async move {
1746 remote_allocator
1747 .start_with_allocator(serve_addr, allocator, None)
1748 .await
1749 }
1750 });
1751
1752 tx.send(RemoteProcessAllocatorMessage::Allocate {
1753 view: extent.clone().into(),
1754 bootstrap_addr,
1755 hosts: vec![],
1756 client_context: None,
1757 })
1758 .await
1759 .unwrap();
1760
1761 let m = rx.recv().await.unwrap();
1763 assert_matches!(
1764 m,
1765 RemoteProcessProcStateMessage::Allocated { world_id, view }
1766 if world_id == test_world_id && view.extent() == extent
1767 );
1768 let m = rx.recv().await.unwrap();
1770 assert_matches!(m, RemoteProcessProcStateMessage::Update(ProcState::Failed {world_id, description}) if world_id == test_world_id && description == "test");
1771
1772 tracing::info!("stopping allocation");
1773 tx.send(RemoteProcessAllocatorMessage::Stop).await.unwrap();
1774 next_tx.send(()).unwrap();
1776 let m = rx.recv().await.unwrap();
1778 assert_matches!(m, RemoteProcessProcStateMessage::Done(world_id) if world_id == test_world_id);
1779
1780 remote_allocator.terminate();
1781 handle.await.unwrap().unwrap();
1782 }
1783
1784 #[timed_test::async_timed_test(timeout_secs = 15)]
1785 async fn test_trace_id_propagation() {
1786 let config = hyperactor::config::global::lock();
1787 let _guard = config.override_key(
1788 hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
1789 Duration::from_secs(60),
1790 );
1791 hyperactor_telemetry::initialize_logging(ClockKind::default());
1792 let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
1793 let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
1794 let (_, mut rx) = channel::serve(bootstrap_addr.clone()).await.unwrap();
1795
1796 let extent = extent!(host = 1, gpu = 1);
1797 let tx = channel::dial(serve_addr.clone()).unwrap();
1798 let test_world_id: WorldId = id!(test_world_id);
1799 let test_trace_id = "test_trace_id_12345";
1800
1801 let mut alloc = MockAlloc::new();
1803 alloc.expect_world_id().return_const(test_world_id.clone());
1804 alloc.expect_extent().return_const(extent.clone());
1805 alloc.expect_next().return_const(None);
1806
1807 let mut allocator = MockAllocator::new();
1809 allocator
1810 .expect_allocate()
1811 .times(1)
1812 .withf(move |spec: &AllocSpec| {
1813 spec.constraints
1815 .match_labels
1816 .get(CLIENT_TRACE_ID_LABEL)
1817 .is_some_and(|trace_id| trace_id == test_trace_id)
1818 })
1819 .return_once(|_| Ok(MockAllocWrapper::new(alloc)));
1820
1821 let remote_allocator = RemoteProcessAllocator::new();
1822 let handle = tokio::spawn({
1823 let remote_allocator = remote_allocator.clone();
1824 async move {
1825 remote_allocator
1826 .start_with_allocator(serve_addr, allocator, None)
1827 .await
1828 }
1829 });
1830
1831 tx.send(RemoteProcessAllocatorMessage::Allocate {
1833 view: extent.clone().into(),
1834 bootstrap_addr,
1835 hosts: vec![],
1836 client_context: Some(ClientContext {
1837 trace_id: test_trace_id.to_string(),
1838 }),
1839 })
1840 .await
1841 .unwrap();
1842
1843 let m = rx.recv().await.unwrap();
1845 assert_matches!(
1846 m,
1847 RemoteProcessProcStateMessage::Allocated { world_id, view }
1848 if world_id == test_world_id && view.extent() == extent
1849 );
1850
1851 let m = rx.recv().await.unwrap();
1853 assert_matches!(m, RemoteProcessProcStateMessage::Done(world_id) if world_id == test_world_id);
1854
1855 remote_allocator.terminate();
1856 handle.await.unwrap().unwrap();
1857 }
1858
1859 #[timed_test::async_timed_test(timeout_secs = 15)]
1860 async fn test_trace_id_propagation_no_client_context() {
1861 let config = hyperactor::config::global::lock();
1862 let _guard = config.override_key(
1863 hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
1864 Duration::from_secs(60),
1865 );
1866 hyperactor_telemetry::initialize_logging(ClockKind::default());
1867 let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
1868 let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
1869 let (_, mut rx) = channel::serve(bootstrap_addr.clone()).await.unwrap();
1870
1871 let extent = extent!(host = 1, gpu = 1);
1872 let tx = channel::dial(serve_addr.clone()).unwrap();
1873 let test_world_id: WorldId = id!(test_world_id);
1874
1875 let mut alloc = MockAlloc::new();
1877 alloc.expect_world_id().return_const(test_world_id.clone());
1878 alloc.expect_extent().return_const(extent.clone());
1879 alloc.expect_next().return_const(None);
1880
1881 let mut allocator = MockAllocator::new();
1883 allocator
1884 .expect_allocate()
1885 .times(1)
1886 .withf(move |spec: &AllocSpec| {
1887 spec.constraints.match_labels.is_empty()
1889 })
1890 .return_once(|_| Ok(MockAllocWrapper::new(alloc)));
1891
1892 let remote_allocator = RemoteProcessAllocator::new();
1893 let handle = tokio::spawn({
1894 let remote_allocator = remote_allocator.clone();
1895 async move {
1896 remote_allocator
1897 .start_with_allocator(serve_addr, allocator, None)
1898 .await
1899 }
1900 });
1901
1902 tx.send(RemoteProcessAllocatorMessage::Allocate {
1904 view: extent.clone().into(),
1905 bootstrap_addr,
1906 hosts: vec![],
1907 client_context: None,
1908 })
1909 .await
1910 .unwrap();
1911
1912 let m = rx.recv().await.unwrap();
1914 assert_matches!(
1915 m,
1916 RemoteProcessProcStateMessage::Allocated { world_id, view }
1917 if world_id == test_world_id && view.extent() == extent
1918 );
1919
1920 let m = rx.recv().await.unwrap();
1922 assert_matches!(m, RemoteProcessProcStateMessage::Done(world_id) if world_id == test_world_id);
1923
1924 remote_allocator.terminate();
1925 handle.await.unwrap().unwrap();
1926 }
1927}
1928
1929#[cfg(test)]
1930mod test_alloc {
1931 use std::os::unix::process::ExitStatusExt;
1932
1933 use hyperactor::clock::ClockKind;
1934 use hyperactor::config;
1935 use ndslice::extent;
1936 use nix::sys::signal;
1937 use nix::unistd::Pid;
1938 use timed_test::async_timed_test;
1939
1940 use super::*;
1941
1942 #[async_timed_test(timeout_secs = 60)]
1943 async fn test_alloc_simple() {
1944 let config = hyperactor::config::global::lock();
1946 let _guard1 = config.override_key(
1947 hyperactor::config::MESSAGE_DELIVERY_TIMEOUT,
1948 Duration::from_secs(1),
1949 );
1950 let _guard2 = config.override_key(
1951 hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
1952 Duration::from_millis(100),
1953 );
1954 hyperactor_telemetry::initialize_logging(ClockKind::default());
1955
1956 let spec = AllocSpec {
1957 extent: extent!(host = 2, gpu = 2),
1958 constraints: Default::default(),
1959 };
1960 let world_id = WorldId("test_world_id".to_string());
1961 let transport = ChannelTransport::Unix;
1962
1963 let task1_allocator = RemoteProcessAllocator::new();
1964 let task1_addr = ChannelAddr::any(ChannelTransport::Unix);
1965 let task1_addr_string = task1_addr.to_string();
1966 let task1_cmd =
1967 Command::new(buck_resources::get("monarch/hyperactor_mesh/bootstrap").unwrap());
1968 let task2_allocator = RemoteProcessAllocator::new();
1969 let task2_addr = ChannelAddr::any(ChannelTransport::Unix);
1970 let task2_addr_string = task2_addr.to_string();
1971 let task2_cmd =
1972 Command::new(buck_resources::get("monarch/hyperactor_mesh/bootstrap").unwrap());
1973 let task1_allocator_copy = task1_allocator.clone();
1974 let task1_allocator_handle = tokio::spawn(async move {
1975 tracing::info!("spawning task1");
1976 task1_allocator_copy
1977 .start(task1_cmd, task1_addr, None)
1978 .await
1979 .unwrap();
1980 });
1981 let task2_allocator_copy = task2_allocator.clone();
1982 let task2_allocator_handle = tokio::spawn(async move {
1983 task2_allocator_copy
1984 .start(task2_cmd, task2_addr, None)
1985 .await
1986 .unwrap();
1987 });
1988
1989 let mut initializer = MockRemoteProcessAllocInitializer::new();
1990 initializer.expect_initialize_alloc().return_once(move || {
1991 Ok(vec![
1992 RemoteProcessAllocHost {
1993 hostname: task1_addr_string,
1994 id: "task1".to_string(),
1995 },
1996 RemoteProcessAllocHost {
1997 hostname: task2_addr_string,
1998 id: "task2".to_string(),
1999 },
2000 ])
2001 });
2002 let mut alloc = RemoteProcessAlloc::new(spec.clone(), world_id, transport, 0, initializer)
2003 .await
2004 .unwrap();
2005 let mut procs = HashSet::new();
2006 let mut started_procs = HashSet::new();
2007 let mut proc_points = HashSet::new();
2008 for _ in 0..spec.extent.num_ranks() * 2 {
2009 let proc_state = alloc.next().await.unwrap();
2010 tracing::debug!("test got message: {:?}", proc_state);
2011 match proc_state {
2012 ProcState::Created { proc_id, point, .. } => {
2013 procs.insert(proc_id);
2014 proc_points.insert(point);
2015 }
2016 ProcState::Running { proc_id, .. } => {
2017 assert!(procs.contains(&proc_id));
2018 started_procs.insert(proc_id);
2019 }
2020 _ => panic!("expected Created or Running"),
2021 }
2022 }
2023 assert_eq!(procs, started_procs);
2024 assert!(
2026 spec.extent
2027 .points()
2028 .all(|point| proc_points.contains(&point))
2029 );
2030
2031 let timeout = hyperactor::clock::RealClock.now() + std::time::Duration::from_millis(1000);
2033 tokio::select! {
2034 _ = hyperactor::clock::RealClock.sleep_until(timeout) => {},
2035 _ = alloc.next() => panic!("expected no more items"),
2036 }
2037
2038 alloc.stop().await.unwrap();
2040 for _ in 0..spec.extent.num_ranks() {
2041 let proc_state = alloc.next().await.unwrap();
2042 tracing::info!("test received next proc_state: {:?}", proc_state);
2043 match proc_state {
2044 ProcState::Stopped { proc_id, reason } => {
2045 assert!(started_procs.remove(&proc_id));
2046 assert_eq!(reason, ProcStopReason::Stopped);
2047 }
2048 _ => panic!("expected stopped"),
2049 }
2050 }
2051 let proc_state = alloc.next().await;
2053 assert!(proc_state.is_none());
2054 let proc_state = alloc.next().await;
2056 assert!(proc_state.is_none());
2057
2058 task1_allocator.terminate();
2059 task1_allocator_handle.await.unwrap();
2060 task2_allocator.terminate();
2061 task2_allocator_handle.await.unwrap();
2062 }
2063
    #[async_timed_test(timeout_secs = 60)]
    async fn test_alloc_host_failure() {
        let config = hyperactor::config::global::lock();
        let _guard1 = config.override_key(
            hyperactor::config::MESSAGE_DELIVERY_TIMEOUT,
            Duration::from_secs(1),
        );
        let _guard2 = config.override_key(
            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
            Duration::from_millis(100),
        );
        hyperactor_telemetry::initialize_logging(ClockKind::default());

        let spec = AllocSpec {
            extent: extent!(host = 2, gpu = 2),
            constraints: Default::default(),
        };
        let world_id = WorldId("test_world_id".to_string());
        let transport = ChannelTransport::Unix;

        let task1_allocator = RemoteProcessAllocator::new();
        let task1_addr = ChannelAddr::any(ChannelTransport::Unix);
        let task1_addr_string = task1_addr.to_string();
        let task1_cmd =
            Command::new(buck_resources::get("monarch/hyperactor_mesh/bootstrap").unwrap());
        let task2_allocator = RemoteProcessAllocator::new();
        let task2_addr = ChannelAddr::any(ChannelTransport::Unix);
        let task2_addr_string = task2_addr.to_string();
        let task2_cmd =
            Command::new(buck_resources::get("monarch/hyperactor_mesh/bootstrap").unwrap());
        let task1_allocator_copy = task1_allocator.clone();
        let task1_allocator_handle = tokio::spawn(async move {
            tracing::info!("spawning task1");
            task1_allocator_copy
                .start(task1_cmd, task1_addr, None)
                .await
                .unwrap();
            tracing::info!("task1 terminated");
        });
        let task2_allocator_copy = task2_allocator.clone();
        let task2_allocator_handle = tokio::spawn(async move {
            task2_allocator_copy
                .start(task2_cmd, task2_addr, None)
                .await
                .unwrap();
            tracing::info!("task2 terminated");
        });

        let mut initializer = MockRemoteProcessAllocInitializer::new();
        initializer.expect_initialize_alloc().return_once(move || {
            Ok(vec![
                RemoteProcessAllocHost {
                    hostname: task1_addr_string,
                    id: "task1".to_string(),
                },
                RemoteProcessAllocHost {
                    hostname: task2_addr_string,
                    id: "task2".to_string(),
                },
            ])
        });
        let mut alloc = RemoteProcessAlloc::new(spec.clone(), world_id, transport, 0, initializer)
            .await
            .unwrap();
        for _ in 0..spec.extent.num_ranks() * 2 {
            match alloc.next().await {
                Some(ProcState::Created { .. }) | Some(ProcState::Running { .. }) => {}
                _ => panic!("expected Created or Running"),
            }
        }

        let timeout = hyperactor::clock::RealClock.now() + std::time::Duration::from_millis(1000);
        tokio::select! {
            _ = hyperactor::clock::RealClock.sleep_until(timeout) => {},
            _ = alloc.next() => panic!("expected no more items"),
        }

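        // Abort task1's allocator to simulate losing the host, then give the
        // heartbeat monitor a couple of intervals to notice.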
        tracing::info!("aborting task1 allocator");
        task1_allocator_handle.abort();
        RealClock
            .sleep(config::global::get(config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL) * 2)
            .await;
        for _ in 0..spec.extent.num_ranks() / 2 {
            let proc_state = alloc.next().await.unwrap();
            tracing::info!("test received next proc_state: {:?}", proc_state);
            match proc_state {
                ProcState::Stopped { proc_id: _, reason } => {
                    assert_eq!(reason, ProcStopReason::HostWatchdog);
                }
                _ => panic!("expected stopped"),
            }
        }
        let timeout = hyperactor::clock::RealClock.now() + std::time::Duration::from_millis(1000);
        tokio::select! {
            _ = hyperactor::clock::RealClock.sleep_until(timeout) => {},
            _ = alloc.next() => panic!("expected no more items"),
        }

        tracing::info!("aborting task2 allocator");
        task2_allocator_handle.abort();
        RealClock
            .sleep(config::global::get(config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL) * 2)
            .await;
        for _ in 0..spec.extent.num_ranks() / 2 {
            let proc_state = alloc.next().await.unwrap();
            tracing::info!("test received next proc_state: {:?}", proc_state);
            match proc_state {
                ProcState::Stopped { proc_id: _, reason } => {
                    assert_eq!(reason, ProcStopReason::HostWatchdog);
                }
                _ => panic!("expected stopped"),
            }
        }
        let proc_state = alloc.next().await;
        assert!(proc_state.is_none());
        let proc_state = alloc.next().await;
        assert!(proc_state.is_none());
    }

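    /// Exercises a failing inner allocation: task2's command points at a nonexistent
    /// binary, so only task1's ranks come up and the alloc is expected to surface
    /// exactly one `Failed` state before stopping cleanly.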
    #[async_timed_test(timeout_secs = 15)]
    async fn test_alloc_inner_alloc_failure() {
        unsafe {
            std::env::set_var("MONARCH_MESSAGE_DELIVERY_TIMEOUT_SECS", "1");
        }
        let config = hyperactor::config::global::lock();
        let _guard = config.override_key(
            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
            Duration::from_millis(100),
        );
        hyperactor_telemetry::initialize_logging(ClockKind::default());

        let spec = AllocSpec {
            extent: extent!(host = 2, gpu = 2),
            constraints: Default::default(),
        };
        let world_id = WorldId("test_world_id".to_string());
        let transport = ChannelTransport::Unix;

        let task1_allocator = RemoteProcessAllocator::new();
        let task1_addr = ChannelAddr::any(ChannelTransport::Unix);
        let task1_addr_string = task1_addr.to_string();
        let task1_cmd =
            Command::new(buck_resources::get("monarch/hyperactor_mesh/bootstrap").unwrap());
        let task2_allocator = RemoteProcessAllocator::new();
        let task2_addr = ChannelAddr::any(ChannelTransport::Unix);
        let task2_addr_string = task2_addr.to_string();
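        // Deliberately bogus binary path so task2's inner allocation fails.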
        let task2_cmd = Command::new("/caught/somewhere/in/time");
        let task1_allocator_copy = task1_allocator.clone();
        let task1_allocator_handle = tokio::spawn(async move {
            tracing::info!("spawning task1");
            task1_allocator_copy
                .start(task1_cmd, task1_addr, None)
                .await
                .unwrap();
        });
        let task2_allocator_copy = task2_allocator.clone();
        let task2_allocator_handle = tokio::spawn(async move {
            task2_allocator_copy
                .start(task2_cmd, task2_addr, None)
                .await
                .unwrap();
        });

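        // Only task1's ranks should reach Created/Running; task2's failure should
        // show up as a single Failed state.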
        let mut initializer = MockRemoteProcessAllocInitializer::new();
        initializer.expect_initialize_alloc().return_once(move || {
            Ok(vec![
                RemoteProcessAllocHost {
                    hostname: task1_addr_string,
                    id: "task1".to_string(),
                },
                RemoteProcessAllocHost {
                    hostname: task2_addr_string,
                    id: "task2".to_string(),
                },
            ])
        });
        let mut alloc = RemoteProcessAlloc::new(spec.clone(), world_id, transport, 0, initializer)
            .await
            .unwrap();
        let mut procs = HashSet::new();
        let mut started_procs = HashSet::new();
        let mut proc_points = HashSet::new();
        let mut failed = 0;
        for _ in 0..spec.extent.num_ranks() + 1 {
            let proc_state = alloc.next().await.unwrap();
            tracing::debug!("test got message: {:?}", proc_state);
            match proc_state {
                ProcState::Created { proc_id, point, .. } => {
                    procs.insert(proc_id);
                    proc_points.insert(point);
                }
                ProcState::Running { proc_id, .. } => {
                    assert!(procs.contains(&proc_id));
                    started_procs.insert(proc_id);
                }
                ProcState::Failed { .. } => {
                    failed += 1;
                }
                _ => panic!("expected Created, Running or Failed"),
            }
        }
        assert_eq!(procs, started_procs);
        assert_eq!(failed, 1);
        for rank in 0..spec.extent.num_ranks() / 2 {
            let point = spec.extent.point_of_rank(rank).unwrap();
            assert!(proc_points.contains(&point));
        }

        let timeout = hyperactor::clock::RealClock.now() + std::time::Duration::from_millis(1000);
        tokio::select! {
            _ = hyperactor::clock::RealClock.sleep_until(timeout) => {},
            _ = alloc.next() => panic!("expected no more items"),
        }

        alloc.stop().await.unwrap();
        for _ in 0..spec.extent.num_ranks() / 2 {
            let proc_state = alloc.next().await.unwrap();
            tracing::info!("test received next proc_state: {:?}", proc_state);
            match proc_state {
                ProcState::Stopped { proc_id, reason } => {
                    assert!(started_procs.remove(&proc_id));
                    assert_eq!(reason, ProcStopReason::Stopped);
                }
                _ => panic!("expected stopped"),
            }
        }
        let proc_state = alloc.next().await;
        assert!(proc_state.is_none());
        let proc_state = alloc.next().await;
        assert!(proc_state.is_none());

        task1_allocator.terminate();
        task1_allocator_handle.await.unwrap();
        task2_allocator.terminate();
        task2_allocator_handle.await.unwrap();
    }

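    /// Spawns real `remote_process_allocator` processes plus a `remote_process_alloc`
    /// client binary, sends the client SIGINT, and verifies that every child PID it
    /// reported is no longer alive afterwards.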
    #[tracing_test::traced_test]
    #[async_timed_test(timeout_secs = 60)]
    async fn test_remote_process_alloc_signal_handler() {
        let num_proc_meshes = 5;
        let hosts_per_proc_mesh = 5;

        let pid_addr = ChannelAddr::any(ChannelTransport::Unix);
        let (pid_addr, mut pid_rx) = channel::serve::<u32>(pid_addr).await.unwrap();

        let addresses = (0..(num_proc_meshes * hosts_per_proc_mesh))
            .map(|_| ChannelAddr::any(ChannelTransport::Unix).to_string())
            .collect::<Vec<_>>();

        let remote_process_allocators = addresses
            .iter()
            .map(|addr| {
                Command::new(
                    buck_resources::get("monarch/hyperactor_mesh/remote_process_allocator")
                        .unwrap(),
                )
                .env("RUST_LOG", "info")
                .arg(format!("--addr={addr}"))
                .stdout(std::process::Stdio::piped())
                .spawn()
                .unwrap()
            })
            .collect::<Vec<_>>();

        let done_allocating_addr = ChannelAddr::any(ChannelTransport::Unix);
        let (done_allocating_addr, mut done_allocating_rx) =
            channel::serve::<()>(done_allocating_addr).await.unwrap();
        let mut remote_process_alloc = Command::new(
            buck_resources::get("monarch/hyperactor_mesh/remote_process_alloc").unwrap(),
        )
        .arg(format!("--done-allocating-addr={}", done_allocating_addr))
        .arg(format!("--addresses={}", addresses.join(",")))
        .arg(format!("--num-proc-meshes={}", num_proc_meshes))
        .arg(format!("--hosts-per-proc-mesh={}", hosts_per_proc_mesh))
        .arg(format!("--pid-addr={}", pid_addr))
        .spawn()
        .unwrap();

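        // Wait for the client to report that allocation is done, then collect
        // reported PIDs until there is one per remote allocator.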
        done_allocating_rx.recv().await.unwrap();
        let mut received_pids = Vec::new();
        while let Ok(pid) = pid_rx.recv().await {
            received_pids.push(pid);
            if received_pids.len() == remote_process_allocators.len() {
                break;
            }
        }

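        // SIGINT the client and expect it to exit via that signal.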
        signal::kill(
            Pid::from_raw(remote_process_alloc.id().unwrap() as i32),
            signal::SIGINT,
        )
        .unwrap();

        assert_eq!(
            remote_process_alloc.wait().await.unwrap().signal(),
            Some(signal::SIGINT as i32)
        );

        RealClock.sleep(tokio::time::Duration::from_secs(5)).await;

        for child_pid in received_pids {
            let pid_check = Command::new("kill")
                .arg("-0")
                .arg(child_pid.to_string())
                .output()
                .await
                .expect("Failed to check if PID is alive");

            assert!(
                !pid_check.status.success(),
                "PID {} should no longer be alive",
                child_pid
            );
        }

        for mut remote_process_allocator in remote_process_allocators {
            remote_process_allocator.kill().await.unwrap();
        }
    }
}