hyperactor_mesh/alloc/
remoteprocess.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::collections::HashMap;
10use std::collections::HashSet;
11use std::collections::VecDeque;
12use std::sync::Arc;
13use std::time::Duration;
14
15use anyhow::Context;
16use async_trait::async_trait;
17use dashmap::DashMap;
18use futures::FutureExt;
19use futures::future::join_all;
20use futures::future::select_all;
21use hyperactor::Named;
22use hyperactor::WorldId;
23use hyperactor::channel;
24use hyperactor::channel::ChannelAddr;
25use hyperactor::channel::ChannelRx;
26use hyperactor::channel::ChannelTransport;
27use hyperactor::channel::ChannelTx;
28use hyperactor::channel::Rx;
29use hyperactor::channel::TcpMode;
30use hyperactor::channel::Tx;
31use hyperactor::channel::TxStatus;
32use hyperactor::clock;
33use hyperactor::clock::Clock;
34use hyperactor::clock::RealClock;
35use hyperactor::config;
36use hyperactor::mailbox::DialMailboxRouter;
37use hyperactor::mailbox::MailboxServer;
38use hyperactor::observe_async;
39use hyperactor::observe_result;
40use hyperactor::reference::Reference;
41use hyperactor::serde_json;
42use mockall::automock;
43use ndslice::Region;
44use ndslice::Slice;
45use ndslice::View;
46use ndslice::ViewExt;
47use ndslice::view::Extent;
48use ndslice::view::Point;
49use serde::Deserialize;
50use serde::Serialize;
51use strum::AsRefStr;
52use tokio::io::AsyncWriteExt;
53use tokio::process::Command;
54use tokio::sync::mpsc::UnboundedReceiver;
55use tokio::sync::mpsc::UnboundedSender;
56use tokio::sync::mpsc::unbounded_channel;
57use tokio::task::JoinHandle;
58use tokio_stream::StreamExt;
59use tokio_stream::wrappers::WatchStream;
60use tokio_util::sync::CancellationToken;
61
62use crate::alloc::Alloc;
63use crate::alloc::AllocConstraints;
64use crate::alloc::AllocSpec;
65use crate::alloc::Allocator;
66use crate::alloc::AllocatorError;
67use crate::alloc::ProcState;
68use crate::alloc::ProcStopReason;
69use crate::alloc::ProcessAllocator;
70use crate::alloc::REMOTE_ALLOC_BOOTSTRAP_ADDR;
71use crate::alloc::process::CLIENT_TRACE_ID_LABEL;
72use crate::alloc::process::ClientContext;
73use crate::alloc::serve_with_config;
74use crate::alloc::with_unspecified_port_or_any;
75use crate::shortuuid::ShortUuid;
76
77/// Control messages sent from remote process allocator to local allocator.
78#[derive(Debug, Clone, Serialize, Deserialize, Named, AsRefStr)]
79pub enum RemoteProcessAllocatorMessage {
80    /// Create allocation with given spec and send updates to bootstrap_addr.
81    Allocate {
82        /// The key used to identify this allocation.
83        alloc_key: ShortUuid,
84        /// The extent to allocate.
85        extent: Extent,
86        /// Bootstrap address to be used for sending updates.
87        bootstrap_addr: ChannelAddr,
88        /// Ordered list of hosts in this allocation. Can be used to
89        /// pre-populate the any local configurations such as torch.dist.
90        hosts: Vec<String>,
91        /// Client context which is passed to the ProcessAlloc
92        /// Todo: Once RemoteProcessAllocator moves to mailbox,
93        /// the client_context will go to the message header instead
94        client_context: Option<ClientContext>,
95        /// The address allocator should use for its forwarder.
96        forwarder_addr: ChannelAddr,
97    },
98    /// Stop allocation.
99    Stop,
100    /// Heartbeat message to check if remote process allocator and its
101    /// host are alive.
102    HeartBeat,
103}
104
105/// Control message sent from local allocator to remote allocator
106/// relaying process state updates.
107/// AsRefStr allows us to log the values
108#[derive(Debug, Clone, Serialize, Deserialize, Named, AsRefStr)]
109pub enum RemoteProcessProcStateMessage {
110    /// Allocation successful and Update, Done messages will follow.
111    Allocated {
112        alloc_key: ShortUuid,
113        world_id: WorldId,
114    },
115    /// ProcState updates.
116    Update(ShortUuid, ProcState),
117    /// Underlying Alloc is done.
118    Done(ShortUuid),
119    /// Heartbeat message to check if client is alive.
120    HeartBeat,
121}
122
123/// Allocator with a service frontend that wraps ProcessAllocator.
124pub struct RemoteProcessAllocator {
125    cancel_token: CancellationToken,
126}
127
128async fn conditional_sleeper<F: futures::Future<Output = ()>>(t: Option<F>) {
129    match t {
130        Some(timer) => timer.await,
131        None => futures::future::pending().await,
132    }
133}
134
135impl RemoteProcessAllocator {
136    /// Create a new allocator. It will not start until start() is called.
137    pub fn new() -> Arc<Self> {
138        Arc::new(Self {
139            cancel_token: CancellationToken::new(),
140        })
141    }
142
143    /// Stop the allocator. This will stop any ongoing allocations.
144    pub fn terminate(&self) {
145        self.cancel_token.cancel();
146    }
147
148    /// Start a remote process allocator with given cmd listening for
149    /// RemoteProcessAllocatorMessage on serve_addr. Call will block until cancelled.
150    /// The implementation is simple such that it can only handle one Alloc at
151    /// a time. Generally that's the most common use-case.
152    /// Flow works as follows:
153    /// 1. Client sends Allocate message to serve_addr.
154    /// 2. Allocator connects to bootstrap_addr, creates Alloc and sends Allocated message.
155    /// 3. Allocator streams one or more Update messages to bootstrap_addr as Alloc progresses
156    ///    making the following changes:
157    ///    * Remap mesh_agent listen address to our own forwarder actor address.
158    /// 4. Allocator sends Done message to bootstrap_addr when Alloc is done.
159    ///
160    /// At any point, client can send Stop message to serve_addr to stop the allocator.
161    /// If timeout is Some, the allocator will exit if no client connects within
162    /// that timeout, and no child allocation is running.
163    #[hyperactor::instrument]
164    pub async fn start(
165        &self,
166        cmd: Command,
167        serve_addr: ChannelAddr,
168        timeout: Option<Duration>,
169    ) -> Result<(), anyhow::Error> {
170        let process_allocator = ProcessAllocator::new(cmd);
171        self.start_with_allocator(serve_addr, process_allocator, timeout)
172            .await
173    }
174
175    /// Start a remote process allocator with given allocator listening for
176    /// RemoteProcessAllocatorMessage on serve_addr.
177    /// Used for testing.
178    #[tracing::instrument(skip(self, process_allocator))]
179    pub async fn start_with_allocator<A: Allocator + Send + Sync + 'static>(
180        &self,
181        serve_addr: ChannelAddr,
182        mut process_allocator: A,
183        timeout: Option<Duration>,
184    ) -> Result<(), anyhow::Error>
185    where
186        <A as Allocator>::Alloc: Send,
187        <A as Allocator>::Alloc: Sync,
188    {
189        tracing::info!("starting remote allocator on: {}", serve_addr);
190        let (_, mut rx) = channel::serve(serve_addr.clone()).map_err(anyhow::Error::from)?;
191
192        struct ActiveAllocation {
193            handle: JoinHandle<()>,
194            cancel_token: CancellationToken,
195        }
196        #[observe_async("RemoteProcessAllocator")]
197        async fn ensure_previous_alloc_stopped(active_allocation: &mut Option<ActiveAllocation>) {
198            if let Some(active_allocation) = active_allocation.take() {
199                tracing::info!("previous alloc found, stopping");
200                active_allocation.cancel_token.cancel();
201                match active_allocation.handle.await {
202                    Ok(_) => {
203                        // Todo, add named tracing for state change here
204                        tracing::info!("allocation stopped.")
205                    }
206                    Err(e) => {
207                        tracing::error!("allocation handler failed: {}", e);
208                    }
209                }
210            }
211        }
212
213        let mut active_allocation: Option<ActiveAllocation> = None;
214        loop {
215            // Refresh each loop iteration so the timer updates whenever a message
216            // is received.
217            let sleep = conditional_sleeper(timeout.map(|t| RealClock.sleep(t)));
218            tokio::select! {
219                msg = rx.recv() => {
220                    match msg {
221                        Ok(RemoteProcessAllocatorMessage::Allocate {
222                            alloc_key,
223                            extent,
224                            bootstrap_addr,
225                            hosts,
226                            client_context,
227                            forwarder_addr,
228                        }) => {
229                            tracing::info!("received allocation request for {} with extent {}", alloc_key, extent);
230                            ensure_previous_alloc_stopped(&mut active_allocation).await;
231
232                            // Create the corresponding local allocation spec.
233                            let mut constraints: AllocConstraints = Default::default();
234                            if let Some(context) = &client_context {
235                                constraints = AllocConstraints {
236                                    match_labels: HashMap::from([(
237                                    CLIENT_TRACE_ID_LABEL.to_string(),
238                                    context.trace_id.to_string(),
239                                    )]
240                                )};
241                                tracing::info!(
242                                    monarch_client_trace_id = context.trace_id.to_string(),
243                                    "allocating...",
244                                );
245                            }
246
247
248                            let spec = AllocSpec {
249                                extent,
250                                constraints,
251                                proc_name: None, // TODO(meriksen, direct addressing): we need to pass the addressing mode here
252                                transport: ChannelTransport::Unix,
253                                proc_allocation_mode: Default::default(),
254                            };
255
256                            match process_allocator.allocate(spec.clone()).await {
257                                Ok(alloc) => {
258                                    let cancel_token = CancellationToken::new();
259                                    active_allocation = Some(ActiveAllocation {
260                                        cancel_token: cancel_token.clone(),
261                                        handle: tokio::spawn(Self::handle_allocation_request(
262                                            Box::new(alloc) as Box<dyn Alloc + Send + Sync>,
263                                            alloc_key,
264                                            bootstrap_addr,
265                                            hosts,
266                                            cancel_token,
267                                            forwarder_addr,
268                                        )),
269                                    })
270                                }
271                                Err(e) => {
272                                    tracing::error!("allocation for {:?} failed: {}", spec, e);
273                                    continue;
274                                }
275                            }
276                        }
277                        Ok(RemoteProcessAllocatorMessage::Stop) => {
278                            tracing::info!("received stop request");
279
280                            ensure_previous_alloc_stopped(&mut active_allocation).await;
281                        }
282                        // Hearbeat message is discarded immediately after being received, sender (client)
283                        // relies on channel ack to know if the receiver (remote process allocator) is
284                        // still alive. No state needs to be updated.
285                        Ok(RemoteProcessAllocatorMessage::HeartBeat) => {}
286                        Err(e) => {
287                            tracing::error!("upstream channel error: {}", e);
288                            continue;
289                        }
290                    }
291                }
292                _ = self.cancel_token.cancelled() => {
293                    tracing::info!("main loop cancelled");
294
295                    ensure_previous_alloc_stopped(&mut active_allocation).await;
296
297                    break;
298                }
299                _ = sleep => {
300                    // If there are any active allocations, reset the timeout.
301                    if active_allocation.is_some() {
302                        continue;
303                    }
304                    // Else, exit the loop as a client hasn't connected in a reasonable
305                    // amount of time.
306                    tracing::warn!("timeout of {} seconds elapsed without any allocations, exiting", timeout.unwrap_or_default().as_secs());
307                    break;
308                }
309            }
310        }
311
312        Ok(())
313    }
314
315    #[tracing::instrument(skip(alloc, cancel_token))]
316    #[observe_async("RemoteProcessAllocator")]
317    async fn handle_allocation_request(
318        alloc: Box<dyn Alloc + Send + Sync>,
319        alloc_key: ShortUuid,
320        bootstrap_addr: ChannelAddr,
321        hosts: Vec<String>,
322        cancel_token: CancellationToken,
323        forwarder_addr: ChannelAddr,
324    ) {
325        tracing::info!("handle allocation request, bootstrap_addr: {bootstrap_addr}");
326        // start proc message forwarder
327        let (forwarder_addr, forwarder_rx) = match serve_with_config(forwarder_addr) {
328            Ok(v) => v,
329            Err(e) => {
330                tracing::error!("failed to to bootstrap forwarder actor: {}", e);
331                return;
332            }
333        };
334        let router = DialMailboxRouter::new();
335        let mailbox_handle = router.clone().serve(forwarder_rx);
336        tracing::info!("started forwarder on: {}", forwarder_addr);
337
338        // Check if we need to write TORCH_ELASTIC_CUSTOM_HOSTNAMES_LIST_FILE
339        // See: https://github.com/fairinternal/xlformers/blob/llama4_monarch/tools/launching/torchx/entrypoint/generate_ranks.py
340        if let Ok(hosts_file) = std::env::var("TORCH_ELASTIC_CUSTOM_HOSTNAMES_LIST_FILE") {
341            tracing::info!("writing hosts to {}", hosts_file);
342            #[derive(Serialize)]
343            struct Hosts {
344                hostnames: Vec<String>,
345            }
346            match serde_json::to_string(&Hosts { hostnames: hosts }) {
347                Ok(json) => match tokio::fs::File::create(&hosts_file).await {
348                    Ok(mut file) => {
349                        if file.write_all(json.as_bytes()).await.is_err() {
350                            tracing::error!("failed to write hosts to {}", hosts_file);
351                            return;
352                        }
353                    }
354                    Err(e) => {
355                        tracing::error!("failed to open hosts file {}: {}", hosts_file, e);
356                        return;
357                    }
358                },
359                Err(e) => {
360                    tracing::error!("failed to serialize hosts: {}", e);
361                    return;
362                }
363            }
364        }
365
366        Self::handle_allocation_loop(
367            alloc,
368            alloc_key,
369            bootstrap_addr,
370            router,
371            forwarder_addr,
372            cancel_token,
373        )
374        .await;
375
376        mailbox_handle.stop("alloc stopped");
377        if let Err(e) = mailbox_handle.await {
378            tracing::error!("failed to join forwarder: {}", e);
379        }
380    }
381
382    async fn handle_allocation_loop(
383        mut alloc: Box<dyn Alloc + Send + Sync>,
384        alloc_key: ShortUuid,
385        bootstrap_addr: ChannelAddr,
386        router: DialMailboxRouter,
387        forward_addr: ChannelAddr,
388        cancel_token: CancellationToken,
389    ) {
390        let world_id = alloc.world_id().clone();
391        tracing::info!("starting handle allocation loop for {}", world_id);
392        let tx = match channel::dial(bootstrap_addr) {
393            Ok(tx) => tx,
394            Err(err) => {
395                tracing::error!("failed to dial bootstrap address: {}", err);
396                return;
397            }
398        };
399        let message = RemoteProcessProcStateMessage::Allocated {
400            alloc_key: alloc_key.clone(),
401            world_id,
402        };
403        tracing::info!(name = message.as_ref(), "sending allocated message",);
404        if let Err(e) = tx.send(message).await {
405            tracing::error!("failed to send Allocated message: {}", e);
406            return;
407        }
408
409        let mut mesh_agents_by_create_key = HashMap::new();
410        let mut running = true;
411        let tx_status = tx.status().clone();
412        let mut tx_watcher = WatchStream::new(tx_status);
413        loop {
414            tokio::select! {
415                _ = cancel_token.cancelled(), if running => {
416                    tracing::info!("cancelled, stopping allocation");
417                    running = false;
418                    if let Err(e) = alloc.stop().await {
419                        tracing::error!("stop failed: {}", e);
420                        break;
421                    }
422                }
423                status = tx_watcher.next(), if running => {
424                    match status  {
425                        Some(TxStatus::Closed) => {
426                            tracing::error!("upstream channel state closed");
427                            break;
428                        },
429                        _ => {
430                            tracing::debug!("got channel event: {:?}", status.unwrap());
431                            continue;
432                        }
433                    }
434                }
435                e = alloc.next() => {
436                    match e {
437                        Some(event) => {
438                            tracing::debug!(name = event.as_ref(), "got event: {:?}", event);
439                            let event = match event {
440                                ProcState::Created { .. } => event,
441                                ProcState::Running { create_key, proc_id, mesh_agent, addr } => {
442                                    // TODO(meriksen, direct addressing): disable remapping in direct addressing mode
443                                    tracing::debug!("remapping mesh_agent {}: addr {} -> {}", mesh_agent, addr, forward_addr);
444                                    mesh_agents_by_create_key.insert(create_key.clone(), mesh_agent.clone());
445                                    router.bind(mesh_agent.actor_id().proc_id().clone().into(), addr);
446                                    ProcState::Running { create_key, proc_id, mesh_agent, addr: forward_addr.clone() }
447                                },
448                                ProcState::Stopped { create_key, reason } => {
449                                    match mesh_agents_by_create_key.remove(&create_key) {
450                                        Some(mesh_agent) => {
451                                            tracing::debug!("unmapping mesh_agent {}", mesh_agent);
452                                            let agent_ref: Reference = mesh_agent.actor_id().proc_id().clone().into();
453                                            router.unbind(&agent_ref);
454                                        },
455                                        None => {
456                                            tracing::warn!("mesh_agent not found for create key {}", create_key);
457                                        }
458                                    }
459                                    ProcState::Stopped { create_key, reason }
460                                },
461                                ProcState::Failed { ref world_id, ref description } => {
462                                    tracing::error!("allocation failed for {}: {}", world_id, description);
463                                    event
464                                }
465                            };
466                            tracing::debug!(name = event.as_ref(), "sending event: {:?}", event);
467                            tx.post(RemoteProcessProcStateMessage::Update(alloc_key.clone(), event));
468                        }
469                        None => {
470                            tracing::debug!("sending done");
471                            tx.post(RemoteProcessProcStateMessage::Done(alloc_key.clone()));
472                            running = false;
473                            break;
474                        }
475                    }
476                }
477                _ = RealClock.sleep(config::global::get(config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL)) => {
478                    tracing::trace!("sending heartbeat");
479                    tx.post(RemoteProcessProcStateMessage::HeartBeat);
480                }
481            }
482        }
483        tracing::info!("allocation handler loop exited");
484        if running {
485            tracing::info!("stopping processes");
486            if let Err(e) = alloc.stop_and_wait().await {
487                tracing::error!("stop failed: {}", e);
488                return;
489            }
490            tracing::info!("stop finished");
491        }
492    }
493}
494
495/// HostId is a string-based identifier of the host. It may be the same
496/// as the hostname but in some situations a separate ID may be used.
497type HostId = String;
498
499/// A host entry passed down from the initializer.
500#[derive(Clone)]
501pub struct RemoteProcessAllocHost {
502    /// A unique identifier of that host. Hostname may be used by ID can
503    /// be different in some situations.
504    pub id: HostId,
505    /// The FQDN of the host.
506    pub hostname: String,
507}
508
509/// State of a host in the RemoteProcessAlloc.
510struct RemoteProcessAllocHostState {
511    /// The allocation key used to identify the host.
512    alloc_key: ShortUuid,
513    /// The host ID of the remote host.
514    host_id: HostId,
515    /// TX channel to the remote host allocator.
516    tx: ChannelTx<RemoteProcessAllocatorMessage>,
517    /// Set of active processes on this host.
518    active_procs: HashSet<ShortUuid>,
519    /// Region allocated by host.
520    region: Region,
521    /// World ID for this host as indicated from Allocated message.
522    world_id: Option<WorldId>,
523    /// If remote allocator sent us ProcState::Failed.
524    failed: bool,
525    /// If remote allocater has ever allocated a proc.
526    allocated: bool,
527}
528
529#[automock]
530#[async_trait]
531/// Interface to provide the set of hosts to be used by RemoteProcessAlloc.
532pub trait RemoteProcessAllocInitializer {
533    /// Initializes and returns a list of hosts to be used by this RemoteProcessAlloc.
534    async fn initialize_alloc(&mut self) -> Result<Vec<RemoteProcessAllocHost>, anyhow::Error>;
535}
536
537/// Wrapper struct around `HashMap<HostId, RemoteProcessAllocHostState>`
538/// to ensure that host addresses are synced with the signal handler
539struct HostStates {
540    inner: HashMap<HostId, RemoteProcessAllocHostState>,
541    host_addresses: Arc<DashMap<HostId, ChannelAddr>>,
542}
543
544impl HostStates {
545    fn new(host_addresses: Arc<DashMap<HostId, ChannelAddr>>) -> HostStates {
546        Self {
547            inner: HashMap::new(),
548            host_addresses,
549        }
550    }
551
552    fn insert(
553        &mut self,
554        host_id: HostId,
555        state: RemoteProcessAllocHostState,
556        address: ChannelAddr,
557    ) {
558        self.host_addresses.insert(host_id.clone(), address);
559        self.inner.insert(host_id, state);
560    }
561
562    fn get(&self, host_id: &HostId) -> Option<&RemoteProcessAllocHostState> {
563        self.inner.get(host_id)
564    }
565
566    fn get_mut(&mut self, host_id: &HostId) -> Option<&mut RemoteProcessAllocHostState> {
567        self.inner.get_mut(host_id)
568    }
569
570    fn remove(&mut self, host_id: &HostId) -> Option<RemoteProcessAllocHostState> {
571        self.host_addresses.remove(host_id);
572        self.inner.remove(host_id)
573    }
574
575    fn iter(&self) -> impl Iterator<Item = (&HostId, &RemoteProcessAllocHostState)> {
576        self.inner.iter()
577    }
578
579    fn iter_mut(&mut self) -> impl Iterator<Item = (&HostId, &mut RemoteProcessAllocHostState)> {
580        self.inner.iter_mut()
581    }
582
583    fn is_empty(&self) -> bool {
584        self.inner.is_empty()
585    }
586    // Any missing HashMap methods should be added here as needed
587}
588
589/// A generalized implementation of an Alloc using one or more hosts running
590/// RemoteProcessAlloc for process allocation.
591pub struct RemoteProcessAlloc {
592    // The initializer to be called at the first next() invocation to obtain
593    // allocated hosts.
594    initializer: Box<dyn RemoteProcessAllocInitializer + Send + Sync>,
595    spec: AllocSpec,
596    remote_allocator_port: u16,
597    world_id: WorldId,
598    ordered_hosts: Vec<RemoteProcessAllocHost>,
599    // Indicates that the initial remote allocation requests have been sent.
600    started: bool,
601    // Indicates that this Alloc is active (we have at least one remote process running).
602    running: bool,
603    // Inidicates that the allocation process has permanently failed.
604    failed: bool,
605    // Maps the alloc key to the host.
606    alloc_to_host: HashMap<ShortUuid, HostId>,
607    host_states: HostStates,
608    world_offsets: HashMap<WorldId, usize>,
609    event_queue: VecDeque<ProcState>,
610    comm_watcher_tx: UnboundedSender<HostId>,
611    comm_watcher_rx: UnboundedReceiver<HostId>,
612
613    bootstrap_addr: ChannelAddr,
614    rx: ChannelRx<RemoteProcessProcStateMessage>,
615    _signal_cleanup_guard: hyperactor::SignalCleanupGuard,
616}
617
618impl RemoteProcessAlloc {
619    /// Create a new Alloc. initializer will be called on the first invocation of next()
620    /// to obtain a list of allocate hosts. Then Allocate message will be sent to all
621    /// RemoteProcessAllocator on all hosts. Heartbeats will be used to maintain health
622    /// status of remote hosts.
623    #[tracing::instrument(skip(initializer))]
624    #[observe_result("RemoteProcessAlloc")]
625    pub async fn new(
626        spec: AllocSpec,
627        world_id: WorldId,
628        remote_allocator_port: u16,
629        initializer: impl RemoteProcessAllocInitializer + Send + Sync + 'static,
630    ) -> Result<Self, anyhow::Error> {
631        let alloc_serve_addr = match config::global::try_get_cloned(REMOTE_ALLOC_BOOTSTRAP_ADDR) {
632            Some(addr_str) => addr_str.parse()?,
633            None => ChannelAddr::any(spec.transport.clone()),
634        };
635
636        let (bootstrap_addr, rx) = serve_with_config(alloc_serve_addr)?;
637
638        tracing::info!(
639            "starting alloc for {} on: {}",
640            world_id,
641            bootstrap_addr.clone()
642        );
643
644        let (comm_watcher_tx, comm_watcher_rx) = unbounded_channel();
645
646        let host_addresses = Arc::new(DashMap::<HostId, ChannelAddr>::new());
647        let host_addresses_for_signal = host_addresses.clone();
648
649        // Register cleanup callback with global signal manager
650        let signal_cleanup_guard =
651            hyperactor::register_signal_cleanup_scoped(Box::pin(async move {
652                join_all(host_addresses_for_signal.iter().map(|entry| async move {
653                    let addr = entry.value().clone();
654                    match channel::dial(addr.clone()) {
655                        Ok(tx) => {
656                            if let Err(e) = tx.send(RemoteProcessAllocatorMessage::Stop).await {
657                                tracing::error!("Failed to send Stop to {}: {}", addr, e);
658                            }
659                        }
660                        Err(e) => {
661                            tracing::error!("Failed to dial {} during signal cleanup: {}", addr, e);
662                        }
663                    }
664                }))
665                .await;
666            }));
667
668        Ok(Self {
669            spec,
670            world_id,
671            remote_allocator_port,
672            initializer: Box::new(initializer),
673            world_offsets: HashMap::new(),
674            ordered_hosts: Vec::new(),
675            alloc_to_host: HashMap::new(),
676            host_states: HostStates::new(host_addresses),
677            bootstrap_addr,
678            event_queue: VecDeque::new(),
679            comm_watcher_tx,
680            comm_watcher_rx,
681            rx,
682            started: false,
683            running: true,
684            failed: false,
685            _signal_cleanup_guard: signal_cleanup_guard,
686        })
687    }
688
689    /// Start a RemoteProcessAllocator tx watcher. It spawns a task that monitor the underlying
690    /// TxStatus and repot to the main next() loop via comm_watcher_tx. It does not however
691    /// send the actual heartbteats. Just watches the status.
692    /// The task will terminate once comm_watcher_rx is closed.
693    ///
694    /// A task approach was used here as opposed to select! to avoid nesting complexities with
695    /// select!() and select_all().
696    async fn start_comm_watcher(&self) {
697        let mut tx_watchers = Vec::new();
698        for host in &self.ordered_hosts {
699            let tx_status = self.host_states.get(&host.id).unwrap().tx.status().clone();
700            let watcher = WatchStream::new(tx_status);
701            tx_watchers.push((watcher, host.id.clone()));
702        }
703        assert!(!tx_watchers.is_empty());
704        let tx = self.comm_watcher_tx.clone();
705        tokio::spawn(async move {
706            loop {
707                let mut tx_status_futures = Vec::new();
708                for (watcher, _) in &mut tx_watchers {
709                    let fut = watcher.next().boxed();
710                    tx_status_futures.push(fut);
711                }
712                let (tx_status, index, _) = select_all(tx_status_futures).await;
713                let host_id = match tx_watchers.get(index) {
714                    Some((_, host_id)) => host_id.clone(),
715                    None => {
716                        // Should never happen
717                        tracing::error!(
718                            "got selected index {} with no matching host in {}",
719                            index,
720                            tx_watchers.len()
721                        );
722                        continue;
723                    }
724                };
725                if let Some(tx_status) = tx_status {
726                    tracing::debug!("host {} channel event: {:?}", host_id, tx_status);
727                    if tx_status == TxStatus::Closed {
728                        if tx.send(host_id.clone()).is_err() {
729                            // other side closed
730                            break;
731                        }
732                        tx_watchers.remove(index);
733                        if tx_watchers.is_empty() {
734                            // All of the statuses have been closed, exit the loop.
735                            break;
736                        }
737                    }
738                }
739            }
740        });
741    }
742
743    /// Ensure that we have the list of allocated hosts and that we send Allocate
744    /// request to all o fthem. Once done, start_comm_watcher() is called to start
745    /// the channel watcher.
746    /// Function is idempotent.
747    async fn ensure_started(&mut self) -> Result<(), anyhow::Error> {
748        if self.started || self.failed {
749            return Ok(());
750        }
751
752        self.started = true;
753        let hosts = self
754            .initializer
755            .initialize_alloc()
756            .await
757            .context("alloc initializer error")?;
758        if hosts.is_empty() {
759            anyhow::bail!("initializer returned empty list of hosts");
760        }
761        // prepare a list of host names in this allocation to be sent
762        // to remote allocators.
763        let hostnames: Vec<_> = hosts.iter().map(|e| e.hostname.clone()).collect();
764        tracing::info!("obtained {} hosts for this allocation", hostnames.len());
765
766        // Split the extent into regions, one per host.
767        use crate::alloc::ProcAllocationMode;
768
769        // For HostLevel, pre-compute regions. For ProcLevel, skip this step.
770        let regions: Option<Vec<_>> = match self.spec.proc_allocation_mode {
771            ProcAllocationMode::ProcLevel => {
772                // We require at least a dimension for hosts, and one for sub-host (e.g., GPUs)
773                anyhow::ensure!(
774                    self.spec.extent.len() >= 2,
775                    "invalid extent: {}, expected at least 2 dimensions",
776                    self.spec.extent
777                );
778                None
779            }
780            ProcAllocationMode::HostLevel => Some({
781                // HostLevel: each point is a host, create a region for each point
782                let num_points = self.spec.extent.num_ranks();
783                anyhow::ensure!(
784                    hosts.len() >= num_points,
785                    "HostLevel allocation mode requires {} hosts (one per point in extent {}), but only {} hosts were provided",
786                    num_points,
787                    self.spec.extent,
788                    hosts.len()
789                );
790
791                // For HostLevel, create a single-point region for each rank
792                // Each region contains one point that maps to the correct global rank
793                let labels = self.spec.extent.labels().to_vec();
794
795                // Compute strides for row-major layout: strides[i] = product of sizes[i+1..n]
796                let extent_sizes = self.spec.extent.sizes();
797                let mut parent_strides = vec![1; extent_sizes.len()];
798                for i in (0..extent_sizes.len() - 1).rev() {
799                    parent_strides[i] = parent_strides[i + 1] * extent_sizes[i + 1];
800                }
801
802                (0..num_points)
803                    .map(|rank| {
804                        // Create a slice containing only this rank
805                        // Use parent's strides so local point [0,0,...] maps to the correct global rank
806                        let sizes = vec![1; labels.len()];
807                        Region::new(
808                            labels.clone(),
809                            Slice::new(rank, sizes, parent_strides.clone()).unwrap(),
810                        )
811                    })
812                    .collect()
813            }),
814        };
815
816        match self.spec.proc_allocation_mode {
817            ProcAllocationMode::ProcLevel => {
818                // We group by the innermost dimension of the extent.
819                let split_dim = &self.spec.extent.labels()[self.spec.extent.len() - 1];
820                for (i, region) in self.spec.extent.group_by(split_dim)?.enumerate() {
821                    let host = &hosts[i];
822                    tracing::debug!("allocating: {} for host: {}", region, host.id);
823
824                    let remote_addr = match self.spec.transport {
825                        ChannelTransport::MetaTls(_) => {
826                            format!("metatls!{}:{}", host.hostname, self.remote_allocator_port)
827                        }
828                        ChannelTransport::Tcp(TcpMode::Localhost) => {
829                            // TODO: @rusch see about moving over to config for this
830                            format!("tcp![::1]:{}", self.remote_allocator_port)
831                        }
832                        ChannelTransport::Tcp(TcpMode::Hostname) => {
833                            format!("tcp!{}:{}", host.hostname, self.remote_allocator_port)
834                        }
835                        // Used only for testing.
836                        ChannelTransport::Unix => host.hostname.clone(),
837                        _ => {
838                            anyhow::bail!(
839                                "unsupported transport for host {}: {:?}",
840                                host.id,
841                                self.spec.transport,
842                            );
843                        }
844                    };
845
846                    tracing::debug!("dialing remote: {} for host {}", remote_addr, host.id);
847                    let remote_addr = remote_addr.parse::<ChannelAddr>()?;
848                    let tx = channel::dial(remote_addr.clone())
849                        .map_err(anyhow::Error::from)
850                        .context(format!(
851                            "failed to dial remote {} for host {}",
852                            remote_addr, host.id
853                        ))?;
854
855                    // Possibly we could use the HostId directly here.
856                    let alloc_key = ShortUuid::generate();
857                    assert!(
858                        self.alloc_to_host
859                            .insert(alloc_key.clone(), host.id.clone())
860                            .is_none()
861                    );
862
863                    let trace_id = hyperactor_telemetry::trace::get_or_create_trace_id();
864                    let client_context = Some(ClientContext { trace_id });
865                    let message = RemoteProcessAllocatorMessage::Allocate {
866                        alloc_key: alloc_key.clone(),
867                        extent: region.extent(),
868                        bootstrap_addr: self.bootstrap_addr.clone(),
869                        hosts: hostnames.clone(),
870                        client_context,
871                        // Make sure allocator's forwarder uses the same IP address
872                        // which is known to alloc. This is to avoid allocator picks
873                        // its host's private IP address, while its known addres to
874                        // alloc is a public IP address. In some environment, that
875                        // could lead to port unreachable error.
876                        forwarder_addr: with_unspecified_port_or_any(&remote_addr),
877                    };
878                    tracing::info!(
879                        name = message.as_ref(),
880                        "sending allocate message to workers"
881                    );
882                    tx.post(message);
883
884                    self.host_states.insert(
885                        host.id.clone(),
886                        RemoteProcessAllocHostState {
887                            alloc_key,
888                            host_id: host.id.clone(),
889                            tx,
890                            active_procs: HashSet::new(),
891                            region,
892                            world_id: None,
893                            failed: false,
894                            allocated: false,
895                        },
896                        remote_addr,
897                    );
898                }
899
900                self.ordered_hosts = hosts;
901            }
902            ProcAllocationMode::HostLevel => {
903                let regions = regions.unwrap();
904                let num_regions = regions.len();
905                for (i, region) in regions.into_iter().enumerate() {
906                    let host = &hosts[i];
907                    tracing::debug!("allocating: {} for host: {}", region, host.id);
908
909                    let remote_addr = match self.spec.transport {
910                        ChannelTransport::MetaTls(_) => {
911                            format!("metatls!{}:{}", host.hostname, self.remote_allocator_port)
912                        }
913                        ChannelTransport::Tcp(TcpMode::Localhost) => {
914                            // TODO: @rusch see about moving over to config for this
915                            format!("tcp![::1]:{}", self.remote_allocator_port)
916                        }
917                        ChannelTransport::Tcp(TcpMode::Hostname) => {
918                            format!("tcp!{}:{}", host.hostname, self.remote_allocator_port)
919                        }
920                        // Used only for testing.
921                        ChannelTransport::Unix => host.hostname.clone(),
922                        _ => {
923                            anyhow::bail!(
924                                "unsupported transport for host {}: {:?}",
925                                host.id,
926                                self.spec.transport,
927                            );
928                        }
929                    };
930
931                    tracing::debug!("dialing remote: {} for host {}", remote_addr, host.id);
932                    let remote_addr = remote_addr.parse::<ChannelAddr>()?;
933                    let tx = channel::dial(remote_addr.clone())
934                        .map_err(anyhow::Error::from)
935                        .context(format!(
936                            "failed to dial remote {} for host {}",
937                            remote_addr, host.id
938                        ))?;
939
940                    // Possibly we could use the HostId directly here.
941                    let alloc_key = ShortUuid::generate();
942                    assert!(
943                        self.alloc_to_host
944                            .insert(alloc_key.clone(), host.id.clone())
945                            .is_none()
946                    );
947
948                    let trace_id = hyperactor_telemetry::trace::get_or_create_trace_id();
949                    let client_context = Some(ClientContext { trace_id });
950                    let message = RemoteProcessAllocatorMessage::Allocate {
951                        alloc_key: alloc_key.clone(),
952                        extent: region.extent(),
953                        bootstrap_addr: self.bootstrap_addr.clone(),
954                        hosts: hostnames.clone(),
955                        client_context,
956                        // Make sure allocator's forwarder uses the same IP address
957                        // which is known to alloc. This is to avoid allocator picks
958                        // its host's private IP address, while its known addres to
959                        // alloc is a public IP address. In some environment, that
960                        // could lead to port unreachable error.
961                        forwarder_addr: with_unspecified_port_or_any(&remote_addr),
962                    };
963                    tracing::info!(
964                        name = message.as_ref(),
965                        "sending allocate message to workers"
966                    );
967                    tx.post(message);
968
969                    self.host_states.insert(
970                        host.id.clone(),
971                        RemoteProcessAllocHostState {
972                            alloc_key,
973                            host_id: host.id.clone(),
974                            tx,
975                            active_procs: HashSet::new(),
976                            region,
977                            world_id: None,
978                            failed: false,
979                            allocated: false,
980                        },
981                        remote_addr,
982                    );
983                }
984
985                // Only store hosts that were actually used for regions
986                // If num_regions < hosts.len(), we only use the first num_regions hosts
987                self.ordered_hosts = hosts.into_iter().take(num_regions).collect();
988            }
989        }
990        self.start_comm_watcher().await;
991        self.started = true;
992
993        Ok(())
994    }
995
996    // Given a proc_id, obtain the internal HostState structure.
997    fn get_host_state_mut(
998        &mut self,
999        alloc_key: &ShortUuid,
1000    ) -> Result<&mut RemoteProcessAllocHostState, anyhow::Error> {
1001        let host_id: &HostId = self
1002            .alloc_to_host
1003            .get(alloc_key)
1004            .ok_or_else(|| anyhow::anyhow!("alloc with key {} not found", alloc_key))?;
1005
1006        self.host_states
1007            .get_mut(host_id)
1008            .ok_or_else(|| anyhow::anyhow!("no host state found for host {}", host_id))
1009    }
1010
1011    // Given a proc_id, obtain the internal HostState structure.
1012    fn get_host_state(
1013        &self,
1014        alloc_key: &ShortUuid,
1015    ) -> Result<&RemoteProcessAllocHostState, anyhow::Error> {
1016        let host_id: &HostId = self
1017            .alloc_to_host
1018            .get(alloc_key)
1019            .ok_or_else(|| anyhow::anyhow!("alloc with key {} not found", alloc_key))?;
1020
1021        self.host_states
1022            .get(host_id)
1023            .ok_or_else(|| anyhow::anyhow!("no host state found for host {}", host_id))
1024    }
1025
1026    fn remove_host_state(
1027        &mut self,
1028        alloc_key: &ShortUuid,
1029    ) -> Result<RemoteProcessAllocHostState, anyhow::Error> {
1030        let host_id: &HostId = self
1031            .alloc_to_host
1032            .get(alloc_key)
1033            .ok_or_else(|| anyhow::anyhow!("alloc with key {} not found", alloc_key))?;
1034
1035        self.host_states
1036            .remove(host_id)
1037            .ok_or_else(|| anyhow::anyhow!("no host state found for host {}", host_id))
1038    }
1039
1040    fn add_proc_id_to_host_state(
1041        &mut self,
1042        alloc_key: &ShortUuid,
1043        create_key: &ShortUuid,
1044    ) -> Result<(), anyhow::Error> {
1045        let task_state = self.get_host_state_mut(alloc_key)?;
1046        if !task_state.active_procs.insert(create_key.clone()) {
1047            // Should not happen but we can ignore
1048            tracing::error!("proc with create key {} already in host state", create_key);
1049        }
1050        task_state.allocated = true;
1051        Ok(())
1052    }
1053
1054    fn remove_proc_from_host_state(
1055        &mut self,
1056        alloc_key: &ShortUuid,
1057        create_key: &ShortUuid,
1058    ) -> Result<(), anyhow::Error> {
1059        let task_state = self.get_host_state_mut(alloc_key)?;
1060        if !task_state.active_procs.remove(create_key) {
1061            // Should not happen but we can ignore
1062            tracing::error!("proc with create_key already in host state: {}", create_key);
1063        }
1064        Ok(())
1065    }
1066
1067    // Reproject proc world coords to global shape coords.
1068    fn project_proc_into_global_extent(
1069        &self,
1070        alloc_key: &ShortUuid,
1071        point: &Point,
1072    ) -> Result<Point, anyhow::Error> {
1073        let global_rank = self
1074            .get_host_state(alloc_key)?
1075            .region
1076            .get(point.rank())
1077            .ok_or_else(|| {
1078                anyhow::anyhow!(
1079                    "rank {} out of bounds for in alloc {}",
1080                    point.rank(),
1081                    alloc_key
1082                )
1083            })?;
1084        Ok(self.spec.extent.point_of_rank(global_rank)?)
1085    }
1086
1087    // Cleanup a comm-failed host information by its ID.
1088    fn cleanup_host_channel_closed(
1089        &mut self,
1090        host_id: HostId,
1091    ) -> Result<Vec<ShortUuid>, anyhow::Error> {
1092        let state = match self.host_states.remove(&host_id) {
1093            Some(state) => state,
1094            None => {
1095                // this should never happen.
1096                anyhow::bail!(
1097                    "got channel closed event for host {} which has no known state",
1098                    host_id
1099                );
1100            }
1101        };
1102        self.ordered_hosts.retain(|host| host.id != host_id);
1103        self.alloc_to_host.remove(&state.alloc_key);
1104        if let Some(world_id) = state.world_id {
1105            self.world_offsets.remove(&world_id);
1106        }
1107        let create_keys = state.active_procs.iter().cloned().collect();
1108
1109        Ok(create_keys)
1110    }
1111}
1112
1113#[async_trait]
1114impl Alloc for RemoteProcessAlloc {
1115    async fn next(&mut self) -> Option<ProcState> {
1116        loop {
1117            if let state @ Some(_) = self.event_queue.pop_front() {
1118                break state;
1119            }
1120
1121            if !self.running {
1122                break None;
1123            }
1124
1125            if let Err(e) = self.ensure_started().await {
1126                break Some(ProcState::Failed {
1127                    world_id: self.world_id.clone(),
1128                    description: format!("failed to ensure started: {:#}", e),
1129                });
1130            }
1131
1132            let heartbeat_interval =
1133                config::global::get(config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL);
1134            let mut heartbeat_time = hyperactor::clock::RealClock.now() + heartbeat_interval;
1135            // rerun outer loop in case we pushed new items to the event queue
1136            let mut reloop = false;
1137            let update = loop {
1138                tokio::select! {
1139                    msg = self.rx.recv() => {
1140                        tracing::debug!("got ProcState message from allocator: {:?}", msg);
1141                        match msg {
1142                            Ok(RemoteProcessProcStateMessage::Allocated { alloc_key, world_id }) => {
1143                                tracing::info!("remote alloc {}: allocated", alloc_key);
1144                                match self.get_host_state_mut(&alloc_key) {
1145                                    Ok(state) => {
1146                                        state.world_id = Some(world_id.clone());
1147                                    }
1148                                    Err(err) => {
1149                                        // should never happenA
1150                                        tracing::error!(
1151                                            "received allocated message alloc: {} with no known state: {}",
1152                                            alloc_key, err,
1153                                        );
1154                                    }
1155                                }
1156                            }
1157                            Ok(RemoteProcessProcStateMessage::Update(alloc_key, proc_state)) => {
1158                                let update = match proc_state {
1159                                    ProcState::Created { ref create_key, .. } => {
1160                                        if let Err(e) = self.add_proc_id_to_host_state(&alloc_key, create_key) {
1161                                            tracing::error!("failed to add proc with create key {} host state: {}", create_key, e);
1162                                        }
1163                                        proc_state
1164                                    }
1165                                    ProcState::Stopped{ ref create_key, ..} => {
1166                                        if let Err(e) = self.remove_proc_from_host_state(&alloc_key, create_key) {
1167                                            tracing::error!("failed to remove proc with create key {} host state: {}", create_key, e);
1168                                        }
1169                                        proc_state
1170                                    }
1171                                    ProcState::Failed { ref world_id, ref description } => {
1172                                        match self.get_host_state_mut(&alloc_key) {
1173                                            Ok(state) => {
1174                                                state.failed = true;
1175                                                ProcState::Failed {
1176                                                    world_id: world_id.clone(),
1177                                                    description: format!("host {} failed: {}", state.host_id, description),
1178                                                }
1179                                            }
1180                                            Err(e) => {
1181                                                tracing::error!("failed to find host state for world id: {}: {}", world_id, e);
1182                                                proc_state
1183                                            }
1184                                        }
1185                                    }
1186                                    _ => proc_state
1187                                };
1188
1189                                break Some((Some(alloc_key), update));
1190                            }
1191                            Ok(RemoteProcessProcStateMessage::Done(alloc_key)) => {
1192                                tracing::info!("allocator {} is done", alloc_key);
1193
1194                                if let Ok(state) = self.remove_host_state(&alloc_key) {
1195                                    if !state.active_procs.is_empty() {
1196                                        tracing::error!("received done for alloc {} with active procs: {:?}", alloc_key, state.active_procs);
1197                                    }
1198                                } else {
1199                                    tracing::error!("received done for unknown alloc {}", alloc_key);
1200                                }
1201
1202                                if self.host_states.is_empty() {
1203                                    self.running = false;
1204                                    break None;
1205                                }
1206                            }
1207                            // Hearbeat message is discarded immediately after being received, sender (remote
1208                            // process allocator) relies on channel ack to know if the receiver (client) is
1209                            // still alive. No state needs to be updated.
1210                            Ok(RemoteProcessProcStateMessage::HeartBeat) => {}
1211                            Err(e) => {
1212                                break Some((None, ProcState::Failed {world_id: self.world_id.clone(), description: format!("error receiving events: {}", e)}));
1213                            }
1214                        }
1215                    }
1216
1217                    _ = clock::RealClock.sleep_until(heartbeat_time) => {
1218                        self.host_states.iter().for_each(|(_, host_state)| host_state.tx.post(RemoteProcessAllocatorMessage::HeartBeat));
1219                        heartbeat_time = hyperactor::clock::RealClock.now() + heartbeat_interval;
1220                    }
1221
1222                    closed_host_id = self.comm_watcher_rx.recv() => {
1223                        if let Some(closed_host_id) = closed_host_id {
1224                            tracing::debug!("host {} channel closed, cleaning up", closed_host_id);
1225                            if let Some(state) = self.host_states.get(&closed_host_id)
1226                                && !state.allocated {
1227                                    break Some((None, ProcState::Failed {
1228                                        world_id: self.world_id.clone(),
1229                                        description: format!(
1230                                            "no process was ever allocated on {} before its channel closed; \
1231                                            a common cause is that the channel was never established",
1232                                            closed_host_id
1233                                        )}));
1234                                }
1235                            let create_keys = match self.cleanup_host_channel_closed(closed_host_id) {
1236                                Ok(create_keys) => create_keys,
1237                                Err(err) => {
1238                                    tracing::error!("failed to cleanup disconnected host: {}", err);
1239                                    continue;
1240                                }
1241                            };
1242                            for create_key in create_keys {
1243                                tracing::debug!("queuing Stopped state for proc with create key {}", create_key);
1244                                self.event_queue.push_back(
1245                                    ProcState::Stopped {
1246                                        create_key,
1247                                        reason: ProcStopReason::HostWatchdog
1248                                    }
1249                                );
1250                            }
1251                            // Check if there are any hosts left
1252                            if self.host_states.is_empty() {
1253                                tracing::info!("no more hosts left, stopping the alloc");
1254                                self.running = false;
1255                            }
1256                            // Kick back to the outer loop to pop off the queue if necessary
1257                            reloop = true;
1258                            break None;
1259                        } else {
1260                            // other side closed
1261                            tracing::warn!("unexpected comm watcher channel close");
1262                            break None;
1263                        }
1264                    }
1265                }
1266            };
1267
1268            if reloop {
1269                // Kick back to the outer loop to pop off the queue.
1270                // Should not have produced an update.
1271                assert!(update.is_none());
1272                continue;
1273            }
1274
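            // Post-process the raw update before returning it: Created events carry
            // host-local coordinates and must be projected into the global extent,
            // while Failed events also mark this alloc as failed.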
1275            break match update {
1276                Some((
1277                    Some(alloc_key),
1278                    ProcState::Created {
1279                        create_key,
1280                        point,
1281                        pid,
1282                    },
1283                )) => match self.project_proc_into_global_extent(&alloc_key, &point) {
1284                    Ok(global_point) => {
1285                        tracing::debug!("reprojected coords: {} -> {}", point, global_point);
1286                        Some(ProcState::Created {
1287                            create_key,
1288                            point: global_point,
1289                            pid,
1290                        })
1291                    }
1292                    Err(e) => {
1293                        tracing::error!(
1294                            "failed to project coords for proc: {}.{}: {}",
1295                            alloc_key,
1296                            create_key,
1297                            e
1298                        );
1299                        None
1300                    }
1301                },
1302                Some((None, ProcState::Created { .. })) => {
1303                    panic!("illegal state: missing alloc_key for ProcState::Created event")
1304                }
1305                Some((_, update)) => {
1306                    if let ProcState::Failed { description, .. } = &update {
1307                        tracing::error!(description);
1308                        self.failed = true;
1309                    }
1310                    Some(update)
1311                }
1312                None => None,
1313            };
1314        }
1315    }
1316
1317    fn spec(&self) -> &AllocSpec {
1318        &self.spec
1319    }
1320
1321    fn extent(&self) -> &Extent {
1322        &self.spec.extent
1323    }
1324
1325    fn world_id(&self) -> &WorldId {
1326        &self.world_id
1327    }
1328
1329    async fn stop(&mut self) -> Result<(), AllocatorError> {
1330        tracing::info!("stopping alloc");
1331
1332        for (host_id, task_state) in self.host_states.iter_mut() {
1333            tracing::debug!("stopping alloc at host {}", host_id);
1334            task_state.tx.post(RemoteProcessAllocatorMessage::Stop);
1335        }
1336
1337        Ok(())
1338    }
1339
1340    /// For Tcp and Metatls, return a router address that has the same IP address
1341    /// as the alloc's bootstrap address, but with the port set to 0. We do this
1342    /// instead of using `ChannelAddr::any` because we want the alloc to use the
1343    /// same IP address for all of its frontend ports. In some environments, the
1344    /// host can have both a public and a private IP address, and using the wrong
1345    /// one could lead to port-unreachable errors.
1346    ///
1347    /// For other channel types, this method still uses `ChannelAddr::any`.
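    ///
    /// Illustrative sketch (hypothetical addresses, not a doctest): a TCP
    /// bootstrap address of `10.0.0.1:26600` yields a router address of
    /// `10.0.0.1:0`, so the OS picks a free port on the same interface,
    /// while a Unix bootstrap address falls back to `ChannelAddr::any`.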
1348    fn client_router_addr(&self) -> ChannelAddr {
1349        with_unspecified_port_or_any(&self.bootstrap_addr)
1350    }
1351}
1352
1353impl Drop for RemoteProcessAlloc {
1354    fn drop(&mut self) {
1355        tracing::debug!("dropping RemoteProcessAlloc of world_id {}", self.world_id);
1356    }
1357}
1358
1359#[cfg(test)]
1360mod test {
1361    use std::assert_matches::assert_matches;
1362
1363    use hyperactor::ActorRef;
1364    use hyperactor::channel::ChannelRx;
1365    use hyperactor::clock::ClockKind;
1366    use hyperactor::id;
1367    use ndslice::extent;
1368    use tokio::sync::oneshot;
1369
1370    use super::*;
1371    use crate::alloc::ChannelTransport;
1372    use crate::alloc::MockAlloc;
1373    use crate::alloc::MockAllocWrapper;
1374    use crate::alloc::MockAllocator;
1375    use crate::alloc::ProcStopReason;
1376    use crate::alloc::with_unspecified_port_or_any;
1377    use crate::proc_mesh::mesh_agent::ProcMeshAgent;
1378
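    // Test helpers: each drains `alloc_len` Update messages of the expected
    // ProcState variant from the bootstrap channel, ignoring any interleaved
    // HeartBeat messages.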
1379    async fn read_all_created(rx: &mut ChannelRx<RemoteProcessProcStateMessage>, alloc_len: usize) {
1380        let mut i: usize = 0;
1381        while i < alloc_len {
1382            let m = rx.recv().await.unwrap();
1383            match m {
1384                RemoteProcessProcStateMessage::Update(_, ProcState::Created { .. }) => i += 1,
1385                RemoteProcessProcStateMessage::HeartBeat => {}
1386                _ => panic!("unexpected message: {:?}", m),
1387            }
1388        }
1389    }
1390
1391    async fn read_all_running(rx: &mut ChannelRx<RemoteProcessProcStateMessage>, alloc_len: usize) {
1392        let mut i: usize = 0;
1393        while i < alloc_len {
1394            let m = rx.recv().await.unwrap();
1395            match m {
1396                RemoteProcessProcStateMessage::Update(_, ProcState::Running { .. }) => i += 1,
1397                RemoteProcessProcStateMessage::HeartBeat => {}
1398                _ => panic!("unexpected message: {:?}", m),
1399            }
1400        }
1401    }
1402
1403    async fn read_all_stopped(rx: &mut ChannelRx<RemoteProcessProcStateMessage>, alloc_len: usize) {
1404        let mut i: usize = 0;
1405        while i < alloc_len {
1406            let m = rx.recv().await.unwrap();
1407            match m {
1408                RemoteProcessProcStateMessage::Update(_, ProcState::Stopped { .. }) => i += 1,
1409                RemoteProcessProcStateMessage::HeartBeat => {}
1410                _ => panic!("unexpected message: {:?}", m),
1411            }
1412        }
1413    }
1414
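    /// Programs the mock alloc to emit a Created, Running, and Stopped event
    /// for every point in the extent; callers typically append a final
    /// `expect_next().return_const(None)` themselves.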
1415    fn set_procstate_expectations(alloc: &mut MockAlloc, extent: Extent) {
1416        alloc.expect_extent().return_const(extent.clone());
1417        let mut create_keys = Vec::new();
1418        for point in extent.points() {
1419            let create_key = ShortUuid::generate();
1420            create_keys.push(create_key.clone());
1421            alloc.expect_next().times(1).return_once(move || {
1422                Some(ProcState::Created {
1423                    create_key: create_key.clone(),
1424                    point,
1425                    pid: 0,
1426                })
1427            });
1428        }
1429        for (i, create_key) in create_keys
1430            .iter()
1431            .take(extent.num_ranks())
1432            .cloned()
1433            .enumerate()
1434        {
1435            let proc_id = format!("test[{i}]").parse().unwrap();
1436            let mesh_agent = ActorRef::<ProcMeshAgent>::attest(
1437                format!("test[{i}].mesh_agent[{i}]").parse().unwrap(),
1438            );
1439            alloc.expect_next().times(1).return_once(move || {
1440                Some(ProcState::Running {
1441                    create_key,
1442                    proc_id,
1443                    addr: ChannelAddr::Unix("/proc0".parse().unwrap()),
1444                    mesh_agent,
1445                })
1446            });
1447        }
1448        for create_key in create_keys.iter().take(extent.num_ranks()).cloned() {
1449            alloc.expect_next().times(1).return_once(|| {
1450                Some(ProcState::Stopped {
1451                    create_key,
1452                    reason: ProcStopReason::Unknown,
1453                })
1454            });
1455        }
1456    }
1457
1458    #[timed_test::async_timed_test(timeout_secs = 5)]
1459    async fn test_simple() {
1460        let config = hyperactor::config::global::lock();
1461        let _guard = config.override_key(
1462            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
1463            Duration::from_millis(100),
1464        );
1465        hyperactor_telemetry::initialize_logging_for_test();
1466        let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
1467        let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
1468        let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();
1469
1470        let extent = extent!(host = 1, gpu = 2);
1471        let tx = channel::dial(serve_addr.clone()).unwrap();
1472
1473        let world_id: WorldId = id!(test_world_id);
1474        let mut alloc = MockAlloc::new();
1475        alloc.expect_world_id().return_const(world_id.clone());
1476        alloc.expect_extent().return_const(extent.clone());
1477
1478        set_procstate_expectations(&mut alloc, extent.clone());
1479
1480        // final none
1481        alloc.expect_next().return_const(None);
1482
1483        let mut allocator = MockAllocator::new();
1484        let total_messages = extent.num_ranks() * 3 + 1;
1485        let mock_wrapper = MockAllocWrapper::new_block_next(
1486            alloc,
1487            // block after create, running, stopped and done.
1488            total_messages,
1489        );
1490        allocator
1491            .expect_allocate()
1492            .times(1)
1493            .return_once(move |_| Ok(mock_wrapper));
1494
1495        let remote_allocator = RemoteProcessAllocator::new();
1496        let handle = tokio::spawn({
1497            let remote_allocator = remote_allocator.clone();
1498            async move {
1499                remote_allocator
1500                    .start_with_allocator(serve_addr, allocator, None)
1501                    .await
1502            }
1503        });
1504
1505        let alloc_key = ShortUuid::generate();
1506
1507        tx.send(RemoteProcessAllocatorMessage::Allocate {
1508            alloc_key: alloc_key.clone(),
1509            extent: extent.clone(),
1510            bootstrap_addr,
1511            hosts: vec![],
1512            client_context: None,
1513            forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
1514        })
1515        .await
1516        .unwrap();
1517
1518        // Allocated
1519        let m = rx.recv().await.unwrap();
1520        assert_matches!(
1521            m, RemoteProcessProcStateMessage::Allocated { alloc_key: got_alloc_key, world_id: got_world_id }
1522                if got_world_id == world_id && got_alloc_key == alloc_key
1523        );
1524
1525        // All Created events
1526        let mut rank: usize = 0;
1527        let mut create_keys = Vec::with_capacity(extent.num_ranks());
1528        while rank < extent.num_ranks() {
1529            let m = rx.recv().await.unwrap();
1530            match m {
1531                RemoteProcessProcStateMessage::Update(
1532                    got_alloc_key,
1533                    ProcState::Created {
1534                        create_key, point, ..
1535                    },
1536                ) => {
1537                    let expected_point = extent.point_of_rank(rank).unwrap();
1538                    assert_eq!(got_alloc_key, alloc_key);
1539                    assert_eq!(point, expected_point);
1540                    create_keys.push(create_key);
1541                    rank += 1;
1542                }
1543                RemoteProcessProcStateMessage::HeartBeat => {}
1544                _ => panic!("unexpected message: {:?}", m),
1545            }
1546        }
1547        // All Running events
1548        let mut rank: usize = 0;
1549        while rank < extent.num_ranks() {
1550            let m = rx.recv().await.unwrap();
1551            match m {
1552                RemoteProcessProcStateMessage::Update(
1553                    got_alloc_key,
1554                    ProcState::Running {
1555                        create_key,
1556                        proc_id,
1557                        mesh_agent,
1558                        addr: _,
1559                    },
1560                ) => {
1561                    assert_eq!(got_alloc_key, alloc_key);
1562                    assert_eq!(create_key, create_keys[rank]);
1563                    let expected_proc_id = format!("test[{}]", rank).parse().unwrap();
1564                    let expected_mesh_agent = ActorRef::<ProcMeshAgent>::attest(
1565                        format!("test[{}].mesh_agent[{}]", rank, rank)
1566                            .parse()
1567                            .unwrap(),
1568                    );
1569                    assert_eq!(proc_id, expected_proc_id);
1570                    assert_eq!(mesh_agent, expected_mesh_agent);
1571                    rank += 1;
1572                }
1573                RemoteProcessProcStateMessage::HeartBeat => {}
1574                _ => panic!("unexpected message: {:?}", m),
1575            }
1576        }
1577        // All Stopped events
1578        let mut rank: usize = 0;
1579        while rank < extent.num_ranks() {
1580            let m = rx.recv().await.unwrap();
1581            match m {
1582                RemoteProcessProcStateMessage::Update(
1583                    got_alloc_key,
1584                    ProcState::Stopped {
1585                        create_key,
1586                        reason: ProcStopReason::Unknown,
1587                    },
1588                ) => {
1589                    assert_eq!(got_alloc_key, alloc_key);
1590                    assert_eq!(create_key, create_keys[rank]);
1591                    rank += 1;
1592                }
1593                RemoteProcessProcStateMessage::HeartBeat => {}
1594                _ => panic!("unexpected message: {:?}", m),
1595            }
1596        }
1597        // Done
1598        loop {
1599            let m = rx.recv().await.unwrap();
1600            match m {
1601                RemoteProcessProcStateMessage::Done(got_alloc_key) => {
1602                    assert_eq!(got_alloc_key, alloc_key);
1603                    break;
1604                }
1605                RemoteProcessProcStateMessage::HeartBeat => {}
1606                _ => panic!("unexpected message: {:?}", m),
1607            }
1608        }
1609
1610        remote_allocator.terminate();
1611        handle.await.unwrap().unwrap();
1612    }
1613
1614    #[timed_test::async_timed_test(timeout_secs = 15)]
1615    async fn test_normal_stop() {
1616        let config = hyperactor::config::global::lock();
1617        let _guard = config.override_key(
1618            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
1619            Duration::from_millis(100),
1620        );
1621        hyperactor_telemetry::initialize_logging_for_test();
1622        let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
1623        let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
1624        let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();
1625
1626        let extent = extent!(host = 1, gpu = 2);
1627        let tx = channel::dial(serve_addr.clone()).unwrap();
1628
1629        let world_id: WorldId = id!(test_world_id);
1630        let mut alloc = MockAllocWrapper::new_block_next(
1631            MockAlloc::new(),
1632            // block after all created, all running
1633            extent.num_ranks() * 2,
1634        );
1635        let next_tx = alloc.notify_tx();
1636        alloc.alloc.expect_world_id().return_const(world_id.clone());
1637        alloc.alloc.expect_extent().return_const(extent.clone());
1638
1639        set_procstate_expectations(&mut alloc.alloc, extent.clone());
1640
1641        alloc.alloc.expect_next().return_const(None);
1642        alloc.alloc.expect_stop().times(1).return_once(|| Ok(()));
1643
1644        let mut allocator = MockAllocator::new();
1645        allocator
1646            .expect_allocate()
1647            .times(1)
1648            .return_once(|_| Ok(alloc));
1649
1650        let remote_allocator = RemoteProcessAllocator::new();
1651        let handle = tokio::spawn({
1652            let remote_allocator = remote_allocator.clone();
1653            async move {
1654                remote_allocator
1655                    .start_with_allocator(serve_addr, allocator, None)
1656                    .await
1657            }
1658        });
1659
1660        let alloc_key = ShortUuid::generate();
1661        tx.send(RemoteProcessAllocatorMessage::Allocate {
1662            alloc_key: alloc_key.clone(),
1663            extent: extent.clone(),
1664            bootstrap_addr,
1665            hosts: vec![],
1666            client_context: None,
1667            forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
1668        })
1669        .await
1670        .unwrap();
1671
1672        // Allocated
1673        let m = rx.recv().await.unwrap();
1674        assert_matches!(
1675            m,
1676            RemoteProcessProcStateMessage::Allocated {  world_id: got_world_id, alloc_key: got_alloc_key }
1677            if world_id == got_world_id && alloc_key == got_alloc_key
1678        );
1679
1680        read_all_created(&mut rx, extent.num_ranks()).await;
1681        read_all_running(&mut rx, extent.num_ranks()).await;
1682
1683        // allocation finished. now we stop it.
1684        tracing::info!("stopping allocation");
1685        tx.send(RemoteProcessAllocatorMessage::Stop).await.unwrap();
1686        // receive all stops
1687        next_tx.send(()).unwrap();
1688
1689        read_all_stopped(&mut rx, extent.num_ranks()).await;
1690
1691        remote_allocator.terminate();
1692        handle.await.unwrap().unwrap();
1693    }
1694
1695    #[timed_test::async_timed_test(timeout_secs = 15)]
1696    async fn test_realloc() {
1697        let config = hyperactor::config::global::lock();
1698        let _guard = config.override_key(
1699            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
1700            Duration::from_millis(100),
1701        );
1702        hyperactor_telemetry::initialize_logging_for_test();
1703        let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
1704        let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
1705        let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();
1706
1707        let extent = extent!(host = 1, gpu = 2);
1708
1709        let tx = channel::dial(serve_addr.clone()).unwrap();
1710
1711        let world_id: WorldId = id!(test_world_id);
1712        let mut alloc1 = MockAllocWrapper::new_block_next(
1713            MockAlloc::new(),
1714            // block after all created, all running
1715            extent.num_ranks() * 2,
1716        );
1717        let next_tx1 = alloc1.notify_tx();
1718        alloc1
1719            .alloc
1720            .expect_world_id()
1721            .return_const(world_id.clone());
1722        alloc1.alloc.expect_extent().return_const(extent.clone());
1723
1724        set_procstate_expectations(&mut alloc1.alloc, extent.clone());
1725        alloc1.alloc.expect_next().return_const(None);
1726        alloc1.alloc.expect_stop().times(1).return_once(|| Ok(()));
1727        // second allocation
1728        let mut alloc2 = MockAllocWrapper::new_block_next(
1729            MockAlloc::new(),
1730            // block after all created, all running
1731            extent.num_ranks() * 2,
1732        );
1733        let next_tx2 = alloc2.notify_tx();
1734        alloc2
1735            .alloc
1736            .expect_world_id()
1737            .return_const(world_id.clone());
1738        alloc2.alloc.expect_extent().return_const(extent.clone());
1739        set_procstate_expectations(&mut alloc2.alloc, extent.clone());
1740        alloc2.alloc.expect_next().return_const(None);
1741        alloc2.alloc.expect_stop().times(1).return_once(|| Ok(()));
1742
1743        let mut allocator = MockAllocator::new();
1744        allocator
1745            .expect_allocate()
1746            .times(1)
1747            .return_once(|_| Ok(alloc1));
1748        // second alloc
1749        allocator
1750            .expect_allocate()
1751            .times(1)
1752            .return_once(|_| Ok(alloc2));
1753
1754        let remote_allocator = RemoteProcessAllocator::new();
1755        let handle = tokio::spawn({
1756            let remote_allocator = remote_allocator.clone();
1757            async move {
1758                remote_allocator
1759                    .start_with_allocator(serve_addr, allocator, None)
1760                    .await
1761            }
1762        });
1763
1764        let alloc_key = ShortUuid::generate();
1765
1766        tx.send(RemoteProcessAllocatorMessage::Allocate {
1767            alloc_key: alloc_key.clone(),
1768            extent: extent.clone(),
1769            bootstrap_addr: bootstrap_addr.clone(),
1770            hosts: vec![],
1771            client_context: None,
1772            forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
1773        })
1774        .await
1775        .unwrap();
1776
1777        // Allocated
1778        let m = rx.recv().await.unwrap();
1779        assert_matches!(
1780            m,
1781            RemoteProcessProcStateMessage::Allocated { world_id: got_world_id, alloc_key: got_alloc_key }
1782            if got_world_id == world_id && got_alloc_key == alloc_key
1783        );
1784
1785        read_all_created(&mut rx, extent.num_ranks()).await;
1786        read_all_running(&mut rx, extent.num_ranks()).await;
1787
1788        let alloc_key = ShortUuid::generate();
1789
1790        // allocation finished; now we request a new one
1791        tx.send(RemoteProcessAllocatorMessage::Allocate {
1792            alloc_key: alloc_key.clone(),
1793            extent: extent.clone(),
1794            bootstrap_addr,
1795            hosts: vec![],
1796            client_context: None,
1797            forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
1798        })
1799        .await
1800        .unwrap();
1801        // unblock next for the first allocation
1802        next_tx1.send(()).unwrap();
1803        // we expect a stop(), then Stopped proc states, then a new Allocated
1804        read_all_stopped(&mut rx, extent.num_ranks()).await;
1805        let m = rx.recv().await.unwrap();
1806        assert_matches!(m, RemoteProcessProcStateMessage::Done(_));
1807        let m = rx.recv().await.unwrap();
1808        assert_matches!(
1809            m,
1810            RemoteProcessProcStateMessage::Allocated { world_id: got_world_id, alloc_key: got_alloc_key }
1811            if got_world_id == world_id && got_alloc_key == alloc_key
1812        );
1813        // ProcStates for the new allocation
1814        read_all_created(&mut rx, extent.num_ranks()).await;
1815        read_all_running(&mut rx, extent.num_ranks()).await;
1816        // finally stop
1817        tracing::info!("stopping allocation");
1818        tx.send(RemoteProcessAllocatorMessage::Stop).await.unwrap();
1819        // receive all stops
1820        next_tx2.send(()).unwrap();
1821
1822        read_all_stopped(&mut rx, extent.num_ranks()).await;
1823
1824        remote_allocator.terminate();
1825        handle.await.unwrap().unwrap();
1826    }
1827
1828    #[timed_test::async_timed_test(timeout_secs = 15)]
1829    async fn test_upstream_closed() {
1830        // Use temporary config for this test
1831        let config = hyperactor::config::global::lock();
1832        let _guard1 = config.override_key(
1833            hyperactor::config::MESSAGE_DELIVERY_TIMEOUT,
1834            Duration::from_secs(1),
1835        );
1836        let _guard2 = config.override_key(
1837            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
1838            Duration::from_millis(100),
1839        );
1840
1841        hyperactor_telemetry::initialize_logging_for_test();
1842        let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
1843        let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
1844        let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();
1845
1846        let extent = extent!(host = 1, gpu = 2);
1847
1848        let tx = channel::dial(serve_addr.clone()).unwrap();
1849
1850        let world_id: WorldId = id!(test_world_id);
1851        let mut alloc = MockAllocWrapper::new_block_next(
1852            MockAlloc::new(),
1853            // block after all created, all running
1854            extent.num_ranks() * 2,
1855        );
1856        let next_tx = alloc.notify_tx();
1857        alloc.alloc.expect_world_id().return_const(world_id.clone());
1858        alloc.alloc.expect_extent().return_const(extent.clone());
1859
1860        set_procstate_expectations(&mut alloc.alloc, extent.clone());
1861
1862        alloc.alloc.expect_next().return_const(None);
1863        // we expect a stop due to the failure
1864        // synchronize test with the stop
1865        let (stop_tx, stop_rx) = oneshot::channel();
1866        alloc.alloc.expect_stop().times(1).return_once(|| {
1867            stop_tx.send(()).unwrap();
1868            Ok(())
1869        });
1870
1871        let mut allocator = MockAllocator::new();
1872        allocator
1873            .expect_allocate()
1874            .times(1)
1875            .return_once(|_| Ok(alloc));
1876
1877        let remote_allocator = RemoteProcessAllocator::new();
1878        let handle = tokio::spawn({
1879            let remote_allocator = remote_allocator.clone();
1880            async move {
1881                remote_allocator
1882                    .start_with_allocator(serve_addr, allocator, None)
1883                    .await
1884            }
1885        });
1886
1887        let alloc_key = ShortUuid::generate();
1888
1889        tx.send(RemoteProcessAllocatorMessage::Allocate {
1890            alloc_key: alloc_key.clone(),
1891            extent: extent.clone(),
1892            bootstrap_addr,
1893            hosts: vec![],
1894            client_context: None,
1895            forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
1896        })
1897        .await
1898        .unwrap();
1899
1900        // Allocated
1901        let m = rx.recv().await.unwrap();
1902        assert_matches!(
1903            m, RemoteProcessProcStateMessage::Allocated { alloc_key: got_alloc_key, world_id: got_world_id }
1904                if got_world_id == world_id && got_alloc_key == alloc_key
1905        );
1906
1907        read_all_created(&mut rx, extent.num_ranks()).await;
1908        read_all_running(&mut rx, extent.num_ranks()).await;
1909
1910        // allocation finished. terminate connection.
1911        tracing::info!("closing upstream");
1912        drop(rx);
1913        // wait for the heartbeat to expire
1914        #[allow(clippy::disallowed_methods)]
1915        tokio::time::sleep(Duration::from_secs(2)).await;
1916        // wait for the stop to be called
1917        stop_rx.await.unwrap();
1918        // unblock next
1919        next_tx.send(()).unwrap();
1920        remote_allocator.terminate();
1921        handle.await.unwrap().unwrap();
1922    }
1923
1924    #[timed_test::async_timed_test(timeout_secs = 15)]
1925    async fn test_inner_alloc_failure() {
1926        let config = hyperactor::config::global::lock();
1927        let _guard = config.override_key(
1928            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
1929            Duration::from_secs(60),
1930        );
1931        hyperactor_telemetry::initialize_logging_for_test();
1932        let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
1933        let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
1934        let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();
1935
1936        let extent = extent!(host = 1, gpu = 2);
1937
1938        let tx = channel::dial(serve_addr.clone()).unwrap();
1939
1940        let test_world_id: WorldId = id!(test_world_id);
1941        let mut alloc = MockAllocWrapper::new_block_next(
1942            MockAlloc::new(),
1943            // block after the failure update
1944            1,
1945        );
1946        let next_tx = alloc.notify_tx();
1947        alloc
1948            .alloc
1949            .expect_world_id()
1950            .return_const(test_world_id.clone());
1951        alloc.alloc.expect_extent().return_const(extent.clone());
1952        alloc
1953            .alloc
1954            .expect_next()
1955            .times(1)
1956            .return_const(Some(ProcState::Failed {
1957                world_id: test_world_id.clone(),
1958                description: "test".to_string(),
1959            }));
1960        alloc.alloc.expect_next().times(1).return_const(None);
1961
1962        alloc.alloc.expect_stop().times(1).return_once(|| Ok(()));
1963
1964        let mut allocator = MockAllocator::new();
1965        allocator
1966            .expect_allocate()
1967            .times(1)
1968            .return_once(|_| Ok(alloc));
1969
1970        let remote_allocator = RemoteProcessAllocator::new();
1971        let handle = tokio::spawn({
1972            let remote_allocator = remote_allocator.clone();
1973            async move {
1974                remote_allocator
1975                    .start_with_allocator(serve_addr, allocator, None)
1976                    .await
1977            }
1978        });
1979
1980        let alloc_key = ShortUuid::generate();
1981        tx.send(RemoteProcessAllocatorMessage::Allocate {
1982            alloc_key: alloc_key.clone(),
1983            extent: extent.clone(),
1984            bootstrap_addr,
1985            hosts: vec![],
1986            client_context: None,
1987            forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
1988        })
1989        .await
1990        .unwrap();
1991
1992        // Allocated
1993        let m = rx.recv().await.unwrap();
1994        assert_matches!(
1995            m,
1996            RemoteProcessProcStateMessage::Allocated {  world_id: got_world_id, alloc_key: got_alloc_key }
1997            if test_world_id == got_world_id && alloc_key == got_alloc_key
1998        );
1999
2000        // Failed
2001        let m = rx.recv().await.unwrap();
2002        assert_matches!(
2003            m,
2004            RemoteProcessProcStateMessage::Update(
2005                got_alloc_key,
2006                ProcState::Failed { world_id, description }
2007            ) if got_alloc_key == alloc_key && world_id == test_world_id && description == "test"
2008        );
2009
2010        tracing::info!("stopping allocation");
2011        tx.send(RemoteProcessAllocatorMessage::Stop).await.unwrap();
2012        // receive all stops
2013        next_tx.send(()).unwrap();
2014        // we are expecting 1 Done when Alloc successfully stops.
2015        let m = rx.recv().await.unwrap();
2016        assert_matches!(
2017            m,
2018            RemoteProcessProcStateMessage::Done(got_alloc_key)
2019            if got_alloc_key == alloc_key
2020        );
2021
2022        remote_allocator.terminate();
2023        handle.await.unwrap().unwrap();
2024    }
2025
2026    #[timed_test::async_timed_test(timeout_secs = 15)]
2027    async fn test_trace_id_propagation() {
2028        let config = hyperactor::config::global::lock();
2029        let _guard = config.override_key(
2030            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
2031            Duration::from_secs(60),
2032        );
2033        hyperactor_telemetry::initialize_logging(ClockKind::default());
2034        let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
2035        let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
2036        let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();
2037
2038        let extent = extent!(host = 1, gpu = 1);
2039        let tx = channel::dial(serve_addr.clone()).unwrap();
2040        let test_world_id: WorldId = id!(test_world_id);
2041        let test_trace_id = "test_trace_id_12345";
2042
2043        // Create a mock alloc that we can verify receives the correct trace id
2044        let mut alloc = MockAlloc::new();
2045        alloc.expect_world_id().return_const(test_world_id.clone());
2046        alloc.expect_extent().return_const(extent.clone());
2047        alloc.expect_next().return_const(None);
2048
2049        // Create a mock allocator that captures the AllocSpec passed to it
2050        let mut allocator = MockAllocator::new();
2051        allocator
2052            .expect_allocate()
2053            .times(1)
2054            .withf(move |spec: &AllocSpec| {
2055                // Verify that the trace id is correctly set in the constraints
2056                spec.constraints
2057                    .match_labels
2058                    .get(CLIENT_TRACE_ID_LABEL)
2059                    .is_some_and(|trace_id| trace_id == test_trace_id)
2060            })
2061            .return_once(|_| Ok(MockAllocWrapper::new(alloc)));
2062
2063        let remote_allocator = RemoteProcessAllocator::new();
2064        let handle = tokio::spawn({
2065            let remote_allocator = remote_allocator.clone();
2066            async move {
2067                remote_allocator
2068                    .start_with_allocator(serve_addr, allocator, None)
2069                    .await
2070            }
2071        });
2072
2073        let alloc_key = ShortUuid::generate();
2074        tx.send(RemoteProcessAllocatorMessage::Allocate {
2075            alloc_key: alloc_key.clone(),
2076            extent: extent.clone(),
2077            bootstrap_addr,
2078            hosts: vec![],
2079            client_context: Some(ClientContext {
2080                trace_id: test_trace_id.to_string(),
2081            }),
2082            forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
2083        })
2084        .await
2085        .unwrap();
2086
2087        // Verify we get the allocated message
2088        let m = rx.recv().await.unwrap();
2089        assert_matches!(
2090            m,
2091            RemoteProcessProcStateMessage::Allocated { alloc_key: got_alloc_key, world_id: got_world_id }
2092            if got_world_id == test_world_id && got_alloc_key == alloc_key
2093        );
2094
2095        // Verify we get the done message since the mock alloc returns None immediately
2096        let m = rx.recv().await.unwrap();
2097        assert_matches!(
2098            m,
2099            RemoteProcessProcStateMessage::Done(got_alloc_key)
2100            if alloc_key == got_alloc_key
2101        );
2102
2103        remote_allocator.terminate();
2104        handle.await.unwrap().unwrap();
2105    }
2106
2107    #[timed_test::async_timed_test(timeout_secs = 15)]
2108    async fn test_trace_id_propagation_no_client_context() {
2109        let config = hyperactor::config::global::lock();
2110        let _guard = config.override_key(
2111            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
2112            Duration::from_secs(60),
2113        );
2114        hyperactor_telemetry::initialize_logging(ClockKind::default());
2115        let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
2116        let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
2117        let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();
2118
2119        let extent = extent!(host = 1, gpu = 1);
2120        let tx = channel::dial(serve_addr.clone()).unwrap();
2121        let test_world_id: WorldId = id!(test_world_id);
2122
2123        // Create a mock alloc
2124        let mut alloc = MockAlloc::new();
2125        alloc.expect_world_id().return_const(test_world_id.clone());
2126        alloc.expect_extent().return_const(extent.clone());
2127        alloc.expect_next().return_const(None);
2128
2129        // Create a mock allocator that verifies no trace id is set when client_context is None
2130        let mut allocator = MockAllocator::new();
2131        allocator
2132            .expect_allocate()
2133            .times(1)
2134            .withf(move |spec: &AllocSpec| {
2135                // Verify that no trace id is set in the constraints when client_context is None
2136                spec.constraints.match_labels.is_empty()
2137            })
2138            .return_once(|_| Ok(MockAllocWrapper::new(alloc)));
2139
2140        let remote_allocator = RemoteProcessAllocator::new();
2141        let handle = tokio::spawn({
2142            let remote_allocator = remote_allocator.clone();
2143            async move {
2144                remote_allocator
2145                    .start_with_allocator(serve_addr, allocator, None)
2146                    .await
2147            }
2148        });
2149
2150        let alloc_key = ShortUuid::generate();
2151        tx.send(RemoteProcessAllocatorMessage::Allocate {
2152            alloc_key: alloc_key.clone(),
2153            extent: extent.clone(),
2154            bootstrap_addr,
2155            hosts: vec![],
2156            client_context: None,
2157            forwarder_addr: with_unspecified_port_or_any(&tx.addr()),
2158        })
2159        .await
2160        .unwrap();
2161
2162        // Verify we get the allocated message
2163        let m = rx.recv().await.unwrap();
2164        assert_matches!(
2165            m,
2166            RemoteProcessProcStateMessage::Allocated { alloc_key: got_alloc_key, world_id: got_world_id }
2167            if got_world_id == test_world_id && got_alloc_key == alloc_key
2168        );
2169
2170        // Verify we get the done message since the mock alloc returns None immediately
2171        let m = rx.recv().await.unwrap();
2172        assert_matches!(
2173            m,
2174            RemoteProcessProcStateMessage::Done(got_alloc_key)
2175            if got_alloc_key == alloc_key
2176        );
2177
2178        remote_allocator.terminate();
2179        handle.await.unwrap().unwrap();
2180    }
2181}
2182
2183#[cfg(test)]
2184mod test_alloc {
2185    use std::os::unix::process::ExitStatusExt;
2186
2187    use hyperactor::clock::ClockKind;
2188    use hyperactor::config;
2189    use ndslice::extent;
2190    use nix::sys::signal;
2191    use nix::unistd::Pid;
2192    use timed_test::async_timed_test;
2193
2194    use super::*;
2195
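    // End-to-end happy path: spawn two real RemoteProcessAllocators backed by the
    // bootstrap binary, allocate a host=2 x gpu=2 extent across them, and walk the
    // Created -> Running -> Stopped lifecycle through RemoteProcessAlloc.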
2196    #[async_timed_test(timeout_secs = 60)]
2197    #[cfg(fbcode_build)]
2198    async fn test_alloc_simple() {
2199        // Use temporary config for this test
2200        let config = hyperactor::config::global::lock();
2201        let _guard1 = config.override_key(
2202            hyperactor::config::MESSAGE_DELIVERY_TIMEOUT,
2203            Duration::from_secs(1),
2204        );
2205        let _guard2 = config.override_key(
2206            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
2207            Duration::from_millis(100),
2208        );
2209        hyperactor_telemetry::initialize_logging(ClockKind::default());
2210
2211        let spec = AllocSpec {
2212            extent: extent!(host = 2, gpu = 2),
2213            constraints: Default::default(),
2214            proc_name: None,
2215            transport: ChannelTransport::Unix,
2216            proc_allocation_mode: Default::default(),
2217        };
2218        let world_id = WorldId("test_world_id".to_string());
2219
2220        let task1_allocator = RemoteProcessAllocator::new();
2221        let task1_addr = ChannelAddr::any(ChannelTransport::Unix);
2222        let task1_addr_string = task1_addr.to_string();
2223        let task1_cmd = Command::new(crate::testresource::get(
2224            "monarch/hyperactor_mesh/bootstrap",
2225        ));
2226        let task2_allocator = RemoteProcessAllocator::new();
2227        let task2_addr = ChannelAddr::any(ChannelTransport::Unix);
2228        let task2_addr_string = task2_addr.to_string();
2229        let task2_cmd = Command::new(crate::testresource::get(
2230            "monarch/hyperactor_mesh/bootstrap",
2231        ));
2232        let task1_allocator_copy = task1_allocator.clone();
2233        let task1_allocator_handle = tokio::spawn(async move {
2234            tracing::info!("spawning task1");
2235            task1_allocator_copy
2236                .start(task1_cmd, task1_addr, None)
2237                .await
2238                .unwrap();
2239        });
2240        let task2_allocator_copy = task2_allocator.clone();
2241        let task2_allocator_handle = tokio::spawn(async move {
2242            task2_allocator_copy
2243                .start(task2_cmd, task2_addr, None)
2244                .await
2245                .unwrap();
2246        });
2247
2248        let mut initializer = MockRemoteProcessAllocInitializer::new();
2249        initializer.expect_initialize_alloc().return_once(move || {
2250            Ok(vec![
2251                RemoteProcessAllocHost {
2252                    hostname: task1_addr_string,
2253                    id: "task1".to_string(),
2254                },
2255                RemoteProcessAllocHost {
2256                    hostname: task2_addr_string,
2257                    id: "task2".to_string(),
2258                },
2259            ])
2260        });
2261        let mut alloc = RemoteProcessAlloc::new(spec.clone(), world_id, 0, initializer)
2262            .await
2263            .unwrap();
2264        let mut created = HashSet::new();
2265        let mut running_procs = HashSet::new();
2266        let mut proc_points = HashSet::new();
2267        for _ in 0..spec.extent.num_ranks() * 2 {
2268            let proc_state = alloc.next().await.unwrap();
2269            tracing::debug!("test got message: {:?}", proc_state);
2270            match proc_state {
2271                ProcState::Created {
2272                    create_key, point, ..
2273                } => {
2274                    created.insert(create_key);
2275                    proc_points.insert(point);
2276                }
2277                ProcState::Running { create_key, .. } => {
2278                    assert!(created.remove(&create_key));
2279                    running_procs.insert(create_key);
2280                }
2281                _ => panic!("expected Created or Running"),
2282            }
2283        }
2284        assert!(created.is_empty());
2285        // ensure coords coverage
2286        assert!(
2287            spec.extent
2288                .points()
2289                .all(|point| proc_points.contains(&point))
2290        );
2291
2292        // ensure no more pending items
2293        let timeout = hyperactor::clock::RealClock.now() + std::time::Duration::from_millis(1000);
2294        tokio::select! {
2295            _ = hyperactor::clock::RealClock.sleep_until(timeout) => {},
2296            _ = alloc.next() => panic!("expected no more items"),
2297        }
2298
2299        // stop the allocation
2300        alloc.stop().await.unwrap();
2301        for _ in 0..spec.extent.num_ranks() {
2302            let proc_state = alloc.next().await.unwrap();
2303            tracing::info!("test received next proc_state: {:?}", proc_state);
2304            match proc_state {
2305                ProcState::Stopped {
2306                    create_key, reason, ..
2307                } => {
2308                    assert!(running_procs.remove(&create_key));
2309                    assert_eq!(reason, ProcStopReason::Stopped);
2310                }
2311                _ => panic!("expected stopped"),
2312            }
2313        }
2314        // Exactly one None
2315        let proc_state = alloc.next().await;
2316        assert!(proc_state.is_none());
2317        // Anything afterwards is None
2318        let proc_state = alloc.next().await;
2319        assert!(proc_state.is_none());
2320
2321        task1_allocator.terminate();
2322        task1_allocator_handle.await.unwrap();
2323        task2_allocator.terminate();
2324        task2_allocator_handle.await.unwrap();
2325    }
2326
2327    #[async_timed_test(timeout_secs = 60)]
2328    #[cfg(fbcode_build)]
2329    async fn test_alloc_host_failure() {
2330        // Use temporary config for this test
2331        let config = hyperactor::config::global::lock();
2332        let _guard1 = config.override_key(
2333            hyperactor::config::MESSAGE_DELIVERY_TIMEOUT,
2334            Duration::from_secs(1),
2335        );
2336        let _guard2 = config.override_key(
2337            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
2338            Duration::from_millis(100),
2339        );
2340        hyperactor_telemetry::initialize_logging(ClockKind::default());
2341
2342        let spec = AllocSpec {
2343            extent: extent!(host = 2, gpu = 2),
2344            constraints: Default::default(),
2345            proc_name: None,
2346            transport: ChannelTransport::Unix,
2347            proc_allocation_mode: Default::default(),
2348        };
2349        let world_id = WorldId("test_world_id".to_string());
2350
2351        let task1_allocator = RemoteProcessAllocator::new();
2352        let task1_addr = ChannelAddr::any(ChannelTransport::Unix);
2353        let task1_addr_string = task1_addr.to_string();
2354        let task1_cmd = Command::new(crate::testresource::get(
2355            "monarch/hyperactor_mesh/bootstrap",
2356        ));
2357        let task2_allocator = RemoteProcessAllocator::new();
2358        let task2_addr = ChannelAddr::any(ChannelTransport::Unix);
2359        let task2_addr_string = task2_addr.to_string();
2360        let task2_cmd = Command::new(crate::testresource::get(
2361            "monarch/hyperactor_mesh/bootstrap",
2362        ));
2363        let task1_allocator_copy = task1_allocator.clone();
2364        let task1_allocator_handle = tokio::spawn(async move {
2365            tracing::info!("spawning task1");
2366            task1_allocator_copy
2367                .start(task1_cmd, task1_addr, None)
2368                .await
2369                .unwrap();
2370            tracing::info!("task1 terminated");
2371        });
2372        let task2_allocator_copy = task2_allocator.clone();
2373        let task2_allocator_handle = tokio::spawn(async move {
2374            task2_allocator_copy
2375                .start(task2_cmd, task2_addr, None)
2376                .await
2377                .unwrap();
2378            tracing::info!("task2 terminated");
2379        });
2380
2381        let mut initializer = MockRemoteProcessAllocInitializer::new();
2382        initializer.expect_initialize_alloc().return_once(move || {
2383            Ok(vec![
2384                RemoteProcessAllocHost {
2385                    hostname: task1_addr_string,
2386                    id: "task1".to_string(),
2387                },
2388                RemoteProcessAllocHost {
2389                    hostname: task2_addr_string,
2390                    id: "task2".to_string(),
2391                },
2392            ])
2393        });
2394        let mut alloc = RemoteProcessAlloc::new(spec.clone(), world_id, 0, initializer)
2395            .await
2396            .unwrap();
2397        for _ in 0..spec.extent.num_ranks() * 2 {
2398            match alloc.next().await {
2399                Some(ProcState::Created { .. }) | Some(ProcState::Running { .. }) => {}
2400                _ => panic!("expected Created or Running"),
2401            }
2402        }
2403
2404        // ensure no more pending items
2405        let timeout = hyperactor::clock::RealClock.now() + std::time::Duration::from_millis(1000);
2406        tokio::select! {
2407            _ = hyperactor::clock::RealClock
2408                .sleep_until(timeout) => {},
2409            _ = alloc.next() => panic!("expected no more items"),
2410        }
2411
2412        // now we kill task1 and wait for timeout
2413        tracing::info!("aborting task1 allocator");
2414        task1_allocator_handle.abort();
2415        RealClock
2416            .sleep(config::global::get(config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL) * 2)
2417            .await;
2418        for _ in 0..spec.extent.num_ranks() / 2 {
2419            let proc_state = alloc.next().await.unwrap();
2420            tracing::info!("test received next proc_state: {:?}", proc_state);
2421            match proc_state {
2422                ProcState::Stopped { reason, .. } => {
2423                    assert_eq!(reason, ProcStopReason::HostWatchdog);
2424                }
2425                _ => panic!("expected stopped"),
2426            }
2427        }
2428        // no more events
2429        let timeout = hyperactor::clock::RealClock.now() + std::time::Duration::from_millis(1000);
2430        tokio::select! {
2431            _ = hyperactor::clock::RealClock
2432                .sleep_until(timeout) => {},
2433            _ = alloc.next() => panic!("expected no more items"),
2434        }
2435
2436        // abort the second host
2437        tracing::info!("aborting task2 allocator");
2438        task2_allocator_handle.abort();
2439        RealClock
2440            .sleep(config::global::get(config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL) * 2)
2441            .await;
2442        for _ in 0..spec.extent.num_ranks() / 2 {
2443            let proc_state = alloc.next().await.unwrap();
2444            tracing::info!("test received next proc_state: {:?}", proc_state);
2445            match proc_state {
2446                ProcState::Stopped { reason, .. } => {
2447                    assert_eq!(reason, ProcStopReason::HostWatchdog);
2448                }
2449                _ => panic!("expected stopped"),
2450            }
2451        }
2452        // now that the alloc has stopped, we always get None
2453        let proc_state = alloc.next().await;
2454        assert!(proc_state.is_none());
2455        // Anything afterwards is None
2456        let proc_state = alloc.next().await;
2457        assert!(proc_state.is_none());
2458    }
2459
2460    #[async_timed_test(timeout_secs = 15)]
2461    #[cfg(fbcode_build)]
2462    async fn test_alloc_inner_alloc_failure() {
2463        // SAFETY: Test happens in single-threaded code.
2464        unsafe {
2465            std::env::set_var("MONARCH_MESSAGE_DELIVERY_TIMEOUT_SECS", "1");
2466        }
2467
2468        let config = hyperactor::config::global::lock();
2469        let _guard = config.override_key(
2470            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
2471            Duration::from_millis(100),
2472        );
2473        hyperactor_telemetry::initialize_logging_for_test();
2474
2475        let spec = AllocSpec {
2476            extent: extent!(host = 2, gpu = 2),
2477            constraints: Default::default(),
2478            proc_name: None,
2479            transport: ChannelTransport::Unix,
2480            proc_allocation_mode: Default::default(),
2481        };
2482        let world_id = WorldId("test_world_id".to_string());
2483
2484        let task1_allocator = RemoteProcessAllocator::new();
2485        let task1_addr = ChannelAddr::any(ChannelTransport::Unix);
2486        let task1_addr_string = task1_addr.to_string();
2487        let task1_cmd = Command::new(crate::testresource::get(
2488            "monarch/hyperactor_mesh/bootstrap",
2489        ));
2490        let task2_allocator = RemoteProcessAllocator::new();
2491        let task2_addr = ChannelAddr::any(ChannelTransport::Unix);
2492        let task2_addr_string = task2_addr.to_string();
2493        // non-existent binary to fail the allocation
2494        let task2_cmd = Command::new("/caught/somewhere/in/time");
2495        let task1_allocator_copy = task1_allocator.clone();
2496        let task1_allocator_handle = tokio::spawn(async move {
2497            tracing::info!("spawning task1");
2498            task1_allocator_copy
2499                .start(task1_cmd, task1_addr, None)
2500                .await
2501                .unwrap();
2502        });
2503        let task2_allocator_copy = task2_allocator.clone();
2504        let task2_allocator_handle = tokio::spawn(async move {
2505            task2_allocator_copy
2506                .start(task2_cmd, task2_addr, None)
2507                .await
2508                .unwrap();
2509        });
2510
2511        let mut initializer = MockRemoteProcessAllocInitializer::new();
2512        initializer.expect_initialize_alloc().return_once(move || {
2513            Ok(vec![
2514                RemoteProcessAllocHost {
2515                    hostname: task1_addr_string,
2516                    id: "task1".to_string(),
2517                },
2518                RemoteProcessAllocHost {
2519                    hostname: task2_addr_string,
2520                    id: "task2".to_string(),
2521                },
2522            ])
2523        });
2524        let mut alloc = RemoteProcessAlloc::new(spec.clone(), world_id, 0, initializer)
2525            .await
2526            .unwrap();
2527        let mut created = HashSet::new();
2528        let mut started_procs = HashSet::new();
2529        let mut proc_points = HashSet::new();
2530        let mut failed = 0;
2531        // task 1 procs + 1 failed event for task 2
2532        for _ in 0..spec.extent.num_ranks() + 1 {
2533            let proc_state = alloc.next().await.unwrap();
2534            tracing::debug!("test got message: {:?}", proc_state);
2535            match proc_state {
2536                ProcState::Created {
2537                    create_key, point, ..
2538                } => {
2539                    created.insert(create_key);
2540                    proc_points.insert(point);
2541                }
2542                ProcState::Running { create_key, .. } => {
2543                    assert!(created.remove(&create_key));
2544                    started_procs.insert(create_key);
2545                }
2546                ProcState::Failed { .. } => {
2547                    failed += 1;
2548                }
2549                _ => panic!("expected Created, Running or Failed"),
2550            }
2551        }
2552        assert!(created.is_empty());
2553        assert_eq!(failed, 1);
2554        // ensure coords coverage for task 1
2555        for rank in 0..spec.extent.num_ranks() / 2 {
2556            let point = spec.extent.point_of_rank(rank).unwrap();
2557            assert!(proc_points.contains(&point));
2558        }
2559
2560        // ensure no more pending items
2561        let timeout = hyperactor::clock::RealClock.now() + std::time::Duration::from_millis(1000);
2562        tokio::select! {
2563            _ = hyperactor::clock::RealClock
2564                .sleep_until(timeout) => {},
2565            _ = alloc.next() => panic!("expected no more items"),
2566        }
2567
2568        // stop the allocation
2569        alloc.stop().await.unwrap();
2570        for _ in 0..spec.extent.num_ranks() / 2 {
2571            let proc_state = alloc.next().await.unwrap();
2572            tracing::info!("test received next proc_state: {:?}", proc_state);
2573            match proc_state {
2574                ProcState::Stopped {
2575                    create_key, reason, ..
2576                } => {
2577                    assert!(started_procs.remove(&create_key));
2578                    assert_eq!(reason, ProcStopReason::Stopped);
2579                }
2580                _ => panic!("expected stopped"),
2581            }
2582        }
2583        // Exactly one None
2584        let proc_state = alloc.next().await;
2585        assert!(proc_state.is_none());
2586        // Anything afterwards is None
2587        let proc_state = alloc.next().await;
2588        assert!(proc_state.is_none());
2589
2590        task1_allocator.terminate();
2591        task1_allocator_handle.await.unwrap();
2592        task2_allocator.terminate();
2593        task2_allocator_handle.await.unwrap();
2594    }
2595
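    /// Launches one `remote_process_allocator` server per host plus a
    /// `remote_process_alloc` driver that allocates across them, then sends the
    /// driver SIGINT and verifies that the procs reported over the pid channel
    /// are no longer alive.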
2596    #[tracing_test::traced_test]
2597    #[async_timed_test(timeout_secs = 60)]
2598    #[cfg(fbcode_build)]
2599    async fn test_remote_process_alloc_signal_handler() {
2600        let num_proc_meshes = 5;
2601        let hosts_per_proc_mesh = 5;
2602
2603        let pid_addr = ChannelAddr::any(ChannelTransport::Unix);
2604        let (pid_addr, mut pid_rx) = channel::serve::<u32>(pid_addr).unwrap();
2605
2606        let addresses = (0..(num_proc_meshes * hosts_per_proc_mesh))
2607            .map(|_| ChannelAddr::any(ChannelTransport::Unix).to_string())
2608            .collect::<Vec<_>>();
2609
2610        let remote_process_allocators = addresses
2611            .iter()
2612            .map(|addr| {
2613                Command::new(crate::testresource::get(
2614                    "monarch/hyperactor_mesh/remote_process_allocator",
2615                ))
2616                .env("RUST_LOG", "info")
2617                .arg(format!("--addr={addr}"))
2618                .stdout(std::process::Stdio::piped())
2619                .spawn()
2620                .unwrap()
2621            })
2622            .collect::<Vec<_>>();
2623
2624        let done_allocating_addr = ChannelAddr::any(ChannelTransport::Unix);
2625        let (done_allocating_addr, mut done_allocating_rx) =
2626            channel::serve::<()>(done_allocating_addr).unwrap();
2627        let mut remote_process_alloc = Command::new(crate::testresource::get(
2628            "monarch/hyperactor_mesh/remote_process_alloc",
2629        ))
2630        .arg(format!("--done-allocating-addr={}", done_allocating_addr))
2631        .arg(format!("--addresses={}", addresses.join(",")))
2632        .arg(format!("--num-proc-meshes={}", num_proc_meshes))
2633        .arg(format!("--hosts-per-proc-mesh={}", hosts_per_proc_mesh))
2634        .arg(format!("--pid-addr={}", pid_addr))
2635        .spawn()
2636        .unwrap();
2637
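        // Wait for the driver to report that allocation finished, then collect one
        // spawned-proc PID per remote allocator for the liveness checks below.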
2638        done_allocating_rx.recv().await.unwrap();
2639        let mut received_pids = Vec::new();
2640        while let Ok(pid) = pid_rx.recv().await {
2641            received_pids.push(pid);
2642            if received_pids.len() == remote_process_allocators.len() {
2643                break;
2644            }
2645        }
2646
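        // Interrupt the driver and confirm it exits due to SIGINT.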
2647        signal::kill(
2648            Pid::from_raw(remote_process_alloc.id().unwrap() as i32),
2649            signal::SIGINT,
2650        )
2651        .unwrap();
2652
2653        assert_eq!(
2654            remote_process_alloc.wait().await.unwrap().signal(),
2655            Some(signal::SIGINT as i32)
2656        );
2657
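        // Give the allocators a moment to tear down their children after the driver exits.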
2658        RealClock.sleep(tokio::time::Duration::from_secs(5)).await;
2659
2660        // Assert that the processes spawned by ProcessAllocator have been killed
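        // `kill -0` sends no signal; it only succeeds if the target PID still exists.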
2661        for child_pid in received_pids {
2662            let pid_check = Command::new("kill")
2663                .arg("-0")
2664                .arg(child_pid.to_string())
2665                .output()
2666                .await
2667                .expect("Failed to check if PID is alive");
2668
2669            assert!(
2670                !pid_check.status.success(),
2671                "PID {} should no longer be alive",
2672                child_pid
2673            );
2674        }
2675
2676        // Clean up the remote process allocator processes: SIGINT stops the current
2677        // allocs but does not stop the RemoteProcessAllocator loops themselves.
2678        for mut remote_process_allocator in remote_process_allocators {
2679            remote_process_allocator.kill().await.unwrap();
2680        }
2681    }
2682}