hyperactor_mesh/alloc/
remoteprocess.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::collections::HashMap;
10use std::collections::HashSet;
11use std::collections::VecDeque;
12use std::sync::Arc;
13use std::time::Duration;
14
15use anyhow::Context;
16use async_trait::async_trait;
17use dashmap::DashMap;
18use futures::FutureExt;
19use futures::future::join_all;
20use futures::future::select_all;
21use hyperactor::WorldId;
22use hyperactor::channel;
23use hyperactor::channel::ChannelAddr;
24use hyperactor::channel::ChannelRx;
25use hyperactor::channel::ChannelTransport;
26use hyperactor::channel::ChannelTx;
27use hyperactor::channel::Rx;
28use hyperactor::channel::TcpMode;
29use hyperactor::channel::Tx;
30use hyperactor::channel::TxStatus;
31use hyperactor::clock;
32use hyperactor::clock::Clock;
33use hyperactor::clock::RealClock;
34use hyperactor::internal_macro_support::serde_json;
35use hyperactor::mailbox::DialMailboxRouter;
36use hyperactor::mailbox::MailboxServer;
37use hyperactor::observe_async;
38use hyperactor::observe_result;
39use hyperactor::reference::Reference;
40use mockall::automock;
41use ndslice::Region;
42use ndslice::Slice;
43use ndslice::View;
44use ndslice::ViewExt;
45use ndslice::view::Extent;
46use ndslice::view::Point;
47use serde::Deserialize;
48use serde::Serialize;
49use strum::AsRefStr;
50use tokio::io::AsyncWriteExt;
51use tokio::process::Command;
52use tokio::sync::mpsc::UnboundedReceiver;
53use tokio::sync::mpsc::UnboundedSender;
54use tokio::sync::mpsc::unbounded_channel;
55use tokio::task::JoinHandle;
56use tokio_stream::StreamExt;
57use tokio_stream::wrappers::WatchStream;
58use tokio_util::sync::CancellationToken;
59use typeuri::Named;
60
61use crate::alloc::Alloc;
62use crate::alloc::AllocConstraints;
63use crate::alloc::AllocSpec;
64use crate::alloc::Allocator;
65use crate::alloc::AllocatorError;
66use crate::alloc::ProcState;
67use crate::alloc::ProcStopReason;
68use crate::alloc::ProcessAllocator;
69use crate::alloc::process::CLIENT_TRACE_ID_LABEL;
70use crate::alloc::process::ClientContext;
71use crate::alloc::with_unspecified_port_or_any;
72use crate::shortuuid::ShortUuid;
73
74/// Control messages sent from remote process allocator to local allocator.
75#[derive(Debug, Clone, Serialize, Deserialize, Named, AsRefStr)]
76pub enum RemoteProcessAllocatorMessage {
77    /// Create allocation with given spec and send updates to bootstrap_addr.
78    Allocate {
79        /// The key used to identify this allocation.
80        alloc_key: ShortUuid,
81        /// The extent to allocate.
82        extent: Extent,
83        /// Bootstrap address to be used for sending updates.
84        bootstrap_addr: ChannelAddr,
85        /// Ordered list of hosts in this allocation. Can be used to
86        /// pre-populate the any local configurations such as torch.dist.
87        hosts: Vec<String>,
88        /// Client context which is passed to the ProcessAlloc
89        /// Todo: Once RemoteProcessAllocator moves to mailbox,
90        /// the client_context will go to the message header instead
91        client_context: Option<ClientContext>,
92        /// The address allocator should use for its forwarder.
93        forwarder_addr: ChannelAddr,
94    },
95    /// Stop allocation.
96    Stop,
97    /// Heartbeat message to check if remote process allocator and its
98    /// host are alive.
99    HeartBeat,
100}
101wirevalue::register_type!(RemoteProcessAllocatorMessage);
102
103/// Control message sent from local allocator to remote allocator
104/// relaying process state updates.
105/// AsRefStr allows us to log the values
106#[derive(Debug, Clone, Serialize, Deserialize, Named, AsRefStr)]
107pub enum RemoteProcessProcStateMessage {
108    /// Allocation successful and Update, Done messages will follow.
109    Allocated {
110        alloc_key: ShortUuid,
111        world_id: WorldId,
112    },
113    /// ProcState updates.
114    Update(ShortUuid, ProcState),
115    /// Underlying Alloc is done.
116    Done(ShortUuid),
117    /// Heartbeat message to check if client is alive.
118    HeartBeat,
119}
120wirevalue::register_type!(RemoteProcessProcStateMessage);
121
122/// Allocator with a service frontend that wraps ProcessAllocator.
123pub struct RemoteProcessAllocator {
124    cancel_token: CancellationToken,
125}
126
127async fn conditional_sleeper<F: futures::Future<Output = ()>>(t: Option<F>) {
128    match t {
129        Some(timer) => timer.await,
130        None => futures::future::pending().await,
131    }
132}
133
134impl RemoteProcessAllocator {
135    /// Create a new allocator. It will not start until start() is called.
136    pub fn new() -> Arc<Self> {
137        Arc::new(Self {
138            cancel_token: CancellationToken::new(),
139        })
140    }
141
142    /// Stop the allocator. This will stop any ongoing allocations.
143    pub fn terminate(&self) {
144        self.cancel_token.cancel();
145    }
146
147    /// Start a remote process allocator with given cmd listening for
148    /// RemoteProcessAllocatorMessage on serve_addr. Call will block until cancelled.
149    /// The implementation is simple such that it can only handle one Alloc at
150    /// a time. Generally that's the most common use-case.
151    /// Flow works as follows:
152    /// 1. Client sends Allocate message to serve_addr.
153    /// 2. Allocator connects to bootstrap_addr, creates Alloc and sends Allocated message.
154    /// 3. Allocator streams one or more Update messages to bootstrap_addr as Alloc progresses
155    ///    making the following changes:
156    ///    * Remap mesh_agent listen address to our own forwarder actor address.
157    /// 4. Allocator sends Done message to bootstrap_addr when Alloc is done.
158    ///
159    /// At any point, client can send Stop message to serve_addr to stop the allocator.
160    /// If timeout is Some, the allocator will exit if no client connects within
161    /// that timeout, and no child allocation is running.
162    #[hyperactor::instrument]
163    pub async fn start(
164        &self,
165        cmd: Command,
166        serve_addr: ChannelAddr,
167        timeout: Option<Duration>,
168    ) -> Result<(), anyhow::Error> {
169        let process_allocator = ProcessAllocator::new(cmd);
170        self.start_with_allocator(serve_addr, process_allocator, timeout)
171            .await
172    }
173
174    /// Start a remote process allocator with given allocator listening for
175    /// RemoteProcessAllocatorMessage on serve_addr.
176    /// Used for testing.
177    #[hyperactor::instrument(fields(addr=serve_addr.to_string()))]
178    pub async fn start_with_allocator<A: Allocator + Send + Sync + 'static>(
179        &self,
180        serve_addr: ChannelAddr,
181        mut process_allocator: A,
182        timeout: Option<Duration>,
183    ) -> Result<(), anyhow::Error>
184    where
185        <A as Allocator>::Alloc: Send,
186        <A as Allocator>::Alloc: Sync,
187    {
188        tracing::info!("starting remote allocator on: {}", serve_addr);
189        let (_, mut rx) = channel::serve(serve_addr.clone()).map_err(anyhow::Error::from)?;
190
191        struct ActiveAllocation {
192            handle: JoinHandle<()>,
193            cancel_token: CancellationToken,
194        }
195        #[observe_async("RemoteProcessAllocator")]
196        async fn ensure_previous_alloc_stopped(active_allocation: &mut Option<ActiveAllocation>) {
197            if let Some(active_allocation) = active_allocation.take() {
198                tracing::info!("previous alloc found, stopping");
199                active_allocation.cancel_token.cancel();
200                match active_allocation.handle.await {
201                    Ok(_) => {
202                        // Todo, add named tracing for state change here
203                        tracing::info!("allocation stopped.")
204                    }
205                    Err(e) => {
206                        tracing::error!("allocation handler failed: {}", e);
207                    }
208                }
209            }
210        }
211
212        let mut active_allocation: Option<ActiveAllocation> = None;
213        loop {
214            // Refresh each loop iteration so the timer updates whenever a message
215            // is received.
216            let sleep = conditional_sleeper(timeout.map(|t| RealClock.sleep(t)));
217            tokio::select! {
218                msg = rx.recv() => {
219                    match msg {
220                        Ok(RemoteProcessAllocatorMessage::Allocate {
221                            alloc_key,
222                            extent,
223                            bootstrap_addr,
224                            hosts,
225                            client_context,
226                            forwarder_addr,
227                        }) => {
228                            tracing::info!("received allocation request for {} with extent {}", alloc_key, extent);
229                            ensure_previous_alloc_stopped(&mut active_allocation).await;
230
231                            // Create the corresponding local allocation spec.
232                            let mut constraints: AllocConstraints = Default::default();
233                            if let Some(context) = &client_context {
234                                constraints = AllocConstraints {
235                                    match_labels: HashMap::from([(
236                                    CLIENT_TRACE_ID_LABEL.to_string(),
237                                    context.trace_id.to_string(),
238                                    )]
239                                )};
240                                tracing::info!(
241                                    monarch_client_trace_id = context.trace_id.to_string(),
242                                    "allocating...",
243                                );
244                            }
245
246
247                            let spec = AllocSpec {
248                                extent,
249                                constraints,
250                                proc_name: None, // TODO(meriksen, direct addressing): we need to pass the addressing mode here
251                                transport: ChannelTransport::Unix,
252                                proc_allocation_mode: Default::default(),
253                            };
254
255                            match process_allocator.allocate(spec.clone()).await {
256                                Ok(alloc) => {
257                                    let cancel_token = CancellationToken::new();
258                                    active_allocation = Some(ActiveAllocation {
259                                        cancel_token: cancel_token.clone(),
260                                        handle: tokio::spawn(Self::handle_allocation_request(
261                                            Box::new(alloc) as Box<dyn Alloc + Send + Sync>,
262                                            alloc_key,
263                                            bootstrap_addr,
264                                            hosts,
265                                            cancel_token,
266                                            forwarder_addr,
267                                        )),
268                                    })
269                                }
270                                Err(e) => {
271                                    tracing::error!("allocation for {:?} failed: {}", spec, e);
272                                    continue;
273                                }
274                            }
275                        }
276                        Ok(RemoteProcessAllocatorMessage::Stop) => {
277                            tracing::info!("received stop request");
278
279                            ensure_previous_alloc_stopped(&mut active_allocation).await;
280                        }
281                        // Hearbeat message is discarded immediately after being received, sender (client)
282                        // relies on channel ack to know if the receiver (remote process allocator) is
283                        // still alive. No state needs to be updated.
284                        Ok(RemoteProcessAllocatorMessage::HeartBeat) => {}
285                        Err(e) => {
286                            tracing::error!("upstream channel error: {}", e);
287                            continue;
288                        }
289                    }
290                }
291                _ = self.cancel_token.cancelled() => {
292                    tracing::info!("main loop cancelled");
293
294                    ensure_previous_alloc_stopped(&mut active_allocation).await;
295
296                    break;
297                }
298                _ = sleep => {
299                    // If there are any active allocations, reset the timeout.
300                    if active_allocation.is_some() {
301                        continue;
302                    }
303                    // Else, exit the loop as a client hasn't connected in a reasonable
304                    // amount of time.
305                    tracing::warn!("timeout of {} seconds elapsed without any allocations, exiting", timeout.unwrap_or_default().as_secs());
306                    break;
307                }
308            }
309        }
310
311        Ok(())
312    }
313
314    #[tracing::instrument(skip(alloc, cancel_token))]
315    #[observe_async("RemoteProcessAllocator")]
316    async fn handle_allocation_request(
317        alloc: Box<dyn Alloc + Send + Sync>,
318        alloc_key: ShortUuid,
319        bootstrap_addr: ChannelAddr,
320        hosts: Vec<String>,
321        cancel_token: CancellationToken,
322        forwarder_addr: ChannelAddr,
323    ) {
324        tracing::info!("handle allocation request, bootstrap_addr: {bootstrap_addr}");
325        // start proc message forwarder
326        let (forwarder_addr, forwarder_rx) = match channel::serve(forwarder_addr) {
327            Ok(v) => v,
328            Err(e) => {
329                tracing::error!("failed to to bootstrap forwarder actor: {}", e);
330                return;
331            }
332        };
333        let router = DialMailboxRouter::new();
334        let mailbox_handle = router.clone().serve(forwarder_rx);
335        tracing::info!("started forwarder on: {}", forwarder_addr);
336
337        // Check if we need to write TORCH_ELASTIC_CUSTOM_HOSTNAMES_LIST_FILE
338        // See: https://github.com/fairinternal/xlformers/blob/llama4_monarch/tools/launching/torchx/entrypoint/generate_ranks.py
339        if let Ok(hosts_file) = std::env::var("TORCH_ELASTIC_CUSTOM_HOSTNAMES_LIST_FILE") {
340            tracing::info!("writing hosts to {}", hosts_file);
341            #[derive(Serialize)]
342            struct Hosts {
343                hostnames: Vec<String>,
344            }
345            match serde_json::to_string(&Hosts { hostnames: hosts }) {
346                Ok(json) => match tokio::fs::File::create(&hosts_file).await {
347                    Ok(mut file) => {
348                        if file.write_all(json.as_bytes()).await.is_err() {
349                            tracing::error!("failed to write hosts to {}", hosts_file);
350                            return;
351                        }
352                    }
353                    Err(e) => {
354                        tracing::error!("failed to open hosts file {}: {}", hosts_file, e);
355                        return;
356                    }
357                },
358                Err(e) => {
359                    tracing::error!("failed to serialize hosts: {}", e);
360                    return;
361                }
362            }
363        }
364
365        Self::handle_allocation_loop(
366            alloc,
367            alloc_key,
368            bootstrap_addr,
369            router,
370            forwarder_addr,
371            cancel_token,
372        )
373        .await;
374
375        mailbox_handle.stop("alloc stopped");
376        if let Err(e) = mailbox_handle.await {
377            tracing::error!("failed to join forwarder: {}", e);
378        }
379    }
380
381    async fn handle_allocation_loop(
382        mut alloc: Box<dyn Alloc + Send + Sync>,
383        alloc_key: ShortUuid,
384        bootstrap_addr: ChannelAddr,
385        router: DialMailboxRouter,
386        forward_addr: ChannelAddr,
387        cancel_token: CancellationToken,
388    ) {
389        let world_id = alloc.world_id().clone();
390        tracing::info!("starting handle allocation loop for {}", world_id);
391        let tx = match channel::dial(bootstrap_addr) {
392            Ok(tx) => tx,
393            Err(err) => {
394                tracing::error!("failed to dial bootstrap address: {}", err);
395                return;
396            }
397        };
398        let message = RemoteProcessProcStateMessage::Allocated {
399            alloc_key: alloc_key.clone(),
400            world_id,
401        };
402        tracing::info!(name = message.as_ref(), "sending allocated message",);
403        if let Err(e) = tx.send(message).await {
404            tracing::error!("failed to send Allocated message: {}", e);
405            return;
406        }
407
408        let mut mesh_agents_by_create_key = HashMap::new();
409        let mut running = true;
410        let tx_status = tx.status().clone();
411        let mut tx_watcher = WatchStream::new(tx_status);
412        loop {
413            tokio::select! {
414                _ = cancel_token.cancelled(), if running => {
415                    tracing::info!("cancelled, stopping allocation");
416                    running = false;
417                    if let Err(e) = alloc.stop().await {
418                        tracing::error!("stop failed: {}", e);
419                        break;
420                    }
421                }
422                status = tx_watcher.next(), if running => {
423                    match status  {
424                        Some(TxStatus::Closed) => {
425                            tracing::error!("upstream channel state closed");
426                            break;
427                        },
428                        _ => {
429                            tracing::debug!("got channel event: {:?}", status.unwrap());
430                            continue;
431                        }
432                    }
433                }
434                e = alloc.next() => {
435                    match e {
436                        Some(event) => {
437                            tracing::debug!(name = event.as_ref(), "got event: {:?}", event);
438                            let event = match event {
439                                ProcState::Created { .. } => event,
440                                ProcState::Running { create_key, proc_id, mesh_agent, addr } => {
441                                    // TODO(meriksen, direct addressing): disable remapping in direct addressing mode
442                                    tracing::debug!("remapping mesh_agent {}: addr {} -> {}", mesh_agent, addr, forward_addr);
443                                    mesh_agents_by_create_key.insert(create_key.clone(), mesh_agent.clone());
444                                    router.bind(mesh_agent.actor_id().proc_id().clone().into(), addr);
445                                    ProcState::Running { create_key, proc_id, mesh_agent, addr: forward_addr.clone() }
446                                },
447                                ProcState::Stopped { create_key, reason } => {
448                                    match mesh_agents_by_create_key.remove(&create_key) {
449                                        Some(mesh_agent) => {
450                                            tracing::debug!("unmapping mesh_agent {}", mesh_agent);
451                                            let agent_ref: Reference = mesh_agent.actor_id().proc_id().clone().into();
452                                            router.unbind(&agent_ref);
453                                        },
454                                        None => {
455                                            tracing::warn!("mesh_agent not found for create key {}", create_key);
456                                        }
457                                    }
458                                    ProcState::Stopped { create_key, reason }
459                                },
460                                ProcState::Failed { ref world_id, ref description } => {
461                                    tracing::error!("allocation failed for {}: {}", world_id, description);
462                                    event
463                                }
464                            };
465                            tracing::debug!(name = event.as_ref(), "sending event: {:?}", event);
466                            tx.post(RemoteProcessProcStateMessage::Update(alloc_key.clone(), event));
467                        }
468                        None => {
469                            tracing::debug!("sending done");
470                            tx.post(RemoteProcessProcStateMessage::Done(alloc_key.clone()));
471                            running = false;
472                            break;
473                        }
474                    }
475                }
476                _ = RealClock.sleep(hyperactor_config::global::get(hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL)) => {
477                    tracing::trace!("sending heartbeat");
478                    tx.post(RemoteProcessProcStateMessage::HeartBeat);
479                }
480            }
481        }
482        tracing::info!("allocation handler loop exited");
483        if running {
484            tracing::info!("stopping processes");
485            if let Err(e) = alloc.stop_and_wait().await {
486                tracing::error!("stop failed: {}", e);
487                return;
488            }
489            tracing::info!("stop finished");
490        }
491    }
492}
493
494/// HostId is a string-based identifier of the host. It may be the same
495/// as the hostname but in some situations a separate ID may be used.
496type HostId = String;
497
498/// A host entry passed down from the initializer.
499#[derive(Clone)]
500pub struct RemoteProcessAllocHost {
501    /// A unique identifier of that host. Hostname may be used by ID can
502    /// be different in some situations.
503    pub id: HostId,
504    /// The FQDN of the host.
505    pub hostname: String,
506}
507
508/// State of a host in the RemoteProcessAlloc.
509struct RemoteProcessAllocHostState {
510    /// The allocation key used to identify the host.
511    alloc_key: ShortUuid,
512    /// The host ID of the remote host.
513    host_id: HostId,
514    /// TX channel to the remote host allocator.
515    tx: ChannelTx<RemoteProcessAllocatorMessage>,
516    /// Set of active processes on this host.
517    active_procs: HashSet<ShortUuid>,
518    /// Region allocated by host.
519    region: Region,
520    /// World ID for this host as indicated from Allocated message.
521    world_id: Option<WorldId>,
522    /// If remote allocator sent us ProcState::Failed.
523    failed: bool,
524    /// If remote allocater has ever allocated a proc.
525    allocated: bool,
526}
527
528#[automock]
529#[async_trait]
530/// Interface to provide the set of hosts to be used by RemoteProcessAlloc.
531pub trait RemoteProcessAllocInitializer {
532    /// Initializes and returns a list of hosts to be used by this RemoteProcessAlloc.
533    async fn initialize_alloc(&mut self) -> Result<Vec<RemoteProcessAllocHost>, anyhow::Error>;
534}
535
536/// Wrapper struct around `HashMap<HostId, RemoteProcessAllocHostState>`
537/// to ensure that host addresses are synced with the signal handler
538struct HostStates {
539    inner: HashMap<HostId, RemoteProcessAllocHostState>,
540    host_addresses: Arc<DashMap<HostId, ChannelAddr>>,
541}
542
543impl HostStates {
544    fn new(host_addresses: Arc<DashMap<HostId, ChannelAddr>>) -> HostStates {
545        Self {
546            inner: HashMap::new(),
547            host_addresses,
548        }
549    }
550
551    fn insert(
552        &mut self,
553        host_id: HostId,
554        state: RemoteProcessAllocHostState,
555        address: ChannelAddr,
556    ) {
557        self.host_addresses.insert(host_id.clone(), address);
558        self.inner.insert(host_id, state);
559    }
560
561    fn get(&self, host_id: &HostId) -> Option<&RemoteProcessAllocHostState> {
562        self.inner.get(host_id)
563    }
564
565    fn get_mut(&mut self, host_id: &HostId) -> Option<&mut RemoteProcessAllocHostState> {
566        self.inner.get_mut(host_id)
567    }
568
569    fn remove(&mut self, host_id: &HostId) -> Option<RemoteProcessAllocHostState> {
570        self.host_addresses.remove(host_id);
571        self.inner.remove(host_id)
572    }
573
574    fn iter(&self) -> impl Iterator<Item = (&HostId, &RemoteProcessAllocHostState)> {
575        self.inner.iter()
576    }
577
578    fn iter_mut(&mut self) -> impl Iterator<Item = (&HostId, &mut RemoteProcessAllocHostState)> {
579        self.inner.iter_mut()
580    }
581
582    fn is_empty(&self) -> bool {
583        self.inner.is_empty()
584    }
585    // Any missing HashMap methods should be added here as needed
586}
587
588/// A generalized implementation of an Alloc using one or more hosts running
589/// RemoteProcessAlloc for process allocation.
590pub struct RemoteProcessAlloc {
591    // The initializer to be called at the first next() invocation to obtain
592    // allocated hosts.
593    initializer: Box<dyn RemoteProcessAllocInitializer + Send + Sync>,
594    spec: AllocSpec,
595    remote_allocator_port: u16,
596    world_id: WorldId,
597    ordered_hosts: Vec<RemoteProcessAllocHost>,
598    // Indicates that the initial remote allocation requests have been sent.
599    started: bool,
600    // Indicates that this Alloc is active (we have at least one remote process running).
601    running: bool,
602    // Inidicates that the allocation process has permanently failed.
603    failed: bool,
604    // Maps the alloc key to the host.
605    alloc_to_host: HashMap<ShortUuid, HostId>,
606    host_states: HostStates,
607    world_offsets: HashMap<WorldId, usize>,
608    event_queue: VecDeque<ProcState>,
609    comm_watcher_tx: UnboundedSender<HostId>,
610    comm_watcher_rx: UnboundedReceiver<HostId>,
611
612    bootstrap_addr: ChannelAddr,
613    rx: ChannelRx<RemoteProcessProcStateMessage>,
614    _signal_cleanup_guard: hyperactor::SignalCleanupGuard,
615}
616
617impl RemoteProcessAlloc {
618    /// Create a new Alloc. initializer will be called on the first invocation of next()
619    /// to obtain a list of allocate hosts. Then Allocate message will be sent to all
620    /// RemoteProcessAllocator on all hosts. Heartbeats will be used to maintain health
621    /// status of remote hosts.
622    #[tracing::instrument(skip(initializer))]
623    #[observe_result("RemoteProcessAlloc")]
624    pub async fn new(
625        spec: AllocSpec,
626        world_id: WorldId,
627        remote_allocator_port: u16,
628        initializer: impl RemoteProcessAllocInitializer + Send + Sync + 'static,
629    ) -> Result<Self, anyhow::Error> {
630        let alloc_serve_addr = ChannelAddr::any(spec.transport.clone());
631
632        let (bootstrap_addr, rx) = channel::serve(alloc_serve_addr)?;
633
634        tracing::info!(
635            "starting alloc for {} on: {}",
636            world_id,
637            bootstrap_addr.clone()
638        );
639
640        let (comm_watcher_tx, comm_watcher_rx) = unbounded_channel();
641
642        let host_addresses = Arc::new(DashMap::<HostId, ChannelAddr>::new());
643        let host_addresses_for_signal = host_addresses.clone();
644
645        // Register cleanup callback with global signal manager
646        let signal_cleanup_guard =
647            hyperactor::register_signal_cleanup_scoped(Box::pin(async move {
648                join_all(host_addresses_for_signal.iter().map(|entry| async move {
649                    let addr = entry.value().clone();
650                    match channel::dial(addr.clone()) {
651                        Ok(tx) => {
652                            if let Err(e) = tx.send(RemoteProcessAllocatorMessage::Stop).await {
653                                tracing::error!("Failed to send Stop to {}: {}", addr, e);
654                            }
655                        }
656                        Err(e) => {
657                            tracing::error!("Failed to dial {} during signal cleanup: {}", addr, e);
658                        }
659                    }
660                }))
661                .await;
662            }));
663
664        Ok(Self {
665            spec,
666            world_id,
667            remote_allocator_port,
668            initializer: Box::new(initializer),
669            world_offsets: HashMap::new(),
670            ordered_hosts: Vec::new(),
671            alloc_to_host: HashMap::new(),
672            host_states: HostStates::new(host_addresses),
673            bootstrap_addr,
674            event_queue: VecDeque::new(),
675            comm_watcher_tx,
676            comm_watcher_rx,
677            rx,
678            started: false,
679            running: true,
680            failed: false,
681            _signal_cleanup_guard: signal_cleanup_guard,
682        })
683    }
684
685    /// Start a RemoteProcessAllocator tx watcher. It spawns a task that monitor the underlying
686    /// TxStatus and repot to the main next() loop via comm_watcher_tx. It does not however
687    /// send the actual heartbteats. Just watches the status.
688    /// The task will terminate once comm_watcher_rx is closed.
689    ///
690    /// A task approach was used here as opposed to select! to avoid nesting complexities with
691    /// select!() and select_all().
692    async fn start_comm_watcher(&self) {
693        let mut tx_watchers = Vec::new();
694        for host in &self.ordered_hosts {
695            let tx_status = self.host_states.get(&host.id).unwrap().tx.status().clone();
696            let watcher = WatchStream::new(tx_status);
697            tx_watchers.push((watcher, host.id.clone()));
698        }
699        assert!(!tx_watchers.is_empty());
700        let tx = self.comm_watcher_tx.clone();
701        tokio::spawn(async move {
702            loop {
703                let mut tx_status_futures = Vec::new();
704                for (watcher, _) in &mut tx_watchers {
705                    let fut = watcher.next().boxed();
706                    tx_status_futures.push(fut);
707                }
708                let (tx_status, index, _) = select_all(tx_status_futures).await;
709                let host_id = match tx_watchers.get(index) {
710                    Some((_, host_id)) => host_id.clone(),
711                    None => {
712                        // Should never happen
713                        tracing::error!(
714                            "got selected index {} with no matching host in {}",
715                            index,
716                            tx_watchers.len()
717                        );
718                        continue;
719                    }
720                };
721                if let Some(tx_status) = tx_status {
722                    tracing::debug!("host {} channel event: {:?}", host_id, tx_status);
723                    if tx_status == TxStatus::Closed {
724                        if tx.send(host_id.clone()).is_err() {
725                            // other side closed
726                            break;
727                        }
728                        tx_watchers.remove(index);
729                        if tx_watchers.is_empty() {
730                            // All of the statuses have been closed, exit the loop.
731                            break;
732                        }
733                    }
734                }
735            }
736        });
737    }
738
739    /// Ensure that we have the list of allocated hosts and that we send Allocate
740    /// request to all o fthem. Once done, start_comm_watcher() is called to start
741    /// the channel watcher.
742    /// Function is idempotent.
743    async fn ensure_started(&mut self) -> Result<(), anyhow::Error> {
744        if self.started || self.failed {
745            return Ok(());
746        }
747
748        self.started = true;
749        let hosts = self
750            .initializer
751            .initialize_alloc()
752            .await
753            .context("alloc initializer error")?;
754        if hosts.is_empty() {
755            anyhow::bail!("initializer returned empty list of hosts");
756        }
757        // prepare a list of host names in this allocation to be sent
758        // to remote allocators.
759        let hostnames: Vec<_> = hosts.iter().map(|e| e.hostname.clone()).collect();
760        tracing::info!("obtained {} hosts for this allocation", hostnames.len());
761
762        // Split the extent into regions, one per host.
763        use crate::alloc::ProcAllocationMode;
764
765        // For HostLevel, pre-compute regions. For ProcLevel, skip this step.
766        let regions: Option<Vec<_>> = match self.spec.proc_allocation_mode {
767            ProcAllocationMode::ProcLevel => {
768                // We require at least a dimension for hosts, and one for sub-host (e.g., GPUs)
769                anyhow::ensure!(
770                    self.spec.extent.len() >= 2,
771                    "invalid extent: {}, expected at least 2 dimensions",
772                    self.spec.extent
773                );
774                None
775            }
776            ProcAllocationMode::HostLevel => Some({
777                // HostLevel: each point is a host, create a region for each point
778                let num_points = self.spec.extent.num_ranks();
779                anyhow::ensure!(
780                    hosts.len() >= num_points,
781                    "HostLevel allocation mode requires {} hosts (one per point in extent {}), but only {} hosts were provided",
782                    num_points,
783                    self.spec.extent,
784                    hosts.len()
785                );
786
787                // For HostLevel, create a single-point region for each rank
788                // Each region contains one point that maps to the correct global rank
789                let labels = self.spec.extent.labels().to_vec();
790
791                // Compute strides for row-major layout: strides[i] = product of sizes[i+1..n]
792                let extent_sizes = self.spec.extent.sizes();
793                let mut parent_strides = vec![1; extent_sizes.len()];
794                for i in (0..extent_sizes.len() - 1).rev() {
795                    parent_strides[i] = parent_strides[i + 1] * extent_sizes[i + 1];
796                }
797
798                (0..num_points)
799                    .map(|rank| {
800                        // Create a slice containing only this rank
801                        // Use parent's strides so local point [0,0,...] maps to the correct global rank
802                        let sizes = vec![1; labels.len()];
803                        Region::new(
804                            labels.clone(),
805                            Slice::new(rank, sizes, parent_strides.clone()).unwrap(),
806                        )
807                    })
808                    .collect()
809            }),
810        };
811
812        match self.spec.proc_allocation_mode {
813            ProcAllocationMode::ProcLevel => {
814                // We group by the innermost dimension of the extent.
815                let split_dim = &self.spec.extent.labels()[self.spec.extent.len() - 1];
816                for (i, region) in self.spec.extent.group_by(split_dim)?.enumerate() {
817                    let host = &hosts[i];
818                    tracing::debug!("allocating: {} for host: {}", region, host.id);
819
820                    let remote_addr = match self.spec.transport {
821                        ChannelTransport::MetaTls(_) => {
822                            format!("metatls!{}:{}", host.hostname, self.remote_allocator_port)
823                        }
824                        ChannelTransport::Tcp(TcpMode::Localhost) => {
825                            // TODO: @rusch see about moving over to config for this
826                            format!("tcp![::1]:{}", self.remote_allocator_port)
827                        }
828                        ChannelTransport::Tcp(TcpMode::Hostname) => {
829                            format!("tcp!{}:{}", host.hostname, self.remote_allocator_port)
830                        }
831                        // Used only for testing.
832                        ChannelTransport::Unix => host.hostname.clone(),
833                        _ => {
834                            anyhow::bail!(
835                                "unsupported transport for host {}: {:?}",
836                                host.id,
837                                self.spec.transport,
838                            );
839                        }
840                    };
841
842                    tracing::debug!("dialing remote: {} for host {}", remote_addr, host.id);
843                    let remote_addr = remote_addr.parse::<ChannelAddr>()?;
844                    let tx = channel::dial(remote_addr.clone())
845                        .map_err(anyhow::Error::from)
846                        .context(format!(
847                            "failed to dial remote {} for host {}",
848                            remote_addr, host.id
849                        ))?;
850
851                    // Possibly we could use the HostId directly here.
852                    let alloc_key = ShortUuid::generate();
853                    assert!(
854                        self.alloc_to_host
855                            .insert(alloc_key.clone(), host.id.clone())
856                            .is_none()
857                    );
858
859                    let trace_id = hyperactor_telemetry::trace::get_or_create_trace_id();
860                    let client_context = Some(ClientContext { trace_id });
861                    let message = RemoteProcessAllocatorMessage::Allocate {
862                        alloc_key: alloc_key.clone(),
863                        extent: region.extent(),
864                        bootstrap_addr: self.bootstrap_addr.clone(),
865                        hosts: hostnames.clone(),
866                        client_context,
867                        // Make sure allocator's forwarder uses the same IP address
868                        // which is known to alloc. This is to avoid allocator picks
869                        // its host's private IP address, while its known addres to
870                        // alloc is a public IP address. In some environment, that
871                        // could lead to port unreachable error.
872                        forwarder_addr: with_unspecified_port_or_any(&remote_addr),
873                    };
874                    tracing::info!(
875                        name = message.as_ref(),
876                        "sending allocate message to workers"
877                    );
878                    tx.post(message);
879
880                    self.host_states.insert(
881                        host.id.clone(),
882                        RemoteProcessAllocHostState {
883                            alloc_key,
884                            host_id: host.id.clone(),
885                            tx,
886                            active_procs: HashSet::new(),
887                            region,
888                            world_id: None,
889                            failed: false,
890                            allocated: false,
891                        },
892                        remote_addr,
893                    );
894                }
895
896                self.ordered_hosts = hosts;
897            }
898            ProcAllocationMode::HostLevel => {
899                let regions = regions.unwrap();
900                let num_regions = regions.len();
901                for (i, region) in regions.into_iter().enumerate() {
902                    let host = &hosts[i];
903                    tracing::debug!("allocating: {} for host: {}", region, host.id);
904
905                    let remote_addr = match self.spec.transport {
906                        ChannelTransport::MetaTls(_) => {
907                            format!("metatls!{}:{}", host.hostname, self.remote_allocator_port)
908                        }
909                        ChannelTransport::Tcp(TcpMode::Localhost) => {
910                            // TODO: @rusch see about moving over to config for this
911                            format!("tcp![::1]:{}", self.remote_allocator_port)
912                        }
913                        ChannelTransport::Tcp(TcpMode::Hostname) => {
914                            format!("tcp!{}:{}", host.hostname, self.remote_allocator_port)
915                        }
916                        // Used only for testing.
917                        ChannelTransport::Unix => host.hostname.clone(),
918                        _ => {
919                            anyhow::bail!(
920                                "unsupported transport for host {}: {:?}",
921                                host.id,
922                                self.spec.transport,
923                            );
924                        }
925                    };
926
927                    tracing::debug!("dialing remote: {} for host {}", remote_addr, host.id);
928                    let remote_addr = remote_addr.parse::<ChannelAddr>()?;
929                    let tx = channel::dial(remote_addr.clone())
930                        .map_err(anyhow::Error::from)
931                        .context(format!(
932                            "failed to dial remote {} for host {}",
933                            remote_addr, host.id
934                        ))?;
935
936                    // Possibly we could use the HostId directly here.
937                    let alloc_key = ShortUuid::generate();
938                    assert!(
939                        self.alloc_to_host
940                            .insert(alloc_key.clone(), host.id.clone())
941                            .is_none()
942                    );
943
944                    let trace_id = hyperactor_telemetry::trace::get_or_create_trace_id();
945                    let client_context = Some(ClientContext { trace_id });
946                    let message = RemoteProcessAllocatorMessage::Allocate {
947                        alloc_key: alloc_key.clone(),
948                        extent: region.extent(),
949                        bootstrap_addr: self.bootstrap_addr.clone(),
950                        hosts: hostnames.clone(),
951                        client_context,
952                        // Make sure allocator's forwarder uses the same IP address
953                        // which is known to alloc. This is to avoid allocator picks
954                        // its host's private IP address, while its known addres to
955                        // alloc is a public IP address. In some environment, that
956                        // could lead to port unreachable error.
957                        forwarder_addr: with_unspecified_port_or_any(&remote_addr),
958                    };
959                    tracing::info!(
960                        name = message.as_ref(),
961                        "sending allocate message to workers"
962                    );
963                    tx.post(message);
964
965                    self.host_states.insert(
966                        host.id.clone(),
967                        RemoteProcessAllocHostState {
968                            alloc_key,
969                            host_id: host.id.clone(),
970                            tx,
971                            active_procs: HashSet::new(),
972                            region,
973                            world_id: None,
974                            failed: false,
975                            allocated: false,
976                        },
977                        remote_addr,
978                    );
979                }
980
981                // Only store hosts that were actually used for regions
982                // If num_regions < hosts.len(), we only use the first num_regions hosts
983                self.ordered_hosts = hosts.into_iter().take(num_regions).collect();
984            }
985        }
986        self.start_comm_watcher().await;
987        self.started = true;
988
989        Ok(())
990    }
991
992    // Given a proc_id, obtain the internal HostState structure.
993    fn get_host_state_mut(
994        &mut self,
995        alloc_key: &ShortUuid,
996    ) -> Result<&mut RemoteProcessAllocHostState, anyhow::Error> {
997        let host_id: &HostId = self
998            .alloc_to_host
999            .get(alloc_key)
1000            .ok_or_else(|| anyhow::anyhow!("alloc with key {} not found", alloc_key))?;
1001
1002        self.host_states
1003            .get_mut(host_id)
1004            .ok_or_else(|| anyhow::anyhow!("no host state found for host {}", host_id))
1005    }
1006
1007    // Given a proc_id, obtain the internal HostState structure.
1008    fn get_host_state(
1009        &self,
1010        alloc_key: &ShortUuid,
1011    ) -> Result<&RemoteProcessAllocHostState, anyhow::Error> {
1012        let host_id: &HostId = self
1013            .alloc_to_host
1014            .get(alloc_key)
1015            .ok_or_else(|| anyhow::anyhow!("alloc with key {} not found", alloc_key))?;
1016
1017        self.host_states
1018            .get(host_id)
1019            .ok_or_else(|| anyhow::anyhow!("no host state found for host {}", host_id))
1020    }
1021
1022    fn remove_host_state(
1023        &mut self,
1024        alloc_key: &ShortUuid,
1025    ) -> Result<RemoteProcessAllocHostState, anyhow::Error> {
1026        let host_id: &HostId = self
1027            .alloc_to_host
1028            .get(alloc_key)
1029            .ok_or_else(|| anyhow::anyhow!("alloc with key {} not found", alloc_key))?;
1030
1031        self.host_states
1032            .remove(host_id)
1033            .ok_or_else(|| anyhow::anyhow!("no host state found for host {}", host_id))
1034    }
1035
1036    fn add_proc_id_to_host_state(
1037        &mut self,
1038        alloc_key: &ShortUuid,
1039        create_key: &ShortUuid,
1040    ) -> Result<(), anyhow::Error> {
1041        let task_state = self.get_host_state_mut(alloc_key)?;
1042        if !task_state.active_procs.insert(create_key.clone()) {
1043            // Should not happen but we can ignore
1044            tracing::error!("proc with create key {} already in host state", create_key);
1045        }
1046        task_state.allocated = true;
1047        Ok(())
1048    }
1049
1050    fn remove_proc_from_host_state(
1051        &mut self,
1052        alloc_key: &ShortUuid,
1053        create_key: &ShortUuid,
1054    ) -> Result<(), anyhow::Error> {
1055        let task_state = self.get_host_state_mut(alloc_key)?;
1056        if !task_state.active_procs.remove(create_key) {
1057            // Should not happen but we can ignore
1058            tracing::error!("proc with create_key already in host state: {}", create_key);
1059        }
1060        Ok(())
1061    }
1062
1063    // Reproject proc world coords to global shape coords.
1064    fn project_proc_into_global_extent(
1065        &self,
1066        alloc_key: &ShortUuid,
1067        point: &Point,
1068    ) -> Result<Point, anyhow::Error> {
1069        let global_rank = self
1070            .get_host_state(alloc_key)?
1071            .region
1072            .get(point.rank())
1073            .ok_or_else(|| {
1074                anyhow::anyhow!(
1075                    "rank {} out of bounds for in alloc {}",
1076                    point.rank(),
1077                    alloc_key
1078                )
1079            })?;
1080        Ok(self.spec.extent.point_of_rank(global_rank)?)
1081    }
1082
1083    // Cleanup a comm-failed host information by its ID.
1084    fn cleanup_host_channel_closed(
1085        &mut self,
1086        host_id: HostId,
1087    ) -> Result<Vec<ShortUuid>, anyhow::Error> {
1088        let state = match self.host_states.remove(&host_id) {
1089            Some(state) => state,
1090            None => {
1091                // this should never happen.
1092                anyhow::bail!(
1093                    "got channel closed event for host {} which has no known state",
1094                    host_id
1095                );
1096            }
1097        };
1098        self.ordered_hosts.retain(|host| host.id != host_id);
1099        self.alloc_to_host.remove(&state.alloc_key);
1100        if let Some(world_id) = state.world_id {
1101            self.world_offsets.remove(&world_id);
1102        }
1103        let create_keys = state.active_procs.iter().cloned().collect();
1104
1105        Ok(create_keys)
1106    }
1107}
1108
1109#[async_trait]
1110impl Alloc for RemoteProcessAlloc {
1111    async fn next(&mut self) -> Option<ProcState> {
1112        loop {
1113            if let state @ Some(_) = self.event_queue.pop_front() {
1114                break state;
1115            }
1116
1117            if !self.running {
1118                break None;
1119            }
1120
1121            if let Err(e) = self.ensure_started().await {
1122                break Some(ProcState::Failed {
1123                    world_id: self.world_id.clone(),
1124                    description: format!("failed to ensure started: {:#}", e),
1125                });
1126            }
1127
1128            let heartbeat_interval = hyperactor_config::global::get(
1129                hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
1130            );
1131            let mut heartbeat_time = hyperactor::clock::RealClock.now() + heartbeat_interval;
1132            // rerun outer loop in case we pushed new items to the event queue
1133            let mut reloop = false;
1134            let update = loop {
1135                tokio::select! {
1136                    msg = self.rx.recv() => {
1137                        tracing::debug!("got ProcState message from allocator: {:?}", msg);
1138                        match msg {
1139                            Ok(RemoteProcessProcStateMessage::Allocated { alloc_key, world_id }) => {
1140                                tracing::info!("remote alloc {}: allocated", alloc_key);
1141                                match self.get_host_state_mut(&alloc_key) {
1142                                    Ok(state) => {
1143                                        state.world_id = Some(world_id.clone());
1144                                    }
1145                                    Err(err) => {
1146                                        // should never happenA
1147                                        tracing::error!(
1148                                            "received allocated message alloc: {} with no known state: {}",
1149                                            alloc_key, err,
1150                                        );
1151                                    }
1152                                }
1153                            }
1154                            Ok(RemoteProcessProcStateMessage::Update(alloc_key, proc_state)) => {
1155                                let update = match proc_state {
1156                                    ProcState::Created { ref create_key, .. } => {
1157                                        if let Err(e) = self.add_proc_id_to_host_state(&alloc_key, create_key) {
1158                                            tracing::error!("failed to add proc with create key {} host state: {}", create_key, e);
1159                                        }
1160                                        proc_state
1161                                    }
1162                                    ProcState::Stopped{ ref create_key, ..} => {
1163                                        if let Err(e) = self.remove_proc_from_host_state(&alloc_key, create_key) {
1164                                            tracing::error!("failed to remove proc with create key {} host state: {}", create_key, e);
1165                                        }
1166                                        proc_state
1167                                    }
1168                                    ProcState::Failed { ref world_id, ref description } => {
1169                                        match self.get_host_state_mut(&alloc_key) {
1170                                            Ok(state) => {
1171                                                state.failed = true;
1172                                                ProcState::Failed {
1173                                                    world_id: world_id.clone(),
1174                                                    description: format!("host {} failed: {}", state.host_id, description),
1175                                                }
1176                                            }
1177                                            Err(e) => {
1178                                                tracing::error!("failed to find host state for world id: {}: {}", world_id, e);
1179                                                proc_state
1180                                            }
1181                                        }
1182                                    }
1183                                    _ => proc_state
1184                                };
1185
1186                                break Some((Some(alloc_key), update));
1187                            }
1188                            Ok(RemoteProcessProcStateMessage::Done(alloc_key)) => {
1189                                tracing::info!("allocator {} is done", alloc_key);
1190
1191                                if let Ok(state) = self.remove_host_state(&alloc_key) {
1192                                    if !state.active_procs.is_empty() {
1193                                        tracing::error!("received done for alloc {} with active procs: {:?}", alloc_key, state.active_procs);
1194                                    }
1195                                } else {
1196                                    tracing::error!("received done for unknown alloc {}", alloc_key);
1197                                }
1198
1199                                if self.host_states.is_empty() {
1200                                    self.running = false;
1201                                    break None;
1202                                }
1203                            }
1204                            // Hearbeat message is discarded immediately after being received, sender (remote
1205                            // process allocator) relies on channel ack to know if the receiver (client) is
1206                            // still alive. No state needs to be updated.
1207                            Ok(RemoteProcessProcStateMessage::HeartBeat) => {}
1208                            Err(e) => {
1209                                break Some((None, ProcState::Failed {world_id: self.world_id.clone(), description: format!("error receiving events: {}", e)}));
1210                            }
1211                        }
1212                    }
1213
1214                    _ = clock::RealClock.sleep_until(heartbeat_time) => {
1215                        self.host_states.iter().for_each(|(_, host_state)| host_state.tx.post(RemoteProcessAllocatorMessage::HeartBeat));
1216                        heartbeat_time = hyperactor::clock::RealClock.now() + heartbeat_interval;
1217                    }
1218
1219                    closed_host_id = self.comm_watcher_rx.recv() => {
1220                        if let Some(closed_host_id) = closed_host_id {
1221                            tracing::debug!("host {} channel closed, cleaning up", closed_host_id);
1222                            if let Some(state) = self.host_states.get(&closed_host_id)
1223                                && !state.allocated {
1224                                    break Some((None, ProcState::Failed {
1225                                        world_id: self.world_id.clone(),
1226                                        description: format!(
1227                                            "no process has ever been allocated on {}, connection is not established; \
1228                                            a common issue could be allocator port mismatch, or transport not enabled on client",
1229                                            closed_host_id
1230                                        )}));
1231                                }
1232                            let create_keys = match self.cleanup_host_channel_closed(closed_host_id) {
1233                                Ok(create_keys) => create_keys,
1234                                Err(err) => {
1235                                    tracing::error!("failed to cleanup disconnected host: {}", err);
1236                                    continue;
1237                                }
1238                            };
1239                            for create_key in create_keys {
1240                                tracing::debug!("queuing Stopped state for proc with create key {}", create_key);
1241                                self.event_queue.push_back(
1242                                    ProcState::Stopped {
1243                                        create_key,
1244                                        reason: ProcStopReason::HostWatchdog
1245                                    }
1246                                );
1247                            }
1248                            // Check if there are any hosts left
1249                            if self.host_states.is_empty() {
1250                                tracing::info!("no more hosts left, stopping the alloc");
1251                                self.running = false;
1252                            }
1253                            // Kick back to the outer loop to pop off the queue if necessary
1254                            reloop = true;
1255                            break None;
1256                        } else {
1257                            // other side closed
1258                            tracing::warn!("unexpected comm watcher channel close");
1259                            break None;
1260                        }
1261                    }
1262                }
1263            };
1264
1265            if reloop {
1266                // Kick back to the outer loop to pop off the queue.
1267                // Should not have produced an update.
1268                assert!(update.is_none());
1269                continue;
1270            }
1271
1272            break match update {
1273                Some((
1274                    Some(alloc_key),
1275                    ProcState::Created {
1276                        create_key,
1277                        point,
1278                        pid,
1279                    },
1280                )) => match self.project_proc_into_global_extent(&alloc_key, &point) {
1281                    Ok(global_point) => {
1282                        tracing::debug!("reprojected coords: {} -> {}", point, global_point);
1283                        Some(ProcState::Created {
1284                            create_key,
1285                            point: global_point,
1286                            pid,
1287                        })
1288                    }
1289                    Err(e) => {
1290                        tracing::error!(
1291                            "failed to project coords for proc: {}.{}: {}",
1292                            alloc_key,
1293                            create_key,
1294                            e
1295                        );
1296                        None
1297                    }
1298                },
1299                Some((None, ProcState::Created { .. })) => {
1300                    panic!("illegal state: missing alloc_key for ProcState::Created event")
1301                }
1302                Some((_, update)) => {
1303                    if let ProcState::Failed { description, .. } = &update {
1304                        tracing::error!(description);
1305                        self.failed = true;
1306                    }
1307                    Some(update)
1308                }
1309                None => None,
1310            };
1311        }
1312    }
1313
1314    fn spec(&self) -> &AllocSpec {
1315        &self.spec
1316    }
1317
1318    fn extent(&self) -> &Extent {
1319        &self.spec.extent
1320    }
1321
1322    fn world_id(&self) -> &WorldId {
1323        &self.world_id
1324    }
1325
1326    async fn stop(&mut self) -> Result<(), AllocatorError> {
1327        tracing::info!("stopping alloc");
1328
1329        for (host_id, task_state) in self.host_states.iter_mut() {
1330            tracing::debug!("stopping alloc at host {}", host_id);
1331            task_state.tx.post(RemoteProcessAllocatorMessage::Stop);
1332        }
1333
1334        Ok(())
1335    }
1336
1337    /// For Tcp and Metatls, return a router address which has the same IP address
1338    /// as the alloc's bootstrap address, but with port set as 0. We do this
1339    /// instead of using `ChannelAddr::any` is because we want alloc uses the
1340    /// same IP address for all its frontend ports. In some environment, the
1341    /// host can have public IP address and private IP address, and use the wrong
1342    /// one could lead to port unreachable error.
1343    ///
1344    /// For other channel types, this method still uses ChannelAddr::any.
1345    fn client_router_addr(&self) -> ChannelAddr {
1346        with_unspecified_port_or_any(&self.bootstrap_addr)
1347    }
1348}
1349
1350impl Drop for RemoteProcessAlloc {
1351    fn drop(&mut self) {
1352        tracing::debug!("dropping RemoteProcessAlloc of world_id {}", self.world_id);
1353    }
1354}
1355
1356#[cfg(test)]
1357mod test {
1358    use std::assert_matches::assert_matches;
1359
1360    use hyperactor::ActorRef;
1361    use hyperactor::channel::ChannelRx;
1362    use hyperactor::clock::ClockKind;
1363    use hyperactor::id;
1364    use ndslice::extent;
1365    use tokio::sync::oneshot;
1366
1367    use super::*;
1368    use crate::alloc::ChannelTransport;
1369    use crate::alloc::MockAlloc;
1370    use crate::alloc::MockAllocWrapper;
1371    use crate::alloc::MockAllocator;
1372    use crate::alloc::ProcStopReason;
1373    use crate::alloc::with_unspecified_port_or_any;
1374    use crate::proc_mesh::mesh_agent::ProcMeshAgent;
1375
1376    async fn read_all_created(rx: &mut ChannelRx<RemoteProcessProcStateMessage>, alloc_len: usize) {
1377        let mut i: usize = 0;
1378        while i < alloc_len {
1379            let m = rx.recv().await.unwrap();
1380            match m {
1381                RemoteProcessProcStateMessage::Update(_, ProcState::Created { .. }) => i += 1,
1382                RemoteProcessProcStateMessage::HeartBeat => {}
1383                _ => panic!("unexpected message: {:?}", m),
1384            }
1385        }
1386    }
1387
1388    async fn read_all_running(rx: &mut ChannelRx<RemoteProcessProcStateMessage>, alloc_len: usize) {
1389        let mut i: usize = 0;
1390        while i < alloc_len {
1391            let m = rx.recv().await.unwrap();
1392            match m {
1393                RemoteProcessProcStateMessage::Update(_, ProcState::Running { .. }) => i += 1,
1394                RemoteProcessProcStateMessage::HeartBeat => {}
1395                _ => panic!("unexpected message: {:?}", m),
1396            }
1397        }
1398    }
1399
1400    async fn read_all_stopped(rx: &mut ChannelRx<RemoteProcessProcStateMessage>, alloc_len: usize) {
1401        let mut i: usize = 0;
1402        while i < alloc_len {
1403            let m = rx.recv().await.unwrap();
1404            match m {
1405                RemoteProcessProcStateMessage::Update(_, ProcState::Stopped { .. }) => i += 1,
1406                RemoteProcessProcStateMessage::HeartBeat => {}
1407                _ => panic!("unexpected message: {:?}", m),
1408            }
1409        }
1410    }
1411
1412    fn set_procstate_expectations(alloc: &mut MockAlloc, extent: Extent) {
1413        alloc.expect_extent().return_const(extent.clone());
1414        let mut create_keys = Vec::new();
1415        for point in extent.points() {
1416            let create_key = ShortUuid::generate();
1417            create_keys.push(create_key.clone());
1418            alloc.expect_next().times(1).return_once(move || {
1419                Some(ProcState::Created {
1420                    create_key: create_key.clone(),
1421                    point,
1422                    pid: 0,
1423                })
1424            });
1425        }
1426        for (i, create_key) in create_keys
1427            .iter()
1428            .take(extent.num_ranks())
1429            .cloned()
1430            .enumerate()
1431        {
1432            let proc_id = format!("test[{i}]").parse().unwrap();
1433            let mesh_agent = ActorRef::<ProcMeshAgent>::attest(
1434                format!("test[{i}].mesh_agent[{i}]").parse().unwrap(),
1435            );
1436            alloc.expect_next().times(1).return_once(move || {
1437                Some(ProcState::Running {
1438                    create_key,
1439                    proc_id,
1440                    addr: ChannelAddr::Unix("/proc0".parse().unwrap()),
1441                    mesh_agent,
1442                })
1443            });
1444        }
1445        for create_key in create_keys.iter().take(extent.num_ranks()).cloned() {
1446            alloc.expect_next().times(1).return_once(|| {
1447                Some(ProcState::Stopped {
1448                    create_key,
1449                    reason: ProcStopReason::Unknown,
1450                })
1451            });
1452        }
1453    }
1454
1455    #[timed_test::async_timed_test(timeout_secs = 5)]
1456    async fn test_simple() {
1457        let config = hyperactor_config::global::lock();
1458        let _guard = config.override_key(
1459            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
1460            Duration::from_millis(100),
1461        );
1462        hyperactor_telemetry::initialize_logging_for_test();
1463        let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
1464        let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
1465        let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();
1466
1467        let extent = extent!(host = 1, gpu = 2);
1468        let tx = channel::dial(serve_addr.clone()).unwrap();
1469
1470        let world_id: WorldId = id!(test_world_id);
1471        let mut alloc = MockAlloc::new();
1472        alloc.expect_world_id().return_const(world_id.clone());
1473        alloc.expect_extent().return_const(extent.clone());
1474
1475        set_procstate_expectations(&mut alloc, extent.clone());
1476
1477        // final none
1478        alloc.expect_next().return_const(None);
1479
1480        let mut allocator = MockAllocator::new();
1481        let total_messages = extent.num_ranks() * 3 + 1;
1482        let mock_wrapper = MockAllocWrapper::new_block_next(
1483            alloc,
1484            // block after create, running, stopped and done.
1485            total_messages,
1486        );
1487        allocator
1488            .expect_allocate()
1489            .times(1)
1490            .return_once(move |_| Ok(mock_wrapper));
1491
1492        let remote_allocator = RemoteProcessAllocator::new();
1493        let handle = tokio::spawn({
1494            let remote_allocator = remote_allocator.clone();
1495            async move {
1496                remote_allocator
1497                    .start_with_allocator(serve_addr, allocator, None)
1498                    .await
1499            }
1500        });
1501
1502        let alloc_key = ShortUuid::generate();
1503
1504        tx.send(RemoteProcessAllocatorMessage::Allocate {
1505            alloc_key: alloc_key.clone(),
1506            extent: extent.clone(),
1507            bootstrap_addr,
1508            hosts: vec![],
1509            client_context: None,
1510            forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
1511        })
1512        .await
1513        .unwrap();
1514
1515        // Allocated
1516        let m = rx.recv().await.unwrap();
1517        assert_matches!(
1518            m, RemoteProcessProcStateMessage::Allocated { alloc_key: got_alloc_key, world_id: got_world_id }
1519                if got_world_id == world_id && got_alloc_key == alloc_key
1520        );
1521
1522        // All Created events
1523        let mut rank: usize = 0;
1524        let mut create_keys = Vec::with_capacity(extent.num_ranks());
1525        while rank < extent.num_ranks() {
1526            let m = rx.recv().await.unwrap();
1527            match m {
1528                RemoteProcessProcStateMessage::Update(
1529                    got_alloc_key,
1530                    ProcState::Created {
1531                        create_key, point, ..
1532                    },
1533                ) => {
1534                    let expected_point = extent.point_of_rank(rank).unwrap();
1535                    assert_eq!(got_alloc_key, alloc_key);
1536                    assert_eq!(point, expected_point);
1537                    create_keys.push(create_key);
1538                    rank += 1;
1539                }
1540                RemoteProcessProcStateMessage::HeartBeat => {}
1541                _ => panic!("unexpected message: {:?}", m),
1542            }
1543        }
1544        // All Running events
1545        let mut rank: usize = 0;
1546        while rank < extent.num_ranks() {
1547            let m = rx.recv().await.unwrap();
1548            match m {
1549                RemoteProcessProcStateMessage::Update(
1550                    got_alloc_key,
1551                    ProcState::Running {
1552                        create_key,
1553                        proc_id,
1554                        mesh_agent,
1555                        addr: _,
1556                    },
1557                ) => {
1558                    assert_eq!(got_alloc_key, alloc_key);
1559                    assert_eq!(create_key, create_keys[rank]);
1560                    let expected_proc_id = format!("test[{}]", rank).parse().unwrap();
1561                    let expected_mesh_agent = ActorRef::<ProcMeshAgent>::attest(
1562                        format!("test[{}].mesh_agent[{}]", rank, rank)
1563                            .parse()
1564                            .unwrap(),
1565                    );
1566                    assert_eq!(proc_id, expected_proc_id);
1567                    assert_eq!(mesh_agent, expected_mesh_agent);
1568                    rank += 1;
1569                }
1570                RemoteProcessProcStateMessage::HeartBeat => {}
1571                _ => panic!("unexpected message: {:?}", m),
1572            }
1573        }
1574        // All Stopped events
1575        let mut rank: usize = 0;
1576        while rank < extent.num_ranks() {
1577            let m = rx.recv().await.unwrap();
1578            match m {
1579                RemoteProcessProcStateMessage::Update(
1580                    got_alloc_key,
1581                    ProcState::Stopped {
1582                        create_key,
1583                        reason: ProcStopReason::Unknown,
1584                    },
1585                ) => {
1586                    assert_eq!(got_alloc_key, alloc_key);
1587                    assert_eq!(create_key, create_keys[rank]);
1588                    rank += 1;
1589                }
1590                RemoteProcessProcStateMessage::HeartBeat => {}
1591                _ => panic!("unexpected message: {:?}", m),
1592            }
1593        }
1594        // Done
1595        loop {
1596            let m = rx.recv().await.unwrap();
1597            match m {
1598                RemoteProcessProcStateMessage::Done(got_alloc_key) => {
1599                    assert_eq!(got_alloc_key, alloc_key);
1600                    break;
1601                }
1602                RemoteProcessProcStateMessage::HeartBeat => {}
1603                _ => panic!("unexpected message: {:?}", m),
1604            }
1605        }
1606
1607        remote_allocator.terminate();
1608        handle.await.unwrap().unwrap();
1609    }
1610
1611    #[timed_test::async_timed_test(timeout_secs = 15)]
1612    async fn test_normal_stop() {
1613        let config = hyperactor_config::global::lock();
1614        let _guard = config.override_key(
1615            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
1616            Duration::from_millis(100),
1617        );
1618        hyperactor_telemetry::initialize_logging_for_test();
1619        let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
1620        let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
1621        let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();
1622
1623        let extent = extent!(host = 1, gpu = 2);
1624        let tx = channel::dial(serve_addr.clone()).unwrap();
1625
1626        let world_id: WorldId = id!(test_world_id);
1627        let mut alloc = MockAllocWrapper::new_block_next(
1628            MockAlloc::new(),
1629            // block after all created, all running
1630            extent.num_ranks() * 2,
1631        );
1632        let next_tx = alloc.notify_tx();
1633        alloc.alloc.expect_world_id().return_const(world_id.clone());
1634        alloc.alloc.expect_extent().return_const(extent.clone());
1635
1636        set_procstate_expectations(&mut alloc.alloc, extent.clone());
1637
1638        alloc.alloc.expect_next().return_const(None);
1639        alloc.alloc.expect_stop().times(1).return_once(|| Ok(()));
1640
1641        let mut allocator = MockAllocator::new();
1642        allocator
1643            .expect_allocate()
1644            .times(1)
1645            .return_once(|_| Ok(alloc));
1646
1647        let remote_allocator = RemoteProcessAllocator::new();
1648        let handle = tokio::spawn({
1649            let remote_allocator = remote_allocator.clone();
1650            async move {
1651                remote_allocator
1652                    .start_with_allocator(serve_addr, allocator, None)
1653                    .await
1654            }
1655        });
1656
1657        let alloc_key = ShortUuid::generate();
1658        tx.send(RemoteProcessAllocatorMessage::Allocate {
1659            alloc_key: alloc_key.clone(),
1660            extent: extent.clone(),
1661            bootstrap_addr,
1662            hosts: vec![],
1663            client_context: None,
1664            forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
1665        })
1666        .await
1667        .unwrap();
1668
1669        // Allocated
1670        let m = rx.recv().await.unwrap();
1671        assert_matches!(
1672            m,
1673            RemoteProcessProcStateMessage::Allocated {  world_id: got_world_id, alloc_key: got_alloc_key }
1674            if world_id == got_world_id && alloc_key == got_alloc_key
1675        );
1676
1677        read_all_created(&mut rx, extent.num_ranks()).await;
1678        read_all_running(&mut rx, extent.num_ranks()).await;
1679
1680        // allocation finished. now we stop it.
1681        tracing::info!("stopping allocation");
1682        tx.send(RemoteProcessAllocatorMessage::Stop).await.unwrap();
1683        // receive all stops
1684        next_tx.send(()).unwrap();
1685
1686        read_all_stopped(&mut rx, extent.num_ranks()).await;
1687
1688        remote_allocator.terminate();
1689        handle.await.unwrap().unwrap();
1690    }
1691
1692    #[timed_test::async_timed_test(timeout_secs = 15)]
1693    async fn test_realloc() {
1694        let config = hyperactor_config::global::lock();
1695        let _guard = config.override_key(
1696            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
1697            Duration::from_millis(100),
1698        );
1699        hyperactor_telemetry::initialize_logging_for_test();
1700        let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
1701        let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
1702        let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();
1703
1704        let extent = extent!(host = 1, gpu = 2);
1705
1706        let tx = channel::dial(serve_addr.clone()).unwrap();
1707
1708        let world_id: WorldId = id!(test_world_id);
1709        let mut alloc1 = MockAllocWrapper::new_block_next(
1710            MockAlloc::new(),
1711            // block after all created, all running
1712            extent.num_ranks() * 2,
1713        );
1714        let next_tx1 = alloc1.notify_tx();
1715        alloc1
1716            .alloc
1717            .expect_world_id()
1718            .return_const(world_id.clone());
1719        alloc1.alloc.expect_extent().return_const(extent.clone());
1720
1721        set_procstate_expectations(&mut alloc1.alloc, extent.clone());
1722        alloc1.alloc.expect_next().return_const(None);
1723        alloc1.alloc.expect_stop().times(1).return_once(|| Ok(()));
1724        // second allocation
1725        let mut alloc2 = MockAllocWrapper::new_block_next(
1726            MockAlloc::new(),
1727            // block after all created, all running
1728            extent.num_ranks() * 2,
1729        );
1730        let next_tx2 = alloc2.notify_tx();
1731        alloc2
1732            .alloc
1733            .expect_world_id()
1734            .return_const(world_id.clone());
1735        alloc2.alloc.expect_extent().return_const(extent.clone());
1736        set_procstate_expectations(&mut alloc2.alloc, extent.clone());
1737        alloc2.alloc.expect_next().return_const(None);
1738        alloc2.alloc.expect_stop().times(1).return_once(|| Ok(()));
1739
1740        let mut allocator = MockAllocator::new();
1741        allocator
1742            .expect_allocate()
1743            .times(1)
1744            .return_once(|_| Ok(alloc1));
1745        // second alloc
1746        allocator
1747            .expect_allocate()
1748            .times(1)
1749            .return_once(|_| Ok(alloc2));
1750
1751        let remote_allocator = RemoteProcessAllocator::new();
1752        let handle = tokio::spawn({
1753            let remote_allocator = remote_allocator.clone();
1754            async move {
1755                remote_allocator
1756                    .start_with_allocator(serve_addr, allocator, None)
1757                    .await
1758            }
1759        });
1760
1761        let alloc_key = ShortUuid::generate();
1762
1763        tx.send(RemoteProcessAllocatorMessage::Allocate {
1764            alloc_key: alloc_key.clone(),
1765            extent: extent.clone(),
1766            bootstrap_addr: bootstrap_addr.clone(),
1767            hosts: vec![],
1768            client_context: None,
1769            forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
1770        })
1771        .await
1772        .unwrap();
1773
1774        // Allocated
1775        let m = rx.recv().await.unwrap();
1776        assert_matches!(
1777            m,
1778            RemoteProcessProcStateMessage::Allocated { world_id: got_world_id, alloc_key: got_alloc_key }
1779            if got_world_id == world_id && got_alloc_key == alloc_key
1780        );
1781
1782        read_all_created(&mut rx, extent.num_ranks()).await;
1783        read_all_running(&mut rx, extent.num_ranks()).await;
1784
1785        let alloc_key = ShortUuid::generate();
1786
1787        // allocation finished now we request a new one
1788        tx.send(RemoteProcessAllocatorMessage::Allocate {
1789            alloc_key: alloc_key.clone(),
1790            extent: extent.clone(),
1791            bootstrap_addr,
1792            hosts: vec![],
1793            client_context: None,
1794            forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
1795        })
1796        .await
1797        .unwrap();
1798        // unblock next for the first allocation
1799        next_tx1.send(()).unwrap();
1800        // we expect a stop(), then Stopped proc states, then a new Allocated
1801        read_all_stopped(&mut rx, extent.num_ranks()).await;
1802        let m = rx.recv().await.unwrap();
1803        assert_matches!(m, RemoteProcessProcStateMessage::Done(_));
1804        let m = rx.recv().await.unwrap();
1805        assert_matches!(
1806            m,
1807            RemoteProcessProcStateMessage::Allocated { world_id: got_world_id, alloc_key: got_alloc_key }
1808            if got_world_id == world_id && got_alloc_key == alloc_key
1809        );
1810        // ProcStates for the new allocation
1811        read_all_created(&mut rx, extent.num_ranks()).await;
1812        read_all_running(&mut rx, extent.num_ranks()).await;
1813        // finally stop
1814        tracing::info!("stopping allocation");
1815        tx.send(RemoteProcessAllocatorMessage::Stop).await.unwrap();
1816        // receive all stops
1817        next_tx2.send(()).unwrap();
1818
1819        read_all_stopped(&mut rx, extent.num_ranks()).await;
1820
1821        remote_allocator.terminate();
1822        handle.await.unwrap().unwrap();
1823    }
1824
1825    #[timed_test::async_timed_test(timeout_secs = 15)]
1826    async fn test_upstream_closed() {
1827        // Use temporary config for this test
1828        let config = hyperactor_config::global::lock();
1829        let _guard1 = config.override_key(
1830            hyperactor::config::MESSAGE_DELIVERY_TIMEOUT,
1831            Duration::from_secs(1),
1832        );
1833        let _guard2 = config.override_key(
1834            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
1835            Duration::from_millis(100),
1836        );
1837
1838        hyperactor_telemetry::initialize_logging_for_test();
1839        let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
1840        let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
1841        let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();
1842
1843        let extent = extent!(host = 1, gpu = 2);
1844
1845        let tx = channel::dial(serve_addr.clone()).unwrap();
1846
1847        let world_id: WorldId = id!(test_world_id);
1848        let mut alloc = MockAllocWrapper::new_block_next(
1849            MockAlloc::new(),
1850            // block after all created, all running
1851            extent.num_ranks() * 2,
1852        );
1853        let next_tx = alloc.notify_tx();
1854        alloc.alloc.expect_world_id().return_const(world_id.clone());
1855        alloc.alloc.expect_extent().return_const(extent.clone());
1856
1857        set_procstate_expectations(&mut alloc.alloc, extent.clone());
1858
1859        alloc.alloc.expect_next().return_const(None);
1860        // we expect a stop due to the failure
1861        // synchronize test with the stop
1862        let (stop_tx, stop_rx) = oneshot::channel();
1863        alloc.alloc.expect_stop().times(1).return_once(|| {
1864            stop_tx.send(()).unwrap();
1865            Ok(())
1866        });
1867
1868        let mut allocator = MockAllocator::new();
1869        allocator
1870            .expect_allocate()
1871            .times(1)
1872            .return_once(|_| Ok(alloc));
1873
1874        let remote_allocator = RemoteProcessAllocator::new();
1875        let handle = tokio::spawn({
1876            let remote_allocator = remote_allocator.clone();
1877            async move {
1878                remote_allocator
1879                    .start_with_allocator(serve_addr, allocator, None)
1880                    .await
1881            }
1882        });
1883
1884        let alloc_key = ShortUuid::generate();
1885
1886        tx.send(RemoteProcessAllocatorMessage::Allocate {
1887            alloc_key: alloc_key.clone(),
1888            extent: extent.clone(),
1889            bootstrap_addr,
1890            hosts: vec![],
1891            client_context: None,
1892            forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
1893        })
1894        .await
1895        .unwrap();
1896
1897        // Allocated
1898        let m = rx.recv().await.unwrap();
1899        assert_matches!(
1900            m, RemoteProcessProcStateMessage::Allocated { alloc_key: got_alloc_key, world_id: got_world_id }
1901                if got_world_id == world_id && got_alloc_key == alloc_key
1902        );
1903
1904        read_all_created(&mut rx, extent.num_ranks()).await;
1905        read_all_running(&mut rx, extent.num_ranks()).await;
1906
1907        // allocation finished. terminate connection.
1908        tracing::info!("closing upstream");
1909        drop(rx);
1910        // wait for the heartbeat to expire
1911        #[allow(clippy::disallowed_methods)]
1912        tokio::time::sleep(Duration::from_secs(2)).await;
1913        // wait for the stop to be called
1914        stop_rx.await.unwrap();
1915        // unblock next
1916        next_tx.send(()).unwrap();
1917        remote_allocator.terminate();
1918        handle.await.unwrap().unwrap();
1919    }
1920
1921    #[timed_test::async_timed_test(timeout_secs = 15)]
1922    async fn test_inner_alloc_failure() {
1923        let config = hyperactor_config::global::lock();
1924        let _guard = config.override_key(
1925            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
1926            Duration::from_mins(1),
1927        );
1928        hyperactor_telemetry::initialize_logging_for_test();
1929        let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
1930        let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
1931        let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();
1932
1933        let extent = extent!(host = 1, gpu = 2);
1934
1935        let tx = channel::dial(serve_addr.clone()).unwrap();
1936
1937        let test_world_id: WorldId = id!(test_world_id);
1938        let mut alloc = MockAllocWrapper::new_block_next(
1939            MockAlloc::new(),
1940            // block after the failure update
1941            1,
1942        );
1943        let next_tx = alloc.notify_tx();
1944        alloc
1945            .alloc
1946            .expect_world_id()
1947            .return_const(test_world_id.clone());
1948        alloc.alloc.expect_extent().return_const(extent.clone());
1949        alloc
1950            .alloc
1951            .expect_next()
1952            .times(1)
1953            .return_const(Some(ProcState::Failed {
1954                world_id: test_world_id.clone(),
1955                description: "test".to_string(),
1956            }));
1957        alloc.alloc.expect_next().times(1).return_const(None);
1958
1959        alloc.alloc.expect_stop().times(1).return_once(|| Ok(()));
1960
1961        let mut allocator = MockAllocator::new();
1962        allocator
1963            .expect_allocate()
1964            .times(1)
1965            .return_once(|_| Ok(alloc));
1966
1967        let remote_allocator = RemoteProcessAllocator::new();
1968        let handle = tokio::spawn({
1969            let remote_allocator = remote_allocator.clone();
1970            async move {
1971                remote_allocator
1972                    .start_with_allocator(serve_addr, allocator, None)
1973                    .await
1974            }
1975        });
1976
1977        let alloc_key = ShortUuid::generate();
1978        tx.send(RemoteProcessAllocatorMessage::Allocate {
1979            alloc_key: alloc_key.clone(),
1980            extent: extent.clone(),
1981            bootstrap_addr,
1982            hosts: vec![],
1983            client_context: None,
1984            forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
1985        })
1986        .await
1987        .unwrap();
1988
1989        // Allocated
1990        let m = rx.recv().await.unwrap();
1991        assert_matches!(
1992            m,
1993            RemoteProcessProcStateMessage::Allocated {  world_id: got_world_id, alloc_key: got_alloc_key }
1994            if test_world_id == got_world_id && alloc_key == got_alloc_key
1995        );
1996
1997        // Failed
1998        let m = rx.recv().await.unwrap();
1999        assert_matches!(
2000            m,
2001            RemoteProcessProcStateMessage::Update(
2002                got_alloc_key,
2003                ProcState::Failed { world_id, description }
2004            ) if got_alloc_key == alloc_key && world_id == test_world_id && description == "test"
2005        );
2006
2007        tracing::info!("stopping allocation");
2008        tx.send(RemoteProcessAllocatorMessage::Stop).await.unwrap();
2009        // receive all stops
2010        next_tx.send(()).unwrap();
2011        // we are expecting 1 Done when Alloc successfully stops.
2012        let m = rx.recv().await.unwrap();
2013        assert_matches!(
2014            m,
2015            RemoteProcessProcStateMessage::Done(got_alloc_key)
2016            if got_alloc_key == alloc_key
2017        );
2018
2019        remote_allocator.terminate();
2020        handle.await.unwrap().unwrap();
2021    }
2022
2023    #[timed_test::async_timed_test(timeout_secs = 15)]
2024    async fn test_trace_id_propagation() {
2025        let config = hyperactor_config::global::lock();
2026        let _guard = config.override_key(
2027            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
2028            Duration::from_mins(1),
2029        );
2030        hyperactor_telemetry::initialize_logging(ClockKind::default());
2031        let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
2032        let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
2033        let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();
2034
2035        let extent = extent!(host = 1, gpu = 1);
2036        let tx = channel::dial(serve_addr.clone()).unwrap();
2037        let test_world_id: WorldId = id!(test_world_id);
2038        let test_trace_id = "test_trace_id_12345";
2039
2040        // Create a mock alloc that we can verify receives the correct trace id
2041        let mut alloc = MockAlloc::new();
2042        alloc.expect_world_id().return_const(test_world_id.clone());
2043        alloc.expect_extent().return_const(extent.clone());
2044        alloc.expect_next().return_const(None);
2045
2046        // Create a mock allocator that captures the AllocSpec passed to it
2047        let mut allocator = MockAllocator::new();
2048        allocator
2049            .expect_allocate()
2050            .times(1)
2051            .withf(move |spec: &AllocSpec| {
2052                // Verify that the trace id is correctly set in the constraints
2053                spec.constraints
2054                    .match_labels
2055                    .get(CLIENT_TRACE_ID_LABEL)
2056                    .is_some_and(|trace_id| trace_id == test_trace_id)
2057            })
2058            .return_once(|_| Ok(MockAllocWrapper::new(alloc)));
2059
2060        let remote_allocator = RemoteProcessAllocator::new();
2061        let handle = tokio::spawn({
2062            let remote_allocator = remote_allocator.clone();
2063            async move {
2064                remote_allocator
2065                    .start_with_allocator(serve_addr, allocator, None)
2066                    .await
2067            }
2068        });
2069
2070        let alloc_key = ShortUuid::generate();
2071        tx.send(RemoteProcessAllocatorMessage::Allocate {
2072            alloc_key: alloc_key.clone(),
2073            extent: extent.clone(),
2074            bootstrap_addr,
2075            hosts: vec![],
2076            client_context: Some(ClientContext {
2077                trace_id: test_trace_id.to_string(),
2078            }),
2079            forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
2080        })
2081        .await
2082        .unwrap();
2083
2084        // Verify we get the allocated message
2085        let m = rx.recv().await.unwrap();
2086        assert_matches!(
2087            m,
2088            RemoteProcessProcStateMessage::Allocated { alloc_key: got_alloc_key, world_id: got_world_id }
2089            if got_world_id == test_world_id && got_alloc_key == alloc_key
2090        );
2091
2092        // Verify we get the done message since the mock alloc returns None immediately
2093        let m = rx.recv().await.unwrap();
2094        assert_matches!(
2095            m,
2096            RemoteProcessProcStateMessage::Done(got_alloc_key)
2097            if alloc_key == got_alloc_key
2098        );
2099
2100        remote_allocator.terminate();
2101        handle.await.unwrap().unwrap();
2102    }
2103
2104    #[timed_test::async_timed_test(timeout_secs = 15)]
2105    async fn test_trace_id_propagation_no_client_context() {
2106        let config = hyperactor_config::global::lock();
2107        let _guard = config.override_key(
2108            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
2109            Duration::from_mins(1),
2110        );
2111        hyperactor_telemetry::initialize_logging(ClockKind::default());
2112        let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
2113        let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
2114        let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();
2115
2116        let extent = extent!(host = 1, gpu = 1);
2117        let tx = channel::dial(serve_addr.clone()).unwrap();
2118        let test_world_id: WorldId = id!(test_world_id);
2119
2120        // Create a mock alloc
2121        let mut alloc = MockAlloc::new();
2122        alloc.expect_world_id().return_const(test_world_id.clone());
2123        alloc.expect_extent().return_const(extent.clone());
2124        alloc.expect_next().return_const(None);
2125
2126        // Create a mock allocator that verifies no trace id is set when client_context is None
2127        let mut allocator = MockAllocator::new();
2128        allocator
2129            .expect_allocate()
2130            .times(1)
2131            .withf(move |spec: &AllocSpec| {
2132                // Verify that no trace id is set in the constraints when client_context is None
2133                spec.constraints.match_labels.is_empty()
2134            })
2135            .return_once(|_| Ok(MockAllocWrapper::new(alloc)));
2136
2137        let remote_allocator = RemoteProcessAllocator::new();
2138        let handle = tokio::spawn({
2139            let remote_allocator = remote_allocator.clone();
2140            async move {
2141                remote_allocator
2142                    .start_with_allocator(serve_addr, allocator, None)
2143                    .await
2144            }
2145        });
2146
2147        let alloc_key = ShortUuid::generate();
2148        tx.send(RemoteProcessAllocatorMessage::Allocate {
2149            alloc_key: alloc_key.clone(),
2150            extent: extent.clone(),
2151            bootstrap_addr,
2152            hosts: vec![],
2153            client_context: None,
2154            forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
2155        })
2156        .await
2157        .unwrap();
2158
2159        // Verify we get the allocated message
2160        let m = rx.recv().await.unwrap();
2161        assert_matches!(
2162            m,
2163            RemoteProcessProcStateMessage::Allocated { alloc_key: got_alloc_key, world_id: got_world_id }
2164            if got_world_id == test_world_id && got_alloc_key == alloc_key
2165        );
2166
2167        // Verify we get the done message since the mock alloc returns None immediately
2168        let m = rx.recv().await.unwrap();
2169        assert_matches!(
2170            m,
2171            RemoteProcessProcStateMessage::Done(got_alloc_key)
2172            if got_alloc_key == alloc_key
2173        );
2174
2175        remote_allocator.terminate();
2176        handle.await.unwrap().unwrap();
2177    }
2178}
2179
2180#[cfg(test)]
2181mod test_alloc {
2182    use std::os::unix::process::ExitStatusExt;
2183
2184    use hyperactor::clock::ClockKind;
2185    use hyperactor_config;
2186    use ndslice::extent;
2187    use nix::sys::signal;
2188    use nix::unistd::Pid;
2189    use timed_test::async_timed_test;
2190
2191    use super::*;
2192
2193    #[async_timed_test(timeout_secs = 60)]
2194    #[cfg(fbcode_build)]
2195    async fn test_alloc_simple() {
2196        // Use temporary config for this test
2197        let config = hyperactor_config::global::lock();
2198        let _guard1 = config.override_key(
2199            hyperactor::config::MESSAGE_DELIVERY_TIMEOUT,
2200            Duration::from_secs(1),
2201        );
2202        let _guard2 = config.override_key(
2203            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
2204            Duration::from_millis(100),
2205        );
2206        hyperactor_telemetry::initialize_logging(ClockKind::default());
2207
2208        let spec = AllocSpec {
2209            extent: extent!(host = 2, gpu = 2),
2210            constraints: Default::default(),
2211            proc_name: None,
2212            transport: ChannelTransport::Unix,
2213            proc_allocation_mode: Default::default(),
2214        };
2215        let world_id = WorldId("test_world_id".to_string());
2216
2217        let task1_allocator = RemoteProcessAllocator::new();
2218        let task1_addr = ChannelAddr::any(ChannelTransport::Unix);
2219        let task1_addr_string = task1_addr.to_string();
2220        let task1_cmd = Command::new(crate::testresource::get(
2221            "monarch/hyperactor_mesh/bootstrap",
2222        ));
2223        let task2_allocator = RemoteProcessAllocator::new();
2224        let task2_addr = ChannelAddr::any(ChannelTransport::Unix);
2225        let task2_addr_string = task2_addr.to_string();
2226        let task2_cmd = Command::new(crate::testresource::get(
2227            "monarch/hyperactor_mesh/bootstrap",
2228        ));
2229        let task1_allocator_copy = task1_allocator.clone();
2230        let task1_allocator_handle = tokio::spawn(async move {
2231            tracing::info!("spawning task1");
2232            task1_allocator_copy
2233                .start(task1_cmd, task1_addr, None)
2234                .await
2235                .unwrap();
2236        });
2237        let task2_allocator_copy = task2_allocator.clone();
2238        let task2_allocator_handle = tokio::spawn(async move {
2239            task2_allocator_copy
2240                .start(task2_cmd, task2_addr, None)
2241                .await
2242                .unwrap();
2243        });
2244
2245        let mut initializer = MockRemoteProcessAllocInitializer::new();
2246        initializer.expect_initialize_alloc().return_once(move || {
2247            Ok(vec![
2248                RemoteProcessAllocHost {
2249                    hostname: task1_addr_string,
2250                    id: "task1".to_string(),
2251                },
2252                RemoteProcessAllocHost {
2253                    hostname: task2_addr_string,
2254                    id: "task2".to_string(),
2255                },
2256            ])
2257        });
2258        let mut alloc = RemoteProcessAlloc::new(spec.clone(), world_id, 0, initializer)
2259            .await
2260            .unwrap();
2261        let mut created = HashSet::new();
2262        let mut running_procs = HashSet::new();
2263        let mut proc_points = HashSet::new();
2264        for _ in 0..spec.extent.num_ranks() * 2 {
2265            let proc_state = alloc.next().await.unwrap();
2266            tracing::debug!("test got message: {:?}", proc_state);
2267            match proc_state {
2268                ProcState::Created {
2269                    create_key, point, ..
2270                } => {
2271                    created.insert(create_key);
2272                    proc_points.insert(point);
2273                }
2274                ProcState::Running { create_key, .. } => {
2275                    assert!(created.remove(&create_key));
2276                    running_procs.insert(create_key);
2277                }
2278                _ => panic!("expected Created or Running"),
2279            }
2280        }
2281        assert!(created.is_empty());
2282        // ensure coords coverage
2283        assert!(
2284            spec.extent
2285                .points()
2286                .all(|point| proc_points.contains(&point))
2287        );
2288
2289        // ensure no more pending items
2290        let timeout = hyperactor::clock::RealClock.now() + std::time::Duration::from_millis(1000);
2291        tokio::select! {
2292            _ = hyperactor::clock::RealClock.sleep_until(timeout) => {},
2293            _ = alloc.next() => panic!("expected no more items"),
2294        }
2295
2296        // stop the allocation
2297        alloc.stop().await.unwrap();
2298        for _ in 0..spec.extent.num_ranks() {
2299            let proc_state = alloc.next().await.unwrap();
2300            tracing::info!("test received next proc_state: {:?}", proc_state);
2301            match proc_state {
2302                ProcState::Stopped {
2303                    create_key, reason, ..
2304                } => {
2305                    assert!(running_procs.remove(&create_key));
2306                    assert_eq!(reason, ProcStopReason::Stopped);
2307                }
2308                _ => panic!("expected stopped"),
2309            }
2310        }
2311        // Exactly one None
2312        let proc_state = alloc.next().await;
2313        assert!(proc_state.is_none());
2314        // Anything afterwards is None
2315        let proc_state = alloc.next().await;
2316        assert!(proc_state.is_none());
2317
2318        task1_allocator.terminate();
2319        task1_allocator_handle.await.unwrap();
2320        task2_allocator.terminate();
2321        task2_allocator_handle.await.unwrap();
2322    }
2323
2324    #[async_timed_test(timeout_secs = 60)]
2325    #[cfg(fbcode_build)]
2326    async fn test_alloc_host_failure() {
2327        // Use temporary config for this test
2328        let config = hyperactor_config::global::lock();
2329        let _guard1 = config.override_key(
2330            hyperactor::config::MESSAGE_DELIVERY_TIMEOUT,
2331            Duration::from_secs(1),
2332        );
2333        let _guard2 = config.override_key(
2334            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
2335            Duration::from_millis(100),
2336        );
2337        hyperactor_telemetry::initialize_logging(ClockKind::default());
2338
2339        let spec = AllocSpec {
2340            extent: extent!(host = 2, gpu = 2),
2341            constraints: Default::default(),
2342            proc_name: None,
2343            transport: ChannelTransport::Unix,
2344            proc_allocation_mode: Default::default(),
2345        };
2346        let world_id = WorldId("test_world_id".to_string());
2347
2348        let task1_allocator = RemoteProcessAllocator::new();
2349        let task1_addr = ChannelAddr::any(ChannelTransport::Unix);
2350        let task1_addr_string = task1_addr.to_string();
2351        let task1_cmd = Command::new(crate::testresource::get(
2352            "monarch/hyperactor_mesh/bootstrap",
2353        ));
2354        let task2_allocator = RemoteProcessAllocator::new();
2355        let task2_addr = ChannelAddr::any(ChannelTransport::Unix);
2356        let task2_addr_string = task2_addr.to_string();
2357        let task2_cmd = Command::new(crate::testresource::get(
2358            "monarch/hyperactor_mesh/bootstrap",
2359        ));
2360        let task1_allocator_copy = task1_allocator.clone();
2361        let task1_allocator_handle = tokio::spawn(async move {
2362            tracing::info!("spawning task1");
2363            task1_allocator_copy
2364                .start(task1_cmd, task1_addr, None)
2365                .await
2366                .unwrap();
2367            tracing::info!("task1 terminated");
2368        });
2369        let task2_allocator_copy = task2_allocator.clone();
2370        let task2_allocator_handle = tokio::spawn(async move {
2371            task2_allocator_copy
2372                .start(task2_cmd, task2_addr, None)
2373                .await
2374                .unwrap();
2375            tracing::info!("task2 terminated");
2376        });
2377
2378        let mut initializer = MockRemoteProcessAllocInitializer::new();
2379        initializer.expect_initialize_alloc().return_once(move || {
2380            Ok(vec![
2381                RemoteProcessAllocHost {
2382                    hostname: task1_addr_string,
2383                    id: "task1".to_string(),
2384                },
2385                RemoteProcessAllocHost {
2386                    hostname: task2_addr_string,
2387                    id: "task2".to_string(),
2388                },
2389            ])
2390        });
2391        let mut alloc = RemoteProcessAlloc::new(spec.clone(), world_id, 0, initializer)
2392            .await
2393            .unwrap();
2394        for _ in 0..spec.extent.num_ranks() * 2 {
2395            match alloc.next().await {
2396                Some(ProcState::Created { .. }) | Some(ProcState::Running { .. }) => {}
2397                _ => panic!("expected Created or Running"),
2398            }
2399        }
2400
2401        // ensure no more pending items
2402        let timeout = hyperactor::clock::RealClock.now() + std::time::Duration::from_millis(1000);
2403        tokio::select! {
2404            _ = hyperactor::clock::RealClock
2405            .sleep_until(timeout) => {},
2406            _ = alloc.next() => panic!("expected no more items"),
2407        }
2408
2409        // now we kill task1 and wait for timeout
2410        tracing::info!("aborting task1 allocator");
2411        task1_allocator_handle.abort();
2412        RealClock
2413            .sleep(
2414                hyperactor_config::global::get(
2415                    hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
2416                ) * 2,
2417            )
2418            .await;
2419        for _ in 0..spec.extent.num_ranks() / 2 {
2420            let proc_state = alloc.next().await.unwrap();
2421            tracing::info!("test received next proc_state: {:?}", proc_state);
2422            match proc_state {
2423                ProcState::Stopped { reason, .. } => {
2424                    assert_eq!(reason, ProcStopReason::HostWatchdog);
2425                }
2426                _ => panic!("expected stopped"),
2427            }
2428        }
2429        // no more events
2430        let timeout = hyperactor::clock::RealClock.now() + std::time::Duration::from_millis(1000);
2431        tokio::select! {
2432            _ = hyperactor::clock::RealClock
2433            .sleep_until(timeout) => {},
2434            _ = alloc.next() => panic!("expected no more items"),
2435        }
2436
2437        // abort the second host
2438        tracing::info!("aborting task2 allocator");
2439        task2_allocator_handle.abort();
2440        RealClock
2441            .sleep(
2442                hyperactor_config::global::get(
2443                    hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
2444                ) * 2,
2445            )
2446            .await;
2447        for _ in 0..spec.extent.num_ranks() / 2 {
2448            let proc_state = alloc.next().await.unwrap();
2449            tracing::info!("test received next proc_state: {:?}", proc_state);
2450            match proc_state {
2451                ProcState::Stopped { reason, .. } => {
2452                    assert_eq!(reason, ProcStopReason::HostWatchdog);
2453                }
2454                _ => panic!("expected stopped"),
2455            }
2456        }
2457        // now the alloc has stopped we always get None
2458        let proc_state = alloc.next().await;
2459        assert!(proc_state.is_none());
2460        // Anything afterwards is None
2461        let proc_state = alloc.next().await;
2462        assert!(proc_state.is_none());
2463    }
2464
2465    #[async_timed_test(timeout_secs = 15)]
2466    #[cfg(fbcode_build)]
2467    async fn test_alloc_inner_alloc_failure() {
2468        // SAFETY: Test happens in single-threaded code.
2469        unsafe {
2470            std::env::set_var("MONARCH_MESSAGE_DELIVERY_TIMEOUT_SECS", "1");
2471        }
2472
2473        let config = hyperactor_config::global::lock();
2474        let _guard = config.override_key(
2475            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
2476            Duration::from_millis(100),
2477        );
2478        hyperactor_telemetry::initialize_logging_for_test();
2479
2480        let spec = AllocSpec {
2481            extent: extent!(host = 2, gpu = 2),
2482            constraints: Default::default(),
2483            proc_name: None,
2484            transport: ChannelTransport::Unix,
2485            proc_allocation_mode: Default::default(),
2486        };
2487        let world_id = WorldId("test_world_id".to_string());
2488
2489        let task1_allocator = RemoteProcessAllocator::new();
2490        let task1_addr = ChannelAddr::any(ChannelTransport::Unix);
2491        let task1_addr_string = task1_addr.to_string();
2492        let task1_cmd = Command::new(crate::testresource::get(
2493            "monarch/hyperactor_mesh/bootstrap",
2494        ));
2495        let task2_allocator = RemoteProcessAllocator::new();
2496        let task2_addr = ChannelAddr::any(ChannelTransport::Unix);
2497        let task2_addr_string = task2_addr.to_string();
2498        // non-existent binary to fail the allocation
2499        let task2_cmd = Command::new("/caught/somewhere/in/time");
2500        let task1_allocator_copy = task1_allocator.clone();
2501        let task1_allocator_handle = tokio::spawn(async move {
2502            tracing::info!("spawning task1");
2503            task1_allocator_copy
2504                .start(task1_cmd, task1_addr, None)
2505                .await
2506                .unwrap();
2507        });
2508        let task2_allocator_copy = task2_allocator.clone();
2509        let task2_allocator_handle = tokio::spawn(async move {
2510            task2_allocator_copy
2511                .start(task2_cmd, task2_addr, None)
2512                .await
2513                .unwrap();
2514        });
2515
2516        let mut initializer = MockRemoteProcessAllocInitializer::new();
2517        initializer.expect_initialize_alloc().return_once(move || {
2518            Ok(vec![
2519                RemoteProcessAllocHost {
2520                    hostname: task1_addr_string,
2521                    id: "task1".to_string(),
2522                },
2523                RemoteProcessAllocHost {
2524                    hostname: task2_addr_string,
2525                    id: "task2".to_string(),
2526                },
2527            ])
2528        });
2529        let mut alloc = RemoteProcessAlloc::new(spec.clone(), world_id, 0, initializer)
2530            .await
2531            .unwrap();
2532        let mut created = HashSet::new();
2533        let mut started_procs = HashSet::new();
2534        let mut proc_points = HashSet::new();
2535        let mut failed = 0;
2536        // task 1 procs + 1 failed event for task 2
2537        for _ in 0..spec.extent.num_ranks() + 1 {
2538            let proc_state = alloc.next().await.unwrap();
2539            tracing::debug!("test got message: {:?}", proc_state);
2540            match proc_state {
2541                ProcState::Created {
2542                    create_key, point, ..
2543                } => {
2544                    created.insert(create_key);
2545                    proc_points.insert(point);
2546                }
2547                ProcState::Running { create_key, .. } => {
2548                    assert!(created.remove(&create_key));
2549                    started_procs.insert(create_key);
2550                }
2551                ProcState::Failed { .. } => {
2552                    failed += 1;
2553                }
2554                _ => panic!("expected Created, Running or Failed"),
2555            }
2556        }
2557        assert!(created.is_empty());
2558        assert_eq!(failed, 1);
2559        // ensure coords coverage for task 1
2560        for rank in 0..spec.extent.num_ranks() / 2 {
2561            let point = spec.extent.point_of_rank(rank).unwrap();
2562            assert!(proc_points.contains(&point));
2563        }
2564
2565        // ensure no more pending items
2566        let timeout = hyperactor::clock::RealClock.now() + std::time::Duration::from_millis(1000);
2567        tokio::select! {
2568            _ = hyperactor::clock::RealClock
2569            .sleep_until(timeout) => {},
2570            _ = alloc.next() => panic!("expected no more items"),
2571        }
2572
2573        // stop the allocation
2574        alloc.stop().await.unwrap();
2575        for _ in 0..spec.extent.num_ranks() / 2 {
2576            let proc_state = alloc.next().await.unwrap();
2577            tracing::info!("test received next proc_state: {:?}", proc_state);
2578            match proc_state {
2579                ProcState::Stopped {
2580                    create_key, reason, ..
2581                } => {
2582                    assert!(started_procs.remove(&create_key));
2583                    assert_eq!(reason, ProcStopReason::Stopped);
2584                }
2585                _ => panic!("expected stopped"),
2586            }
2587        }
2588        // Exactly one None
2589        let proc_state = alloc.next().await;
2590        assert!(proc_state.is_none());
2591        // Anything afterwards is None
2592        let proc_state = alloc.next().await;
2593        assert!(proc_state.is_none());
2594
2595        task1_allocator.terminate();
2596        task1_allocator_handle.await.unwrap();
2597        task2_allocator.terminate();
2598        task2_allocator_handle.await.unwrap();
2599    }
2600
2601    #[async_timed_test(timeout_secs = 60)]
2602    #[cfg(fbcode_build)]
2603    async fn test_remote_process_alloc_signal_handler() {
2604        hyperactor_telemetry::initialize_logging_for_test();
2605        let num_proc_meshes = 5;
2606        let hosts_per_proc_mesh = 5;
2607
2608        let pid_addr = ChannelAddr::any(ChannelTransport::Unix);
2609        let (pid_addr, mut pid_rx) = channel::serve::<u32>(pid_addr).unwrap();
2610
2611        let addresses = (0..(num_proc_meshes * hosts_per_proc_mesh))
2612            .map(|_| ChannelAddr::any(ChannelTransport::Unix).to_string())
2613            .collect::<Vec<_>>();
2614
2615        let remote_process_allocators = addresses
2616            .iter()
2617            .map(|addr| {
2618                Command::new(crate::testresource::get(
2619                    "monarch/hyperactor_mesh/remote_process_allocator",
2620                ))
2621                .env("RUST_LOG", "info")
2622                .arg(format!("--addr={addr}"))
2623                .stdout(std::process::Stdio::piped())
2624                .spawn()
2625                .unwrap()
2626            })
2627            .collect::<Vec<_>>();
2628
2629        let done_allocating_addr = ChannelAddr::any(ChannelTransport::Unix);
2630        let (done_allocating_addr, mut done_allocating_rx) =
2631            channel::serve::<()>(done_allocating_addr).unwrap();
2632        let mut remote_process_alloc = Command::new(crate::testresource::get(
2633            "monarch/hyperactor_mesh/remote_process_alloc",
2634        ))
2635        .arg(format!("--done-allocating-addr={}", done_allocating_addr))
2636        .arg(format!("--addresses={}", addresses.join(",")))
2637        .arg(format!("--num-proc-meshes={}", num_proc_meshes))
2638        .arg(format!("--hosts-per-proc-mesh={}", hosts_per_proc_mesh))
2639        .arg(format!("--pid-addr={}", pid_addr))
2640        .spawn()
2641        .unwrap();
2642
2643        done_allocating_rx.recv().await.unwrap();
2644        let mut received_pids = Vec::new();
2645        while let Ok(pid) = pid_rx.recv().await {
2646            received_pids.push(pid);
2647            if received_pids.len() == remote_process_allocators.len() {
2648                break;
2649            }
2650        }
2651
2652        signal::kill(
2653            Pid::from_raw(remote_process_alloc.id().unwrap() as i32),
2654            signal::SIGINT,
2655        )
2656        .unwrap();
2657
2658        assert_eq!(
2659            remote_process_alloc.wait().await.unwrap().signal(),
2660            Some(signal::SIGINT as i32)
2661        );
2662
2663        RealClock.sleep(tokio::time::Duration::from_secs(5)).await;
2664
2665        // Assert that the processes spawned by ProcessAllocator have been killed
2666        for child_pid in received_pids {
2667            let pid_check = Command::new("kill")
2668                .arg("-0")
2669                .arg(child_pid.to_string())
2670                .output()
2671                .await
2672                .expect("Failed to check if PID is alive");
2673
2674            assert!(
2675                !pid_check.status.success(),
2676                "PID {} should no longer be alive",
2677                child_pid
2678            );
2679        }
2680
2681        // Cleanup remote process allocator processes as SIGINT causes the current
2682        // allocs to stop but not the RemoteProcessAllocator loops
2683        for mut remote_process_allocator in remote_process_allocators {
2684            remote_process_allocator.kill().await.unwrap();
2685        }
2686    }
2687}