hyperactor_mesh/alloc/
remoteprocess.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::collections::HashMap;
10use std::collections::HashSet;
11use std::collections::VecDeque;
12use std::sync::Arc;
13use std::time::Duration;
14
15use anyhow::Context;
16use async_trait::async_trait;
17use dashmap::DashMap;
18use futures::FutureExt;
19use futures::future::join_all;
20use futures::future::select_all;
21use hyperactor::channel;
22use hyperactor::channel::ChannelAddr;
23use hyperactor::channel::ChannelRx;
24use hyperactor::channel::ChannelTransport;
25use hyperactor::channel::ChannelTx;
26use hyperactor::channel::Rx;
27use hyperactor::channel::TcpMode;
28use hyperactor::channel::Tx;
29use hyperactor::channel::TxStatus;
30use hyperactor::internal_macro_support::serde_json;
31use hyperactor::mailbox::DialMailboxRouter;
32use hyperactor::mailbox::MailboxServer;
33use hyperactor::observe_async;
34use hyperactor::observe_result;
35use mockall::automock;
36use ndslice::Region;
37use ndslice::Slice;
38use ndslice::View;
39use ndslice::ViewExt;
40use ndslice::view::Extent;
41use ndslice::view::Point;
42use serde::Deserialize;
43use serde::Serialize;
44use strum::AsRefStr;
45use tokio::io::AsyncWriteExt;
46use tokio::process::Command;
47use tokio::sync::mpsc::UnboundedReceiver;
48use tokio::sync::mpsc::UnboundedSender;
49use tokio::sync::mpsc::unbounded_channel;
50use tokio::task::JoinHandle;
51use tokio_stream::StreamExt;
52use tokio_stream::wrappers::WatchStream;
53use tokio_util::sync::CancellationToken;
54use typeuri::Named;
55
56use crate::alloc::Alloc;
57use crate::alloc::AllocConstraints;
58use crate::alloc::AllocName;
59use crate::alloc::AllocSpec;
60use crate::alloc::Allocator;
61use crate::alloc::AllocatorError;
62use crate::alloc::ProcState;
63use crate::alloc::ProcStopReason;
64use crate::alloc::ProcessAllocator;
65use crate::alloc::process::CLIENT_TRACE_ID_LABEL;
66use crate::alloc::process::ClientContext;
67use crate::alloc::with_unspecified_port_or_any;
68use crate::shortuuid::ShortUuid;
69
/// Control messages sent from remote process allocator to local allocator.
#[derive(Debug, Clone, Serialize, Deserialize, Named, AsRefStr)]
pub enum RemoteProcessAllocatorMessage {
    /// Create allocation with given spec and send updates to bootstrap_addr.
    Allocate {
        /// The key used to identify this allocation.
        alloc_key: ShortUuid,
        /// The extent to allocate.
        extent: Extent,
        /// Bootstrap address to be used for sending updates.
        bootstrap_addr: ChannelAddr,
        /// Ordered list of hosts in this allocation. Can be used to
        /// pre-populate any local configurations such as torch.dist.
        hosts: Vec<String>,
        /// Client context which is passed to the ProcessAlloc.
        /// TODO: Once RemoteProcessAllocator moves to mailbox,
        /// the client_context will go to the message header instead.
        client_context: Option<ClientContext>,
        /// The address the allocator should use for its forwarder.
        forwarder_addr: ChannelAddr,
    },
    /// Stop allocation.
    Stop,
    /// Heartbeat message to check if remote process allocator and its
    /// host are alive.
    HeartBeat,
}
97wirevalue::register_type!(RemoteProcessAllocatorMessage);
98
/// Control message sent from local allocator to remote allocator
/// relaying process state updates.
/// AsRefStr allows us to log the variant names.
#[derive(Debug, Clone, Serialize, Deserialize, Named, AsRefStr)]
pub enum RemoteProcessProcStateMessage {
    /// Allocation successful and Update, Done messages will follow.
    Allocated {
        /// The key identifying the allocation (echoed from Allocate).
        alloc_key: ShortUuid,
        /// Name of the Alloc created by the remote allocator.
        alloc_name: AllocName,
    },
    /// ProcState updates for the allocation with the given key.
    Update(ShortUuid, ProcState),
    /// Underlying Alloc with the given key is done.
    Done(ShortUuid),
    /// Heartbeat message to check if client is alive.
    HeartBeat,
}
116wirevalue::register_type!(RemoteProcessProcStateMessage);
117
/// Allocator with a service frontend that wraps ProcessAllocator.
pub struct RemoteProcessAllocator {
    // Cancelled by terminate(); observed by the serve loop in
    // start_with_allocator() to stop any ongoing allocation and exit.
    cancel_token: CancellationToken,
}
122
/// Await the given timer if present; otherwise never resolve.
///
/// This lets an optional timeout be used uniformly inside `tokio::select!`:
/// a `None` timeout produces an arm that simply never fires.
///
/// Note: uses `std::future` (same `Future` trait that the `futures` crate
/// re-exports), so this helper has no third-party dependency.
async fn conditional_sleeper<F: std::future::Future<Output = ()>>(t: Option<F>) {
    match t {
        Some(timer) => timer.await,
        // No timeout configured: pend forever so the select! arm never fires.
        None => std::future::pending().await,
    }
}
129
130impl RemoteProcessAllocator {
131    /// Create a new allocator. It will not start until start() is called.
132    pub fn new() -> Arc<Self> {
133        Arc::new(Self {
134            cancel_token: CancellationToken::new(),
135        })
136    }
137
    /// Stop the allocator. This will stop any ongoing allocations.
    /// Safe to call multiple times; cancellation is idempotent.
    pub fn terminate(&self) {
        self.cancel_token.cancel();
    }
142
    /// Start a remote process allocator with given cmd listening for
    /// RemoteProcessAllocatorMessage on serve_addr. Call will block until cancelled.
    /// The implementation is simple such that it can only handle one Alloc at
    /// a time. Generally that's the most common use-case.
    /// Flow works as follows:
    /// 1. Client sends Allocate message to serve_addr.
    /// 2. Allocator connects to bootstrap_addr, creates Alloc and sends Allocated message.
    /// 3. Allocator streams one or more Update messages to bootstrap_addr as Alloc progresses
    ///    making the following changes:
    ///    * Remap mesh_agent listen address to our own forwarder actor address.
    /// 4. Allocator sends Done message to bootstrap_addr when Alloc is done.
    ///
    /// At any point, client can send Stop message to serve_addr to stop the allocator.
    /// If timeout is Some, the allocator will exit if no client connects within
    /// that timeout, and no child allocation is running.
    #[hyperactor::instrument]
    pub async fn start(
        &self,
        cmd: Command,
        serve_addr: ChannelAddr,
        timeout: Option<Duration>,
    ) -> Result<(), anyhow::Error> {
        // Wrap the command in a ProcessAllocator and delegate to the
        // allocator-generic entry point (also used by tests).
        let process_allocator = ProcessAllocator::new(cmd);
        self.start_with_allocator(serve_addr, process_allocator, timeout)
            .await
    }
169
    /// Start a remote process allocator with given allocator listening for
    /// RemoteProcessAllocatorMessage on serve_addr.
    /// Used for testing.
    #[hyperactor::instrument(fields(addr=serve_addr.to_string()))]
    pub async fn start_with_allocator<A: Allocator + Send + Sync + 'static>(
        &self,
        serve_addr: ChannelAddr,
        mut process_allocator: A,
        timeout: Option<Duration>,
    ) -> Result<(), anyhow::Error>
    where
        <A as Allocator>::Alloc: Send,
        <A as Allocator>::Alloc: Sync,
    {
        tracing::info!("starting remote allocator on: {}", serve_addr);
        let (_, mut rx) = channel::serve(serve_addr.clone()).map_err(anyhow::Error::from)?;

        // The single in-flight allocation: the task driving it and the token
        // used to cancel it. Only one allocation is served at a time.
        struct ActiveAllocation {
            handle: JoinHandle<()>,
            cancel_token: CancellationToken,
        }
        // Cancel the active allocation (if any) and wait for its task to
        // finish, so a new allocation never overlaps the previous one.
        #[observe_async("RemoteProcessAllocator")]
        async fn ensure_previous_alloc_stopped(active_allocation: &mut Option<ActiveAllocation>) {
            if let Some(active_allocation) = active_allocation.take() {
                tracing::info!("previous alloc found, stopping");
                active_allocation.cancel_token.cancel();
                match active_allocation.handle.await {
                    Ok(_) => {
                        // Todo, add named tracing for state change here
                        tracing::info!("allocation stopped.")
                    }
                    Err(e) => {
                        tracing::error!("allocation handler failed: {}", e);
                    }
                }
            }
        }

        let mut active_allocation: Option<ActiveAllocation> = None;
        loop {
            // Refresh each loop iteration so the timer updates whenever a message
            // is received.
            let sleep = conditional_sleeper(timeout.map(|t| tokio::time::sleep(t)));
            tokio::select! {
                msg = rx.recv() => {
                    match msg {
                        Ok(RemoteProcessAllocatorMessage::Allocate {
                            alloc_key,
                            extent,
                            bootstrap_addr,
                            hosts,
                            client_context,
                            forwarder_addr,
                        }) => {
                            tracing::info!("received allocation request for {} with extent {}", alloc_key, extent);
                            // A new Allocate supersedes any in-flight allocation.
                            ensure_previous_alloc_stopped(&mut active_allocation).await;

                            // Create the corresponding local allocation spec.
                            let mut constraints: AllocConstraints = Default::default();
                            if let Some(context) = &client_context {
                                // Propagate the client trace id as a match label so the
                                // local ProcessAlloc can associate its spans with the client.
                                constraints = AllocConstraints {
                                    match_labels: HashMap::from([(
                                    CLIENT_TRACE_ID_LABEL.to_string(),
                                    context.trace_id.to_string(),
                                    )]
                                )};
                                tracing::info!(
                                    monarch_client_trace_id = context.trace_id.to_string(),
                                    "allocating...",
                                );
                            }


                            let spec = AllocSpec {
                                extent,
                                constraints,
                                proc_name: None, // TODO(meriksen, direct addressing): we need to pass the addressing mode here
                                transport: ChannelTransport::Unix,
                                proc_allocation_mode: Default::default(),
                            };

                            match process_allocator.allocate(spec.clone()).await {
                                Ok(alloc) => {
                                    let cancel_token = CancellationToken::new();
                                    // Drive the allocation on its own task; the select loop
                                    // stays responsive to Stop/HeartBeat/cancellation.
                                    active_allocation = Some(ActiveAllocation {
                                        cancel_token: cancel_token.clone(),
                                        handle: tokio::spawn(Self::handle_allocation_request(
                                            Box::new(alloc) as Box<dyn Alloc + Send + Sync>,
                                            alloc_key,
                                            bootstrap_addr,
                                            hosts,
                                            cancel_token,
                                            forwarder_addr,
                                        )),
                                    })
                                }
                                Err(e) => {
                                    tracing::error!("allocation for {:?} failed: {}", spec, e);
                                    continue;
                                }
                            }
                        }
                        Ok(RemoteProcessAllocatorMessage::Stop) => {
                            tracing::info!("received stop request");

                            ensure_previous_alloc_stopped(&mut active_allocation).await;
                        }
                        // Heartbeat message is discarded immediately after being received, sender (client)
                        // relies on channel ack to know if the receiver (remote process allocator) is
                        // still alive. No state needs to be updated.
                        Ok(RemoteProcessAllocatorMessage::HeartBeat) => {}
                        Err(e) => {
                            tracing::error!("upstream channel error: {}", e);
                            continue;
                        }
                    }
                }
                _ = self.cancel_token.cancelled() => {
                    // terminate() was called: stop any child allocation, then exit.
                    tracing::info!("main loop cancelled");

                    ensure_previous_alloc_stopped(&mut active_allocation).await;

                    break;
                }
                _ = sleep => {
                    // If there are any active allocations, reset the timeout.
                    if active_allocation.is_some() {
                        continue;
                    }
                    // Else, exit the loop as a client hasn't connected in a reasonable
                    // amount of time.
                    tracing::warn!("timeout of {} seconds elapsed without any allocations, exiting", timeout.unwrap_or_default().as_secs());
                    break;
                }
            }
        }

        Ok(())
    }
309
310    #[tracing::instrument(skip(alloc, cancel_token))]
311    #[observe_async("RemoteProcessAllocator")]
312    async fn handle_allocation_request(
313        alloc: Box<dyn Alloc + Send + Sync>,
314        alloc_key: ShortUuid,
315        bootstrap_addr: ChannelAddr,
316        hosts: Vec<String>,
317        cancel_token: CancellationToken,
318        forwarder_addr: ChannelAddr,
319    ) {
320        tracing::info!("handle allocation request, bootstrap_addr: {bootstrap_addr}");
321        // start proc message forwarder
322        let (forwarder_addr, forwarder_rx) = match channel::serve(forwarder_addr) {
323            Ok(v) => v,
324            Err(e) => {
325                tracing::error!("failed to to bootstrap forwarder actor: {}", e);
326                return;
327            }
328        };
329        let router = DialMailboxRouter::new();
330        let mailbox_handle = router.clone().serve(forwarder_rx);
331        tracing::info!("started forwarder on: {}", forwarder_addr);
332
333        // Check if we need to write TORCH_ELASTIC_CUSTOM_HOSTNAMES_LIST_FILE
334        // See: https://github.com/fairinternal/xlformers/blob/llama4_monarch/tools/launching/torchx/entrypoint/generate_ranks.py
335        if let Ok(hosts_file) = std::env::var("TORCH_ELASTIC_CUSTOM_HOSTNAMES_LIST_FILE") {
336            tracing::info!("writing hosts to {}", hosts_file);
337            #[derive(Serialize)]
338            struct Hosts {
339                hostnames: Vec<String>,
340            }
341            match serde_json::to_string(&Hosts { hostnames: hosts }) {
342                Ok(json) => match tokio::fs::File::create(&hosts_file).await {
343                    Ok(mut file) => {
344                        if file.write_all(json.as_bytes()).await.is_err() {
345                            tracing::error!("failed to write hosts to {}", hosts_file);
346                            return;
347                        }
348                    }
349                    Err(e) => {
350                        tracing::error!("failed to open hosts file {}: {}", hosts_file, e);
351                        return;
352                    }
353                },
354                Err(e) => {
355                    tracing::error!("failed to serialize hosts: {}", e);
356                    return;
357                }
358            }
359        }
360
361        Self::handle_allocation_loop(
362            alloc,
363            alloc_key,
364            bootstrap_addr,
365            router,
366            forwarder_addr,
367            cancel_token,
368        )
369        .await;
370
371        mailbox_handle.stop("alloc stopped");
372        if let Err(e) = mailbox_handle.await {
373            tracing::error!("failed to join forwarder: {}", e);
374        }
375    }
376
    /// Core event loop for one allocation: dials the client's bootstrap
    /// address, announces Allocated, then relays each ProcState event from
    /// `alloc` as an Update, remapping Running procs' addresses to our
    /// forwarder (`forward_addr`) and maintaining the router bindings.
    /// Exits on cancellation, upstream channel closure, or alloc completion
    /// (sending Done), stopping the alloc if it is still running.
    async fn handle_allocation_loop(
        mut alloc: Box<dyn Alloc + Send + Sync>,
        alloc_key: ShortUuid,
        bootstrap_addr: ChannelAddr,
        router: DialMailboxRouter,
        forward_addr: ChannelAddr,
        cancel_token: CancellationToken,
    ) {
        let alloc_name = alloc.alloc_name().clone();
        tracing::info!("starting handle allocation loop for {}", alloc_name);
        let tx = match channel::dial(bootstrap_addr) {
            Ok(tx) => tx,
            Err(err) => {
                tracing::error!("failed to dial bootstrap address: {}", err);
                return;
            }
        };
        let message = RemoteProcessProcStateMessage::Allocated {
            alloc_key: alloc_key.clone(),
            alloc_name,
        };
        tracing::info!(name = message.as_ref(), "sending allocated message",);
        if let Err(e) = tx.send(message).await {
            tracing::error!("failed to send Allocated message: {}", e);
            return;
        }

        // Tracks which mesh agent belongs to each created proc so its router
        // binding can be removed when the proc stops.
        let mut mesh_agents_by_create_key = HashMap::new();
        // `running` is false once the alloc has been asked to stop (or has
        // finished); it gates the cancel/status select arms below.
        let mut running = true;
        let tx_status = tx.status().clone();
        let mut tx_watcher = WatchStream::new(tx_status);
        loop {
            tokio::select! {
                _ = cancel_token.cancelled(), if running => {
                    tracing::info!("cancelled, stopping allocation");
                    running = false;
                    // Keep looping after stop() so the final Stopped/Done
                    // events from the alloc are still relayed.
                    if let Err(e) = alloc.stop().await {
                        tracing::error!("stop failed: {}", e);
                        break;
                    }
                }
                status = tx_watcher.next(), if running => {
                    match status  {
                        Some(TxStatus::Closed) => {
                            // Client is gone; bail out (alloc is stopped below).
                            tracing::error!("upstream channel state closed");
                            break;
                        },
                        _ => {
                            tracing::debug!("got channel event: {:?}", status.unwrap());
                            continue;
                        }
                    }
                }
                e = alloc.next() => {
                    match e {
                        Some(event) => {
                            tracing::debug!(name = event.as_ref(), "got event: {:?}", event);
                            let event = match event {
                                ProcState::Created { .. } => event,
                                ProcState::Running { create_key, proc_id, mesh_agent, addr } => {
                                    // TODO(meriksen, direct addressing): disable remapping in direct addressing mode
                                    tracing::debug!("remapping mesh_agent {}: addr {} -> {}", mesh_agent, addr, forward_addr);
                                    mesh_agents_by_create_key.insert(create_key.clone(), mesh_agent.clone());
                                    router.bind(mesh_agent.actor_id().proc_id().clone().into(), addr);
                                    // Advertise our forwarder address instead of the
                                    // proc's real (local) address.
                                    ProcState::Running { create_key, proc_id, mesh_agent, addr: forward_addr.clone() }
                                },
                                ProcState::Stopped { create_key, reason } => {
                                    match mesh_agents_by_create_key.remove(&create_key) {
                                        Some(mesh_agent) => {
                                            tracing::debug!("unmapping mesh_agent {}", mesh_agent);
                                            let agent_ref: hyperactor::reference::Reference = mesh_agent.actor_id().proc_id().clone().into();
                                            router.unbind(&agent_ref);
                                        },
                                        None => {
                                            tracing::warn!("mesh_agent not found for create key {}", create_key);
                                        }
                                    }
                                    ProcState::Stopped { create_key, reason }
                                },
                                ProcState::Failed { ref alloc_name, ref description } => {
                                    tracing::error!("allocation failed for {}: {}", alloc_name, description);
                                    event
                                }
                            };
                            tracing::debug!(name = event.as_ref(), "sending event: {:?}", event);
                            tx.post(RemoteProcessProcStateMessage::Update(alloc_key.clone(), event));
                        }
                        None => {
                            tracing::debug!("sending done");
                            // Use send().await (not post()) to ensure Done is
                            // delivered before this task completes. This
                            // guarantees all prior posted messages (Stopped
                            // updates) have been flushed through this
                            // connection's FIFO buffer, so downstream consumers
                            // see Done before Allocated from a subsequent
                            // allocation.
                            if let Err(e) = tx.send(RemoteProcessProcStateMessage::Done(alloc_key.clone())).await {
                                tracing::error!("failed to send Done message: {}", e);
                            }
                            running = false;
                            break;
                        }
                    }
                }
                _ = tokio::time::sleep(hyperactor_config::global::get(hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL)) => {
                    tracing::trace!("sending heartbeat");
                    tx.post(RemoteProcessProcStateMessage::HeartBeat);
                }
            }
        }
        tracing::info!("allocation handler loop exited");
        // If we exited while the alloc was still live (e.g. upstream channel
        // closed), stop its processes before returning.
        if running {
            tracing::info!("stopping processes");
            if let Err(e) = alloc.stop_and_wait().await {
                tracing::error!("stop failed: {}", e);
                return;
            }
            tracing::info!("stop finished");
        }
    }
497}
498
/// HostId is a string-based identifier of the host. It may be the same
/// as the hostname, but in some situations a separate ID may be used.
type HostId = String;
502
/// A host entry passed down from the initializer.
#[derive(Clone)]
pub struct RemoteProcessAllocHost {
    /// A unique identifier for that host. The hostname may be used, but
    /// the ID can be different in some situations.
    pub id: HostId,
    /// The FQDN of the host.
    pub hostname: String,
}
512
/// State of a host in the RemoteProcessAlloc.
struct RemoteProcessAllocHostState {
    /// The allocation key used to identify the host.
    alloc_key: ShortUuid,
    /// The host ID of the remote host.
    host_id: HostId,
    /// TX channel to the remote host allocator.
    tx: ChannelTx<RemoteProcessAllocatorMessage>,
    /// Set of active processes on this host.
    active_procs: HashSet<ShortUuid>,
    /// Region allocated by host.
    region: Region,
    /// Alloc name for this host as indicated from Allocated message.
    alloc_name: Option<AllocName>,
    /// If the remote allocator sent us ProcState::Failed.
    failed: bool,
    /// If the remote allocator has ever allocated a proc.
    allocated: bool,
}
532
#[automock]
#[async_trait]
/// Interface to provide the set of hosts to be used by RemoteProcessAlloc.
pub trait RemoteProcessAllocInitializer {
    /// Initializes and returns a list of hosts to be used by this RemoteProcessAlloc.
    /// Invoked once, on the first call to RemoteProcessAlloc::next().
    async fn initialize_alloc(&mut self) -> Result<Vec<RemoteProcessAllocHost>, anyhow::Error>;
}
540
/// Wrapper struct around `HashMap<HostId, RemoteProcessAllocHostState>`
/// to ensure that host addresses are synced with the signal handler
struct HostStates {
    // Authoritative per-host allocation state.
    inner: HashMap<HostId, RemoteProcessAllocHostState>,
    // Host addresses mirrored for the signal-cleanup handler; kept in
    // lockstep with `inner` by insert()/remove().
    host_addresses: Arc<DashMap<HostId, ChannelAddr>>,
}
547
impl HostStates {
    /// Create an empty state map sharing `host_addresses` with the
    /// signal-cleanup handler.
    fn new(host_addresses: Arc<DashMap<HostId, ChannelAddr>>) -> HostStates {
        Self {
            inner: HashMap::new(),
            host_addresses,
        }
    }

    /// Insert a host, also recording its address in the map shared with
    /// the signal handler.
    fn insert(
        &mut self,
        host_id: HostId,
        state: RemoteProcessAllocHostState,
        address: ChannelAddr,
    ) {
        self.host_addresses.insert(host_id.clone(), address);
        self.inner.insert(host_id, state);
    }

    fn get(&self, host_id: &HostId) -> Option<&RemoteProcessAllocHostState> {
        self.inner.get(host_id)
    }

    fn get_mut(&mut self, host_id: &HostId) -> Option<&mut RemoteProcessAllocHostState> {
        self.inner.get_mut(host_id)
    }

    /// Remove a host, also dropping it from the shared address map.
    fn remove(&mut self, host_id: &HostId) -> Option<RemoteProcessAllocHostState> {
        self.host_addresses.remove(host_id);
        self.inner.remove(host_id)
    }

    fn iter(&self) -> impl Iterator<Item = (&HostId, &RemoteProcessAllocHostState)> {
        self.inner.iter()
    }

    fn iter_mut(&mut self) -> impl Iterator<Item = (&HostId, &mut RemoteProcessAllocHostState)> {
        self.inner.iter_mut()
    }

    fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }
    // Any missing HashMap methods should be added here as needed
}
592
/// A generalized implementation of an Alloc using one or more hosts running
/// RemoteProcessAlloc for process allocation.
pub struct RemoteProcessAlloc {
    // The initializer to be called at the first next() invocation to obtain
    // allocated hosts.
    initializer: Box<dyn RemoteProcessAllocInitializer + Send + Sync>,
    spec: AllocSpec,
    remote_allocator_port: u16,
    alloc_name: AllocName,
    ordered_hosts: Vec<RemoteProcessAllocHost>,
    // Indicates that the initial remote allocation requests have been sent.
    started: bool,
    // Indicates that this Alloc is active (we have at least one remote process running).
    running: bool,
    // Indicates that the allocation process has permanently failed.
    failed: bool,
    // Maps the alloc key to the host.
    alloc_to_host: HashMap<ShortUuid, HostId>,
    host_states: HostStates,
    // Queue of ProcState events to be handed out from next().
    event_queue: VecDeque<ProcState>,
    // Channel on which the comm watcher task reports hosts whose TX
    // channel has closed.
    comm_watcher_tx: UnboundedSender<HostId>,
    comm_watcher_rx: UnboundedReceiver<HostId>,

    // Address remote allocators stream ProcState updates back to.
    bootstrap_addr: ChannelAddr,
    rx: ChannelRx<RemoteProcessProcStateMessage>,
    // Keeps the signal-cleanup registration alive for this alloc's lifetime.
    _signal_cleanup_guard: hyperactor::SignalCleanupGuard,
}
620
621impl RemoteProcessAlloc {
    /// Create a new Alloc. initializer will be called on the first invocation of next()
    /// to obtain a list of allocated hosts. Then an Allocate message will be sent to the
    /// RemoteProcessAllocator on all hosts. Heartbeats will be used to maintain health
    /// status of remote hosts.
    #[tracing::instrument(skip(initializer))]
    #[observe_result("RemoteProcessAlloc")]
    pub async fn new(
        spec: AllocSpec,
        alloc_name: AllocName,
        remote_allocator_port: u16,
        initializer: impl RemoteProcessAllocInitializer + Send + Sync + 'static,
    ) -> Result<Self, anyhow::Error> {
        // Serve a channel on which remote allocators will stream ProcState
        // updates back to this alloc.
        let alloc_serve_addr = ChannelAddr::any(spec.transport.clone());

        let (bootstrap_addr, rx) = channel::serve(alloc_serve_addr)?;

        tracing::info!(
            "starting alloc for {} on: {}",
            alloc_name,
            bootstrap_addr.clone()
        );

        let (comm_watcher_tx, comm_watcher_rx) = unbounded_channel();

        // Shared with the signal handler so it can reach every known host
        // even while this alloc is borrowed elsewhere.
        let host_addresses = Arc::new(DashMap::<HostId, ChannelAddr>::new());
        let host_addresses_for_signal = host_addresses.clone();

        // Register cleanup callback with global signal manager
        let signal_cleanup_guard =
            hyperactor::register_signal_cleanup_scoped(Box::pin(async move {
                // Best-effort: tell every known remote allocator to stop.
                join_all(host_addresses_for_signal.iter().map(|entry| async move {
                    let addr = entry.value().clone();
                    match channel::dial(addr.clone()) {
                        Ok(tx) => {
                            if let Err(e) = tx.send(RemoteProcessAllocatorMessage::Stop).await {
                                tracing::error!("Failed to send Stop to {}: {}", addr, e);
                            }
                        }
                        Err(e) => {
                            tracing::error!("Failed to dial {} during signal cleanup: {}", addr, e);
                        }
                    }
                }))
                .await;
            }));

        Ok(Self {
            spec,
            alloc_name,
            remote_allocator_port,
            initializer: Box::new(initializer),
            ordered_hosts: Vec::new(),
            alloc_to_host: HashMap::new(),
            host_states: HostStates::new(host_addresses),
            bootstrap_addr,
            event_queue: VecDeque::new(),
            comm_watcher_tx,
            comm_watcher_rx,
            rx,
            started: false,
            running: true,
            failed: false,
            _signal_cleanup_guard: signal_cleanup_guard,
        })
    }
687
    /// Start a RemoteProcessAllocator tx watcher. It spawns a task that monitors the underlying
    /// TxStatus and reports to the main next() loop via comm_watcher_tx. It does not however
    /// send the actual heartbeats. Just watches the status.
    /// The task will terminate once comm_watcher_rx is closed.
    ///
    /// A task approach was used here as opposed to select! to avoid nesting complexities with
    /// select!() and select_all().
    async fn start_comm_watcher(&self) {
        let mut tx_watchers = Vec::new();
        for host in &self.ordered_hosts {
            let tx_status = self.host_states.get(&host.id).unwrap().tx.status().clone();
            let watcher = WatchStream::new(tx_status);
            tx_watchers.push((watcher, host.id.clone()));
        }
        assert!(!tx_watchers.is_empty());
        let tx = self.comm_watcher_tx.clone();
        tokio::spawn(async move {
            loop {
                // Re-arm a next() future for every remaining watcher and wait
                // for the first status change.
                let mut tx_status_futures = Vec::new();
                for (watcher, _) in &mut tx_watchers {
                    let fut = watcher.next().boxed();
                    tx_status_futures.push(fut);
                }
                let (tx_status, index, _) = select_all(tx_status_futures).await;
                let host_id = match tx_watchers.get(index) {
                    Some((_, host_id)) => host_id.clone(),
                    None => {
                        // Should never happen
                        tracing::error!(
                            "got selected index {} with no matching host in {}",
                            index,
                            tx_watchers.len()
                        );
                        continue;
                    }
                };
                if let Some(tx_status) = tx_status {
                    tracing::debug!("host {} channel event: {:?}", host_id, tx_status);
                    if tx_status == TxStatus::Closed {
                        if tx.send(host_id.clone()).is_err() {
                            // other side closed
                            break;
                        }
                        // Closed channels are reported once and then dropped
                        // from the watch set.
                        tx_watchers.remove(index);
                        if tx_watchers.is_empty() {
                            // All of the statuses have been closed, exit the loop.
                            break;
                        }
                    }
                }
            }
        });
    }
741
742    /// Ensure that we have the list of allocated hosts and that we send Allocate
743    /// request to all o fthem. Once done, start_comm_watcher() is called to start
744    /// the channel watcher.
745    /// Function is idempotent.
746    async fn ensure_started(&mut self) -> Result<(), anyhow::Error> {
747        if self.started || self.failed {
748            return Ok(());
749        }
750
751        self.started = true;
752        let hosts = self
753            .initializer
754            .initialize_alloc()
755            .await
756            .context("alloc initializer error")?;
757        if hosts.is_empty() {
758            anyhow::bail!("initializer returned empty list of hosts");
759        }
760        // prepare a list of host names in this allocation to be sent
761        // to remote allocators.
762        let hostnames: Vec<_> = hosts.iter().map(|e| e.hostname.clone()).collect();
763        tracing::info!("obtained {} hosts for this allocation", hostnames.len());
764
765        // Split the extent into regions, one per host.
766        use crate::alloc::ProcAllocationMode;
767
768        // For HostLevel, pre-compute regions. For ProcLevel, skip this step.
769        let regions: Option<Vec<_>> = match self.spec.proc_allocation_mode {
770            ProcAllocationMode::ProcLevel => {
771                // We require at least a dimension for hosts, and one for sub-host (e.g., GPUs)
772                anyhow::ensure!(
773                    self.spec.extent.len() >= 2,
774                    "invalid extent: {}, expected at least 2 dimensions",
775                    self.spec.extent
776                );
777                None
778            }
779            ProcAllocationMode::HostLevel => Some({
780                // HostLevel: each point is a host, create a region for each point
781                let num_points = self.spec.extent.num_ranks();
782                anyhow::ensure!(
783                    hosts.len() >= num_points,
784                    "HostLevel allocation mode requires {} hosts (one per point in extent {}), but only {} hosts were provided",
785                    num_points,
786                    self.spec.extent,
787                    hosts.len()
788                );
789
790                // For HostLevel, create a single-point region for each rank
791                // Each region contains one point that maps to the correct global rank
792                let labels = self.spec.extent.labels().to_vec();
793
794                // Compute strides for row-major layout: strides[i] = product of sizes[i+1..n]
795                let extent_sizes = self.spec.extent.sizes();
796                let mut parent_strides = vec![1; extent_sizes.len()];
797                for i in (0..extent_sizes.len() - 1).rev() {
798                    parent_strides[i] = parent_strides[i + 1] * extent_sizes[i + 1];
799                }
800
801                (0..num_points)
802                    .map(|rank| {
803                        // Create a slice containing only this rank
804                        // Use parent's strides so local point [0,0,...] maps to the correct global rank
805                        let sizes = vec![1; labels.len()];
806                        Region::new(
807                            labels.clone(),
808                            Slice::new(rank, sizes, parent_strides.clone()).unwrap(),
809                        )
810                    })
811                    .collect()
812            }),
813        };
814
815        match self.spec.proc_allocation_mode {
816            ProcAllocationMode::ProcLevel => {
817                // We group by the innermost dimension of the extent.
818                let split_dim = &self.spec.extent.labels()[self.spec.extent.len() - 1];
819                for (i, region) in self.spec.extent.group_by(split_dim)?.enumerate() {
820                    let host = &hosts[i];
821                    tracing::debug!("allocating: {} for host: {}", region, host.id);
822
823                    let remote_addr = match self.spec.transport {
824                        ChannelTransport::MetaTls(_) => {
825                            format!("metatls!{}:{}", host.hostname, self.remote_allocator_port)
826                        }
827                        ChannelTransport::Tcp(TcpMode::Localhost) => {
828                            // TODO: @rusch see about moving over to config for this
829                            format!("tcp![::1]:{}", self.remote_allocator_port)
830                        }
831                        ChannelTransport::Tcp(TcpMode::Hostname) => {
832                            format!("tcp!{}:{}", host.hostname, self.remote_allocator_port)
833                        }
834                        // Used only for testing.
835                        ChannelTransport::Unix => host.hostname.clone(),
836                        _ => {
837                            anyhow::bail!(
838                                "unsupported transport for host {}: {:?}",
839                                host.id,
840                                self.spec.transport,
841                            );
842                        }
843                    };
844
845                    tracing::debug!("dialing remote: {} for host {}", remote_addr, host.id);
846                    let remote_addr = remote_addr.parse::<ChannelAddr>()?;
847                    let tx = channel::dial(remote_addr.clone())
848                        .map_err(anyhow::Error::from)
849                        .context(format!(
850                            "failed to dial remote {} for host {}",
851                            remote_addr, host.id
852                        ))?;
853
854                    // Possibly we could use the HostId directly here.
855                    let alloc_key = ShortUuid::generate();
856                    assert!(
857                        self.alloc_to_host
858                            .insert(alloc_key.clone(), host.id.clone())
859                            .is_none()
860                    );
861
862                    let trace_id = hyperactor_telemetry::trace::get_or_create_trace_id();
863                    let client_context = Some(ClientContext { trace_id });
864                    let message = RemoteProcessAllocatorMessage::Allocate {
865                        alloc_key: alloc_key.clone(),
866                        extent: region.extent(),
867                        bootstrap_addr: self.bootstrap_addr.clone(),
868                        hosts: hostnames.clone(),
869                        client_context,
870                        // Make sure allocator's forwarder uses the same IP address
871                        // which is known to alloc. This is to avoid allocator picks
872                        // its host's private IP address, while its known addres to
873                        // alloc is a public IP address. In some environment, that
874                        // could lead to port unreachable error.
875                        forwarder_addr: with_unspecified_port_or_any(&remote_addr),
876                    };
877                    tracing::info!(
878                        name = message.as_ref(),
879                        "sending allocate message to workers"
880                    );
881                    tx.post(message);
882
883                    self.host_states.insert(
884                        host.id.clone(),
885                        RemoteProcessAllocHostState {
886                            alloc_key,
887                            host_id: host.id.clone(),
888                            tx,
889                            active_procs: HashSet::new(),
890                            region,
891                            alloc_name: None,
892                            failed: false,
893                            allocated: false,
894                        },
895                        remote_addr,
896                    );
897                }
898
899                self.ordered_hosts = hosts;
900            }
901            ProcAllocationMode::HostLevel => {
902                let regions = regions.unwrap();
903                let num_regions = regions.len();
904                for (i, region) in regions.into_iter().enumerate() {
905                    let host = &hosts[i];
906                    tracing::debug!("allocating: {} for host: {}", region, host.id);
907
908                    let remote_addr = match self.spec.transport {
909                        ChannelTransport::MetaTls(_) => {
910                            format!("metatls!{}:{}", host.hostname, self.remote_allocator_port)
911                        }
912                        ChannelTransport::Tcp(TcpMode::Localhost) => {
913                            // TODO: @rusch see about moving over to config for this
914                            format!("tcp![::1]:{}", self.remote_allocator_port)
915                        }
916                        ChannelTransport::Tcp(TcpMode::Hostname) => {
917                            format!("tcp!{}:{}", host.hostname, self.remote_allocator_port)
918                        }
919                        // Used only for testing.
920                        ChannelTransport::Unix => host.hostname.clone(),
921                        _ => {
922                            anyhow::bail!(
923                                "unsupported transport for host {}: {:?}",
924                                host.id,
925                                self.spec.transport,
926                            );
927                        }
928                    };
929
930                    tracing::debug!("dialing remote: {} for host {}", remote_addr, host.id);
931                    let remote_addr = remote_addr.parse::<ChannelAddr>()?;
932                    let tx = channel::dial(remote_addr.clone())
933                        .map_err(anyhow::Error::from)
934                        .context(format!(
935                            "failed to dial remote {} for host {}",
936                            remote_addr, host.id
937                        ))?;
938
939                    // Possibly we could use the HostId directly here.
940                    let alloc_key = ShortUuid::generate();
941                    assert!(
942                        self.alloc_to_host
943                            .insert(alloc_key.clone(), host.id.clone())
944                            .is_none()
945                    );
946
947                    let trace_id = hyperactor_telemetry::trace::get_or_create_trace_id();
948                    let client_context = Some(ClientContext { trace_id });
949                    let message = RemoteProcessAllocatorMessage::Allocate {
950                        alloc_key: alloc_key.clone(),
951                        extent: region.extent(),
952                        bootstrap_addr: self.bootstrap_addr.clone(),
953                        hosts: hostnames.clone(),
954                        client_context,
955                        // Make sure allocator's forwarder uses the same IP address
956                        // which is known to alloc. This is to avoid allocator picks
957                        // its host's private IP address, while its known addres to
958                        // alloc is a public IP address. In some environment, that
959                        // could lead to port unreachable error.
960                        forwarder_addr: with_unspecified_port_or_any(&remote_addr),
961                    };
962                    tracing::info!(
963                        name = message.as_ref(),
964                        "sending allocate message to workers"
965                    );
966                    tx.post(message);
967
968                    self.host_states.insert(
969                        host.id.clone(),
970                        RemoteProcessAllocHostState {
971                            alloc_key,
972                            host_id: host.id.clone(),
973                            tx,
974                            active_procs: HashSet::new(),
975                            region,
976                            alloc_name: None,
977                            failed: false,
978                            allocated: false,
979                        },
980                        remote_addr,
981                    );
982                }
983
984                // Only store hosts that were actually used for regions
985                // If num_regions < hosts.len(), we only use the first num_regions hosts
986                self.ordered_hosts = hosts.into_iter().take(num_regions).collect();
987            }
988        }
989        self.start_comm_watcher().await;
990        self.started = true;
991
992        Ok(())
993    }
994
995    // Given a proc_id, obtain the internal HostState structure.
996    fn get_host_state_mut(
997        &mut self,
998        alloc_key: &ShortUuid,
999    ) -> Result<&mut RemoteProcessAllocHostState, anyhow::Error> {
1000        let host_id: &HostId = self
1001            .alloc_to_host
1002            .get(alloc_key)
1003            .ok_or_else(|| anyhow::anyhow!("alloc with key {} not found", alloc_key))?;
1004
1005        self.host_states
1006            .get_mut(host_id)
1007            .ok_or_else(|| anyhow::anyhow!("no host state found for host {}", host_id))
1008    }
1009
1010    // Given a proc_id, obtain the internal HostState structure.
1011    fn get_host_state(
1012        &self,
1013        alloc_key: &ShortUuid,
1014    ) -> Result<&RemoteProcessAllocHostState, anyhow::Error> {
1015        let host_id: &HostId = self
1016            .alloc_to_host
1017            .get(alloc_key)
1018            .ok_or_else(|| anyhow::anyhow!("alloc with key {} not found", alloc_key))?;
1019
1020        self.host_states
1021            .get(host_id)
1022            .ok_or_else(|| anyhow::anyhow!("no host state found for host {}", host_id))
1023    }
1024
1025    fn remove_host_state(
1026        &mut self,
1027        alloc_key: &ShortUuid,
1028    ) -> Result<RemoteProcessAllocHostState, anyhow::Error> {
1029        let host_id: &HostId = self
1030            .alloc_to_host
1031            .get(alloc_key)
1032            .ok_or_else(|| anyhow::anyhow!("alloc with key {} not found", alloc_key))?;
1033
1034        self.host_states
1035            .remove(host_id)
1036            .ok_or_else(|| anyhow::anyhow!("no host state found for host {}", host_id))
1037    }
1038
1039    fn add_proc_id_to_host_state(
1040        &mut self,
1041        alloc_key: &ShortUuid,
1042        create_key: &ShortUuid,
1043    ) -> Result<(), anyhow::Error> {
1044        let task_state = self.get_host_state_mut(alloc_key)?;
1045        if !task_state.active_procs.insert(create_key.clone()) {
1046            // Should not happen but we can ignore
1047            tracing::error!("proc with create key {} already in host state", create_key);
1048        }
1049        task_state.allocated = true;
1050        Ok(())
1051    }
1052
1053    fn remove_proc_from_host_state(
1054        &mut self,
1055        alloc_key: &ShortUuid,
1056        create_key: &ShortUuid,
1057    ) -> Result<(), anyhow::Error> {
1058        let task_state = self.get_host_state_mut(alloc_key)?;
1059        if !task_state.active_procs.remove(create_key) {
1060            // Should not happen but we can ignore
1061            tracing::error!("proc with create_key already in host state: {}", create_key);
1062        }
1063        Ok(())
1064    }
1065
1066    // Reproject proc world coords to global shape coords.
1067    fn project_proc_into_global_extent(
1068        &self,
1069        alloc_key: &ShortUuid,
1070        point: &Point,
1071    ) -> Result<Point, anyhow::Error> {
1072        let global_rank = self
1073            .get_host_state(alloc_key)?
1074            .region
1075            .get(point.rank())
1076            .ok_or_else(|| {
1077                anyhow::anyhow!(
1078                    "rank {} out of bounds for in alloc {}",
1079                    point.rank(),
1080                    alloc_key
1081                )
1082            })?;
1083        Ok(self.spec.extent.point_of_rank(global_rank)?)
1084    }
1085
1086    // Cleanup a comm-failed host information by its ID.
1087    fn cleanup_host_channel_closed(
1088        &mut self,
1089        host_id: HostId,
1090    ) -> Result<Vec<ShortUuid>, anyhow::Error> {
1091        let state = match self.host_states.remove(&host_id) {
1092            Some(state) => state,
1093            None => {
1094                // this should never happen.
1095                anyhow::bail!(
1096                    "got channel closed event for host {} which has no known state",
1097                    host_id
1098                );
1099            }
1100        };
1101        self.ordered_hosts.retain(|host| host.id != host_id);
1102        self.alloc_to_host.remove(&state.alloc_key);
1103        let create_keys = state.active_procs.iter().cloned().collect();
1104
1105        Ok(create_keys)
1106    }
1107}
1108
#[async_trait]
impl Alloc for RemoteProcessAlloc {
    /// Return the next allocation event. Drains the local event queue first,
    /// lazily starts the allocation, then multiplexes allocator messages,
    /// outgoing heartbeats and channel-close notifications until an event
    /// can be surfaced to the caller. Returns None once the alloc is done.
    async fn next(&mut self) -> Option<ProcState> {
        loop {
            // Queued events (e.g. synthesized Stopped events after a host
            // loss) take priority over new messages.
            if let state @ Some(_) = self.event_queue.pop_front() {
                break state;
            }

            if !self.running {
                break None;
            }

            if let Err(e) = self.ensure_started().await {
                break Some(ProcState::Failed {
                    alloc_name: self.alloc_name.clone(),
                    description: format!("failed to ensure started: {:#}", e),
                });
            }

            let heartbeat_interval = hyperactor_config::global::get(
                hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
            );
            let mut heartbeat_time = tokio::time::Instant::now() + heartbeat_interval;
            // rerun outer loop in case we pushed new items to the event queue
            let mut reloop = false;
            // Inner loop yields Some((alloc_key, proc_state)) when an event is
            // ready, or None when the outer loop should decide what to do.
            let update = loop {
                tokio::select! {
                    msg = self.rx.recv() => {
                        tracing::debug!("got ProcState message from allocator: {:?}", msg);
                        match msg {
                            Ok(RemoteProcessProcStateMessage::Allocated { alloc_key, alloc_name }) => {
                                tracing::info!("remote alloc {}: allocated", alloc_key);
                                match self.get_host_state_mut(&alloc_key) {
                                    Ok(state) => {
                                        state.alloc_name = Some(alloc_name.clone());
                                    }
                                    Err(err) => {
                                        // should never happen
                                        tracing::error!(
                                            "received allocated message alloc: {} with no known state: {}",
                                            alloc_key, err,
                                        );
                                    }
                                }
                            }
                            Ok(RemoteProcessProcStateMessage::Update(alloc_key, proc_state)) => {
                                // Mirror per-proc lifecycle changes into the
                                // local host state before forwarding the event.
                                let update = match proc_state {
                                    ProcState::Created { ref create_key, .. } => {
                                        if let Err(e) = self.add_proc_id_to_host_state(&alloc_key, create_key) {
                                            tracing::error!("failed to add proc with create key {} host state: {}", create_key, e);
                                        }
                                        proc_state
                                    }
                                    ProcState::Stopped{ ref create_key, ..} => {
                                        if let Err(e) = self.remove_proc_from_host_state(&alloc_key, create_key) {
                                            tracing::error!("failed to remove proc with create key {} host state: {}", create_key, e);
                                        }
                                        proc_state
                                    }
                                    ProcState::Failed { ref alloc_name, ref description } => {
                                        match self.get_host_state_mut(&alloc_key) {
                                            Ok(state) => {
                                                state.failed = true;
                                                // Enrich the failure with the host it came from.
                                                ProcState::Failed {
                                                    alloc_name: alloc_name.clone(),
                                                    description: format!("host {} failed: {}", state.host_id, description),
                                                }
                                            }
                                            Err(e) => {
                                                tracing::error!("failed to find host state for alloc: {}: {}", alloc_name, e);
                                                proc_state
                                            }
                                        }
                                    }
                                    _ => proc_state
                                };

                                break Some((Some(alloc_key), update));
                            }
                            Ok(RemoteProcessProcStateMessage::Done(alloc_key)) => {
                                tracing::info!("allocator {} is done", alloc_key);

                                if let Ok(state) = self.remove_host_state(&alloc_key) {
                                    if !state.active_procs.is_empty() {
                                        tracing::error!("received done for alloc {} with active procs: {:?}", alloc_key, state.active_procs);
                                    }
                                } else {
                                    tracing::error!("received done for unknown alloc {}", alloc_key);
                                }

                                // Last host done: the whole alloc is finished.
                                if self.host_states.is_empty() {
                                    self.running = false;
                                    break None;
                                }
                            }
                            // Heartbeat message is discarded immediately after being received, sender (remote
                            // process allocator) relies on channel ack to know if the receiver (client) is
                            // still alive. No state needs to be updated.
                            Ok(RemoteProcessProcStateMessage::HeartBeat) => {}
                            Err(e) => {
                                break Some((None, ProcState::Failed {alloc_name: self.alloc_name.clone(), description: format!("error receiving events: {}", e)}));
                            }
                        }
                    }

                    // Periodically heartbeat every remote allocator so their
                    // channel acks confirm we are still alive.
                    _ = tokio::time::sleep_until(heartbeat_time) => {
                        self.host_states.iter().for_each(|(_, host_state)| host_state.tx.post(RemoteProcessAllocatorMessage::HeartBeat));
                        heartbeat_time = tokio::time::Instant::now() + heartbeat_interval;
                    }

                    // A host's channel closed (reported by start_comm_watcher).
                    closed_host_id = self.comm_watcher_rx.recv() => {
                        if let Some(closed_host_id) = closed_host_id {
                            tracing::debug!("host {} channel closed, cleaning up", closed_host_id);
                            if let Some(state) = self.host_states.get(&closed_host_id)
                                && !state.allocated {
                                    break Some((None, ProcState::Failed {
                                        alloc_name: self.alloc_name.clone(),
                                        description: format!(
                                            "no process has ever been allocated on {}, connection is not established; \
                                            a common issue could be allocator port mismatch, or transport not enabled on client",
                                            closed_host_id
                                        )}));
                                }
                            let create_keys = match self.cleanup_host_channel_closed(closed_host_id) {
                                Ok(create_keys) => create_keys,
                                Err(err) => {
                                    tracing::error!("failed to cleanup disconnected host: {}", err);
                                    continue;
                                }
                            };
                            // Synthesize Stopped events for procs lost with the host.
                            for create_key in create_keys {
                                tracing::debug!("queuing Stopped state for proc with create key {}", create_key);
                                self.event_queue.push_back(
                                    ProcState::Stopped {
                                        create_key,
                                        reason: ProcStopReason::HostWatchdog
                                    }
                                );
                            }
                            // Check if there are any hosts left
                            if self.host_states.is_empty() {
                                tracing::info!("no more hosts left, stopping the alloc");
                                self.running = false;
                            }
                            // Kick back to the outer loop to pop off the queue if necessary
                            reloop = true;
                            break None;
                        } else {
                            // other side closed
                            tracing::warn!("unexpected comm watcher channel close");
                            break None;
                        }
                    }
                }
            };

            if reloop {
                // Kick back to the outer loop to pop off the queue.
                // Should not have produced an update.
                assert!(update.is_none());
                continue;
            }

            break match update {
                // Created events carry host-local coordinates; reproject them
                // into the global extent before handing them to the caller.
                Some((
                    Some(alloc_key),
                    ProcState::Created {
                        create_key,
                        point,
                        pid,
                    },
                )) => match self.project_proc_into_global_extent(&alloc_key, &point) {
                    Ok(global_point) => {
                        tracing::debug!("reprojected coords: {} -> {}", point, global_point);
                        Some(ProcState::Created {
                            create_key,
                            point: global_point,
                            pid,
                        })
                    }
                    Err(e) => {
                        tracing::error!(
                            "failed to project coords for proc: {}.{}: {}",
                            alloc_key,
                            create_key,
                            e
                        );
                        None
                    }
                },
                Some((None, ProcState::Created { .. })) => {
                    panic!("illegal state: missing alloc_key for ProcState::Created event")
                }
                Some((_, update)) => {
                    if let ProcState::Failed { description, .. } = &update {
                        tracing::error!(description);
                        self.failed = true;
                    }
                    Some(update)
                }
                None => None,
            };
        }
    }

    /// The allocation spec this alloc was created from.
    fn spec(&self) -> &AllocSpec {
        &self.spec
    }

    /// The global extent covered by this allocation.
    fn extent(&self) -> &Extent {
        &self.spec.extent
    }

    /// The name of this allocation.
    fn alloc_name(&self) -> &AllocName {
        &self.alloc_name
    }

    /// Request every remote allocator to stop. Fire-and-forget: the resulting
    /// Stopped/Done events arrive asynchronously via next().
    async fn stop(&mut self) -> Result<(), AllocatorError> {
        tracing::info!("stopping alloc");

        for (host_id, task_state) in self.host_states.iter_mut() {
            tracing::debug!("stopping alloc at host {}", host_id);
            task_state.tx.post(RemoteProcessAllocatorMessage::Stop);
        }

        Ok(())
    }

    /// For Tcp and Metatls, return a router address which has the same IP address
    /// as the alloc's bootstrap address, but with port set as 0. We do this
    /// instead of using `ChannelAddr::any` because we want the alloc to use the
    /// same IP address for all its frontend ports. In some environments, the
    /// host can have a public IP address and a private IP address, and using the
    /// wrong one could lead to port unreachable errors.
    ///
    /// For other channel types, this method still uses ChannelAddr::any.
    fn client_router_addr(&self) -> ChannelAddr {
        with_unspecified_port_or_any(&self.bootstrap_addr)
    }
}
1349
1350impl Drop for RemoteProcessAlloc {
1351    fn drop(&mut self) {
1352        tracing::debug!(
1353            "dropping RemoteProcessAlloc of alloc_name {}",
1354            self.alloc_name
1355        );
1356    }
1357}
1358
1359#[cfg(test)]
1360mod test {
1361    use std::assert_matches::assert_matches;
1362
1363    use hyperactor::channel::ChannelRx;
1364    use hyperactor::reference as hyperactor_reference;
1365    use hyperactor::testing::ids::test_proc_id;
1366    use ndslice::extent;
1367    use tokio::sync::oneshot;
1368
1369    use super::*;
1370    use crate::alloc::ChannelTransport;
1371    use crate::alloc::MockAlloc;
1372    use crate::alloc::MockAllocWrapper;
1373    use crate::alloc::MockAllocator;
1374    use crate::alloc::ProcStopReason;
1375    use crate::alloc::with_unspecified_port_or_any;
1376    use crate::proc_agent::ProcAgent;
1377
1378    async fn read_all_created(rx: &mut ChannelRx<RemoteProcessProcStateMessage>, alloc_len: usize) {
1379        let mut i: usize = 0;
1380        while i < alloc_len {
1381            let m = rx.recv().await.unwrap();
1382            match m {
1383                RemoteProcessProcStateMessage::Update(_, ProcState::Created { .. }) => i += 1,
1384                RemoteProcessProcStateMessage::HeartBeat => {}
1385                _ => panic!("unexpected message: {:?}", m),
1386            }
1387        }
1388    }
1389
1390    async fn read_all_running(rx: &mut ChannelRx<RemoteProcessProcStateMessage>, alloc_len: usize) {
1391        let mut i: usize = 0;
1392        while i < alloc_len {
1393            let m = rx.recv().await.unwrap();
1394            match m {
1395                RemoteProcessProcStateMessage::Update(_, ProcState::Running { .. }) => i += 1,
1396                RemoteProcessProcStateMessage::HeartBeat => {}
1397                _ => panic!("unexpected message: {:?}", m),
1398            }
1399        }
1400    }
1401
1402    async fn read_all_stopped(rx: &mut ChannelRx<RemoteProcessProcStateMessage>, alloc_len: usize) {
1403        let mut i: usize = 0;
1404        while i < alloc_len {
1405            let m = rx.recv().await.unwrap();
1406            match m {
1407                RemoteProcessProcStateMessage::Update(_, ProcState::Stopped { .. }) => i += 1,
1408                RemoteProcessProcStateMessage::HeartBeat => {}
1409                _ => panic!("unexpected message: {:?}", m),
1410            }
1411        }
1412    }
1413
1414    fn set_procstate_expectations(alloc: &mut MockAlloc, extent: Extent) {
1415        alloc.expect_extent().return_const(extent.clone());
1416        let mut create_keys = Vec::new();
1417        for point in extent.points() {
1418            let create_key = ShortUuid::generate();
1419            create_keys.push(create_key.clone());
1420            alloc.expect_next().times(1).return_once(move || {
1421                Some(ProcState::Created {
1422                    create_key: create_key.clone(),
1423                    point,
1424                    pid: 0,
1425                })
1426            });
1427        }
1428        for (i, create_key) in create_keys
1429            .iter()
1430            .take(extent.num_ranks())
1431            .cloned()
1432            .enumerate()
1433        {
1434            let proc_id = test_proc_id(&format!("{i}"));
1435            let mesh_agent = hyperactor_reference::ActorRef::<ProcAgent>::attest(
1436                proc_id.actor_id("mesh_agent", i),
1437            );
1438            alloc.expect_next().times(1).return_once(move || {
1439                Some(ProcState::Running {
1440                    create_key,
1441                    proc_id,
1442                    addr: ChannelAddr::Unix("/proc0".parse().unwrap()),
1443                    mesh_agent,
1444                })
1445            });
1446        }
1447        for create_key in create_keys.iter().take(extent.num_ranks()).cloned() {
1448            alloc.expect_next().times(1).return_once(|| {
1449                Some(ProcState::Stopped {
1450                    create_key,
1451                    reason: ProcStopReason::Unknown,
1452                })
1453            });
1454        }
1455    }
1456
    /// Happy path: a single `Allocate` request yields an `Allocated` reply,
    /// then per-rank `Created`/`Running`/`Stopped` updates in rank order,
    /// and finally `Done` once the mock alloc is exhausted.
    #[timed_test::async_timed_test(timeout_secs = 60)]
    async fn test_simple() {
        // Shorten the heartbeat interval so the test does not idle while
        // waiting for interleaved HeartBeat messages.
        let config = hyperactor_config::global::lock();
        let _guard = config.override_key(
            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
            Duration::from_millis(100),
        );
        hyperactor_telemetry::initialize_logging_for_test();
        let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
        let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
        let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();

        let extent = extent!(host = 1, gpu = 2);
        let tx = channel::dial(serve_addr.clone()).unwrap();

        let alloc_name: AllocName = AllocName("test_alloc_name".to_string());
        let mut alloc = MockAlloc::new();
        alloc.expect_alloc_name().return_const(alloc_name.clone());
        alloc.expect_extent().return_const(extent.clone());

        set_procstate_expectations(&mut alloc, extent.clone());

        // final none
        alloc.expect_next().return_const(None);

        let mut allocator = MockAllocator::new();
        // One message per rank for each of Created/Running/Stopped, plus Done.
        let total_messages = extent.num_ranks() * 3 + 1;
        let mock_wrapper = MockAllocWrapper::new_block_next(
            alloc,
            // block after create, running, stopped and done.
            total_messages,
        );
        allocator
            .expect_allocate()
            .times(1)
            .return_once(move |_| Ok(mock_wrapper));

        // Run the remote allocator server in the background; it is shut down
        // via terminate() at the end of the test.
        let remote_allocator = RemoteProcessAllocator::new();
        let handle = tokio::spawn({
            let remote_allocator = remote_allocator.clone();
            async move {
                remote_allocator
                    .start_with_allocator(serve_addr, allocator, None)
                    .await
            }
        });

        let alloc_key = ShortUuid::generate();

        tx.send(RemoteProcessAllocatorMessage::Allocate {
            alloc_key: alloc_key.clone(),
            extent: extent.clone(),
            bootstrap_addr,
            hosts: vec![],
            client_context: None,
            forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
        })
        .await
        .unwrap();

        // Allocated
        let m = rx.recv().await.unwrap();
        assert_matches!(
            m, RemoteProcessProcStateMessage::Allocated { alloc_key: got_alloc_key, alloc_name: got_alloc_name }
                if got_alloc_name == alloc_name && got_alloc_key == alloc_key
        );

        // All Created events, checked against the rank order of the extent.
        let mut rank: usize = 0;
        let mut create_keys = Vec::with_capacity(extent.num_ranks());
        while rank < extent.num_ranks() {
            let m = rx.recv().await.unwrap();
            match m {
                RemoteProcessProcStateMessage::Update(
                    got_alloc_key,
                    ProcState::Created {
                        create_key, point, ..
                    },
                ) => {
                    let expected_point = extent.point_of_rank(rank).unwrap();
                    assert_eq!(got_alloc_key, alloc_key);
                    assert_eq!(point, expected_point);
                    create_keys.push(create_key);
                    rank += 1;
                }
                RemoteProcessProcStateMessage::HeartBeat => {}
                _ => panic!("unexpected message: {:?}", m),
            }
        }
        // All Running events; create keys must match those seen at creation.
        let mut rank: usize = 0;
        while rank < extent.num_ranks() {
            let m = rx.recv().await.unwrap();
            match m {
                RemoteProcessProcStateMessage::Update(
                    got_alloc_key,
                    ProcState::Running {
                        create_key,
                        proc_id,
                        mesh_agent,
                        addr: _,
                    },
                ) => {
                    assert_eq!(got_alloc_key, alloc_key);
                    assert_eq!(create_key, create_keys[rank]);
                    let expected_proc_id = test_proc_id(&format!("{}", rank));
                    let expected_mesh_agent = hyperactor_reference::ActorRef::<ProcAgent>::attest(
                        expected_proc_id.actor_id("mesh_agent", rank),
                    );
                    assert_eq!(proc_id, expected_proc_id);
                    assert_eq!(mesh_agent, expected_mesh_agent);
                    rank += 1;
                }
                RemoteProcessProcStateMessage::HeartBeat => {}
                _ => panic!("unexpected message: {:?}", m),
            }
        }
        // All Stopped events
        let mut rank: usize = 0;
        while rank < extent.num_ranks() {
            let m = rx.recv().await.unwrap();
            match m {
                RemoteProcessProcStateMessage::Update(
                    got_alloc_key,
                    ProcState::Stopped {
                        create_key,
                        reason: ProcStopReason::Unknown,
                    },
                ) => {
                    assert_eq!(got_alloc_key, alloc_key);
                    assert_eq!(create_key, create_keys[rank]);
                    rank += 1;
                }
                RemoteProcessProcStateMessage::HeartBeat => {}
                _ => panic!("unexpected message: {:?}", m),
            }
        }
        // Done
        loop {
            let m = rx.recv().await.unwrap();
            match m {
                RemoteProcessProcStateMessage::Done(got_alloc_key) => {
                    assert_eq!(got_alloc_key, alloc_key);
                    break;
                }
                RemoteProcessProcStateMessage::HeartBeat => {}
                _ => panic!("unexpected message: {:?}", m),
            }
        }

        remote_allocator.terminate();
        handle.await.unwrap().unwrap();
    }
1610
    /// Sending an explicit `Stop` after all procs are running drives every
    /// proc through `Stopped` and calls `stop()` on the inner alloc exactly once.
    #[timed_test::async_timed_test(timeout_secs = 60)]
    // This test is flaky in OSS CI, but healthy in sandcastle. Considering we
    // are in the middle of deprecating allocator, it is not worth the effort to
    // fix it for OSS.
    #[cfg_attr(not(fbcode_build), ignore)]
    async fn test_normal_stop() {
        // Shorten heartbeats so the test does not idle between messages.
        let config = hyperactor_config::global::lock();
        let _guard = config.override_key(
            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
            Duration::from_millis(100),
        );
        hyperactor_telemetry::initialize_logging_for_test();
        let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
        let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
        let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();

        let extent = extent!(host = 1, gpu = 2);
        let tx = channel::dial(serve_addr.clone()).unwrap();

        let alloc_name: AllocName = AllocName("test_alloc_name".to_string());
        let mut alloc = MockAllocWrapper::new_block_next(
            MockAlloc::new(),
            // block after all created, all running
            extent.num_ranks() * 2,
        );
        // `next_tx` unblocks the wrapper's next() stream later in the test.
        let next_tx = alloc.notify_tx();
        alloc
            .alloc
            .expect_alloc_name()
            .return_const(alloc_name.clone());
        alloc.alloc.expect_extent().return_const(extent.clone());

        set_procstate_expectations(&mut alloc.alloc, extent.clone());

        alloc.alloc.expect_next().return_const(None);
        alloc.alloc.expect_stop().times(1).return_once(|| Ok(()));

        let mut allocator = MockAllocator::new();
        allocator
            .expect_allocate()
            .times(1)
            .return_once(|_| Ok(alloc));

        let remote_allocator = RemoteProcessAllocator::new();
        let handle = tokio::spawn({
            let remote_allocator = remote_allocator.clone();
            async move {
                remote_allocator
                    .start_with_allocator(serve_addr, allocator, None)
                    .await
            }
        });

        let alloc_key = ShortUuid::generate();
        tx.send(RemoteProcessAllocatorMessage::Allocate {
            alloc_key: alloc_key.clone(),
            extent: extent.clone(),
            bootstrap_addr,
            hosts: vec![],
            client_context: None,
            forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
        })
        .await
        .unwrap();

        // Allocated
        let m = rx.recv().await.unwrap();
        assert_matches!(
            m,
            RemoteProcessProcStateMessage::Allocated {  alloc_name: got_alloc_name, alloc_key: got_alloc_key }
            if alloc_name == got_alloc_name && alloc_key == got_alloc_key
        );

        read_all_created(&mut rx, extent.num_ranks()).await;
        read_all_running(&mut rx, extent.num_ranks()).await;

        // allocation finished. now we stop it.
        tracing::info!("stopping allocation");
        tx.send(RemoteProcessAllocatorMessage::Stop).await.unwrap();
        // receive all stops
        next_tx.send(()).unwrap();

        read_all_stopped(&mut rx, extent.num_ranks()).await;

        remote_allocator.terminate();
        handle.await.unwrap().unwrap();
    }
1698
    /// A second `Allocate` request while the first allocation is live must
    /// stop the first alloc (Stopped events + Done), then start the second
    /// one (new Allocated + full Created/Running lifecycle).
    #[timed_test::async_timed_test(timeout_secs = 60)]
    // This test is flaky in OSS CI, but healthy in sandcastle. Considering we
    // are in the middle of deprecating allocator, it is not worth the effort to
    // fix it for OSS.
    #[cfg_attr(not(fbcode_build), ignore)]
    async fn test_realloc() {
        // Shorten heartbeats so the test does not idle between messages.
        let config = hyperactor_config::global::lock();
        let _guard = config.override_key(
            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
            Duration::from_millis(100),
        );
        hyperactor_telemetry::initialize_logging_for_test();
        let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
        let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
        let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();

        let extent = extent!(host = 1, gpu = 2);

        let tx = channel::dial(serve_addr.clone()).unwrap();

        let alloc_name: AllocName = AllocName("test_alloc_name".to_string());
        let mut alloc1 = MockAllocWrapper::new_block_next(
            MockAlloc::new(),
            // block after all created, all running
            extent.num_ranks() * 2,
        );
        // Unblocks alloc1's next() stream once the re-allocate is requested.
        let next_tx1 = alloc1.notify_tx();
        alloc1
            .alloc
            .expect_alloc_name()
            .return_const(alloc_name.clone());
        alloc1.alloc.expect_extent().return_const(extent.clone());

        set_procstate_expectations(&mut alloc1.alloc, extent.clone());
        alloc1.alloc.expect_next().return_const(None);
        alloc1.alloc.expect_stop().times(1).return_once(|| Ok(()));
        // second allocation
        let mut alloc2 = MockAllocWrapper::new_block_next(
            MockAlloc::new(),
            // block after all created, all running
            extent.num_ranks() * 2,
        );
        let next_tx2 = alloc2.notify_tx();
        alloc2
            .alloc
            .expect_alloc_name()
            .return_const(alloc_name.clone());
        alloc2.alloc.expect_extent().return_const(extent.clone());
        set_procstate_expectations(&mut alloc2.alloc, extent.clone());
        alloc2.alloc.expect_next().return_const(None);
        alloc2.alloc.expect_stop().times(1).return_once(|| Ok(()));

        // mockall consumes same-method expectations in declaration order:
        // the first allocate() call returns alloc1, the second alloc2.
        let mut allocator = MockAllocator::new();
        allocator
            .expect_allocate()
            .times(1)
            .return_once(|_| Ok(alloc1));
        // second alloc
        allocator
            .expect_allocate()
            .times(1)
            .return_once(|_| Ok(alloc2));

        let remote_allocator = RemoteProcessAllocator::new();
        let handle = tokio::spawn({
            let remote_allocator = remote_allocator.clone();
            async move {
                remote_allocator
                    .start_with_allocator(serve_addr, allocator, None)
                    .await
            }
        });

        let alloc_key = ShortUuid::generate();

        tx.send(RemoteProcessAllocatorMessage::Allocate {
            alloc_key: alloc_key.clone(),
            extent: extent.clone(),
            bootstrap_addr: bootstrap_addr.clone(),
            hosts: vec![],
            client_context: None,
            forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
        })
        .await
        .unwrap();

        // Allocated
        let m = rx.recv().await.unwrap();
        assert_matches!(
            m,
            RemoteProcessProcStateMessage::Allocated { alloc_name: got_alloc_name, alloc_key: got_alloc_key }
            if got_alloc_name == alloc_name && got_alloc_key == alloc_key
        );

        read_all_created(&mut rx, extent.num_ranks()).await;
        read_all_running(&mut rx, extent.num_ranks()).await;

        // NOTE: `alloc_key` is shadowed — the second request uses a fresh key.
        let alloc_key = ShortUuid::generate();

        // allocation finished now we request a new one
        tx.send(RemoteProcessAllocatorMessage::Allocate {
            alloc_key: alloc_key.clone(),
            extent: extent.clone(),
            bootstrap_addr,
            hosts: vec![],
            client_context: None,
            forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
        })
        .await
        .unwrap();
        // unblock next for the first allocation
        next_tx1.send(()).unwrap();
        // we expect a stop(), then Stopped proc states, then a new Allocated
        read_all_stopped(&mut rx, extent.num_ranks()).await;
        let m = rx.recv().await.unwrap();
        assert_matches!(m, RemoteProcessProcStateMessage::Done(_));
        let m = rx.recv().await.unwrap();
        assert_matches!(
            m,
            RemoteProcessProcStateMessage::Allocated { alloc_name: got_alloc_name, alloc_key: got_alloc_key }
            if got_alloc_name == alloc_name && got_alloc_key == alloc_key
        );
        // ProcStates for the new allocation
        read_all_created(&mut rx, extent.num_ranks()).await;
        read_all_running(&mut rx, extent.num_ranks()).await;
        // finally stop
        tracing::info!("stopping allocation");
        tx.send(RemoteProcessAllocatorMessage::Stop).await.unwrap();
        // receive all stops
        next_tx2.send(()).unwrap();

        read_all_stopped(&mut rx, extent.num_ranks()).await;

        remote_allocator.terminate();
        handle.await.unwrap().unwrap();
    }
1835
    /// When the upstream (client-side) channel disappears, heartbeat failure
    /// must cause the remote allocator to call `stop()` on the inner alloc.
    #[timed_test::async_timed_test(timeout_secs = 60)]
    async fn test_upstream_closed() {
        // Use temporary config for this test
        // Short delivery timeout + fast heartbeats so the dropped upstream is
        // detected quickly.
        let config = hyperactor_config::global::lock();
        let _guard1 = config.override_key(
            hyperactor::config::MESSAGE_DELIVERY_TIMEOUT,
            Duration::from_secs(1),
        );
        let _guard2 = config.override_key(
            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
            Duration::from_millis(100),
        );

        hyperactor_telemetry::initialize_logging_for_test();
        let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
        let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
        let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();

        let extent = extent!(host = 1, gpu = 2);

        let tx = channel::dial(serve_addr.clone()).unwrap();

        let alloc_name: AllocName = AllocName("test_alloc_name".to_string());
        let mut alloc = MockAllocWrapper::new_block_next(
            MockAlloc::new(),
            // block after all created, all running
            extent.num_ranks() * 2,
        );
        let next_tx = alloc.notify_tx();
        alloc
            .alloc
            .expect_alloc_name()
            .return_const(alloc_name.clone());
        alloc.alloc.expect_extent().return_const(extent.clone());

        set_procstate_expectations(&mut alloc.alloc, extent.clone());

        alloc.alloc.expect_next().return_const(None);
        // we expect a stop due to the failure
        // synchronize test with the stop
        let (stop_tx, stop_rx) = oneshot::channel();
        alloc.alloc.expect_stop().times(1).return_once(|| {
            stop_tx.send(()).unwrap();
            Ok(())
        });

        let mut allocator = MockAllocator::new();
        allocator
            .expect_allocate()
            .times(1)
            .return_once(|_| Ok(alloc));

        let remote_allocator = RemoteProcessAllocator::new();
        let handle = tokio::spawn({
            let remote_allocator = remote_allocator.clone();
            async move {
                remote_allocator
                    .start_with_allocator(serve_addr, allocator, None)
                    .await
            }
        });

        let alloc_key = ShortUuid::generate();

        tx.send(RemoteProcessAllocatorMessage::Allocate {
            alloc_key: alloc_key.clone(),
            extent: extent.clone(),
            bootstrap_addr,
            hosts: vec![],
            client_context: None,
            forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
        })
        .await
        .unwrap();

        // Allocated
        let m = rx.recv().await.unwrap();
        assert_matches!(
            m, RemoteProcessProcStateMessage::Allocated { alloc_key: got_alloc_key, alloc_name: got_alloc_name }
                if got_alloc_name == alloc_name && got_alloc_key == alloc_key
        );

        read_all_created(&mut rx, extent.num_ranks()).await;
        read_all_running(&mut rx, extent.num_ranks()).await;

        // allocation finished. terminate connection.
        tracing::info!("closing upstream");
        drop(rx);
        // wait for the heartbeat to expire
        tokio::time::sleep(Duration::from_secs(2)).await;
        // wait for the stop to be called
        stop_rx.await.unwrap();
        // unblock next
        next_tx.send(()).unwrap();
        remote_allocator.terminate();
        handle.await.unwrap().unwrap();
    }
1933
    /// A `ProcState::Failed` emitted by the inner alloc must be forwarded to
    /// the client as an `Update`, and a subsequent `Stop` must still finish
    /// cleanly with a `Done` message.
    #[timed_test::async_timed_test(timeout_secs = 60)]
    async fn test_inner_alloc_failure() {
        // A long heartbeat interval keeps HeartBeat messages out of the
        // message sequence being asserted on.
        let config = hyperactor_config::global::lock();
        let _guard = config.override_key(
            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
            Duration::from_mins(1),
        );
        hyperactor_telemetry::initialize_logging_for_test();
        let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
        let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
        let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();

        let extent = extent!(host = 1, gpu = 2);

        let tx = channel::dial(serve_addr.clone()).unwrap();

        let test_alloc_name: AllocName = AllocName("test_alloc_name".to_string());
        let mut alloc = MockAllocWrapper::new_block_next(
            MockAlloc::new(),
            // block after the failure update
            1,
        );
        let next_tx = alloc.notify_tx();
        alloc
            .alloc
            .expect_alloc_name()
            .return_const(test_alloc_name.clone());
        alloc.alloc.expect_extent().return_const(extent.clone());
        // First next() reports the failure; the one after ends the stream.
        alloc
            .alloc
            .expect_next()
            .times(1)
            .return_const(Some(ProcState::Failed {
                alloc_name: test_alloc_name.clone(),
                description: "test".to_string(),
            }));
        alloc.alloc.expect_next().times(1).return_const(None);

        alloc.alloc.expect_stop().times(1).return_once(|| Ok(()));

        let mut allocator = MockAllocator::new();
        allocator
            .expect_allocate()
            .times(1)
            .return_once(|_| Ok(alloc));

        let remote_allocator = RemoteProcessAllocator::new();
        let handle = tokio::spawn({
            let remote_allocator = remote_allocator.clone();
            async move {
                remote_allocator
                    .start_with_allocator(serve_addr, allocator, None)
                    .await
            }
        });

        let alloc_key = ShortUuid::generate();
        tx.send(RemoteProcessAllocatorMessage::Allocate {
            alloc_key: alloc_key.clone(),
            extent: extent.clone(),
            bootstrap_addr,
            hosts: vec![],
            client_context: None,
            forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
        })
        .await
        .unwrap();

        // Allocated
        let m = rx.recv().await.unwrap();
        assert_matches!(
            m,
            RemoteProcessProcStateMessage::Allocated {  alloc_name: got_alloc_name, alloc_key: got_alloc_key }
            if test_alloc_name == got_alloc_name && alloc_key == got_alloc_key
        );

        // Failed
        let m = rx.recv().await.unwrap();
        assert_matches!(
            m,
            RemoteProcessProcStateMessage::Update(
                got_alloc_key,
                ProcState::Failed { alloc_name, description }
            ) if got_alloc_key == alloc_key && alloc_name == test_alloc_name && description == "test"
        );

        tracing::info!("stopping allocation");
        tx.send(RemoteProcessAllocatorMessage::Stop).await.unwrap();
        // receive all stops
        next_tx.send(()).unwrap();
        // we are expecting 1 Done when Alloc successfully stops.
        let m = rx.recv().await.unwrap();
        assert_matches!(
            m,
            RemoteProcessProcStateMessage::Done(got_alloc_key)
            if got_alloc_key == alloc_key
        );

        remote_allocator.terminate();
        handle.await.unwrap().unwrap();
    }
2035
    /// The client's trace id (sent via `ClientContext`) must be propagated
    /// into the `AllocSpec` match labels under `CLIENT_TRACE_ID_LABEL`.
    #[timed_test::async_timed_test(timeout_secs = 60)]
    async fn test_trace_id_propagation() {
        // A long heartbeat interval keeps HeartBeat messages out of the
        // message sequence being asserted on.
        let config = hyperactor_config::global::lock();
        let _guard = config.override_key(
            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
            Duration::from_mins(1),
        );
        hyperactor_telemetry::initialize_logging(hyperactor_telemetry::DefaultTelemetryClock {});
        let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
        let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
        let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();

        let extent = extent!(host = 1, gpu = 1);
        let tx = channel::dial(serve_addr.clone()).unwrap();
        let test_alloc_name: AllocName = AllocName("test_alloc_name".to_string());
        let test_trace_id = "test_trace_id_12345";

        // Create a mock alloc that we can verify receives the correct trace id
        let mut alloc = MockAlloc::new();
        alloc
            .expect_alloc_name()
            .return_const(test_alloc_name.clone());
        alloc.expect_extent().return_const(extent.clone());
        alloc.expect_next().return_const(None);

        // Create a mock allocator that captures the AllocSpec passed to it
        let mut allocator = MockAllocator::new();
        allocator
            .expect_allocate()
            .times(1)
            .withf(move |spec: &AllocSpec| {
                // Verify that the trace id is correctly set in the constraints
                spec.constraints
                    .match_labels
                    .get(CLIENT_TRACE_ID_LABEL)
                    .is_some_and(|trace_id| trace_id == test_trace_id)
            })
            .return_once(|_| Ok(MockAllocWrapper::new(alloc)));

        let remote_allocator = RemoteProcessAllocator::new();
        let handle = tokio::spawn({
            let remote_allocator = remote_allocator.clone();
            async move {
                remote_allocator
                    .start_with_allocator(serve_addr, allocator, None)
                    .await
            }
        });

        let alloc_key = ShortUuid::generate();
        tx.send(RemoteProcessAllocatorMessage::Allocate {
            alloc_key: alloc_key.clone(),
            extent: extent.clone(),
            bootstrap_addr,
            hosts: vec![],
            client_context: Some(ClientContext {
                trace_id: test_trace_id.to_string(),
            }),
            forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
        })
        .await
        .unwrap();

        // Verify we get the allocated message
        let m = rx.recv().await.unwrap();
        assert_matches!(
            m,
            RemoteProcessProcStateMessage::Allocated { alloc_key: got_alloc_key, alloc_name: got_alloc_name }
            if got_alloc_name == test_alloc_name && got_alloc_key == alloc_key
        );

        // Verify we get the done message since the mock alloc returns None immediately
        let m = rx.recv().await.unwrap();
        assert_matches!(
            m,
            RemoteProcessProcStateMessage::Done(got_alloc_key)
            if alloc_key == got_alloc_key
        );

        remote_allocator.terminate();
        handle.await.unwrap().unwrap();
    }
2118
2119    #[timed_test::async_timed_test(timeout_secs = 60)]
2120    async fn test_trace_id_propagation_no_client_context() {
2121        let config = hyperactor_config::global::lock();
2122        let _guard = config.override_key(
2123            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
2124            Duration::from_mins(1),
2125        );
2126        hyperactor_telemetry::initialize_logging(hyperactor_telemetry::DefaultTelemetryClock {});
2127        let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
2128        let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
2129        let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();
2130
2131        let extent = extent!(host = 1, gpu = 1);
2132        let tx = channel::dial(serve_addr.clone()).unwrap();
2133        let test_alloc_name: AllocName = AllocName("test_alloc_name".to_string());
2134
2135        // Create a mock alloc
2136        let mut alloc = MockAlloc::new();
2137        alloc
2138            .expect_alloc_name()
2139            .return_const(test_alloc_name.clone());
2140        alloc.expect_extent().return_const(extent.clone());
2141        alloc.expect_next().return_const(None);
2142
2143        // Create a mock allocator that verifies no trace id is set when client_context is None
2144        let mut allocator = MockAllocator::new();
2145        allocator
2146            .expect_allocate()
2147            .times(1)
2148            .withf(move |spec: &AllocSpec| {
2149                // Verify that no trace id is set in the constraints when client_context is None
2150                spec.constraints.match_labels.is_empty()
2151            })
2152            .return_once(|_| Ok(MockAllocWrapper::new(alloc)));
2153
2154        let remote_allocator = RemoteProcessAllocator::new();
2155        let handle = tokio::spawn({
2156            let remote_allocator = remote_allocator.clone();
2157            async move {
2158                remote_allocator
2159                    .start_with_allocator(serve_addr, allocator, None)
2160                    .await
2161            }
2162        });
2163
2164        let alloc_key = ShortUuid::generate();
2165        tx.send(RemoteProcessAllocatorMessage::Allocate {
2166            alloc_key: alloc_key.clone(),
2167            extent: extent.clone(),
2168            bootstrap_addr,
2169            hosts: vec![],
2170            client_context: None,
2171            forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
2172        })
2173        .await
2174        .unwrap();
2175
2176        // Verify we get the allocated message
2177        let m = rx.recv().await.unwrap();
2178        assert_matches!(
2179            m,
2180            RemoteProcessProcStateMessage::Allocated { alloc_key: got_alloc_key, alloc_name: got_alloc_name }
2181            if got_alloc_name == test_alloc_name && got_alloc_key == alloc_key
2182        );
2183
2184        // Verify we get the done message since the mock alloc returns None immediately
2185        let m = rx.recv().await.unwrap();
2186        assert_matches!(
2187            m,
2188            RemoteProcessProcStateMessage::Done(got_alloc_key)
2189            if got_alloc_key == alloc_key
2190        );
2191
2192        remote_allocator.terminate();
2193        handle.await.unwrap().unwrap();
2194    }
2195}
2196
// End-to-end tests for `RemoteProcessAlloc` against real in-process
// `RemoteProcessAllocator`s and, for the signal-handler test, real spawned
// helper binaries. These tests are timing-sensitive (heartbeats, task
// aborts, SIGINT delivery), so behavior depends on the exact statement
// order below.
#[cfg(test)]
mod test_alloc {
    use std::os::unix::process::ExitStatusExt;

    use hyperactor_config;
    use ndslice::extent;
    use nix::sys::signal;
    use nix::unistd::Pid;
    use timed_test::async_timed_test;

    use super::*;

    // Happy path: a host=2 x gpu=2 alloc spread over two allocator tasks.
    // Expects Created followed by Running for every rank, full coordinate
    // coverage, a clean Stopped for every proc after stop(), and that the
    // exhausted alloc then yields None forever.
    #[async_timed_test(timeout_secs = 60)]
    #[cfg(fbcode_build)]
    async fn test_alloc_simple() {
        // Use temporary config for this test
        let config = hyperactor_config::global::lock();
        let _guard1 = config.override_key(
            hyperactor::config::MESSAGE_DELIVERY_TIMEOUT,
            Duration::from_secs(1),
        );
        let _guard2 = config.override_key(
            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
            Duration::from_millis(100),
        );
        hyperactor_telemetry::initialize_logging(hyperactor_telemetry::DefaultTelemetryClock {});

        let spec = AllocSpec {
            extent: extent!(host = 2, gpu = 2),
            constraints: Default::default(),
            proc_name: None,
            transport: ChannelTransport::Unix,
            proc_allocation_mode: Default::default(),
        };
        let alloc_name = AllocName("test_alloc_name".to_string());

        // Two allocator "hosts", each serving on its own unix channel and
        // bootstrapping procs from the test bootstrap binary.
        let task1_allocator = RemoteProcessAllocator::new();
        let task1_addr = ChannelAddr::any(ChannelTransport::Unix);
        let task1_addr_string = task1_addr.to_string();
        let task1_cmd = Command::new(crate::testresource::get(
            "monarch/hyperactor_mesh/bootstrap",
        ));
        let task2_allocator = RemoteProcessAllocator::new();
        let task2_addr = ChannelAddr::any(ChannelTransport::Unix);
        let task2_addr_string = task2_addr.to_string();
        let task2_cmd = Command::new(crate::testresource::get(
            "monarch/hyperactor_mesh/bootstrap",
        ));
        let task1_allocator_copy = task1_allocator.clone();
        let task1_allocator_handle = tokio::spawn(async move {
            tracing::info!("spawning task1");
            task1_allocator_copy
                .start(task1_cmd, task1_addr, None)
                .await
                .unwrap();
        });
        let task2_allocator_copy = task2_allocator.clone();
        let task2_allocator_handle = tokio::spawn(async move {
            task2_allocator_copy
                .start(task2_cmd, task2_addr, None)
                .await
                .unwrap();
        });

        // The initializer hands the alloc both hosts by address.
        let mut initializer = MockRemoteProcessAllocInitializer::new();
        initializer.expect_initialize_alloc().return_once(move || {
            Ok(vec![
                RemoteProcessAllocHost {
                    hostname: task1_addr_string,
                    id: "task1".to_string(),
                },
                RemoteProcessAllocHost {
                    hostname: task2_addr_string,
                    id: "task2".to_string(),
                },
            ])
        });
        let mut alloc = RemoteProcessAlloc::new(spec.clone(), alloc_name, 0, initializer)
            .await
            .unwrap();
        let mut created = HashSet::new();
        let mut running_procs = HashSet::new();
        let mut proc_points = HashSet::new();
        // Two events per rank: one Created and one Running.
        for _ in 0..spec.extent.num_ranks() * 2 {
            let proc_state = alloc.next().await.unwrap();
            tracing::debug!("test got message: {:?}", proc_state);
            match proc_state {
                ProcState::Created {
                    create_key, point, ..
                } => {
                    created.insert(create_key);
                    proc_points.insert(point);
                }
                ProcState::Running { create_key, .. } => {
                    // Running must always follow the matching Created.
                    assert!(created.remove(&create_key));
                    running_procs.insert(create_key);
                }
                _ => panic!("expected Created or Running"),
            }
        }
        // Every Created was matched by a Running.
        assert!(created.is_empty());
        // ensure coords coverage
        assert!(
            spec.extent
                .points()
                .all(|point| proc_points.contains(&point))
        );

        // ensure no more pending items
        let timeout = tokio::time::Instant::now() + std::time::Duration::from_millis(1000);
        tokio::select! {
            _ = tokio::time::sleep_until(timeout) => {},
            _ = alloc.next() => panic!("expected no more items"),
        }

        // stop the allocation
        alloc.stop().await.unwrap();
        // One Stopped event per rank, all with reason Stopped.
        for _ in 0..spec.extent.num_ranks() {
            let proc_state = alloc.next().await.unwrap();
            tracing::info!("test received next proc_state: {:?}", proc_state);
            match proc_state {
                ProcState::Stopped {
                    create_key, reason, ..
                } => {
                    assert!(running_procs.remove(&create_key));
                    assert_eq!(reason, ProcStopReason::Stopped);
                }
                _ => panic!("expected stopped"),
            }
        }
        // Exactly one None
        let proc_state = alloc.next().await;
        assert!(proc_state.is_none());
        // Anything afterwards is None
        let proc_state = alloc.next().await;
        assert!(proc_state.is_none());

        task1_allocator.terminate();
        task1_allocator_handle.await.unwrap();
        task2_allocator.terminate();
        task2_allocator_handle.await.unwrap();
    }

    // Host-failure path: once both hosts are up and all procs are running,
    // abort each allocator task in turn and verify the heartbeat watchdog
    // reports that host's procs as Stopped with reason HostWatchdog.
    #[async_timed_test(timeout_secs = 60)]
    #[cfg(fbcode_build)]
    async fn test_alloc_host_failure() {
        // Use temporary config for this test
        let config = hyperactor_config::global::lock();
        let _guard1 = config.override_key(
            hyperactor::config::MESSAGE_DELIVERY_TIMEOUT,
            Duration::from_secs(1),
        );
        let _guard2 = config.override_key(
            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
            Duration::from_millis(100),
        );
        hyperactor_telemetry::initialize_logging(hyperactor_telemetry::DefaultTelemetryClock {});

        let spec = AllocSpec {
            extent: extent!(host = 2, gpu = 2),
            constraints: Default::default(),
            proc_name: None,
            transport: ChannelTransport::Unix,
            proc_allocation_mode: Default::default(),
        };
        let alloc_name = AllocName("test_alloc_name".to_string());

        let task1_allocator = RemoteProcessAllocator::new();
        let task1_addr = ChannelAddr::any(ChannelTransport::Unix);
        let task1_addr_string = task1_addr.to_string();
        let task1_cmd = Command::new(crate::testresource::get(
            "monarch/hyperactor_mesh/bootstrap",
        ));
        let task2_allocator = RemoteProcessAllocator::new();
        let task2_addr = ChannelAddr::any(ChannelTransport::Unix);
        let task2_addr_string = task2_addr.to_string();
        let task2_cmd = Command::new(crate::testresource::get(
            "monarch/hyperactor_mesh/bootstrap",
        ));
        let task1_allocator_copy = task1_allocator.clone();
        let task1_allocator_handle = tokio::spawn(async move {
            tracing::info!("spawning task1");
            task1_allocator_copy
                .start(task1_cmd, task1_addr, None)
                .await
                .unwrap();
            tracing::info!("task1 terminated");
        });
        let task2_allocator_copy = task2_allocator.clone();
        let task2_allocator_handle = tokio::spawn(async move {
            task2_allocator_copy
                .start(task2_cmd, task2_addr, None)
                .await
                .unwrap();
            tracing::info!("task2 terminated");
        });

        let mut initializer = MockRemoteProcessAllocInitializer::new();
        initializer.expect_initialize_alloc().return_once(move || {
            Ok(vec![
                RemoteProcessAllocHost {
                    hostname: task1_addr_string,
                    id: "task1".to_string(),
                },
                RemoteProcessAllocHost {
                    hostname: task2_addr_string,
                    id: "task2".to_string(),
                },
            ])
        });
        let mut alloc = RemoteProcessAlloc::new(spec.clone(), alloc_name, 0, initializer)
            .await
            .unwrap();
        // Drain the Created + Running pair for every rank.
        for _ in 0..spec.extent.num_ranks() * 2 {
            match alloc.next().await {
                Some(ProcState::Created { .. }) | Some(ProcState::Running { .. }) => {}
                _ => panic!("expected Created or Running"),
            }
        }

        // ensure no more pending items
        let timeout = tokio::time::Instant::now() + std::time::Duration::from_millis(1000);
        tokio::select! {
            _ = tokio::time::sleep_until(timeout) => {},
            _ = alloc.next() => panic!("expected no more items"),
        }

        // now we kill task1 and wait for timeout
        tracing::info!("aborting task1 allocator");
        task1_allocator_handle.abort();
        // Sleep two heartbeat intervals so the watchdog is guaranteed to
        // miss a heartbeat and declare the host dead.
        tokio::time::sleep(
            hyperactor_config::global::get(hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL)
                * 2,
        )
        .await;
        // Half the ranks live on the dead host; each must report
        // HostWatchdog as its stop reason.
        for _ in 0..spec.extent.num_ranks() / 2 {
            let proc_state = alloc.next().await.unwrap();
            tracing::info!("test received next proc_state: {:?}", proc_state);
            match proc_state {
                ProcState::Stopped { reason, .. } => {
                    assert_eq!(reason, ProcStopReason::HostWatchdog);
                }
                _ => panic!("expected stopped"),
            }
        }
        // no more events
        let timeout = tokio::time::Instant::now() + std::time::Duration::from_millis(1000);
        tokio::select! {
            _ = tokio::time::sleep_until(timeout) => {},
            _ = alloc.next() => panic!("expected no more items"),
        }

        // abort the second host
        tracing::info!("aborting task2 allocator");
        task2_allocator_handle.abort();
        tokio::time::sleep(
            hyperactor_config::global::get(hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL)
                * 2,
        )
        .await;
        for _ in 0..spec.extent.num_ranks() / 2 {
            let proc_state = alloc.next().await.unwrap();
            tracing::info!("test received next proc_state: {:?}", proc_state);
            match proc_state {
                ProcState::Stopped { reason, .. } => {
                    assert_eq!(reason, ProcStopReason::HostWatchdog);
                }
                _ => panic!("expected stopped"),
            }
        }
        // now the alloc has stopped we always get None
        let proc_state = alloc.next().await;
        assert!(proc_state.is_none());
        // Anything afterwards is None
        let proc_state = alloc.next().await;
        assert!(proc_state.is_none());
    }

    // Partial-failure path: one host uses a nonexistent bootstrap binary so
    // its inner allocation fails. Expects Created/Running only for the
    // healthy host, exactly one Failed event for the broken one, and a
    // clean stop of the surviving procs.
    #[async_timed_test(timeout_secs = 60)]
    #[cfg(fbcode_build)]
    async fn test_alloc_inner_alloc_failure() {
        // SAFETY: Test happens in single-threaded code.
        unsafe {
            std::env::set_var("MONARCH_MESSAGE_DELIVERY_TIMEOUT_SECS", "1");
        }

        let config = hyperactor_config::global::lock();
        let _guard = config.override_key(
            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
            Duration::from_millis(100),
        );
        hyperactor_telemetry::initialize_logging_for_test();

        let spec = AllocSpec {
            extent: extent!(host = 2, gpu = 2),
            constraints: Default::default(),
            proc_name: None,
            transport: ChannelTransport::Unix,
            proc_allocation_mode: Default::default(),
        };
        let alloc_name = AllocName("test_alloc_name".to_string());

        let task1_allocator = RemoteProcessAllocator::new();
        let task1_addr = ChannelAddr::any(ChannelTransport::Unix);
        let task1_addr_string = task1_addr.to_string();
        let task1_cmd = Command::new(crate::testresource::get(
            "monarch/hyperactor_mesh/bootstrap",
        ));
        let task2_allocator = RemoteProcessAllocator::new();
        let task2_addr = ChannelAddr::any(ChannelTransport::Unix);
        let task2_addr_string = task2_addr.to_string();
        // non-existent binary to fail the allocation
        let task2_cmd = Command::new("/caught/somewhere/in/time");
        let task1_allocator_copy = task1_allocator.clone();
        let task1_allocator_handle = tokio::spawn(async move {
            tracing::info!("spawning task1");
            task1_allocator_copy
                .start(task1_cmd, task1_addr, None)
                .await
                .unwrap();
        });
        let task2_allocator_copy = task2_allocator.clone();
        let task2_allocator_handle = tokio::spawn(async move {
            task2_allocator_copy
                .start(task2_cmd, task2_addr, None)
                .await
                .unwrap();
        });

        let mut initializer = MockRemoteProcessAllocInitializer::new();
        initializer.expect_initialize_alloc().return_once(move || {
            Ok(vec![
                RemoteProcessAllocHost {
                    hostname: task1_addr_string,
                    id: "task1".to_string(),
                },
                RemoteProcessAllocHost {
                    hostname: task2_addr_string,
                    id: "task2".to_string(),
                },
            ])
        });
        let mut alloc = RemoteProcessAlloc::new(spec.clone(), alloc_name, 0, initializer)
            .await
            .unwrap();
        let mut created = HashSet::new();
        let mut started_procs = HashSet::new();
        let mut proc_points = HashSet::new();
        let mut failed = 0;
        // task 1 procs + 1 failed event for task 2
        for _ in 0..spec.extent.num_ranks() + 1 {
            let proc_state = alloc.next().await.unwrap();
            tracing::debug!("test got message: {:?}", proc_state);
            match proc_state {
                ProcState::Created {
                    create_key, point, ..
                } => {
                    created.insert(create_key);
                    proc_points.insert(point);
                }
                ProcState::Running { create_key, .. } => {
                    assert!(created.remove(&create_key));
                    started_procs.insert(create_key);
                }
                ProcState::Failed { .. } => {
                    failed += 1;
                }
                _ => panic!("expected Created, Running or Failed"),
            }
        }
        assert!(created.is_empty());
        // The broken host produces exactly one Failed event.
        assert_eq!(failed, 1);
        // ensure coords coverage for task 1
        for rank in 0..spec.extent.num_ranks() / 2 {
            let point = spec.extent.point_of_rank(rank).unwrap();
            assert!(proc_points.contains(&point));
        }

        // ensure no more pending items
        let timeout = tokio::time::Instant::now() + std::time::Duration::from_millis(1000);
        tokio::select! {
            _ = tokio::time::sleep_until(timeout) => {},
            _ = alloc.next() => panic!("expected no more items"),
        }

        // stop the allocation
        alloc.stop().await.unwrap();
        // Only task1's half of the ranks are actually running and need to
        // report Stopped.
        for _ in 0..spec.extent.num_ranks() / 2 {
            let proc_state = alloc.next().await.unwrap();
            tracing::info!("test received next proc_state: {:?}", proc_state);
            match proc_state {
                ProcState::Stopped {
                    create_key, reason, ..
                } => {
                    assert!(started_procs.remove(&create_key));
                    assert_eq!(reason, ProcStopReason::Stopped);
                }
                _ => panic!("expected stopped"),
            }
        }
        // Exactly one None
        let proc_state = alloc.next().await;
        assert!(proc_state.is_none());
        // Anything afterwards is None
        let proc_state = alloc.next().await;
        assert!(proc_state.is_none());

        task1_allocator.terminate();
        task1_allocator_handle.await.unwrap();
        task2_allocator.terminate();
        task2_allocator_handle.await.unwrap();
    }

    // Signal-handler path: spawn real remote_process_allocator and
    // remote_process_alloc binaries, SIGINT the client, and verify both
    // that it died from the signal and that every proc it allocated was
    // reaped (no orphaned child PIDs).
    #[async_timed_test(timeout_secs = 180)]
    #[cfg(fbcode_build)]
    async fn test_remote_process_alloc_signal_handler() {
        hyperactor_telemetry::initialize_logging_for_test();
        let num_proc_meshes = 5;
        let hosts_per_proc_mesh = 5;

        // Channel over which spawned procs report their PIDs back to us.
        let pid_addr = ChannelAddr::any(ChannelTransport::Unix);
        let (pid_addr, mut pid_rx) = channel::serve::<u32>(pid_addr).unwrap();

        let addresses = (0..(num_proc_meshes * hosts_per_proc_mesh))
            .map(|_| ChannelAddr::any(ChannelTransport::Unix).to_string())
            .collect::<Vec<_>>();

        // One allocator process per address.
        let remote_process_allocators = addresses
            .iter()
            .map(|addr| {
                Command::new(crate::testresource::get(
                    "monarch/hyperactor_mesh/remote_process_allocator",
                ))
                .env("RUST_LOG", "info")
                .arg(format!("--addr={addr}"))
                .stdout(std::process::Stdio::piped())
                .spawn()
                .unwrap()
            })
            .collect::<Vec<_>>();

        let done_allocating_addr = ChannelAddr::any(ChannelTransport::Unix);
        let (done_allocating_addr, mut done_allocating_rx) =
            channel::serve::<()>(done_allocating_addr).unwrap();
        let mut remote_process_alloc = Command::new(crate::testresource::get(
            "monarch/hyperactor_mesh/remote_process_alloc",
        ))
        .arg(format!("--done-allocating-addr={}", done_allocating_addr))
        .arg(format!("--addresses={}", addresses.join(",")))
        .arg(format!("--num-proc-meshes={}", num_proc_meshes))
        .arg(format!("--hosts-per-proc-mesh={}", hosts_per_proc_mesh))
        .arg(format!("--pid-addr={}", pid_addr))
        .spawn()
        .unwrap();

        // Wait until every mesh is allocated, then collect one PID per
        // allocator before delivering the signal.
        done_allocating_rx.recv().await.unwrap();
        let mut received_pids = Vec::new();
        while let Ok(pid) = pid_rx.recv().await {
            received_pids.push(pid);
            if received_pids.len() == remote_process_allocators.len() {
                break;
            }
        }

        signal::kill(
            Pid::from_raw(remote_process_alloc.id().unwrap() as i32),
            signal::SIGINT,
        )
        .unwrap();

        // The client should die from the SIGINT itself (no custom exit).
        assert_eq!(
            remote_process_alloc.wait().await.unwrap().signal(),
            Some(signal::SIGINT as i32)
        );

        // Give the allocators time to tear their child procs down.
        tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;

        // Assert that the processes spawned by ProcessAllocator have been killed
        for child_pid in received_pids {
            // `kill -0` probes for existence without sending a signal; a
            // non-success status means the PID is gone.
            let pid_check = Command::new("kill")
                .arg("-0")
                .arg(child_pid.to_string())
                .output()
                .await
                .expect("Failed to check if PID is alive");

            assert!(
                !pid_check.status.success(),
                "PID {} should no longer be alive",
                child_pid
            );
        }

        // Cleanup remote process allocator processes as SIGINT causes the current
        // allocs to stop but not the RemoteProcessAllocator loops
        for mut remote_process_allocator in remote_process_allocators {
            remote_process_allocator.kill().await.unwrap();
        }
    }
}