hyperactor_mesh/alloc/remoteprocess.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

use std::collections::HashMap;
use std::collections::HashSet;
use std::collections::VecDeque;
use std::sync::Arc;
use std::time::Duration;

use anyhow::Context;
use async_trait::async_trait;
use dashmap::DashMap;
use futures::FutureExt;
use futures::future::join_all;
use futures::future::select_all;
use hyperactor::Named;
use hyperactor::WorldId;
use hyperactor::channel;
use hyperactor::channel::ChannelAddr;
use hyperactor::channel::ChannelRx;
use hyperactor::channel::ChannelTransport;
use hyperactor::channel::ChannelTx;
use hyperactor::channel::Rx;
use hyperactor::channel::Tx;
use hyperactor::channel::TxStatus;
use hyperactor::clock;
use hyperactor::clock::Clock;
use hyperactor::clock::RealClock;
use hyperactor::config;
use hyperactor::mailbox::DialMailboxRouter;
use hyperactor::mailbox::MailboxServer;
use hyperactor::observe_async;
use hyperactor::observe_result;
use hyperactor::reference::Reference;
use hyperactor::serde_json;
use mockall::automock;
use ndslice::Region;
use ndslice::View;
use ndslice::ViewExt;
use ndslice::view::Extent;
use ndslice::view::Point;
use serde::Deserialize;
use serde::Serialize;
use strum::AsRefStr;
use tokio::io::AsyncWriteExt;
use tokio::process::Command;
use tokio::sync::mpsc::UnboundedReceiver;
use tokio::sync::mpsc::UnboundedSender;
use tokio::sync::mpsc::unbounded_channel;
use tokio::task::JoinHandle;
use tokio_stream::StreamExt;
use tokio_stream::wrappers::WatchStream;
use tokio_util::sync::CancellationToken;

use crate::alloc::Alloc;
use crate::alloc::AllocConstraints;
use crate::alloc::AllocSpec;
use crate::alloc::Allocator;
use crate::alloc::AllocatorError;
use crate::alloc::ProcState;
use crate::alloc::ProcStopReason;
use crate::alloc::ProcessAllocator;
use crate::alloc::process::CLIENT_TRACE_ID_LABEL;
use crate::alloc::process::ClientContext;
use crate::shortuuid::ShortUuid;

/// Control messages sent from the client (`RemoteProcessAlloc`) to the
/// remote process allocator.
#[derive(Debug, Clone, Serialize, Deserialize, Named, AsRefStr)]
pub enum RemoteProcessAllocatorMessage {
    /// Create allocation with given spec and send updates to bootstrap_addr.
    Allocate {
        /// The key used to identify this allocation.
        alloc_key: ShortUuid,
        /// The extent to allocate.
        extent: Extent,
        /// Bootstrap address to be used for sending updates.
        bootstrap_addr: ChannelAddr,
        /// Ordered list of hosts in this allocation. Can be used to
        /// pre-populate any local configuration such as torch.dist.
        hosts: Vec<String>,
        /// Client context which is passed to the ProcessAlloc.
        /// TODO: once RemoteProcessAllocator moves to the mailbox,
        /// the client_context will go into the message header instead.
        client_context: Option<ClientContext>,
    },
    /// Stop allocation.
    Stop,
    /// Heartbeat message to check if remote process allocator and its
    /// host are alive.
    HeartBeat,
}

/// Control messages sent from the remote process allocator back to the
/// client, relaying process state updates.
/// AsRefStr allows us to log the variant names.
#[derive(Debug, Clone, Serialize, Deserialize, Named, AsRefStr)]
pub enum RemoteProcessProcStateMessage {
    /// Allocation successful; Update and Done messages will follow.
    Allocated {
        alloc_key: ShortUuid,
        world_id: WorldId,
    },
    /// ProcState updates.
    Update(ShortUuid, ProcState),
    /// Underlying Alloc is done.
    Done(ShortUuid),
    /// Heartbeat message to check if the client is alive.
    HeartBeat,
}

/// Allocator with a service frontend that wraps ProcessAllocator.
pub struct RemoteProcessAllocator {
    cancel_token: CancellationToken,
}

async fn conditional_sleeper<F: futures::Future<Output = ()>>(t: Option<F>) {
    match t {
        Some(timer) => timer.await,
        None => futures::future::pending().await,
    }
}

impl RemoteProcessAllocator {
    /// Create a new allocator. It will not start until start() is called.
    pub fn new() -> Arc<Self> {
        Arc::new(Self {
            cancel_token: CancellationToken::new(),
        })
    }

    /// Stop the allocator. This will stop any ongoing allocations.
    pub fn terminate(&self) {
        self.cancel_token.cancel();
    }

    /// Start a remote process allocator with the given cmd, listening for
    /// RemoteProcessAllocatorMessage on serve_addr. The call blocks until cancelled.
    /// The implementation is intentionally simple and handles only one Alloc at
    /// a time, which is the most common use case.
    /// The flow works as follows:
    /// 1. Client sends Allocate message to serve_addr.
    /// 2. Allocator connects to bootstrap_addr, creates Alloc and sends Allocated message.
    /// 3. Allocator streams one or more Update messages to bootstrap_addr as Alloc progresses,
    ///    making the following changes:
    ///    * Remap mesh_agent listen address to our own forwarder actor address.
    /// 4. Allocator sends Done message to bootstrap_addr when Alloc is done.
    ///
    /// At any point, the client can send a Stop message to serve_addr to stop the allocator.
    /// If timeout is Some, the allocator will exit if no client connects within
    /// that timeout and no child allocation is running.
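    ///
    /// # Example
    ///
    /// A minimal usage sketch; the bootstrap binary name, address, and
    /// timeout below are illustrative placeholders:
    ///
    /// ```ignore
    /// let allocator = RemoteProcessAllocator::new();
    /// // Command used to bootstrap procs on this host (hypothetical binary).
    /// let cmd = Command::new("process_bootstrap");
    /// let serve_addr: ChannelAddr = "tcp!0.0.0.0:26600".parse()?;
    /// // Serve until cancelled; exit if no client connects within 5 minutes
    /// // and no allocation is active.
    /// allocator
    ///     .start(cmd, serve_addr, Some(Duration::from_secs(300)))
    ///     .await?;
    /// ```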
    #[hyperactor::instrument]
    pub async fn start(
        &self,
        cmd: Command,
        serve_addr: ChannelAddr,
        timeout: Option<Duration>,
    ) -> Result<(), anyhow::Error> {
        let process_allocator = ProcessAllocator::new(cmd);
        self.start_with_allocator(serve_addr, process_allocator, timeout)
            .await
    }

    /// Start a remote process allocator with given allocator listening for
    /// RemoteProcessAllocatorMessage on serve_addr.
    /// Used for testing.
    pub async fn start_with_allocator<A: Allocator + Send + Sync + 'static>(
        &self,
        serve_addr: ChannelAddr,
        mut process_allocator: A,
        timeout: Option<Duration>,
    ) -> Result<(), anyhow::Error>
    where
        <A as Allocator>::Alloc: Send,
        <A as Allocator>::Alloc: Sync,
    {
        tracing::info!("starting remote allocator on: {}", serve_addr);
        let (_, mut rx) = channel::serve(serve_addr.clone()).map_err(anyhow::Error::from)?;

        struct ActiveAllocation {
            handle: JoinHandle<()>,
            cancel_token: CancellationToken,
        }
        #[observe_async("RemoteProcessAllocator")]
        async fn ensure_previous_alloc_stopped(active_allocation: &mut Option<ActiveAllocation>) {
            if let Some(active_allocation) = active_allocation.take() {
                tracing::info!("previous alloc found, stopping");
                active_allocation.cancel_token.cancel();
                match active_allocation.handle.await {
                    Ok(_) => {
                        // TODO: add named tracing for state change here
                        tracing::info!("allocation stopped.")
                    }
                    Err(e) => {
                        tracing::error!("allocation handler failed: {}", e);
                    }
                }
            }
        }

        let mut active_allocation: Option<ActiveAllocation> = None;
        loop {
            // Refresh each loop iteration so the timer updates whenever a message
            // is received.
            let sleep = conditional_sleeper(timeout.map(|t| RealClock.sleep(t)));
            tokio::select! {
                msg = rx.recv() => {
                    match msg {
                        Ok(RemoteProcessAllocatorMessage::Allocate {
                            alloc_key,
                            extent,
                            bootstrap_addr,
                            hosts,
                            client_context,
                        }) => {
                            tracing::info!("received allocation request for {} with extent {}", alloc_key, extent);
                            ensure_previous_alloc_stopped(&mut active_allocation).await;

                            // Create the corresponding local allocation spec.
                            let mut constraints: AllocConstraints = Default::default();
                            if let Some(context) = &client_context {
                                constraints = AllocConstraints {
                                    match_labels: HashMap::from([(
                                        CLIENT_TRACE_ID_LABEL.to_string(),
                                        context.trace_id.to_string(),
                                    )]),
                                };
                                tracing::info!(
                                    monarch_client_trace_id = context.trace_id.to_string(),
                                    "allocating...",
                                );
                            }

                            let spec = AllocSpec {
                                extent,
                                constraints,
                                proc_name: None, // TODO(meriksen, direct addressing): we need to pass the addressing mode here
                                transport: ChannelTransport::Unix,
                            };

                            match process_allocator.allocate(spec.clone()).await {
                                Ok(alloc) => {
                                    let cancel_token = CancellationToken::new();
                                    active_allocation = Some(ActiveAllocation {
                                        cancel_token: cancel_token.clone(),
                                        handle: tokio::spawn(Self::handle_allocation_request(
                                            Box::new(alloc) as Box<dyn Alloc + Send + Sync>,
                                            alloc_key,
                                            serve_addr.transport(),
                                            bootstrap_addr,
                                            hosts,
                                            cancel_token,
                                        )),
                                    })
                                }
                                Err(e) => {
                                    tracing::error!("allocation for {:?} failed: {}", spec, e);
                                    continue;
                                }
                            }
                        }
                        Ok(RemoteProcessAllocatorMessage::Stop) => {
                            tracing::info!("received stop request");

                            ensure_previous_alloc_stopped(&mut active_allocation).await;
                        }
                        // Heartbeat messages are discarded immediately after being received; the sender (client)
                        // relies on the channel ack to know whether the receiver (remote process allocator) is
                        // still alive. No state needs to be updated.
                        Ok(RemoteProcessAllocatorMessage::HeartBeat) => {}
                        Err(e) => {
                            tracing::error!("upstream channel error: {}", e);
                            continue;
                        }
                    }
                }
                _ = self.cancel_token.cancelled() => {
                    tracing::info!("main loop cancelled");

                    ensure_previous_alloc_stopped(&mut active_allocation).await;

                    break;
                }
                _ = sleep => {
                    // If there are any active allocations, reset the timeout.
                    if active_allocation.is_some() {
                        continue;
                    }
                    // Else, exit the loop as a client hasn't connected in a reasonable
                    // amount of time.
                    tracing::warn!("timeout of {} seconds elapsed without any allocations, exiting", timeout.unwrap_or_default().as_secs());
                    break;
                }
            }
        }

        Ok(())
    }

    #[observe_async("RemoteProcessAllocator")]
    async fn handle_allocation_request(
        alloc: Box<dyn Alloc + Send + Sync>,
        alloc_key: ShortUuid,
        serve_transport: ChannelTransport,
        bootstrap_addr: ChannelAddr,
        hosts: Vec<String>,
        cancel_token: CancellationToken,
    ) {
        tracing::info!("handle allocation request, bootstrap_addr: {bootstrap_addr}");
        // Start the proc message forwarder.
        // Use serve_transport instead of bootstrap_addr's transport so the transports are
        // consistent between the remote process allocator and the processes.
        // The bootstrap_addr could be a different transport that the process might not be compatible with.
        let (forwarder_addr, forwarder_rx) = match channel::serve(ChannelAddr::any(serve_transport))
        {
            Ok(v) => v,
            Err(e) => {
                tracing::error!("failed to bootstrap forwarder actor: {}", e);
                return;
            }
        };
        let router = DialMailboxRouter::new();
        let mailbox_handle = router.clone().serve(forwarder_rx);
        tracing::info!("started forwarder on: {}", forwarder_addr);

        // Check if we need to write TORCH_ELASTIC_CUSTOM_HOSTNAMES_LIST_FILE
        // See: https://github.com/fairinternal/xlformers/blob/llama4_monarch/tools/launching/torchx/entrypoint/generate_ranks.py
        if let Ok(hosts_file) = std::env::var("TORCH_ELASTIC_CUSTOM_HOSTNAMES_LIST_FILE") {
            tracing::info!("writing hosts to {}", hosts_file);
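            // The file contains a single JSON object listing the ordered
            // hostnames, e.g. {"hostnames": ["host-0", "host-1"]}
            // (host names here are illustrative).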
            #[derive(Serialize)]
            struct Hosts {
                hostnames: Vec<String>,
            }
            match serde_json::to_string(&Hosts { hostnames: hosts }) {
                Ok(json) => match tokio::fs::File::create(&hosts_file).await {
                    Ok(mut file) => {
                        if file.write_all(json.as_bytes()).await.is_err() {
                            tracing::error!("failed to write hosts to {}", hosts_file);
                            return;
                        }
                    }
                    Err(e) => {
                        tracing::error!("failed to open hosts file {}: {}", hosts_file, e);
                        return;
                    }
                },
                Err(e) => {
                    tracing::error!("failed to serialize hosts: {}", e);
                    return;
                }
            }
        }

        Self::handle_allocation_loop(
            alloc,
            alloc_key,
            bootstrap_addr,
            router,
            forwarder_addr,
            cancel_token,
        )
        .await;

        mailbox_handle.stop("alloc stopped");
        if let Err(e) = mailbox_handle.await {
            tracing::error!("failed to join forwarder: {}", e);
        }
    }

    async fn handle_allocation_loop(
        mut alloc: Box<dyn Alloc + Send + Sync>,
        alloc_key: ShortUuid,
        bootstrap_addr: ChannelAddr,
        router: DialMailboxRouter,
        forward_addr: ChannelAddr,
        cancel_token: CancellationToken,
    ) {
        let world_id = alloc.world_id().clone();
        tracing::info!("starting handle allocation loop for {}", world_id);
        let tx = match channel::dial(bootstrap_addr) {
            Ok(tx) => tx,
            Err(err) => {
                tracing::error!("failed to dial bootstrap address: {}", err);
                return;
            }
        };
        let message = RemoteProcessProcStateMessage::Allocated {
            alloc_key: alloc_key.clone(),
            world_id,
        };
        tracing::info!(name = message.as_ref(), "sending allocated message");
        if let Err(e) = tx.send(message).await {
            tracing::error!("failed to send Allocated message: {}", e);
            return;
        }

        let mut mesh_agents_by_create_key = HashMap::new();
        let mut running = true;
        let tx_status = tx.status().clone();
        let mut tx_watcher = WatchStream::new(tx_status);
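        // Main event loop: multiplex allocation cancellation, upstream
        // channel status changes, Alloc state events, and periodic
        // heartbeats to the client until the Alloc completes or fails.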
        loop {
            tokio::select! {
                _ = cancel_token.cancelled(), if running => {
                    tracing::info!("cancelled, stopping allocation");
                    running = false;
                    if let Err(e) = alloc.stop().await {
                        tracing::error!("stop failed: {}", e);
                        break;
                    }
                }
                status = tx_watcher.next(), if running => {
                    match status {
                        Some(TxStatus::Closed) => {
                            tracing::error!("upstream channel state closed");
                            break;
                        },
                        _ => {
                            tracing::debug!("got channel event: {:?}", status.unwrap());
                            continue;
                        }
                    }
                }
                e = alloc.next() => {
                    match e {
                        Some(event) => {
                            tracing::debug!(name = event.as_ref(), "got event: {:?}", event);
                            let event = match event {
                                ProcState::Created { .. } => event,
                                ProcState::Running { create_key, proc_id, mesh_agent, addr } => {
                                    // TODO(meriksen, direct addressing): disable remapping in direct addressing mode
                                    tracing::debug!("remapping mesh_agent {}: addr {} -> {}", mesh_agent, addr, forward_addr);
                                    mesh_agents_by_create_key.insert(create_key.clone(), mesh_agent.clone());
                                    router.bind(mesh_agent.actor_id().proc_id().clone().into(), addr);
                                    ProcState::Running { create_key, proc_id, mesh_agent, addr: forward_addr.clone() }
                                },
                                ProcState::Stopped { create_key, reason } => {
                                    match mesh_agents_by_create_key.remove(&create_key) {
                                        Some(mesh_agent) => {
                                            tracing::debug!("unmapping mesh_agent {}", mesh_agent);
                                            let agent_ref: Reference = mesh_agent.actor_id().proc_id().clone().into();
                                            router.unbind(&agent_ref);
                                        },
                                        None => {
                                            tracing::warn!("mesh_agent not found for create key {}", create_key);
                                        }
                                    }
                                    ProcState::Stopped { create_key, reason }
                                },
                                ProcState::Failed { ref world_id, ref description } => {
                                    tracing::error!("allocation failed for {}: {}", world_id, description);
                                    event
                                }
                            };
                            tracing::debug!(name = event.as_ref(), "sending event: {:?}", event);
                            tx.post(RemoteProcessProcStateMessage::Update(alloc_key.clone(), event));
                        }
                        None => {
                            tracing::debug!("sending done");
                            tx.post(RemoteProcessProcStateMessage::Done(alloc_key.clone()));
                            running = false;
                            break;
                        }
                    }
                }
                _ = RealClock.sleep(config::global::get(config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL)) => {
                    tracing::trace!("sending heartbeat");
                    tx.post(RemoteProcessProcStateMessage::HeartBeat);
                }
            }
        }
        tracing::info!("allocation handler loop exited");
        if running {
            tracing::info!("stopping processes");
            if let Err(e) = alloc.stop_and_wait().await {
                tracing::error!("stop failed: {}", e);
                return;
            }
            tracing::info!("stop finished");
        }
    }
}

/// HostId is a string-based identifier of the host. It may be the same
/// as the hostname but in some situations a separate ID may be used.
type HostId = String;

/// A host entry passed down from the initializer.
#[derive(Clone)]
pub struct RemoteProcessAllocHost {
    /// A unique identifier for the host. The hostname may be used, but the
    /// ID can differ in some situations.
    pub id: HostId,
    /// The FQDN of the host.
    pub hostname: String,
}

/// State of a host in the RemoteProcessAlloc.
struct RemoteProcessAllocHostState {
    /// The allocation key used to identify the host.
    alloc_key: ShortUuid,
    /// The host ID of the remote host.
    host_id: HostId,
    /// TX channel to the remote host allocator.
    tx: ChannelTx<RemoteProcessAllocatorMessage>,
    /// Set of active processes on this host.
    active_procs: HashSet<ShortUuid>,
    /// Region allocated by host.
    region: Region,
    /// World ID for this host as indicated from Allocated message.
    world_id: Option<WorldId>,
    /// If remote allocator sent us ProcState::Failed.
    failed: bool,
    /// If the remote allocator has ever allocated a proc.
    allocated: bool,
}

#[automock]
#[async_trait]
/// Interface to provide the set of hosts to be used by RemoteProcessAlloc.
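/// # Example
///
/// A minimal sketch of an initializer backed by a static host list (the
/// struct and host names are illustrative):
///
/// ```ignore
/// struct StaticInitializer {
///     hosts: Vec<String>,
/// }
///
/// #[async_trait]
/// impl RemoteProcessAllocInitializer for StaticInitializer {
///     async fn initialize_alloc(&mut self) -> Result<Vec<RemoteProcessAllocHost>, anyhow::Error> {
///         // Use each hostname as both the host ID and the FQDN.
///         Ok(self
///             .hosts
///             .iter()
///             .map(|h| RemoteProcessAllocHost {
///                 id: h.clone(),
///                 hostname: h.clone(),
///             })
///             .collect())
///     }
/// }
/// ```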
pub trait RemoteProcessAllocInitializer {
    /// Initializes and returns a list of hosts to be used by this RemoteProcessAlloc.
    async fn initialize_alloc(&mut self) -> Result<Vec<RemoteProcessAllocHost>, anyhow::Error>;
}

/// Wrapper struct around `HashMap<HostId, RemoteProcessAllocHostState>`
/// to ensure that host addresses are synced with the signal handler.
struct HostStates {
    inner: HashMap<HostId, RemoteProcessAllocHostState>,
    host_addresses: Arc<DashMap<HostId, ChannelAddr>>,
}

impl HostStates {
    fn new(host_addresses: Arc<DashMap<HostId, ChannelAddr>>) -> HostStates {
        Self {
            inner: HashMap::new(),
            host_addresses,
        }
    }

    fn insert(
        &mut self,
        host_id: HostId,
        state: RemoteProcessAllocHostState,
        address: ChannelAddr,
    ) {
        self.host_addresses.insert(host_id.clone(), address);
        self.inner.insert(host_id, state);
    }

    fn get(&self, host_id: &HostId) -> Option<&RemoteProcessAllocHostState> {
        self.inner.get(host_id)
    }

    fn get_mut(&mut self, host_id: &HostId) -> Option<&mut RemoteProcessAllocHostState> {
        self.inner.get_mut(host_id)
    }

    fn remove(&mut self, host_id: &HostId) -> Option<RemoteProcessAllocHostState> {
        self.host_addresses.remove(host_id);
        self.inner.remove(host_id)
    }

    fn iter(&self) -> impl Iterator<Item = (&HostId, &RemoteProcessAllocHostState)> {
        self.inner.iter()
    }

    fn iter_mut(&mut self) -> impl Iterator<Item = (&HostId, &mut RemoteProcessAllocHostState)> {
        self.inner.iter_mut()
    }

    fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }
    // Any missing HashMap methods should be added here as needed.
}

/// A generalized implementation of an Alloc using one or more hosts running
/// RemoteProcessAlloc for process allocation.
pub struct RemoteProcessAlloc {
    // The initializer to be called at the first next() invocation to obtain
    // allocated hosts.
    initializer: Box<dyn RemoteProcessAllocInitializer + Send + Sync>,
    spec: AllocSpec,
    remote_allocator_port: u16,
    world_id: WorldId,
    ordered_hosts: Vec<RemoteProcessAllocHost>,
    // Indicates that the initial remote allocation requests have been sent.
    started: bool,
    // Indicates that this Alloc is active (we have at least one remote process running).
    running: bool,
    // Indicates that the allocation process has permanently failed.
    failed: bool,
    // Maps the alloc key to the host.
    alloc_to_host: HashMap<ShortUuid, HostId>,
    host_states: HostStates,
    world_offsets: HashMap<WorldId, usize>,
    event_queue: VecDeque<ProcState>,
    comm_watcher_tx: UnboundedSender<HostId>,
    comm_watcher_rx: UnboundedReceiver<HostId>,

    bootstrap_addr: ChannelAddr,
    rx: ChannelRx<RemoteProcessProcStateMessage>,
    _signal_cleanup_guard: hyperactor::SignalCleanupGuard,
}

impl RemoteProcessAlloc {
    /// Create a new Alloc. The initializer will be called on the first invocation of next()
    /// to obtain a list of allocated hosts. An Allocate message is then sent to the
    /// RemoteProcessAllocator on each host. Heartbeats are used to maintain the health
    /// status of remote hosts.
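    ///
    /// # Example
    ///
    /// A minimal sketch; the extent, world id, port, and initializer are
    /// illustrative:
    ///
    /// ```ignore
    /// let spec = AllocSpec {
    ///     extent: extent!(host = 2, gpu = 8),
    ///     constraints: Default::default(),
    ///     proc_name: None,
    ///     transport: ChannelTransport::Tcp,
    /// };
    /// let mut alloc = RemoteProcessAlloc::new(spec, id!(my_world), 26600, initializer).await?;
    /// while let Some(state) = alloc.next().await {
    ///     // React to ProcState::Created / Running / Stopped / Failed updates.
    /// }
    /// ```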
    #[observe_result("RemoteProcessAlloc")]
    pub async fn new(
        spec: AllocSpec,
        world_id: WorldId,
        remote_allocator_port: u16,
        initializer: impl RemoteProcessAllocInitializer + Send + Sync + 'static,
    ) -> Result<Self, anyhow::Error> {
        let (bootstrap_addr, rx) = channel::serve(ChannelAddr::any(spec.transport.clone()))
            .map_err(anyhow::Error::from)?;

        tracing::info!(
            "starting alloc for {} on: {}",
            world_id,
            bootstrap_addr.clone()
        );

        let (comm_watcher_tx, comm_watcher_rx) = unbounded_channel();

        let host_addresses = Arc::new(DashMap::<HostId, ChannelAddr>::new());
        let host_addresses_for_signal = host_addresses.clone();

        // Register cleanup callback with global signal manager.
        let signal_cleanup_guard =
            hyperactor::register_signal_cleanup_scoped(Box::pin(async move {
                join_all(host_addresses_for_signal.iter().map(|entry| async move {
                    let addr = entry.value().clone();
                    match channel::dial(addr.clone()) {
                        Ok(tx) => {
                            if let Err(e) = tx.send(RemoteProcessAllocatorMessage::Stop).await {
                                tracing::error!("Failed to send Stop to {}: {}", addr, e);
                            }
                        }
                        Err(e) => {
                            tracing::error!("Failed to dial {} during signal cleanup: {}", addr, e);
                        }
                    }
                }))
                .await;
            }));

        Ok(Self {
            spec,
            world_id,
            remote_allocator_port,
            initializer: Box::new(initializer),
            world_offsets: HashMap::new(),
            ordered_hosts: Vec::new(),
            alloc_to_host: HashMap::new(),
            host_states: HostStates::new(host_addresses),
            bootstrap_addr,
            event_queue: VecDeque::new(),
            comm_watcher_tx,
            comm_watcher_rx,
            rx,
            started: false,
            running: true,
            failed: false,
            _signal_cleanup_guard: signal_cleanup_guard,
        })
    }

    /// Start a RemoteProcessAllocator tx watcher. It spawns a task that monitors the underlying
    /// TxStatus and reports to the main next() loop via comm_watcher_tx. It does not, however,
    /// send the actual heartbeats; it just watches the status.
    /// The task will terminate once comm_watcher_rx is closed.
    ///
    /// A task approach was used here, as opposed to select!, to avoid nesting complexities with
    /// select!() and select_all().
    async fn start_comm_watcher(&self) {
        let mut tx_watchers = Vec::new();
        for host in &self.ordered_hosts {
            let tx_status = self.host_states.get(&host.id).unwrap().tx.status().clone();
            let watcher = WatchStream::new(tx_status);
            tx_watchers.push((watcher, host.id.clone()));
        }
        assert!(!tx_watchers.is_empty());
        let tx = self.comm_watcher_tx.clone();
        tokio::spawn(async move {
            loop {
                let mut tx_status_futures = Vec::new();
                for (watcher, _) in &mut tx_watchers {
                    let fut = watcher.next().boxed();
                    tx_status_futures.push(fut);
                }
                let (tx_status, index, _) = select_all(tx_status_futures).await;
                let host_id = match tx_watchers.get(index) {
                    Some((_, host_id)) => host_id.clone(),
                    None => {
                        // Should never happen.
                        tracing::error!(
                            "got selected index {} with no matching host in {}",
                            index,
                            tx_watchers.len()
                        );
                        continue;
                    }
                };
                if let Some(tx_status) = tx_status {
                    tracing::debug!("host {} channel event: {:?}", host_id, tx_status);
                    if tx_status == TxStatus::Closed {
                        if tx.send(host_id.clone()).is_err() {
                            // Other side closed.
                            break;
                        }
                        tx_watchers.remove(index);
                        if tx_watchers.is_empty() {
                            // All of the statuses have been closed, exit the loop.
                            break;
                        }
                    }
                }
            }
        });
    }

    /// Ensure that we have the list of allocated hosts and that we have sent an Allocate
    /// request to all of them. Once done, start_comm_watcher() is called to start
    /// the channel watcher.
    /// This function is idempotent.
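    ///
    /// For example (illustrative values), an extent of `extent!(host = 2, gpu = 4)`
    /// is grouped by its innermost dimension (`gpu`), producing two regions of
    /// 4 ranks each; region `i` is assigned to the `i`-th host returned by the
    /// initializer.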
    async fn ensure_started(&mut self) -> Result<(), anyhow::Error> {
        if self.started || self.failed {
            return Ok(());
        }

        self.started = true;
        let hosts = self
            .initializer
            .initialize_alloc()
            .await
            .context("alloc initializer error")?;
        if hosts.is_empty() {
            anyhow::bail!("initializer returned empty list of hosts");
        }
        // Prepare a list of host names in this allocation to be sent
        // to remote allocators.
        let hostnames: Vec<_> = hosts.iter().map(|e| e.hostname.clone()).collect();
        tracing::info!("obtained {} hosts for this allocation", hostnames.len());

        // We require at least one dimension for hosts and one for sub-hosts (e.g., GPUs).
        anyhow::ensure!(
            self.spec.extent.len() >= 2,
            "invalid extent: {}, expected at least 2 dimensions",
            self.spec.extent
        );

        // We group by the innermost dimension of the extent.
        let split_dim = &self.spec.extent.labels()[self.spec.extent.len() - 1];
        for (i, region) in self.spec.extent.group_by(split_dim)?.enumerate() {
            let host = &hosts[i];
            tracing::debug!("allocating: {} for host: {}", region, host.id);

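            // Compose the remote allocator's channel address from the
            // transport, hostname, and port, e.g. "tcp!host0:26600" or
            // "metatls!host0:26600" (hostname and port illustrative).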
            let remote_addr = match self.spec.transport {
                ChannelTransport::MetaTls(_) => {
                    format!("metatls!{}:{}", host.hostname, self.remote_allocator_port)
                }
                ChannelTransport::Tcp => {
                    format!("tcp!{}:{}", host.hostname, self.remote_allocator_port)
                }
                // Used only for testing.
                ChannelTransport::Unix => host.hostname.clone(),
                _ => {
                    anyhow::bail!(
                        "unsupported transport for host {}: {:?}",
                        host.id,
                        self.spec.transport,
                    );
                }
            };

            tracing::debug!("dialing remote: {} for host {}", remote_addr, host.id);
            let remote_addr = remote_addr.parse::<ChannelAddr>()?;
            let tx = channel::dial(remote_addr.clone())
                .map_err(anyhow::Error::from)
                .context(format!(
                    "failed to dial remote {} for host {}",
                    remote_addr, host.id
                ))?;

            // Possibly we could use the HostId directly here.
            let alloc_key = ShortUuid::generate();
            assert!(
                self.alloc_to_host
                    .insert(alloc_key.clone(), host.id.clone())
                    .is_none()
            );

            let trace_id = hyperactor_telemetry::trace::get_or_create_trace_id();
            let client_context = Some(ClientContext { trace_id });
            let message = RemoteProcessAllocatorMessage::Allocate {
                alloc_key: alloc_key.clone(),
                extent: region.extent(),
                bootstrap_addr: self.bootstrap_addr.clone(),
                hosts: hostnames.clone(),
                client_context,
            };
            tracing::info!(
                name = message.as_ref(),
                "sending allocate message to workers"
            );
            tx.post(message);

            self.host_states.insert(
                host.id.clone(),
                RemoteProcessAllocHostState {
                    alloc_key,
                    host_id: host.id.clone(),
                    tx,
                    active_procs: HashSet::new(),
                    region,
                    world_id: None,
                    failed: false,
                    allocated: false,
                },
                remote_addr,
            );
        }

        self.ordered_hosts = hosts;
        self.start_comm_watcher().await;
        self.started = true;

        Ok(())
    }

    // Given an alloc key, obtain the internal HostState structure.
    fn get_host_state_mut(
        &mut self,
        alloc_key: &ShortUuid,
    ) -> Result<&mut RemoteProcessAllocHostState, anyhow::Error> {
        let host_id: &HostId = self
            .alloc_to_host
            .get(alloc_key)
            .ok_or_else(|| anyhow::anyhow!("alloc with key {} not found", alloc_key))?;

        self.host_states
            .get_mut(host_id)
            .ok_or_else(|| anyhow::anyhow!("no host state found for host {}", host_id))
    }

    // Given an alloc key, obtain the internal HostState structure.
    fn get_host_state(
        &self,
        alloc_key: &ShortUuid,
    ) -> Result<&RemoteProcessAllocHostState, anyhow::Error> {
        let host_id: &HostId = self
            .alloc_to_host
            .get(alloc_key)
            .ok_or_else(|| anyhow::anyhow!("alloc with key {} not found", alloc_key))?;

        self.host_states
            .get(host_id)
            .ok_or_else(|| anyhow::anyhow!("no host state found for host {}", host_id))
    }

    fn remove_host_state(
        &mut self,
        alloc_key: &ShortUuid,
    ) -> Result<RemoteProcessAllocHostState, anyhow::Error> {
        let host_id: &HostId = self
            .alloc_to_host
            .get(alloc_key)
            .ok_or_else(|| anyhow::anyhow!("alloc with key {} not found", alloc_key))?;

        self.host_states
            .remove(host_id)
            .ok_or_else(|| anyhow::anyhow!("no host state found for host {}", host_id))
    }

    fn add_proc_id_to_host_state(
        &mut self,
        alloc_key: &ShortUuid,
        create_key: &ShortUuid,
    ) -> Result<(), anyhow::Error> {
        let task_state = self.get_host_state_mut(alloc_key)?;
        if !task_state.active_procs.insert(create_key.clone()) {
            // Should not happen, but we can ignore it.
            tracing::error!("proc with create key {} already in host state", create_key);
        }
        task_state.allocated = true;
        Ok(())
    }

    fn remove_proc_from_host_state(
        &mut self,
        alloc_key: &ShortUuid,
        create_key: &ShortUuid,
    ) -> Result<(), anyhow::Error> {
        let task_state = self.get_host_state_mut(alloc_key)?;
        if !task_state.active_procs.remove(create_key) {
            // Should not happen, but we can ignore it.
            tracing::error!("proc with create key {} not found in host state", create_key);
        }
        Ok(())
    }

    // Re-project a proc's per-host coordinates into the global extent.
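    //
    // A worked example (illustrative values): with a global extent of
    // {host = 2, gpu = 4} and a host whose region covers global ranks 4..8,
    // a local point with rank 1 maps to global rank 5, i.e. the point
    // (host = 1, gpu = 1).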
    fn project_proc_into_global_extent(
        &self,
        alloc_key: &ShortUuid,
        point: &Point,
    ) -> Result<Point, anyhow::Error> {
        let global_rank = self
            .get_host_state(alloc_key)?
            .region
            .get(point.rank())
            .ok_or_else(|| {
                anyhow::anyhow!(
                    "rank {} out of bounds in alloc {}",
                    point.rank(),
                    alloc_key
                )
            })?;
        Ok(self.spec.extent.point_of_rank(global_rank)?)
    }

    // Clean up a comm-failed host's state by its ID.
    fn cleanup_host_channel_closed(
        &mut self,
        host_id: HostId,
    ) -> Result<Vec<ShortUuid>, anyhow::Error> {
        let state = match self.host_states.remove(&host_id) {
            Some(state) => state,
            None => {
                // This should never happen.
                anyhow::bail!(
                    "got channel closed event for host {} which has no known state",
                    host_id
                );
            }
        };
        self.ordered_hosts.retain(|host| host.id != host_id);
        self.alloc_to_host.remove(&state.alloc_key);
        if let Some(world_id) = state.world_id {
            self.world_offsets.remove(&world_id);
        }
        let create_keys = state.active_procs.iter().cloned().collect();

        Ok(create_keys)
    }
}

#[async_trait]
impl Alloc for RemoteProcessAlloc {
    async fn next(&mut self) -> Option<ProcState> {
        loop {
            if let state @ Some(_) = self.event_queue.pop_front() {
                break state;
            }

            if !self.running {
                break None;
            }

            if let Err(e) = self.ensure_started().await {
                break Some(ProcState::Failed {
                    world_id: self.world_id.clone(),
                    description: format!("failed to ensure started: {:#}", e),
                });
            }

            let heartbeat_interval =
                config::global::get(config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL);
            let mut heartbeat_time = hyperactor::clock::RealClock.now() + heartbeat_interval;
            // Rerun the outer loop in case we pushed new items to the event queue.
            let mut reloop = false;
            let update = loop {
                tokio::select! {
                    msg = self.rx.recv() => {
                        tracing::debug!("got ProcState message from allocator: {:?}", msg);
                        match msg {
                            Ok(RemoteProcessProcStateMessage::Allocated { alloc_key, world_id }) => {
                                tracing::info!("remote alloc {}: allocated", alloc_key);
                                match self.get_host_state_mut(&alloc_key) {
                                    Ok(state) => {
                                        state.world_id = Some(world_id.clone());
                                    }
                                    Err(err) => {
                                        // Should never happen.
                                        tracing::error!(
                                            "received allocated message alloc: {} with no known state: {}",
                                            alloc_key, err,
                                        );
                                    }
                                }
                            }
                            Ok(RemoteProcessProcStateMessage::Update(alloc_key, proc_state)) => {
                                let update = match proc_state {
                                    ProcState::Created { ref create_key, .. } => {
                                        if let Err(e) = self.add_proc_id_to_host_state(&alloc_key, create_key) {
                                            tracing::error!("failed to add proc with create key {} to host state: {}", create_key, e);
                                        }
                                        proc_state
                                    }
                                    ProcState::Stopped { ref create_key, .. } => {
                                        if let Err(e) = self.remove_proc_from_host_state(&alloc_key, create_key) {
                                            tracing::error!("failed to remove proc with create key {} from host state: {}", create_key, e);
                                        }
                                        proc_state
                                    }
                                    ProcState::Failed { ref world_id, ref description } => {
                                        match self.get_host_state_mut(&alloc_key) {
                                            Ok(state) => {
                                                state.failed = true;
                                                ProcState::Failed {
                                                    world_id: world_id.clone(),
                                                    description: format!("host {} failed: {}", state.host_id, description),
                                                }
                                            }
                                            Err(e) => {
                                                tracing::error!("failed to find host state for world id: {}: {}", world_id, e);
                                                proc_state
                                            }
                                        }
                                    }
                                    _ => proc_state
                                };

                                break Some((Some(alloc_key), update));
                            }
                            Ok(RemoteProcessProcStateMessage::Done(alloc_key)) => {
                                tracing::info!("allocator {} is done", alloc_key);

                                if let Ok(state) = self.remove_host_state(&alloc_key) {
                                    if !state.active_procs.is_empty() {
                                        tracing::error!("received done for alloc {} with active procs: {:?}", alloc_key, state.active_procs);
                                    }
                                } else {
                                    tracing::error!("received done for unknown alloc {}", alloc_key);
                                }

                                if self.host_states.is_empty() {
                                    self.running = false;
                                    break None;
                                }
                            }
                            // Heartbeat messages are discarded immediately after being received; the sender (remote
                            // process allocator) relies on the channel ack to know whether the receiver (client) is
                            // still alive. No state needs to be updated.
                            Ok(RemoteProcessProcStateMessage::HeartBeat) => {}
                            Err(e) => {
                                break Some((None, ProcState::Failed {world_id: self.world_id.clone(), description: format!("error receiving events: {}", e)}));
                            }
                        }
                    }

                    _ = clock::RealClock.sleep_until(heartbeat_time) => {
                        self.host_states.iter().for_each(|(_, host_state)| host_state.tx.post(RemoteProcessAllocatorMessage::HeartBeat));
                        heartbeat_time = hyperactor::clock::RealClock.now() + heartbeat_interval;
                    }

                    closed_host_id = self.comm_watcher_rx.recv() => {
                        if let Some(closed_host_id) = closed_host_id {
                            tracing::debug!("host {} channel closed, cleaning up", closed_host_id);
                            if let Some(state) = self.host_states.get(&closed_host_id)
                                && !state.allocated
                            {
                                break Some((None, ProcState::Failed {
                                    world_id: self.world_id.clone(),
                                    description: format!(
                                        "no process has ever been allocated on {} before the channel is closed; \
                                        a common issue could be the channel was never established",
                                        closed_host_id
                                    ),
                                }));
                            }
                            let create_keys = match self.cleanup_host_channel_closed(closed_host_id) {
                                Ok(create_keys) => create_keys,
                                Err(err) => {
                                    tracing::error!("failed to cleanup disconnected host: {}", err);
                                    continue;
                                }
                            };
                            for create_key in create_keys {
                                tracing::debug!("queuing Stopped state for proc with create key {}", create_key);
                                self.event_queue.push_back(
                                    ProcState::Stopped {
                                        create_key,
                                        reason: ProcStopReason::HostWatchdog
                                    }
                                );
                            }
                            // Check if there are any hosts left.
                            if self.host_states.is_empty() {
                                tracing::info!("no more hosts left, stopping the alloc");
                                self.running = false;
                            }
                            // Kick back to the outer loop to pop off the queue if necessary.
                            reloop = true;
                            break None;
                        } else {
                            // Other side closed.
                            tracing::warn!("unexpected comm watcher channel close");
                            break None;
                        }
                    }
                }
            };

            if reloop {
                // Kick back to the outer loop to pop off the queue.
                // Should not have produced an update.
                assert!(update.is_none());
                continue;
            }

            break match update {
                Some((
                    Some(alloc_key),
                    ProcState::Created {
                        create_key,
                        point,
                        pid,
                    },
                )) => match self.project_proc_into_global_extent(&alloc_key, &point) {
                    Ok(global_point) => {
                        tracing::debug!("reprojected coords: {} -> {}", point, global_point);
                        Some(ProcState::Created {
                            create_key,
                            point: global_point,
                            pid,
                        })
                    }
                    Err(e) => {
                        tracing::error!(
                            "failed to project coords for proc: {}.{}: {}",
                            alloc_key,
                            create_key,
                            e
                        );
                        None
                    }
                },
                Some((None, ProcState::Created { .. })) => {
                    panic!("illegal state: missing alloc_key for ProcState::Created event")
                }
                Some((_, update)) => {
                    if let ProcState::Failed { description, .. } = &update {
                        tracing::error!(description);
                        self.failed = true;
                    }
                    Some(update)
                }
                None => None,
            };
        }
    }

    fn spec(&self) -> &AllocSpec {
        &self.spec
    }

    fn extent(&self) -> &Extent {
        &self.spec.extent
    }

    fn world_id(&self) -> &WorldId {
        &self.world_id
    }

    async fn stop(&mut self) -> Result<(), AllocatorError> {
        tracing::info!("stopping alloc");

        for (host_id, task_state) in self.host_states.iter_mut() {
            tracing::debug!("stopping alloc at host {}", host_id);
            task_state.tx.post(RemoteProcessAllocatorMessage::Stop);
        }

        Ok(())
    }
}

impl Drop for RemoteProcessAlloc {
    fn drop(&mut self) {
        tracing::debug!("dropping RemoteProcessAlloc of world_id {}", self.world_id);
    }
}
1190
1191#[cfg(test)]
1192mod test {
1193    use std::assert_matches::assert_matches;
1194
1195    use hyperactor::ActorRef;
1196    use hyperactor::channel::ChannelRx;
1197    use hyperactor::clock::ClockKind;
1198    use hyperactor::id;
1199    use ndslice::extent;
1200    use tokio::sync::oneshot;
1201
1202    use super::*;
1203    use crate::alloc::ChannelTransport;
1204    use crate::alloc::MockAlloc;
1205    use crate::alloc::MockAllocWrapper;
1206    use crate::alloc::MockAllocator;
1207    use crate::alloc::ProcStopReason;
1208    use crate::proc_mesh::mesh_agent::ProcMeshAgent;
1209
1210    async fn read_all_created(rx: &mut ChannelRx<RemoteProcessProcStateMessage>, alloc_len: usize) {
1211        let mut i: usize = 0;
1212        while i < alloc_len {
1213            let m = rx.recv().await.unwrap();
1214            match m {
1215                RemoteProcessProcStateMessage::Update(_, ProcState::Created { .. }) => i += 1,
1216                RemoteProcessProcStateMessage::HeartBeat => {}
1217                _ => panic!("unexpected message: {:?}", m),
1218            }
1219        }
1220    }
1221
1222    async fn read_all_running(rx: &mut ChannelRx<RemoteProcessProcStateMessage>, alloc_len: usize) {
1223        let mut i: usize = 0;
1224        while i < alloc_len {
1225            let m = rx.recv().await.unwrap();
1226            match m {
1227                RemoteProcessProcStateMessage::Update(_, ProcState::Running { .. }) => i += 1,
1228                RemoteProcessProcStateMessage::HeartBeat => {}
1229                _ => panic!("unexpected message: {:?}", m),
1230            }
1231        }
1232    }
1233
1234    async fn read_all_stopped(rx: &mut ChannelRx<RemoteProcessProcStateMessage>, alloc_len: usize) {
1235        let mut i: usize = 0;
1236        while i < alloc_len {
1237            let m = rx.recv().await.unwrap();
1238            match m {
1239                RemoteProcessProcStateMessage::Update(_, ProcState::Stopped { .. }) => i += 1,
1240                RemoteProcessProcStateMessage::HeartBeat => {}
1241                _ => panic!("unexpected message: {:?}", m),
1242            }
1243        }
1244    }
1245
1246    fn set_procstate_expectations(alloc: &mut MockAlloc, extent: Extent) {
1247        alloc.expect_extent().return_const(extent.clone());
1248        let mut create_keys = Vec::new();
1249        for point in extent.points() {
1250            let create_key = ShortUuid::generate();
1251            create_keys.push(create_key.clone());
1252            alloc.expect_next().times(1).return_once(move || {
1253                Some(ProcState::Created {
1254                    create_key: create_key.clone(),
1255                    point,
1256                    pid: 0,
1257                })
1258            });
1259        }
1260        for (i, create_key) in create_keys
1261            .iter()
1262            .take(extent.num_ranks())
1263            .cloned()
1264            .enumerate()
1265        {
1266            let proc_id = format!("test[{i}]").parse().unwrap();
1267            let mesh_agent = ActorRef::<ProcMeshAgent>::attest(
1268                format!("test[{i}].mesh_agent[{i}]").parse().unwrap(),
1269            );
1270            alloc.expect_next().times(1).return_once(move || {
1271                Some(ProcState::Running {
1272                    create_key,
1273                    proc_id,
1274                    addr: ChannelAddr::Unix("/proc0".parse().unwrap()),
1275                    mesh_agent,
1276                })
1277            });
1278        }
1279        for create_key in create_keys.iter().take(extent.num_ranks()).cloned() {
1280            alloc.expect_next().times(1).return_once(|| {
1281                Some(ProcState::Stopped {
1282                    create_key,
1283                    reason: ProcStopReason::Unknown,
1284                })
1285            });
1286        }
1287    }
1288
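    // Happy path: a single Allocate request yields Allocated, then Created,
    // Running, and Stopped updates for every rank, and finally Done.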
1289    #[timed_test::async_timed_test(timeout_secs = 5)]
1290    async fn test_simple() {
1291        let config = hyperactor::config::global::lock();
1292        let _guard = config.override_key(
1293            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
1294            Duration::from_millis(100),
1295        );
1296        hyperactor_telemetry::initialize_logging(ClockKind::default());
1297        let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
1298        let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
1299        let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();
1300
1301        let extent = extent!(host = 1, gpu = 2);
1302        let tx = channel::dial(serve_addr.clone()).unwrap();
1303
1304        let world_id: WorldId = id!(test_world_id);
1305        let mut alloc = MockAlloc::new();
1306        alloc.expect_world_id().return_const(world_id.clone());
1307        alloc.expect_extent().return_const(extent.clone());
1308
1309        set_procstate_expectations(&mut alloc, extent.clone());
1310
1311        // final None: all subsequent next() calls return None
1312        alloc.expect_next().return_const(None);
1313
1314        let mut allocator = MockAllocator::new();
1315        let total_messages = extent.num_ranks() * 3 + 1;
1316        let mock_wrapper = MockAllocWrapper::new_block_next(
1317            alloc,
1318            // block after create, running, stopped and done.
1319            total_messages,
1320        );
1321        allocator
1322            .expect_allocate()
1323            .times(1)
1324            .return_once(move |_| Ok(mock_wrapper));
1325
1326        let remote_allocator = RemoteProcessAllocator::new();
1327        let handle = tokio::spawn({
1328            let remote_allocator = remote_allocator.clone();
1329            async move {
1330                remote_allocator
1331                    .start_with_allocator(serve_addr, allocator, None)
1332                    .await
1333            }
1334        });
1335
1336        let alloc_key = ShortUuid::generate();
1337
1338        tx.send(RemoteProcessAllocatorMessage::Allocate {
1339            alloc_key: alloc_key.clone(),
1340            extent: extent.clone(),
1341            bootstrap_addr,
1342            hosts: vec![],
1343            client_context: None,
1344        })
1345        .await
1346        .unwrap();
1347
1348        // Allocated
1349        let m = rx.recv().await.unwrap();
1350        assert_matches!(
1351            m, RemoteProcessProcStateMessage::Allocated { alloc_key: got_alloc_key, world_id: got_world_id }
1352                if got_world_id == world_id && got_alloc_key == alloc_key
1353        );
1354
1355        // All Created events
1356        let mut rank: usize = 0;
1357        let mut create_keys = Vec::with_capacity(extent.num_ranks());
1358        while rank < extent.num_ranks() {
1359            let m = rx.recv().await.unwrap();
1360            match m {
1361                RemoteProcessProcStateMessage::Update(
1362                    got_alloc_key,
1363                    ProcState::Created {
1364                        create_key, point, ..
1365                    },
1366                ) => {
1367                    let expected_point = extent.point_of_rank(rank).unwrap();
1368                    assert_eq!(got_alloc_key, alloc_key);
1369                    assert_eq!(point, expected_point);
1370                    create_keys.push(create_key);
1371                    rank += 1;
1372                }
1373                RemoteProcessProcStateMessage::HeartBeat => {}
1374                _ => panic!("unexpected message: {:?}", m),
1375            }
1376        }
1377        // All Running events
1378        let mut rank: usize = 0;
1379        while rank < extent.num_ranks() {
1380            let m = rx.recv().await.unwrap();
1381            match m {
1382                RemoteProcessProcStateMessage::Update(
1383                    got_alloc_key,
1384                    ProcState::Running {
1385                        create_key,
1386                        proc_id,
1387                        mesh_agent,
1388                        addr: _,
1389                    },
1390                ) => {
1391                    assert_eq!(got_alloc_key, alloc_key);
1392                    assert_eq!(create_key, create_keys[rank]);
1393                    let expected_proc_id = format!("test[{}]", rank).parse().unwrap();
1394                    let expected_mesh_agent = ActorRef::<ProcMeshAgent>::attest(
1395                        format!("test[{}].mesh_agent[{}]", rank, rank)
1396                            .parse()
1397                            .unwrap(),
1398                    );
1399                    assert_eq!(proc_id, expected_proc_id);
1400                    assert_eq!(mesh_agent, expected_mesh_agent);
1401                    rank += 1;
1402                }
1403                RemoteProcessProcStateMessage::HeartBeat => {}
1404                _ => panic!("unexpected message: {:?}", m),
1405            }
1406        }
1407        // All Stopped events
1408        let mut rank: usize = 0;
1409        while rank < extent.num_ranks() {
1410            let m = rx.recv().await.unwrap();
1411            match m {
1412                RemoteProcessProcStateMessage::Update(
1413                    got_alloc_key,
1414                    ProcState::Stopped {
1415                        create_key,
1416                        reason: ProcStopReason::Unknown,
1417                    },
1418                ) => {
1419                    assert_eq!(got_alloc_key, alloc_key);
1420                    assert_eq!(create_key, create_keys[rank]);
1421                    rank += 1;
1422                }
1423                RemoteProcessProcStateMessage::HeartBeat => {}
1424                _ => panic!("unexpected message: {:?}", m),
1425            }
1426        }
1427        // Done
1428        loop {
1429            let m = rx.recv().await.unwrap();
1430            match m {
1431                RemoteProcessProcStateMessage::Done(got_alloc_key) => {
1432                    assert_eq!(got_alloc_key, alloc_key);
1433                    break;
1434                }
1435                RemoteProcessProcStateMessage::HeartBeat => {}
1436                _ => panic!("unexpected message: {:?}", m),
1437            }
1438        }
1439
1440        remote_allocator.terminate();
1441        handle.await.unwrap().unwrap();
1442    }
1443
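    // Stop request after all procs are running: the inner alloc is stopped and
    // a Stopped update is relayed for every rank.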
1444    #[timed_test::async_timed_test(timeout_secs = 15)]
1445    async fn test_normal_stop() {
1446        let config = hyperactor::config::global::lock();
1447        let _guard = config.override_key(
1448            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
1449            Duration::from_millis(100),
1450        );
1451        hyperactor_telemetry::initialize_logging(ClockKind::default());
1452        let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
1453        let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
1454        let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();
1455
1456        let extent = extent!(host = 1, gpu = 2);
1457        let tx = channel::dial(serve_addr.clone()).unwrap();
1458
1459        let world_id: WorldId = id!(test_world_id);
1460        let mut alloc = MockAllocWrapper::new_block_next(
1461            MockAlloc::new(),
1462            // block after all created, all running
1463            extent.num_ranks() * 2,
1464        );
1465        let next_tx = alloc.notify_tx();
1466        alloc.alloc.expect_world_id().return_const(world_id.clone());
1467        alloc.alloc.expect_extent().return_const(extent.clone());
1468
1469        set_procstate_expectations(&mut alloc.alloc, extent.clone());
1470
1471        alloc.alloc.expect_next().return_const(None);
1472        alloc.alloc.expect_stop().times(1).return_once(|| Ok(()));
1473
1474        let mut allocator = MockAllocator::new();
1475        allocator
1476            .expect_allocate()
1477            .times(1)
1478            .return_once(|_| Ok(alloc));
1479
1480        let remote_allocator = RemoteProcessAllocator::new();
1481        let handle = tokio::spawn({
1482            let remote_allocator = remote_allocator.clone();
1483            async move {
1484                remote_allocator
1485                    .start_with_allocator(serve_addr, allocator, None)
1486                    .await
1487            }
1488        });
1489
1490        let alloc_key = ShortUuid::generate();
1491        tx.send(RemoteProcessAllocatorMessage::Allocate {
1492            alloc_key: alloc_key.clone(),
1493            extent: extent.clone(),
1494            bootstrap_addr,
1495            hosts: vec![],
1496            client_context: None,
1497        })
1498        .await
1499        .unwrap();
1500
1501        // Allocated
1502        let m = rx.recv().await.unwrap();
1503        assert_matches!(
1504            m,
1505            RemoteProcessProcStateMessage::Allocated {  world_id: got_world_id, alloc_key: got_alloc_key }
1506            if world_id == got_world_id && alloc_key == got_alloc_key
1507        );
1508
1509        read_all_created(&mut rx, extent.num_ranks()).await;
1510        read_all_running(&mut rx, extent.num_ranks()).await;
1511
1512        // all procs are now running; stop the allocation.
1513        tracing::info!("stopping allocation");
1514        tx.send(RemoteProcessAllocatorMessage::Stop).await.unwrap();
1515        // unblock next() so the Stopped updates can be delivered
1516        next_tx.send(()).unwrap();
1517
1518        read_all_stopped(&mut rx, extent.num_ranks()).await;
1519
1520        remote_allocator.terminate();
1521        handle.await.unwrap().unwrap();
1522    }
1523
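    // A second Allocate request preempts the first allocation: the first is
    // stopped (Stopped updates, then Done) before the new one is started.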
1524    #[timed_test::async_timed_test(timeout_secs = 15)]
1525    async fn test_realloc() {
1526        let config = hyperactor::config::global::lock();
1527        let _guard = config.override_key(
1528            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
1529            Duration::from_millis(100),
1530        );
1531        hyperactor_telemetry::initialize_logging(ClockKind::default());
1532        let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
1533        let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
1534        let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();
1535
1536        let extent = extent!(host = 1, gpu = 2);
1537
1538        let tx = channel::dial(serve_addr.clone()).unwrap();
1539
1540        let world_id: WorldId = id!(test_world_id);
1541        let mut alloc1 = MockAllocWrapper::new_block_next(
1542            MockAlloc::new(),
1543            // block after all created, all running
1544            extent.num_ranks() * 2,
1545        );
1546        let next_tx1 = alloc1.notify_tx();
1547        alloc1
1548            .alloc
1549            .expect_world_id()
1550            .return_const(world_id.clone());
1551        alloc1.alloc.expect_extent().return_const(extent.clone());
1552
1553        set_procstate_expectations(&mut alloc1.alloc, extent.clone());
1554        alloc1.alloc.expect_next().return_const(None);
1555        alloc1.alloc.expect_stop().times(1).return_once(|| Ok(()));
1556        // second allocation
1557        let mut alloc2 = MockAllocWrapper::new_block_next(
1558            MockAlloc::new(),
1559            // block after all created, all running
1560            extent.num_ranks() * 2,
1561        );
1562        let next_tx2 = alloc2.notify_tx();
1563        alloc2
1564            .alloc
1565            .expect_world_id()
1566            .return_const(world_id.clone());
1567        alloc2.alloc.expect_extent().return_const(extent.clone());
1568        set_procstate_expectations(&mut alloc2.alloc, extent.clone());
1569        alloc2.alloc.expect_next().return_const(None);
1570        alloc2.alloc.expect_stop().times(1).return_once(|| Ok(()));
1571
1572        let mut allocator = MockAllocator::new();
1573        allocator
1574            .expect_allocate()
1575            .times(1)
1576            .return_once(|_| Ok(alloc1));
1577        // second alloc
1578        allocator
1579            .expect_allocate()
1580            .times(1)
1581            .return_once(|_| Ok(alloc2));
1582
1583        let remote_allocator = RemoteProcessAllocator::new();
1584        let handle = tokio::spawn({
1585            let remote_allocator = remote_allocator.clone();
1586            async move {
1587                remote_allocator
1588                    .start_with_allocator(serve_addr, allocator, None)
1589                    .await
1590            }
1591        });
1592
1593        let alloc_key = ShortUuid::generate();
1594
1595        tx.send(RemoteProcessAllocatorMessage::Allocate {
1596            alloc_key: alloc_key.clone(),
1597            extent: extent.clone(),
1598            bootstrap_addr: bootstrap_addr.clone(),
1599            hosts: vec![],
1600            client_context: None,
1601        })
1602        .await
1603        .unwrap();
1604
1605        // Allocated
1606        let m = rx.recv().await.unwrap();
1607        assert_matches!(
1608            m,
1609            RemoteProcessProcStateMessage::Allocated { world_id: got_world_id, alloc_key: got_alloc_key }
1610            if got_world_id == world_id && got_alloc_key == alloc_key
1611        );
1612
1613        read_all_created(&mut rx, extent.num_ranks()).await;
1614        read_all_running(&mut rx, extent.num_ranks()).await;
1615
1616        let alloc_key = ShortUuid::generate();
1617
1618        // all procs are now running; request a new allocation
1619        tx.send(RemoteProcessAllocatorMessage::Allocate {
1620            alloc_key: alloc_key.clone(),
1621            extent: extent.clone(),
1622            bootstrap_addr,
1623            hosts: vec![],
1624            client_context: None,
1625        })
1626        .await
1627        .unwrap();
1628        // unblock next for the first allocation
1629        next_tx1.send(()).unwrap();
1630        // we expect a stop(), then Stopped proc states, a Done, and finally a new Allocated
1631        read_all_stopped(&mut rx, extent.num_ranks()).await;
1632        let m = rx.recv().await.unwrap();
1633        assert_matches!(m, RemoteProcessProcStateMessage::Done(_));
1634        let m = rx.recv().await.unwrap();
1635        assert_matches!(
1636            m,
1637            RemoteProcessProcStateMessage::Allocated { world_id: got_world_id, alloc_key: got_alloc_key }
1638            if got_world_id == world_id && got_alloc_key == alloc_key
1639        );
1640        // ProcStates for the new allocation
1641        read_all_created(&mut rx, extent.num_ranks()).await;
1642        read_all_running(&mut rx, extent.num_ranks()).await;
1643        // finally stop
1644        tracing::info!("stopping allocation");
1645        tx.send(RemoteProcessAllocatorMessage::Stop).await.unwrap();
1646        // unblock next() so the Stopped updates can be delivered
1647        next_tx2.send(()).unwrap();
1648
1649        read_all_stopped(&mut rx, extent.num_ranks()).await;
1650
1651        remote_allocator.terminate();
1652        handle.await.unwrap().unwrap();
1653    }
1654
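    // Closing the upstream bootstrap channel causes heartbeats to fail; the
    // remote allocator should then stop the inner alloc.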
1655    #[timed_test::async_timed_test(timeout_secs = 15)]
1656    async fn test_upstream_closed() {
1657        // Use temporary config for this test
1658        let config = hyperactor::config::global::lock();
1659        let _guard1 = config.override_key(
1660            hyperactor::config::MESSAGE_DELIVERY_TIMEOUT,
1661            Duration::from_secs(1),
1662        );
1663        let _guard2 = config.override_key(
1664            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
1665            Duration::from_millis(100),
1666        );
1667
1668        hyperactor_telemetry::initialize_logging(ClockKind::default());
1669        let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
1670        let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
1671        let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();
1672
1673        let extent = extent!(host = 1, gpu = 2);
1674
1675        let tx = channel::dial(serve_addr.clone()).unwrap();
1676
1677        let world_id: WorldId = id!(test_world_id);
1678        let mut alloc = MockAllocWrapper::new_block_next(
1679            MockAlloc::new(),
1680            // block after all created, all running
1681            extent.num_ranks() * 2,
1682        );
1683        let next_tx = alloc.notify_tx();
1684        alloc.alloc.expect_world_id().return_const(world_id.clone());
1685        alloc.alloc.expect_extent().return_const(extent.clone());
1686
1687        set_procstate_expectations(&mut alloc.alloc, extent.clone());
1688
1689        alloc.alloc.expect_next().return_const(None);
1690        // we expect stop() to be called due to the upstream failure;
1691        // use a oneshot channel to synchronize the test with that call
1692        let (stop_tx, stop_rx) = oneshot::channel();
1693        alloc.alloc.expect_stop().times(1).return_once(|| {
1694            stop_tx.send(()).unwrap();
1695            Ok(())
1696        });
1697
1698        let mut allocator = MockAllocator::new();
1699        allocator
1700            .expect_allocate()
1701            .times(1)
1702            .return_once(|_| Ok(alloc));
1703
1704        let remote_allocator = RemoteProcessAllocator::new();
1705        let handle = tokio::spawn({
1706            let remote_allocator = remote_allocator.clone();
1707            async move {
1708                remote_allocator
1709                    .start_with_allocator(serve_addr, allocator, None)
1710                    .await
1711            }
1712        });
1713
1714        let alloc_key = ShortUuid::generate();
1715
1716        tx.send(RemoteProcessAllocatorMessage::Allocate {
1717            alloc_key: alloc_key.clone(),
1718            extent: extent.clone(),
1719            bootstrap_addr,
1720            hosts: vec![],
1721            client_context: None,
1722        })
1723        .await
1724        .unwrap();
1725
1726        // Allocated
1727        let m = rx.recv().await.unwrap();
1728        assert_matches!(
1729            m, RemoteProcessProcStateMessage::Allocated { alloc_key: got_alloc_key, world_id: got_world_id }
1730                if got_world_id == world_id && got_alloc_key == alloc_key
1731        );
1732
1733        read_all_created(&mut rx, extent.num_ranks()).await;
1734        read_all_running(&mut rx, extent.num_ranks()).await;
1735
1736        // all procs are now running; terminate the upstream connection.
1737        tracing::info!("closing upstream");
1738        drop(rx);
1739        // wait for the heartbeat to expire
1740        #[allow(clippy::disallowed_methods)]
1741        tokio::time::sleep(Duration::from_secs(2)).await;
1742        // wait for the stop to be called
1743        stop_rx.await.unwrap();
1744        // unblock next
1745        next_tx.send(()).unwrap();
1746        remote_allocator.terminate();
1747        handle.await.unwrap().unwrap();
1748    }
1749
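    // A ProcState::Failed update from the inner alloc is relayed upstream; a
    // subsequent Stop still completes with a Done message.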
1750    #[timed_test::async_timed_test(timeout_secs = 15)]
1751    async fn test_inner_alloc_failure() {
1752        let config = hyperactor::config::global::lock();
1753        let _guard = config.override_key(
1754            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
1755            Duration::from_secs(60),
1756        );
1757        hyperactor_telemetry::initialize_logging(ClockKind::default());
1758        let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
1759        let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
1760        let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();
1761
1762        let extent = extent!(host = 1, gpu = 2);
1763
1764        let tx = channel::dial(serve_addr.clone()).unwrap();
1765
1766        let test_world_id: WorldId = id!(test_world_id);
1767        let mut alloc = MockAllocWrapper::new_block_next(
1768            MockAlloc::new(),
1769            // block after the failure update
1770            1,
1771        );
1772        let next_tx = alloc.notify_tx();
1773        alloc
1774            .alloc
1775            .expect_world_id()
1776            .return_const(test_world_id.clone());
1777        alloc.alloc.expect_extent().return_const(extent.clone());
1778        alloc
1779            .alloc
1780            .expect_next()
1781            .times(1)
1782            .return_const(Some(ProcState::Failed {
1783                world_id: test_world_id.clone(),
1784                description: "test".to_string(),
1785            }));
1786        alloc.alloc.expect_next().times(1).return_const(None);
1787
1788        alloc.alloc.expect_stop().times(1).return_once(|| Ok(()));
1789
1790        let mut allocator = MockAllocator::new();
1791        allocator
1792            .expect_allocate()
1793            .times(1)
1794            .return_once(|_| Ok(alloc));
1795
1796        let remote_allocator = RemoteProcessAllocator::new();
1797        let handle = tokio::spawn({
1798            let remote_allocator = remote_allocator.clone();
1799            async move {
1800                remote_allocator
1801                    .start_with_allocator(serve_addr, allocator, None)
1802                    .await
1803            }
1804        });
1805
1806        let alloc_key = ShortUuid::generate();
1807        tx.send(RemoteProcessAllocatorMessage::Allocate {
1808            alloc_key: alloc_key.clone(),
1809            extent: extent.clone(),
1810            bootstrap_addr,
1811            hosts: vec![],
1812            client_context: None,
1813        })
1814        .await
1815        .unwrap();
1816
1817        // Allocated
1818        let m = rx.recv().await.unwrap();
1819        assert_matches!(
1820            m,
1821            RemoteProcessProcStateMessage::Allocated {  world_id: got_world_id, alloc_key: got_alloc_key }
1822            if test_world_id == got_world_id && alloc_key == got_alloc_key
1823        );
1824
1825        // Failed
1826        let m = rx.recv().await.unwrap();
1827        assert_matches!(
1828            m,
1829            RemoteProcessProcStateMessage::Update(
1830                got_alloc_key,
1831                ProcState::Failed { world_id, description }
1832            ) if got_alloc_key == alloc_key && world_id == test_world_id && description == "test"
1833        );
1834
1835        tracing::info!("stopping allocation");
1836        tx.send(RemoteProcessAllocatorMessage::Stop).await.unwrap();
1837        // unblock next() so the alloc can finish stopping
1838        next_tx.send(()).unwrap();
1839        // we expect exactly one Done once the alloc stops successfully.
1840        let m = rx.recv().await.unwrap();
1841        assert_matches!(
1842            m,
1843            RemoteProcessProcStateMessage::Done(got_alloc_key)
1844            if got_alloc_key == alloc_key
1845        );
1846
1847        remote_allocator.terminate();
1848        handle.await.unwrap().unwrap();
1849    }
1850
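    // The trace id from the client context is propagated into the AllocSpec
    // constraints under CLIENT_TRACE_ID_LABEL.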
1851    #[timed_test::async_timed_test(timeout_secs = 15)]
1852    async fn test_trace_id_propagation() {
1853        let config = hyperactor::config::global::lock();
1854        let _guard = config.override_key(
1855            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
1856            Duration::from_secs(60),
1857        );
1858        hyperactor_telemetry::initialize_logging(ClockKind::default());
1859        let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
1860        let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
1861        let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();
1862
1863        let extent = extent!(host = 1, gpu = 1);
1864        let tx = channel::dial(serve_addr.clone()).unwrap();
1865        let test_world_id: WorldId = id!(test_world_id);
1866        let test_trace_id = "test_trace_id_12345";
1867
1868        // Create a mock alloc; the trace id is verified on the AllocSpec passed to the allocator below
1869        let mut alloc = MockAlloc::new();
1870        alloc.expect_world_id().return_const(test_world_id.clone());
1871        alloc.expect_extent().return_const(extent.clone());
1872        alloc.expect_next().return_const(None);
1873
1874        // Create a mock allocator that captures the AllocSpec passed to it
1875        let mut allocator = MockAllocator::new();
1876        allocator
1877            .expect_allocate()
1878            .times(1)
1879            .withf(move |spec: &AllocSpec| {
1880                // Verify that the trace id is correctly set in the constraints
1881                spec.constraints
1882                    .match_labels
1883                    .get(CLIENT_TRACE_ID_LABEL)
1884                    .is_some_and(|trace_id| trace_id == test_trace_id)
1885            })
1886            .return_once(|_| Ok(MockAllocWrapper::new(alloc)));
1887
1888        let remote_allocator = RemoteProcessAllocator::new();
1889        let handle = tokio::spawn({
1890            let remote_allocator = remote_allocator.clone();
1891            async move {
1892                remote_allocator
1893                    .start_with_allocator(serve_addr, allocator, None)
1894                    .await
1895            }
1896        });
1897
1898        let alloc_key = ShortUuid::generate();
1899        tx.send(RemoteProcessAllocatorMessage::Allocate {
1900            alloc_key: alloc_key.clone(),
1901            extent: extent.clone(),
1902            bootstrap_addr,
1903            hosts: vec![],
1904            client_context: Some(ClientContext {
1905                trace_id: test_trace_id.to_string(),
1906            }),
1907        })
1908        .await
1909        .unwrap();
1910
1911        // Verify we get the allocated message
1912        let m = rx.recv().await.unwrap();
1913        assert_matches!(
1914            m,
1915            RemoteProcessProcStateMessage::Allocated { alloc_key: got_alloc_key, world_id: got_world_id }
1916            if got_world_id == test_world_id && got_alloc_key == alloc_key
1917        );
1918
1919        // Verify we get the done message since the mock alloc returns None immediately
1920        let m = rx.recv().await.unwrap();
1921        assert_matches!(
1922            m,
1923            RemoteProcessProcStateMessage::Done(got_alloc_key)
1924            if alloc_key == got_alloc_key
1925        );
1926
1927        remote_allocator.terminate();
1928        handle.await.unwrap().unwrap();
1929    }
1930
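    // Without a client context, no trace id label is added to the AllocSpec
    // constraints.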
1931    #[timed_test::async_timed_test(timeout_secs = 15)]
1932    async fn test_trace_id_propagation_no_client_context() {
1933        let config = hyperactor::config::global::lock();
1934        let _guard = config.override_key(
1935            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
1936            Duration::from_secs(60),
1937        );
1938        hyperactor_telemetry::initialize_logging(ClockKind::default());
1939        let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
1940        let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
1941        let (_, mut rx) = channel::serve(bootstrap_addr.clone()).unwrap();
1942
1943        let extent = extent!(host = 1, gpu = 1);
1944        let tx = channel::dial(serve_addr.clone()).unwrap();
1945        let test_world_id: WorldId = id!(test_world_id);
1946
1947        // Create a mock alloc
1948        let mut alloc = MockAlloc::new();
1949        alloc.expect_world_id().return_const(test_world_id.clone());
1950        alloc.expect_extent().return_const(extent.clone());
1951        alloc.expect_next().return_const(None);
1952
1953        // Create a mock allocator that verifies no trace id is set when client_context is None
1954        let mut allocator = MockAllocator::new();
1955        allocator
1956            .expect_allocate()
1957            .times(1)
1958            .withf(move |spec: &AllocSpec| {
1959                // Verify that no trace id is set in the constraints when client_context is None
1960                spec.constraints.match_labels.is_empty()
1961            })
1962            .return_once(|_| Ok(MockAllocWrapper::new(alloc)));
1963
1964        let remote_allocator = RemoteProcessAllocator::new();
1965        let handle = tokio::spawn({
1966            let remote_allocator = remote_allocator.clone();
1967            async move {
1968                remote_allocator
1969                    .start_with_allocator(serve_addr, allocator, None)
1970                    .await
1971            }
1972        });
1973
1974        let alloc_key = ShortUuid::generate();
1975        tx.send(RemoteProcessAllocatorMessage::Allocate {
1976            alloc_key: alloc_key.clone(),
1977            extent: extent.clone(),
1978            bootstrap_addr,
1979            hosts: vec![],
1980            client_context: None,
1981        })
1982        .await
1983        .unwrap();
1984
1985        // Verify we get the allocated message
1986        let m = rx.recv().await.unwrap();
1987        assert_matches!(
1988            m,
1989            RemoteProcessProcStateMessage::Allocated { alloc_key: got_alloc_key, world_id: got_world_id }
1990            if got_world_id == test_world_id && got_alloc_key == alloc_key
1991        );
1992
1993        // Verify we get the done message since the mock alloc returns None immediately
1994        let m = rx.recv().await.unwrap();
1995        assert_matches!(
1996            m,
1997            RemoteProcessProcStateMessage::Done(got_alloc_key)
1998            if got_alloc_key == alloc_key
1999        );
2000
2001        remote_allocator.terminate();
2002        handle.await.unwrap().unwrap();
2003    }
2004}
2005
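// Integration-style tests driving RemoteProcessAlloc against real
// RemoteProcessAllocator instances over Unix channels.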
2006#[cfg(test)]
2007mod test_alloc {
2008    use std::os::unix::process::ExitStatusExt;
2009
2010    use hyperactor::clock::ClockKind;
2011    use hyperactor::config;
2012    use ndslice::extent;
2013    use nix::sys::signal;
2014    use nix::unistd::Pid;
2015    use timed_test::async_timed_test;
2016
2017    use super::*;
2018
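    // Two hosts, two GPUs each: every rank is created and runs, and stop()
    // yields a Stopped update for every proc followed by None.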
2019    #[async_timed_test(timeout_secs = 60)]
2020    async fn test_alloc_simple() {
2021        // Use temporary config for this test
2022        let config = hyperactor::config::global::lock();
2023        let _guard1 = config.override_key(
2024            hyperactor::config::MESSAGE_DELIVERY_TIMEOUT,
2025            Duration::from_secs(1),
2026        );
2027        let _guard2 = config.override_key(
2028            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
2029            Duration::from_millis(100),
2030        );
2031        hyperactor_telemetry::initialize_logging(ClockKind::default());
2032
2033        let spec = AllocSpec {
2034            extent: extent!(host = 2, gpu = 2),
2035            constraints: Default::default(),
2036            proc_name: None,
2037            transport: ChannelTransport::Unix,
2038        };
2039        let world_id = WorldId("test_world_id".to_string());
2040
2041        let task1_allocator = RemoteProcessAllocator::new();
2042        let task1_addr = ChannelAddr::any(ChannelTransport::Unix);
2043        let task1_addr_string = task1_addr.to_string();
2044        let task1_cmd = Command::new(crate::testresource::get(
2045            "monarch/hyperactor_mesh/bootstrap",
2046        ));
2047        let task2_allocator = RemoteProcessAllocator::new();
2048        let task2_addr = ChannelAddr::any(ChannelTransport::Unix);
2049        let task2_addr_string = task2_addr.to_string();
2050        let task2_cmd = Command::new(crate::testresource::get(
2051            "monarch/hyperactor_mesh/bootstrap",
2052        ));
2053        let task1_allocator_copy = task1_allocator.clone();
2054        let task1_allocator_handle = tokio::spawn(async move {
2055            tracing::info!("spawning task1");
2056            task1_allocator_copy
2057                .start(task1_cmd, task1_addr, None)
2058                .await
2059                .unwrap();
2060        });
2061        let task2_allocator_copy = task2_allocator.clone();
2062        let task2_allocator_handle = tokio::spawn(async move {
2063            task2_allocator_copy
2064                .start(task2_cmd, task2_addr, None)
2065                .await
2066                .unwrap();
2067        });
2068
2069        let mut initializer = MockRemoteProcessAllocInitializer::new();
2070        initializer.expect_initialize_alloc().return_once(move || {
2071            Ok(vec![
2072                RemoteProcessAllocHost {
2073                    hostname: task1_addr_string,
2074                    id: "task1".to_string(),
2075                },
2076                RemoteProcessAllocHost {
2077                    hostname: task2_addr_string,
2078                    id: "task2".to_string(),
2079                },
2080            ])
2081        });
2082        let mut alloc = RemoteProcessAlloc::new(spec.clone(), world_id, 0, initializer)
2083            .await
2084            .unwrap();
2085        let mut created = HashSet::new();
2086        let mut running_procs = HashSet::new();
2087        let mut proc_points = HashSet::new();
2088        for _ in 0..spec.extent.num_ranks() * 2 {
2089            let proc_state = alloc.next().await.unwrap();
2090            tracing::debug!("test got message: {:?}", proc_state);
2091            match proc_state {
2092                ProcState::Created {
2093                    create_key, point, ..
2094                } => {
2095                    created.insert(create_key);
2096                    proc_points.insert(point);
2097                }
2098                ProcState::Running { create_key, .. } => {
2099                    assert!(created.remove(&create_key));
2100                    running_procs.insert(create_key);
2101                }
2102                _ => panic!("expected Created or Running"),
2103            }
2104        }
2105        assert!(created.is_empty());
2106        // ensure every point in the extent was covered
2107        assert!(
2108            spec.extent
2109                .points()
2110                .all(|point| proc_points.contains(&point))
2111        );
2112
2113        // ensure no more pending items
2114        let timeout = hyperactor::clock::RealClock.now() + std::time::Duration::from_millis(1000);
2115        tokio::select! {
2116            _ = hyperactor::clock::RealClock.sleep_until(timeout) => {},
2117            _ = alloc.next() => panic!("expected no more items"),
2118        }
2119
2120        // stop the allocation
2121        alloc.stop().await.unwrap();
2122        for _ in 0..spec.extent.num_ranks() {
2123            let proc_state = alloc.next().await.unwrap();
2124            tracing::info!("test received next proc_state: {:?}", proc_state);
2125            match proc_state {
2126                ProcState::Stopped {
2127                    create_key, reason, ..
2128                } => {
2129                    assert!(running_procs.remove(&create_key));
2130                    assert_eq!(reason, ProcStopReason::Stopped);
2131                }
2132                _ => panic!("expected stopped"),
2133            }
2134        }
2135        // Exactly one None
2136        let proc_state = alloc.next().await;
2137        assert!(proc_state.is_none());
2138        // Anything afterwards is None
2139        let proc_state = alloc.next().await;
2140        assert!(proc_state.is_none());
2141
2142        task1_allocator.terminate();
2143        task1_allocator_handle.await.unwrap();
2144        task2_allocator.terminate();
2145        task2_allocator_handle.await.unwrap();
2146    }
2147
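    // Aborting a host's RemoteProcessAllocator surfaces its procs as Stopped
    // with ProcStopReason::HostWatchdog once the heartbeat times out.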
2148    #[async_timed_test(timeout_secs = 60)]
2149    async fn test_alloc_host_failure() {
2150        // Use temporary config for this test
2151        let config = hyperactor::config::global::lock();
2152        let _guard1 = config.override_key(
2153            hyperactor::config::MESSAGE_DELIVERY_TIMEOUT,
2154            Duration::from_secs(1),
2155        );
2156        let _guard2 = config.override_key(
2157            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
2158            Duration::from_millis(100),
2159        );
2160        hyperactor_telemetry::initialize_logging(ClockKind::default());
2161
2162        let spec = AllocSpec {
2163            extent: extent!(host = 2, gpu = 2),
2164            constraints: Default::default(),
2165            proc_name: None,
2166            transport: ChannelTransport::Unix,
2167        };
2168        let world_id = WorldId("test_world_id".to_string());
2169
2170        let task1_allocator = RemoteProcessAllocator::new();
2171        let task1_addr = ChannelAddr::any(ChannelTransport::Unix);
2172        let task1_addr_string = task1_addr.to_string();
2173        let task1_cmd = Command::new(crate::testresource::get(
2174            "monarch/hyperactor_mesh/bootstrap",
2175        ));
2176        let task2_allocator = RemoteProcessAllocator::new();
2177        let task2_addr = ChannelAddr::any(ChannelTransport::Unix);
2178        let task2_addr_string = task2_addr.to_string();
2179        let task2_cmd = Command::new(crate::testresource::get(
2180            "monarch/hyperactor_mesh/bootstrap",
2181        ));
2182        let task1_allocator_copy = task1_allocator.clone();
2183        let task1_allocator_handle = tokio::spawn(async move {
2184            tracing::info!("spawning task1");
2185            task1_allocator_copy
2186                .start(task1_cmd, task1_addr, None)
2187                .await
2188                .unwrap();
2189            tracing::info!("task1 terminated");
2190        });
2191        let task2_allocator_copy = task2_allocator.clone();
2192        let task2_allocator_handle = tokio::spawn(async move {
2193            task2_allocator_copy
2194                .start(task2_cmd, task2_addr, None)
2195                .await
2196                .unwrap();
2197            tracing::info!("task2 terminated");
2198        });
2199
2200        let mut initializer = MockRemoteProcessAllocInitializer::new();
2201        initializer.expect_initialize_alloc().return_once(move || {
2202            Ok(vec![
2203                RemoteProcessAllocHost {
2204                    hostname: task1_addr_string,
2205                    id: "task1".to_string(),
2206                },
2207                RemoteProcessAllocHost {
2208                    hostname: task2_addr_string,
2209                    id: "task2".to_string(),
2210                },
2211            ])
2212        });
2213        let mut alloc = RemoteProcessAlloc::new(spec.clone(), world_id, 0, initializer)
2214            .await
2215            .unwrap();
2216        for _ in 0..spec.extent.num_ranks() * 2 {
2217            match alloc.next().await {
2218                Some(ProcState::Created { .. }) | Some(ProcState::Running { .. }) => {}
2219                _ => panic!("expected Created or Running"),
2220            }
2221        }
2222
2223        // ensure no more pending items
2224        let timeout = hyperactor::clock::RealClock.now() + std::time::Duration::from_millis(1000);
2225        tokio::select! {
2226            _ = hyperactor::clock::RealClock
2227                .sleep_until(timeout) => {},
2228            _ = alloc.next() => panic!("expected no more items"),
2229        }
2230
2231        // now abort task1's allocator and wait for the heartbeat timeout
2232        tracing::info!("aborting task1 allocator");
2233        task1_allocator_handle.abort();
2234        RealClock
2235            .sleep(config::global::get(config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL) * 2)
2236            .await;
2237        for _ in 0..spec.extent.num_ranks() / 2 {
2238            let proc_state = alloc.next().await.unwrap();
2239            tracing::info!("test received next proc_state: {:?}", proc_state);
2240            match proc_state {
2241                ProcState::Stopped { reason, .. } => {
2242                    assert_eq!(reason, ProcStopReason::HostWatchdog);
2243                }
2244                _ => panic!("expected stopped"),
2245            }
2246        }
2247        // no more events
2248        let timeout = hyperactor::clock::RealClock.now() + std::time::Duration::from_millis(1000);
2249        tokio::select! {
2250            _ = hyperactor::clock::RealClock
2251                .sleep_until(timeout) => {},
2252            _ = alloc.next() => panic!("expected no more items"),
2253        }
2254
2255        // abort the second host
2256        tracing::info!("aborting task2 allocator");
2257        task2_allocator_handle.abort();
2258        RealClock
2259            .sleep(config::global::get(config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL) * 2)
2260            .await;
2261        for _ in 0..spec.extent.num_ranks() / 2 {
2262            let proc_state = alloc.next().await.unwrap();
2263            tracing::info!("test received next proc_state: {:?}", proc_state);
2264            match proc_state {
2265                ProcState::Stopped { reason, .. } => {
2266                    assert_eq!(reason, ProcStopReason::HostWatchdog);
2267                }
2268                _ => panic!("expected stopped"),
2269            }
2270        }
2271        // now that the alloc has stopped, we always get None
2272        let proc_state = alloc.next().await;
2273        assert!(proc_state.is_none());
2274        // Anything afterwards is None
2275        let proc_state = alloc.next().await;
2276        assert!(proc_state.is_none());
2277    }
2278
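    // One host points at a non-existent binary so its allocation fails: expect a
    // single Failed event alongside the healthy host's Created/Running updates.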
2279    #[async_timed_test(timeout_secs = 15)]
2280    async fn test_alloc_inner_alloc_failure() {
2281        // SAFETY: Test happens in single-threaded code.
2282        unsafe {
2283            std::env::set_var("MONARCH_MESSAGE_DELIVERY_TIMEOUT_SECS", "1");
2284        }
2285        let config = hyperactor::config::global::lock();
2286        let _guard = config.override_key(
2287            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
2288            Duration::from_millis(100),
2289        );
2290        hyperactor_telemetry::initialize_logging(ClockKind::default());
2291
2292        let spec = AllocSpec {
2293            extent: extent!(host = 2, gpu = 2),
2294            constraints: Default::default(),
2295            proc_name: None,
2296            transport: ChannelTransport::Unix,
2297        };
2298        let world_id = WorldId("test_world_id".to_string());
2299
2300        let task1_allocator = RemoteProcessAllocator::new();
2301        let task1_addr = ChannelAddr::any(ChannelTransport::Unix);
2302        let task1_addr_string = task1_addr.to_string();
2303        let task1_cmd = Command::new(crate::testresource::get(
2304            "monarch/hyperactor_mesh/bootstrap",
2305        ));
2306        let task2_allocator = RemoteProcessAllocator::new();
2307        let task2_addr = ChannelAddr::any(ChannelTransport::Unix);
2308        let task2_addr_string = task2_addr.to_string();
2309        // non-existent binary to fail the allocation
2310        let task2_cmd = Command::new("/caught/somewhere/in/time");
2311        let task1_allocator_copy = task1_allocator.clone();
2312        let task1_allocator_handle = tokio::spawn(async move {
2313            tracing::info!("spawning task1");
2314            task1_allocator_copy
2315                .start(task1_cmd, task1_addr, None)
2316                .await
2317                .unwrap();
2318        });
2319        let task2_allocator_copy = task2_allocator.clone();
2320        let task2_allocator_handle = tokio::spawn(async move {
2321            task2_allocator_copy
2322                .start(task2_cmd, task2_addr, None)
2323                .await
2324                .unwrap();
2325        });
2326
2327        let mut initializer = MockRemoteProcessAllocInitializer::new();
2328        initializer.expect_initialize_alloc().return_once(move || {
2329            Ok(vec![
2330                RemoteProcessAllocHost {
2331                    hostname: task1_addr_string,
2332                    id: "task1".to_string(),
2333                },
2334                RemoteProcessAllocHost {
2335                    hostname: task2_addr_string,
2336                    id: "task2".to_string(),
2337                },
2338            ])
2339        });
2340        let mut alloc = RemoteProcessAlloc::new(spec.clone(), world_id, 0, initializer)
2341            .await
2342            .unwrap();
2343        let mut created = HashSet::new();
2344        let mut started_procs = HashSet::new();
2345        let mut proc_points = HashSet::new();
2346        let mut failed = 0;
2347        // Created + Running for each of task 1's procs, plus 1 Failed event for task 2
2348        for _ in 0..spec.extent.num_ranks() + 1 {
2349            let proc_state = alloc.next().await.unwrap();
2350            tracing::debug!("test got message: {:?}", proc_state);
2351            match proc_state {
2352                ProcState::Created {
2353                    create_key, point, ..
2354                } => {
2355                    created.insert(create_key);
2356                    proc_points.insert(point);
2357                }
2358                ProcState::Running { create_key, .. } => {
2359                    assert!(created.remove(&create_key));
2360                    started_procs.insert(create_key);
2361                }
2362                ProcState::Failed { .. } => {
2363                    failed += 1;
2364                }
2365                _ => panic!("expected Created, Running or Failed"),
2366            }
2367        }
2368        assert!(created.is_empty());
2369        assert_eq!(failed, 1);
2370        // ensure coords coverage for task 1
2371        for rank in 0..spec.extent.num_ranks() / 2 {
2372            let point = spec.extent.point_of_rank(rank).unwrap();
2373            assert!(proc_points.contains(&point));
2374        }
2375
2376        // ensure no more pending items
2377        let timeout = hyperactor::clock::RealClock.now() + std::time::Duration::from_millis(1000);
2378        tokio::select! {
2379            _ = hyperactor::clock::RealClock
2380                .sleep_until(timeout) => {},
2381            _ = alloc.next() => panic!("expected no more items"),
2382        }
2383
2384        // stop the allocation
2385        alloc.stop().await.unwrap();
2386        for _ in 0..spec.extent.num_ranks() / 2 {
2387            let proc_state = alloc.next().await.unwrap();
2388            tracing::info!("test received next proc_state: {:?}", proc_state);
2389            match proc_state {
2390                ProcState::Stopped {
2391                    create_key, reason, ..
2392                } => {
2393                    assert!(started_procs.remove(&create_key));
2394                    assert_eq!(reason, ProcStopReason::Stopped);
2395                }
2396                _ => panic!("expected stopped"),
2397            }
2398        }
2399        // Exactly one None
2400        let proc_state = alloc.next().await;
2401        assert!(proc_state.is_none());
2402        // Anything afterwards is None
2403        let proc_state = alloc.next().await;
2404        assert!(proc_state.is_none());
2405
2406        task1_allocator.terminate();
2407        task1_allocator_handle.await.unwrap();
2408        task2_allocator.terminate();
2409        task2_allocator_handle.await.unwrap();
2410    }
2411
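    // SIGINT sent to the remote_process_alloc binary should tear down all of the
    // procs spawned by the per-host ProcessAllocators.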
2412    #[tracing_test::traced_test]
2413    #[async_timed_test(timeout_secs = 60)]
2414    async fn test_remote_process_alloc_signal_handler() {
2415        let num_proc_meshes = 5;
2416        let hosts_per_proc_mesh = 5;
2417
2418        let pid_addr = ChannelAddr::any(ChannelTransport::Unix);
2419        let (pid_addr, mut pid_rx) = channel::serve::<u32>(pid_addr).unwrap();
2420
2421        let addresses = (0..(num_proc_meshes * hosts_per_proc_mesh))
2422            .map(|_| ChannelAddr::any(ChannelTransport::Unix).to_string())
2423            .collect::<Vec<_>>();
2424
2425        let remote_process_allocators = addresses
2426            .iter()
2427            .map(|addr| {
2428                Command::new(crate::testresource::get(
2429                    "monarch/hyperactor_mesh/remote_process_allocator",
2430                ))
2431                .env("RUST_LOG", "info")
2432                .arg(format!("--addr={addr}"))
2433                .stdout(std::process::Stdio::piped())
2434                .spawn()
2435                .unwrap()
2436            })
2437            .collect::<Vec<_>>();
2438
2439        let done_allocating_addr = ChannelAddr::any(ChannelTransport::Unix);
2440        let (done_allocating_addr, mut done_allocating_rx) =
2441            channel::serve::<()>(done_allocating_addr).unwrap();
2442        let mut remote_process_alloc = Command::new(crate::testresource::get(
2443            "monarch/hyperactor_mesh/remote_process_alloc",
2444        ))
2445        .arg(format!("--done-allocating-addr={}", done_allocating_addr))
2446        .arg(format!("--addresses={}", addresses.join(",")))
2447        .arg(format!("--num-proc-meshes={}", num_proc_meshes))
2448        .arg(format!("--hosts-per-proc-mesh={}", hosts_per_proc_mesh))
2449        .arg(format!("--pid-addr={}", pid_addr))
2450        .spawn()
2451        .unwrap();
2452
2453        done_allocating_rx.recv().await.unwrap();
2454        let mut received_pids = Vec::new();
2455        while let Ok(pid) = pid_rx.recv().await {
2456            received_pids.push(pid);
2457            if received_pids.len() == remote_process_allocators.len() {
2458                break;
2459            }
2460        }
2461
2462        signal::kill(
2463            Pid::from_raw(remote_process_alloc.id().unwrap() as i32),
2464            signal::SIGINT,
2465        )
2466        .unwrap();
2467
2468        assert_eq!(
2469            remote_process_alloc.wait().await.unwrap().signal(),
2470            Some(signal::SIGINT as i32)
2471        );
2472
2473        RealClock.sleep(tokio::time::Duration::from_secs(5)).await;
2474
2475        // Assert that the processes spawned by ProcessAllocator have been killed
2476        for child_pid in received_pids {
2477            let pid_check = Command::new("kill")
2478                .arg("-0")
2479                .arg(child_pid.to_string())
2480                .output()
2481                .await
2482                .expect("Failed to check if PID is alive");
2483
2484            assert!(
2485                !pid_check.status.success(),
2486                "PID {} should no longer be alive",
2487                child_pid
2488            );
2489        }
2490
2491        // Clean up the remote process allocator processes: SIGINT stops the current
2492        // allocs but not the RemoteProcessAllocator loops
2493        for mut remote_process_allocator in remote_process_allocators {
2494            remote_process_allocator.kill().await.unwrap();
2495        }
2496    }
2497}