hyperactor_mesh/alloc/remoteprocess.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::collections::HashMap;
10use std::collections::HashSet;
11use std::collections::VecDeque;
12use std::sync::Arc;
13use std::time::Duration;
14
15use anyhow::Context;
16use async_trait::async_trait;
17use dashmap::DashMap;
18use futures::FutureExt;
19use futures::future::join_all;
20use futures::future::select_all;
21use hyperactor::Named;
22use hyperactor::ProcId;
23use hyperactor::WorldId;
24use hyperactor::channel;
25use hyperactor::channel::ChannelAddr;
26use hyperactor::channel::ChannelRx;
27use hyperactor::channel::ChannelTransport;
28use hyperactor::channel::ChannelTx;
29use hyperactor::channel::Rx;
30use hyperactor::channel::Tx;
31use hyperactor::channel::TxStatus;
32use hyperactor::clock;
33use hyperactor::clock::Clock;
34use hyperactor::clock::RealClock;
35use hyperactor::config;
36use hyperactor::mailbox::DialMailboxRouter;
37use hyperactor::mailbox::MailboxServer;
38use hyperactor::reference::Reference;
39use hyperactor::serde_json;
40use mockall::automock;
41use ndslice::View;
42use ndslice::ViewExt;
43use ndslice::view::Extent;
44use ndslice::view::Point;
45use serde::Deserialize;
46use serde::Serialize;
47use tokio::io::AsyncWriteExt;
48use tokio::process::Command;
49use tokio::sync::mpsc::UnboundedReceiver;
50use tokio::sync::mpsc::UnboundedSender;
51use tokio::sync::mpsc::unbounded_channel;
52use tokio::task::JoinHandle;
53use tokio_stream::StreamExt;
54use tokio_stream::wrappers::WatchStream;
55use tokio_util::sync::CancellationToken;
56
57use crate::alloc::Alloc;
58use crate::alloc::AllocConstraints;
59use crate::alloc::AllocSpec;
60use crate::alloc::Allocator;
61use crate::alloc::AllocatorError;
62use crate::alloc::ProcState;
63use crate::alloc::ProcStopReason;
64use crate::alloc::ProcessAllocator;
65use crate::alloc::process::CLIENT_TRACE_ID_LABEL;
66use crate::alloc::process::ClientContext;
67
68/// Control messages sent from the client (`RemoteProcessAlloc`) to the remote process allocator.
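///
/// A minimal sketch of how a client might send an `Allocate` request over a
/// channel. The addresses, port, and extent below are illustrative assumptions,
/// not values defined by this crate.
///
/// ```ignore
/// use hyperactor::channel::{self, ChannelAddr, Tx};
/// use ndslice::extent;
///
/// let addr: ChannelAddr = "tcp!allocator-host:26600".parse()?; // hypothetical address
/// let tx = channel::dial(addr)?;
/// tx.send(RemoteProcessAllocatorMessage::Allocate {
///     view: extent!(host = 1, gpu = 8).into(), // one host with eight GPUs
///     bootstrap_addr: "tcp!client-host:26601".parse()?, // hypothetical address
///     hosts: vec!["allocator-host".to_string()],
///     client_context: None,
/// })
/// .await?;
/// ```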
69#[derive(Debug, Clone, Serialize, Deserialize, Named)]
70pub enum RemoteProcessAllocatorMessage {
71    /// Create allocation with given spec and send updates to bootstrap_addr.
72    Allocate {
73        /// The view of a larger allocation for which this remote allocator is responsible.
74        view: View,
75        /// Bootstrap address to be used for sending updates.
76        bootstrap_addr: ChannelAddr,
77        /// Ordered list of hosts in this allocation. Can be used to
78        /// pre-populate any local configuration, such as torch.dist.
79        hosts: Vec<String>,
80        /// Client context which is passed to the ProcessAlloc.
81        /// TODO: Once RemoteProcessAllocator moves to the mailbox,
82        /// the client_context will move to the message header instead.
83        client_context: Option<ClientContext>,
84    },
85    /// Stop allocation.
86    Stop,
87    /// Heartbeat message to check if remote process allocator and its
88    /// host are alive.
89    HeartBeat,
90}
91
92/// Control messages sent from the remote process allocator back to the client
93/// (`RemoteProcessAlloc`), relaying process state updates.
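///
/// A sketch of how the client side might consume these messages; `rx` is assumed
/// to be a `ChannelRx<RemoteProcessProcStateMessage>` already serving the
/// bootstrap address.
///
/// ```ignore
/// loop {
///     match rx.recv().await? {
///         RemoteProcessProcStateMessage::Allocated { world_id, view } => { /* record world offset */ }
///         RemoteProcessProcStateMessage::Update(proc_state) => { /* surface ProcState to the caller */ }
///         RemoteProcessProcStateMessage::Done(world_id) => break,
///         RemoteProcessProcStateMessage::HeartBeat => {} // liveness only; nothing to update
///     }
/// }
/// ```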
94#[derive(Debug, Clone, Serialize, Deserialize, Named)]
95pub enum RemoteProcessProcStateMessage {
96    /// Allocation succeeded; Update and Done messages will follow.
97    Allocated { world_id: WorldId, view: View },
98    /// ProcState updates.
99    Update(ProcState),
100    /// Underlying Alloc is done.
101    Done(WorldId),
102    /// Heartbeat message to check if client is alive.
103    HeartBeat,
104}
105
106/// Allocator with a service frontend that wraps ProcessAllocator.
107pub struct RemoteProcessAllocator {
108    cancel_token: CancellationToken,
109}
110
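/// Awaits the provided timer if one is given; otherwise pends forever. This lets
/// the select! loop below treat "no timeout configured" as a branch that never fires.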
111async fn conditional_sleeper<F: futures::Future<Output = ()>>(t: Option<F>) {
112    match t {
113        Some(timer) => timer.await,
114        None => futures::future::pending().await,
115    }
116}
117
118impl RemoteProcessAllocator {
119    /// Create a new allocator. It will not start until start() is called.
120    pub fn new() -> Arc<Self> {
121        Arc::new(Self {
122            cancel_token: CancellationToken::new(),
123        })
124    }
125
126    /// Stop the allocator. This will stop any ongoing allocations.
127    pub fn terminate(&self) {
128        self.cancel_token.cancel();
129    }
130
131    /// Start a remote process allocator with given cmd listening for
132    /// RemoteProcessAllocatorMessage on serve_addr. Call will block until cancelled.
133    /// The implementation is deliberately simple and can handle only one Alloc at
134    /// a time, which covers the most common use case.
135    /// Flow works as follows:
136    /// 1. Client sends Allocate message to serve_addr.
137    /// 2. Allocator connects to bootstrap_addr, creates Alloc and sends Allocated message.
138    /// 3. Allocator streams one or more Update messages to bootstrap_addr as Alloc progresses
139    ///    applying the following change:
140    ///    * Remap the mesh_agent listen address to our own forwarder actor address.
141    /// 4. Allocator sends Done message to bootstrap_addr when Alloc is done.
142    ///
143    /// At any point, client can send Stop message to serve_addr to stop the allocator.
144    /// If timeout is Some, the allocator will exit if no client connects within
145    /// that timeout, and no child allocation is running.
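    ///
    /// A minimal sketch of running the service side; the bootstrap command name and
    /// listen address are illustrative assumptions.
    ///
    /// ```ignore
    /// use hyperactor::channel::ChannelAddr;
    /// use tokio::process::Command;
    ///
    /// let allocator = RemoteProcessAllocator::new();
    /// let serve_addr: ChannelAddr = "tcp!0.0.0.0:26600".parse()?; // hypothetical listen address
    /// allocator
    ///     .start(Command::new("bootstrap"), serve_addr, None) // blocks until terminate()
    ///     .await?;
    /// ```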
146    pub async fn start(
147        &self,
148        cmd: Command,
149        serve_addr: ChannelAddr,
150        timeout: Option<Duration>,
151    ) -> Result<(), anyhow::Error> {
152        let process_allocator = ProcessAllocator::new(cmd);
153        self.start_with_allocator(serve_addr, process_allocator, timeout)
154            .await
155    }
156
157    /// Start a remote process allocator with given allocator listening for
158    /// RemoteProcessAllocatorMessage on serve_addr.
159    /// Used for testing.
160    pub async fn start_with_allocator<A: Allocator + Send + Sync + 'static>(
161        &self,
162        serve_addr: ChannelAddr,
163        mut process_allocator: A,
164        timeout: Option<Duration>,
165    ) -> Result<(), anyhow::Error>
166    where
167        <A as Allocator>::Alloc: Send,
168        <A as Allocator>::Alloc: Sync,
169    {
170        tracing::info!("starting remote allocator on: {}", serve_addr);
171        let (_, mut rx) = channel::serve(serve_addr.clone())
172            .await
173            .map_err(anyhow::Error::from)?;
174
175        struct ActiveAllocation {
176            handle: JoinHandle<()>,
177            cancel_token: CancellationToken,
178        }
179        async fn ensure_previous_alloc_stopped(active_allocation: &mut Option<ActiveAllocation>) {
180            if let Some(active_allocation) = active_allocation.take() {
181                tracing::info!("previous alloc found, stopping");
182                active_allocation.cancel_token.cancel();
183                match active_allocation.handle.await {
184                    Ok(_) => {
185                        tracing::info!("allocation stopped.")
186                    }
187                    Err(e) => {
188                        tracing::error!("allocation handler failed: {}", e);
189                    }
190                }
191            }
192        }
193
194        let mut active_allocation: Option<ActiveAllocation> = None;
195        loop {
196            // Refresh each loop iteration so the timer updates whenever a message
197            // is received.
198            let sleep = conditional_sleeper(timeout.map(|t| RealClock.sleep(t)));
199            tokio::select! {
200                msg = rx.recv() => {
201                    match msg {
202                        Ok(RemoteProcessAllocatorMessage::Allocate {
203                            view,
204                            bootstrap_addr,
205                            hosts,
206                            client_context,
207                        }) => {
208                            tracing::info!("received allocation request for view: {}", view);
209                            ensure_previous_alloc_stopped(&mut active_allocation).await;
210                            tracing::info!("allocating...");
211
212                            // Create the corresponding local allocation spec.
213                            let mut constraints: AllocConstraints = Default::default();
214                            if let Some(context) = &client_context {
215                                constraints = AllocConstraints {
216                                    match_labels: HashMap::from([(
217                                    CLIENT_TRACE_ID_LABEL.to_string(),
218                                    context.trace_id.to_string(),
219                                    )]
220                                )};
221                            }
222                            let spec = AllocSpec {
223                                extent: view.extent(),
224                                constraints,
225                            };
226
227                            match process_allocator.allocate(spec.clone()).await {
228                                Ok(alloc) => {
229                                    let cancel_token = CancellationToken::new();
230                                    active_allocation = Some(ActiveAllocation {
231                                        cancel_token: cancel_token.clone(),
232                                        handle: tokio::spawn(Self::handle_allocation_request(
233                                            Box::new(alloc) as Box<dyn Alloc + Send + Sync>,
234                                            view,
235                                            serve_addr.transport(),
236                                            bootstrap_addr,
237                                            hosts,
238                                            cancel_token,
239                                        )),
240                                    })
241                                }
242                                Err(e) => {
243                                    tracing::error!("allocation for {:?} failed: {}", spec, e);
244                                    continue;
245                                }
246                            }
247                        }
248                        Ok(RemoteProcessAllocatorMessage::Stop) => {
249                            tracing::info!("received stop request");
250
251                            ensure_previous_alloc_stopped(&mut active_allocation).await;
252                        }
253                        // Heartbeat message is discarded immediately after being received; the sender (client)
254                        // relies on channel ack to know if the receiver (remote process allocator) is
255                        // still alive. No state needs to be updated.
256                        Ok(RemoteProcessAllocatorMessage::HeartBeat) => {}
257                        Err(e) => {
258                            tracing::error!("upstream channel error: {}", e);
259                            continue;
260                        }
261                    }
262                }
263                _ = self.cancel_token.cancelled() => {
264                    tracing::info!("main loop cancelled");
265
266                    ensure_previous_alloc_stopped(&mut active_allocation).await;
267
268                    break;
269                }
270                _ = sleep => {
271                    // If there are any active allocations, reset the timeout.
272                    if active_allocation.is_some() {
273                        continue;
274                    }
275                    // Else, exit the loop as a client hasn't connected in a reasonable
276                    // amount of time.
277                    tracing::warn!("timeout of {} seconds elapsed without any allocations, exiting", timeout.unwrap_or_default().as_secs());
278                    break;
279                }
280            }
281        }
282
283        Ok(())
284    }
285
286    async fn handle_allocation_request(
287        alloc: Box<dyn Alloc + Send + Sync>,
288        view: View,
289        serve_transport: ChannelTransport,
290        bootstrap_addr: ChannelAddr,
291        hosts: Vec<String>,
292        cancel_token: CancellationToken,
293    ) {
294        tracing::info!("handle allocation request, bootstrap_addr: {bootstrap_addr}");
295        // start proc message forwarder
296        // Use serve_transport instead of bootstrap_addr's transport so the transports are
297        // consistent between the remote process allocator and the processes.
298        // The bootstrap_addr could use a different transport with which the process might not be compatible.
299        let (forwarder_addr, forwarder_rx) =
300            match channel::serve(ChannelAddr::any(serve_transport)).await {
301                Ok(v) => v,
302                Err(e) => {
303                    tracing::error!("failed to bootstrap forwarder actor: {}", e);
304                    return;
305                }
306            };
307        let router = DialMailboxRouter::new();
308        let mailbox_handle = router.clone().serve(forwarder_rx);
309        tracing::info!("started forwarder on: {}", forwarder_addr);
310
311        // Check if we need to write TORCH_ELASTIC_CUSTOM_HOSTNAMES_LIST_FILE
312        // See: https://github.com/fairinternal/xlformers/blob/llama4_monarch/tools/launching/torchx/entrypoint/generate_ranks.py
313        if let Ok(hosts_file) = std::env::var("TORCH_ELASTIC_CUSTOM_HOSTNAMES_LIST_FILE") {
314            tracing::info!("writing hosts to {}", hosts_file);
315            #[derive(Serialize)]
316            struct Hosts {
317                hostnames: Vec<String>,
318            }
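            // The file will contain e.g. {"hostnames": ["host0", "host1"]} (hostnames
            // illustrative), i.e. serde_json's serialization of the `Hosts` struct above.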
319            match serde_json::to_string(&Hosts { hostnames: hosts }) {
320                Ok(json) => match tokio::fs::File::create(&hosts_file).await {
321                    Ok(mut file) => {
322                        if file.write_all(json.as_bytes()).await.is_err() {
323                            tracing::error!("failed to write hosts to {}", hosts_file);
324                            return;
325                        }
326                    }
327                    Err(e) => {
328                        tracing::error!("failed to open hosts file {}: {}", hosts_file, e);
329                        return;
330                    }
331                },
332                Err(e) => {
333                    tracing::error!("failed to serialize hosts: {}", e);
334                    return;
335                }
336            }
337        }
338
339        Self::handle_allocation_loop(
340            alloc,
341            view,
342            bootstrap_addr,
343            router,
344            forwarder_addr,
345            cancel_token,
346        )
347        .await;
348
349        mailbox_handle.stop("alloc stopped");
350        if let Err(e) = mailbox_handle.await {
351            tracing::error!("failed to join forwarder: {}", e);
352        }
353    }
354
355    async fn handle_allocation_loop(
356        mut alloc: Box<dyn Alloc + Send + Sync>,
357        view: View,
358        bootstrap_addr: ChannelAddr,
359        router: DialMailboxRouter,
360        forward_addr: ChannelAddr,
361        cancel_token: CancellationToken,
362    ) {
363        tracing::info!("starting handle allocation loop");
364        let tx = match channel::dial(bootstrap_addr) {
365            Ok(tx) => tx,
366            Err(err) => {
367                tracing::error!("failed to dial bootstrap address: {}", err);
368                return;
369            }
370        };
371        if let Err(e) = tx
372            .send(RemoteProcessProcStateMessage::Allocated {
373                world_id: alloc.world_id().clone(),
374                view,
375            })
376            .await
377        {
378            tracing::error!("failed to send Allocated message: {}", e);
379            return;
380        }
381
382        let mut mesh_agents_by_proc_id = HashMap::new();
383        let mut running = true;
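        // Watch the upstream channel status so we can detect a disappearing client
        // and tear down the allocation instead of streaming updates into the void.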
384        let tx_status = tx.status().clone();
385        let mut tx_watcher = WatchStream::new(tx_status);
386        loop {
387            tokio::select! {
388                _ = cancel_token.cancelled(), if running => {
389                    tracing::info!("cancelled, stopping allocation");
390                    running = false;
391                    if let Err(e) = alloc.stop().await {
392                        tracing::error!("stop failed: {}", e);
393                        break;
394                    }
395                }
396                status = tx_watcher.next(), if running => {
397                    match status  {
398                        Some(TxStatus::Closed) => {
399                            tracing::error!("upstream channel state closed");
400                            break;
401                        },
402                        _ => {
403                            tracing::debug!("got channel event: {:?}", status.unwrap());
404                            continue;
405                        }
406                    }
407                }
408                e = alloc.next() => {
409                    match e {
410                        Some(event) => {
411                            tracing::debug!("got event: {:?}", event);
412                            let event = match event {
413                                ProcState::Created { .. } => event,
414                                ProcState::Running { proc_id, mesh_agent, addr } => {
415                                    tracing::debug!("remapping mesh_agent {}: addr {} -> {}", mesh_agent, addr, forward_addr);
416                                    mesh_agents_by_proc_id.insert(proc_id.clone(), mesh_agent.clone());
417                                    router.bind(mesh_agent.actor_id().proc_id().clone().into(), addr);
418                                    ProcState::Running { proc_id, mesh_agent, addr: forward_addr.clone() }
419                                },
420                                ProcState::Stopped { proc_id, reason } => {
421                                    match mesh_agents_by_proc_id.remove(&proc_id) {
422                                        Some(mesh_agent) => {
423                                            tracing::debug!("unmapping mesh_agent {}", mesh_agent);
424                                            let agent_ref: Reference = mesh_agent.actor_id().proc_id().clone().into();
425                                            router.unbind(&agent_ref);
426                                        },
427                                        None => {
428                                            tracing::warn!("mesh_agent not found for proc_id: {}", proc_id);
429                                        }
430                                    }
431                                    ProcState::Stopped { proc_id, reason }
432                                },
433                                ProcState::Failed { ref world_id, ref description } => {
434                                    tracing::error!("allocation failed for {}: {}", world_id, description);
435                                    event
436                                }
437                            };
438                            tracing::debug!("sending event: {:?}", event);
439                            tx.post(RemoteProcessProcStateMessage::Update(event));
440                        }
441                        None => {
442                            tracing::debug!("sending done");
443                            tx.post(RemoteProcessProcStateMessage::Done(alloc.world_id().clone()));
444                            running = false;
445                            break;
446                        }
447                    }
448                }
449                _ = RealClock.sleep(config::global::get(config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL)) => {
450                    tracing::trace!("sending heartbeat");
451                    tx.post(RemoteProcessProcStateMessage::HeartBeat);
452                }
453            }
454        }
455        tracing::info!("allocation handler loop exited");
456        if running {
457            tracing::info!("stopping processes");
458            if let Err(e) = alloc.stop_and_wait().await {
459                tracing::error!("stop failed: {}", e);
460                return;
461            }
462            tracing::info!("stop finished");
463        }
464    }
465}
466
467/// HostId is a string-based identifier of the host. It may be the same
468/// as the hostname but in some situations a separate ID may be used.
469type HostId = String;
470
471/// A host entry passed down from the initializer.
472#[derive(Clone)]
473pub struct RemoteProcessAllocHost {
474    /// A unique identifier for that host. The hostname may be used, but the ID
475    /// can be different in some situations.
476    pub id: HostId,
477    /// The FQDN of the host.
478    pub hostname: String,
479}
480
481/// State of a host in the RemoteProcessAlloc.
482struct RemoteProcessAllocHostState {
483    /// The host ID of the remote host.
484    host_id: HostId,
485    /// TX channel to the remote host allocator.
486    tx: ChannelTx<RemoteProcessAllocatorMessage>,
487    /// Set of active processes on this host.
488    active_procs: HashSet<ProcId>,
489    /// Slice offset for this host.
490    offset: usize,
491    /// World ID for this host as indicated from Allocated message.
492    world_id: Option<WorldId>,
493    /// If remote allocator sent us ProcState::Failed.
494    failed: bool,
495    /// If the remote allocator has ever allocated a proc.
496    allocated: bool,
497}
498
499#[automock]
500#[async_trait]
501/// Interface to provide the set of hosts to be used by RemoteProcessAlloc.
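///
/// A sketch of a static initializer over a fixed host list; the type name and
/// hostnames are illustrative assumptions.
///
/// ```ignore
/// struct StaticInitializer(Vec<String>);
///
/// #[async_trait]
/// impl RemoteProcessAllocInitializer for StaticInitializer {
///     async fn initialize_alloc(&mut self) -> Result<Vec<RemoteProcessAllocHost>, anyhow::Error> {
///         Ok(self
///             .0
///             .iter()
///             .map(|h| RemoteProcessAllocHost { id: h.clone(), hostname: h.clone() })
///             .collect())
///     }
/// }
/// ```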
502pub trait RemoteProcessAllocInitializer {
503    /// Initializes and returns a list of hosts to be used by this RemoteProcessAlloc.
504    async fn initialize_alloc(&mut self) -> Result<Vec<RemoteProcessAllocHost>, anyhow::Error>;
505}
506
507/// Wrapper struct around `HashMap<HostId, RemoteProcessAllocHostState>`
508/// to ensure that host addresses stay in sync with the signal handler.
509struct HostStates {
510    inner: HashMap<HostId, RemoteProcessAllocHostState>,
511    host_addresses: Arc<DashMap<HostId, ChannelAddr>>,
512}
513
514impl HostStates {
515    fn new(host_addresses: Arc<DashMap<HostId, ChannelAddr>>) -> HostStates {
516        Self {
517            inner: HashMap::new(),
518            host_addresses,
519        }
520    }
521
522    fn insert(
523        &mut self,
524        host_id: HostId,
525        state: RemoteProcessAllocHostState,
526        address: ChannelAddr,
527    ) {
528        self.host_addresses.insert(host_id.clone(), address);
529        self.inner.insert(host_id, state);
530    }
531
532    fn get(&self, host_id: &HostId) -> Option<&RemoteProcessAllocHostState> {
533        self.inner.get(host_id)
534    }
535
536    fn get_mut(&mut self, host_id: &HostId) -> Option<&mut RemoteProcessAllocHostState> {
537        self.inner.get_mut(host_id)
538    }
539
540    fn remove(&mut self, host_id: &HostId) -> Option<RemoteProcessAllocHostState> {
541        self.host_addresses.remove(host_id);
542        self.inner.remove(host_id)
543    }
544
545    fn iter(&self) -> impl Iterator<Item = (&HostId, &RemoteProcessAllocHostState)> {
546        self.inner.iter()
547    }
548
549    fn iter_mut(&mut self) -> impl Iterator<Item = (&HostId, &mut RemoteProcessAllocHostState)> {
550        self.inner.iter_mut()
551    }
552
553    fn is_empty(&self) -> bool {
554        self.inner.is_empty()
555    }
556    // Any missing HashMap methods should be added here as needed
557}
558
559/// A generalized implementation of an Alloc using one or more hosts running
560/// RemoteProcessAlloc for process allocation.
561pub struct RemoteProcessAlloc {
562    // The initializer to be called at the first next() invocation to obtain
563    // allocated hosts.
564    initializer: Box<dyn RemoteProcessAllocInitializer + Send + Sync>,
565    spec: AllocSpec,
566    remote_allocator_port: u16,
567    transport: ChannelTransport,
568    world_id: WorldId,
569    ordered_hosts: Vec<RemoteProcessAllocHost>,
570    // Indicates that the initial remote allocation requests have been sent.
571    started: bool,
572    // Indicates that this Alloc is active (we have at least one remote process running).
573    running: bool,
574    // Indicates that the allocation process has permanently failed.
575    failed: bool,
576    hosts_by_offset: HashMap<usize, HostId>,
577    host_states: HostStates,
578    world_offsets: HashMap<WorldId, usize>,
579    event_queue: VecDeque<ProcState>,
580    comm_watcher_tx: UnboundedSender<HostId>,
581    comm_watcher_rx: UnboundedReceiver<HostId>,
582
583    bootstrap_addr: ChannelAddr,
584    rx: ChannelRx<RemoteProcessProcStateMessage>,
585    _signal_cleanup_guard: hyperactor::SignalCleanupGuard,
586}
587
588impl RemoteProcessAlloc {
589    /// Create a new Alloc. The initializer will be called on the first invocation of next()
590    /// to obtain a list of allocated hosts. An Allocate message is then sent to the
591    /// RemoteProcessAllocator on every host. Heartbeats are used to maintain the health
592    /// status of remote hosts.
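    ///
    /// A minimal sketch of constructing the alloc; the world id, port, and the
    /// `StaticInitializer` type are illustrative assumptions.
    ///
    /// ```ignore
    /// let alloc = RemoteProcessAlloc::new(
    ///     AllocSpec {
    ///         extent: extent!(host = 2, gpu = 8),
    ///         constraints: Default::default(),
    ///     },
    ///     id!(training_world),
    ///     ChannelTransport::Tcp,
    ///     26600,
    ///     StaticInitializer(vec!["host0".into(), "host1".into()]),
    /// )
    /// .await?;
    /// ```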
593    pub async fn new(
594        spec: AllocSpec,
595        world_id: WorldId,
596        transport: ChannelTransport,
597        remote_allocator_port: u16,
598        initializer: impl RemoteProcessAllocInitializer + Send + Sync + 'static,
599    ) -> Result<Self, anyhow::Error> {
600        let (bootstrap_addr, rx) = channel::serve(ChannelAddr::any(transport.clone()))
601            .await
602            .map_err(anyhow::Error::from)?;
603
604        tracing::info!(
605            "starting alloc for {} on: {}",
606            world_id,
607            bootstrap_addr.clone()
608        );
609
610        let (comm_watcher_tx, comm_watcher_rx) = unbounded_channel();
611
612        let host_addresses = Arc::new(DashMap::<HostId, ChannelAddr>::new());
613        let host_addresses_for_signal = host_addresses.clone();
614
615        // Register cleanup callback with global signal manager
616        let signal_cleanup_guard =
617            hyperactor::register_signal_cleanup_scoped(Box::pin(async move {
618                join_all(host_addresses_for_signal.iter().map(|entry| async move {
619                    let addr = entry.value().clone();
620                    match channel::dial(addr.clone()) {
621                        Ok(tx) => {
622                            if let Err(e) = tx.send(RemoteProcessAllocatorMessage::Stop).await {
623                                tracing::error!("Failed to send Stop to {}: {}", addr, e);
624                            }
625                        }
626                        Err(e) => {
627                            tracing::error!("Failed to dial {} during signal cleanup: {}", addr, e);
628                        }
629                    }
630                }))
631                .await;
632            }));
633
634        Ok(Self {
635            spec,
636            world_id,
637            transport,
638            remote_allocator_port,
639            initializer: Box::new(initializer),
640            world_offsets: HashMap::new(),
641            ordered_hosts: Vec::new(),
642            hosts_by_offset: HashMap::new(),
643            host_states: HostStates::new(host_addresses),
644            bootstrap_addr,
645            event_queue: VecDeque::new(),
646            comm_watcher_tx,
647            comm_watcher_rx,
648            rx,
649            started: false,
650            running: true,
651            failed: false,
652            _signal_cleanup_guard: signal_cleanup_guard,
653        })
654    }
655
656    /// Start a RemoteProcessAllocator tx watcher. It spawns a task that monitors the underlying
657    /// TxStatus and reports to the main next() loop via comm_watcher_tx. It does not, however,
658    /// send the actual heartbeats; it just watches the status.
659    /// The task will terminate once comm_watcher_rx is closed.
660    ///
661    /// A task approach was used here as opposed to select! to avoid nesting complexities with
662    /// select!() and select_all().
663    async fn start_comm_watcher(&self) {
664        let mut tx_watchers = Vec::new();
665        for host in &self.ordered_hosts {
666            let tx_status = self.host_states.get(&host.id).unwrap().tx.status().clone();
667            let watcher = WatchStream::new(tx_status);
668            tx_watchers.push((watcher, host.id.clone()));
669        }
670        assert!(!tx_watchers.is_empty());
671        let tx = self.comm_watcher_tx.clone();
672        tokio::spawn(async move {
673            loop {
674                let mut tx_status_futures = Vec::new();
675                for (watcher, _) in &mut tx_watchers {
676                    let fut = watcher.next().boxed();
677                    tx_status_futures.push(fut);
678                }
679                let (tx_status, index, _) = select_all(tx_status_futures).await;
680                let host_id = match tx_watchers.get(index) {
681                    Some((_, host_id)) => host_id.clone(),
682                    None => {
683                        // Should never happen
684                        tracing::error!(
685                            "got selected index {} with no matching host in {}",
686                            index,
687                            tx_watchers.len()
688                        );
689                        continue;
690                    }
691                };
692                if let Some(tx_status) = tx_status {
693                    tracing::debug!("host {} channel event: {:?}", host_id, tx_status);
694                    if tx_status == TxStatus::Closed {
695                        if tx.send(host_id.clone()).is_err() {
696                            // other side closed
697                            break;
698                        }
699                        tx_watchers.remove(index);
700                        if tx_watchers.is_empty() {
701                            // All of the statuses have been closed, exit the loop.
702                            break;
703                        }
704                    }
705                }
706            }
707        });
708    }
709
710    /// Ensure that we have the list of allocated hosts and that we have sent an Allocate
711    /// request to all of them. Once done, start_comm_watcher() is called to start
712    /// the channel watcher.
713    /// Function is idempotent.
714    async fn ensure_started(&mut self) -> Result<(), anyhow::Error> {
715        if self.started || self.failed {
716            return Ok(());
717        }
718
719        self.started = true;
720        let hosts = self
721            .initializer
722            .initialize_alloc()
723            .await
724            .context("alloc initializer error")?;
725        if hosts.is_empty() {
726            anyhow::bail!("Initializer returned empty list of hosts");
727        }
728        // prepare a list of host names in this allocation to be sent
729        // to remote allocators.
730        let hostnames: Vec<_> = hosts.iter().map(|e| e.hostname.clone()).collect();
731        tracing::info!("obtained {} hosts for this allocation", hostnames.len());
732
733        // We require at least one dimension for hosts, and one for sub-hosts (e.g., GPUs).
734        anyhow::ensure!(
735            self.spec.extent.len() >= 2,
736            "invalid extent: {}, expected at least 2 dimensions",
737            self.spec.extent
738        );
739
740        // We group by the innermost dimension of the extent.
741        let split_dim = &self.spec.extent.labels()[self.spec.extent.len() - 1];
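        // For example, with extent {host: 2, gpu: 8} and split_dim "gpu", this yields
        // one contiguous 8-rank view per host, paired with hosts[i] below.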
742        for (i, view) in self.spec.extent.group_by(split_dim)?.enumerate() {
743            let host = &hosts[i];
744            tracing::debug!("allocating: {} for host: {}", view, host.id);
745
746            let remote_addr = match self.transport {
747                ChannelTransport::MetaTls(_) => {
748                    format!("metatls!{}:{}", host.hostname, self.remote_allocator_port)
749                }
750                ChannelTransport::Tcp => {
751                    format!("tcp!{}:{}", host.hostname, self.remote_allocator_port)
752                }
753                // Used only for testing.
754                ChannelTransport::Unix => host.hostname.clone(),
755                _ => {
756                    anyhow::bail!(
757                        "unsupported transport for host {}: {:?}",
758                        host.id,
759                        self.transport
760                    );
761                }
762            };
763
764            let offset = offset_of_view(&view).unwrap();
765
766            tracing::debug!("dialing remote: {} for host {}", remote_addr, host.id);
767            let remote_addr = remote_addr.parse::<ChannelAddr>()?;
768            let tx = channel::dial(remote_addr.clone())
769                .map_err(anyhow::Error::from)
770                .context(format!(
771                    "failed to dial remote {} for host {}",
772                    remote_addr, host.id
773                ))?;
774
775            let trace_id = hyperactor_telemetry::trace::get_or_create_trace_id();
776            let client_context = Some(ClientContext { trace_id });
777
778            tx.post(RemoteProcessAllocatorMessage::Allocate {
779                view,
780                bootstrap_addr: self.bootstrap_addr.clone(),
781                hosts: hostnames.clone(),
782                client_context,
783            });
784
785            self.hosts_by_offset.insert(offset, host.id.clone());
786            self.host_states.insert(
787                host.id.clone(),
788                RemoteProcessAllocHostState {
789                    host_id: host.id.clone(),
790                    tx,
791                    active_procs: HashSet::new(),
792                    offset,
793                    world_id: None,
794                    failed: false,
795                    allocated: false,
796                },
797                remote_addr,
798            );
799        }
800
801        self.ordered_hosts = hosts;
802        self.start_comm_watcher().await;
803        self.started = true;
804
805        Ok(())
806    }
807
808    fn host_id_for_world_id(&self, world_id: &WorldId) -> Option<HostId> {
809        let offset = self.world_offsets.get(world_id)?;
810        self.hosts_by_offset.get(offset).cloned()
811    }
812
813    fn host_state_for_world_id(
814        &mut self,
815        world_id: &WorldId,
816    ) -> Result<&mut RemoteProcessAllocHostState, anyhow::Error> {
817        if let Some(host_id) = self.host_id_for_world_id(world_id) {
818            if let Some(task_state) = self.host_states.get_mut(&host_id) {
819                Ok(task_state)
820            } else {
821                anyhow::bail!(
822                    "task state not found for world id: {}, host id: {}",
823                    world_id,
824                    host_id
825                );
826            }
827        } else {
828            anyhow::bail!("task not found for world id: {}", world_id);
829        }
830    }
831
832    // Given a proc_id, obtain the internal HostState structure.
833    fn host_state_for_proc_id(
834        &mut self,
835        proc_id: &ProcId,
836    ) -> Result<&mut RemoteProcessAllocHostState, anyhow::Error> {
837        self.host_state_for_world_id(
838            proc_id
839                .world_id()
840                .expect("proc must be ranked for host state lookup"),
841        )
842    }
843
844    fn add_proc_id_to_host_state(&mut self, proc_id: &ProcId) -> Result<(), anyhow::Error> {
845        let task_state = self.host_state_for_proc_id(proc_id)?;
846        if !task_state.active_procs.insert(proc_id.clone()) {
847            // Should not happen but we can ignore
848            tracing::error!("proc id already in host state: {}", proc_id);
849        }
850        task_state.allocated = true;
851        Ok(())
852    }
853
854    fn remove_proc_from_host_state(&mut self, proc_id: &ProcId) -> Result<(), anyhow::Error> {
855        let task_state = self.host_state_for_proc_id(proc_id)?;
856        if !task_state.active_procs.remove(proc_id) {
857            // Should not happen but we can ignore
858            tracing::error!("proc id not found in host state: {}", proc_id);
859        }
860        Ok(())
861    }
862
863    // Reproject a proc's world-local coordinates into the global extent's coordinates.
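    // For example, a proc at local rank 3 within a host view that starts at global
    // offset 16 maps to global rank 19 of the allocation's extent.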
864    fn project_proc_into_global_extent(
865        &self,
866        proc_id: &ProcId,
867        point: &Point,
868    ) -> Result<Point, anyhow::Error> {
869        let world_id = proc_id
870            .world_id()
871            .expect("proc must be ranked for point mapping");
872        let offset = self
873            .world_offsets
874            .get(world_id)
875            .ok_or_else(|| anyhow::anyhow!("could not find offset for world id: {}", world_id))?;
876
877        Ok(self.spec.extent.point_of_rank(offset + point.rank())?)
878    }
879
880    // Clean up the state of a host whose channel failed, identified by its ID.
881    fn cleanup_host_channel_closed(
882        &mut self,
883        host_id: HostId,
884    ) -> Result<Vec<ProcId>, anyhow::Error> {
885        let state = match self.host_states.remove(&host_id) {
886            Some(state) => state,
887            None => {
888                // this should never happen.
889                anyhow::bail!(
890                    "got channel closed event for host {} which has no known state",
891                    host_id
892                );
893            }
894        };
895        self.ordered_hosts.retain(|host| host.id != host_id);
896        self.hosts_by_offset.remove(&state.offset);
897        if let Some(world_id) = state.world_id {
898            self.world_offsets.remove(&world_id);
899        }
900        let proc_ids = state.active_procs.iter().cloned().collect();
901
902        Ok(proc_ids)
903    }
904}
905
906#[async_trait]
907impl Alloc for RemoteProcessAlloc {
908    async fn next(&mut self) -> Option<ProcState> {
909        loop {
910            if let state @ Some(_) = self.event_queue.pop_front() {
911                break state;
912            }
913
914            if !self.running {
915                break None;
916            }
917
918            if let Err(e) = self.ensure_started().await {
919                break Some(ProcState::Failed {
920                    world_id: self.world_id.clone(),
921                    description: format!("failed to ensure started: {:#}", e),
922                });
923            }
924
925            let heartbeat_interval =
926                config::global::get(config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL);
927            let mut heartbeat_time = hyperactor::clock::RealClock.now() + heartbeat_interval;
928            // rerun outer loop in case we pushed new items to the event queue
929            let mut reloop = false;
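            // Inner loop: wait for the next allocator message, heartbeat tick, or
            // channel-closed notification, and turn it into at most one ProcState update.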
930            let update = loop {
931                tokio::select! {
932                    msg = self.rx.recv() => {
933                        tracing::debug!("got ProcState message from allocator: {:?}", msg);
934                        match msg {
935                            Ok(RemoteProcessProcStateMessage::Allocated { world_id, view }) => {
936                                let Some(offset) = offset_of_view(&view) else {
937                                    tracing::error!(
938                                        "allocated {}: cannot derive offset from allocated view: {}",
939                                        world_id,
940                                        view
941                                    );
942                                    continue;
943                                };
944                                tracing::info!("received allocated world id: {}", world_id);
945                                match self.hosts_by_offset.get(&offset) {
946                                    Some(host_id) => {
947                                        // update state
948                                        match self.host_states.get_mut(host_id) {
949                                            Some(state) => {
950                                                state.world_id = Some(world_id.clone());
951                                                if let Some(old_entry) = self.world_offsets.insert(world_id.clone(), offset) {
952                                                    // should never happen
953                                                    tracing::warn!(
954                                                        "got allocated for duplicate world id: {} with known offset: {}",
955                                                        world_id,
956                                                        old_entry
957                                                    );
958                                                }
959                                            }
960                                            None => {
961                                                // should never happen
962                                                tracing::error!(
963                                                    "got allocated for host ID: {} with no known state",
964                                                    host_id
965                                                );
966                                            }
967                                        }
968                                    }
969                                    None => {
970                                        // should never happen
971                                        tracing::error!(
972                                            "got allocated for unknown world: {}, view: {}",
973                                            world_id,
974                                            view
975                                        );
976                                    }
977                                }
978                            }
979                            Ok(RemoteProcessProcStateMessage::Update(proc_state)) => {
980                                break match proc_state {
981                                    ProcState::Created { ref proc_id, .. } => {
982                                        if let Err(e) = self.add_proc_id_to_host_state(proc_id) {
983                                            tracing::error!("failed to add proc id to host state: {}", e);
984                                        }
985                                        Some(proc_state)
986                                    }
987                                    ProcState::Stopped{ ref proc_id, ..} => {
988                                        if let Err(e) = self.remove_proc_from_host_state(proc_id) {
989                                            tracing::error!("failed to remove proc id from host state: {}", e);
990                                        }
991                                        Some(proc_state)
992                                    }
993                                    ProcState::Failed { ref world_id, ref description } => {
994                                        match self.host_state_for_world_id(world_id) {
995                                            Ok(state) => {
996                                                state.failed = true;
997                                                Some(ProcState::Failed {
998                                                    world_id: world_id.clone(),
999                                                    description: format!("host {} failed: {}", state.host_id, description),
1000                                                })
1001                                            }
1002                                            Err(e) => {
1003                                                tracing::error!("failed to find host state for world id: {}: {}", world_id, e);
1004                                                Some(proc_state)
1005                                            }
1006                                        }
1007                                    }
1008                                    _ => Some(proc_state)
1009
1010                                };
1011                            }
1012                            Ok(RemoteProcessProcStateMessage::Done(world_id)) => {
1013                                tracing::info!("allocator world_id: {} is done", world_id);
1014                                if let Some(host_id) = self.host_id_for_world_id(&world_id) {
1015                                    if let Some(state) = self.host_states.get(&host_id) {
1016                                        if !state.active_procs.is_empty() {
1017                                            tracing::error!("received done for world id: {} with active procs: {:?}", world_id, state.active_procs);
1018                                        }
1019                                    } else {
1020                                        tracing::error!("received done for unknown state world id: {}", world_id);
1021                                    }
1022                                } else {
1023                                    tracing::error!("received done for unknown world id: {}", world_id);
1024                                }
1025                                if self.world_offsets.remove(&world_id).is_none() {
1026                                    tracing::error!("received done for unknown world id: {}", world_id);
1027                                } else if self.world_offsets.is_empty() {
1028                                    self.running = false;
1029                                    break None;
1030                                }
1031                            }
1032                            // Heartbeat message is discarded immediately after being received; the sender (remote
1033                            // process allocator) relies on channel ack to know if the receiver (client) is
1034                            // still alive. No state needs to be updated.
1035                            Ok(RemoteProcessProcStateMessage::HeartBeat) => {}
1036                            Err(e) => {
1037                                break Some(ProcState::Failed {world_id: self.world_id.clone(), description: format!("error receiving events: {}", e)});
1038                            }
1039                        }
1040                    }
1041
1042                    _ = clock::RealClock.sleep_until(heartbeat_time) => {
1043                        self.host_states.iter().for_each(|(_, host_state)| host_state.tx.post(RemoteProcessAllocatorMessage::HeartBeat));
1044                        heartbeat_time = hyperactor::clock::RealClock.now() + heartbeat_interval;
1045                    }
1046
1047                    closed_host_id = self.comm_watcher_rx.recv() => {
1048                        if let Some(closed_host_id) = closed_host_id {
1049                            tracing::debug!("host {} channel closed, cleaning up", closed_host_id);
1050                            if let Some(state) = self.host_states.get(&closed_host_id) {
1051                                if !state.allocated {
1052                                    break Some(ProcState::Failed {
1053                                        world_id: self.world_id.clone(),
1054                                        description: format!(
1055                                            "no process was ever allocated on {} before the channel closed; \
1056                                            a common cause is that the channel was never established",
1057                                            closed_host_id
1058                                        )});
1059                                }
1060                            }
1061                            let proc_ids = match self.cleanup_host_channel_closed(closed_host_id) {
1062                                Ok(proc_ids) => proc_ids,
1063                                Err(err) => {
1064                                    tracing::error!("failed to cleanup disconnected host: {}", err);
1065                                    continue;
1066                                }
1067                            };
1068                            for proc_id in proc_ids {
1069                                tracing::debug!("queuing Stopped state for {}", proc_id);
1070                                self.event_queue.push_back(ProcState::Stopped{proc_id, reason: ProcStopReason::HostWatchdog});
1071                            }
1072                            // Check if there are any hosts left
1073                            if self.host_states.is_empty() {
1074                                tracing::info!("no more hosts left, stopping the alloc");
1075                                self.running = false;
1076                            }
1077                            // Kick back to the outer loop to pop off the queue if necessary
1078                            reloop = true;
1079                            break None;
1080                        } else {
1081                            // other side closed
1082                            tracing::warn!("unexpected comm watcher channel close");
1083                            break None;
1084                        }
1085                    }
1086                }
1087            };
1088
1089            if reloop {
1090                // Kick back to the outer loop to pop off the queue.
1091                // Should not have produced an update.
1092                assert!(update.is_none());
1093                continue;
1094            }
1095
1096            break match update {
1097                Some(ProcState::Created {
1098                    proc_id,
1099                    point,
1100                    pid,
1101                }) => match self.project_proc_into_global_extent(&proc_id, &point) {
1102                    Ok(global_point) => {
1103                        tracing::debug!("reprojected coords: {} -> {}", point, global_point);
1104                        Some(ProcState::Created {
1105                            proc_id,
1106                            point: global_point,
1107                            pid,
1108                        })
1109                    }
1110                    Err(e) => {
1111                        tracing::error!("failed to project coords for proc: {}: {}", proc_id, e);
1112                        None
1113                    }
1114                },
1115
1116                Some(ProcState::Failed {
1117                    world_id: _,
1118                    ref description,
1119                }) => {
1120                    tracing::error!(description);
1121                    self.failed = true;
1122                    update
1123                }
1124
1125                _ => update,
1126            };
1127        }
1128    }
1129
1130    fn extent(&self) -> &Extent {
1131        &self.spec.extent
1132    }
1133
1134    fn world_id(&self) -> &WorldId {
1135        &self.world_id
1136    }
1137
1138    fn transport(&self) -> ChannelTransport {
1139        self.transport.clone()
1140    }
1141
1142    async fn stop(&mut self) -> Result<(), AllocatorError> {
1143        tracing::info!("stopping alloc");
1144
1145        for (host_id, task_state) in self.host_states.iter_mut() {
1146            tracing::debug!("stopping alloc at host {}", host_id);
1147            task_state.tx.post(RemoteProcessAllocatorMessage::Stop);
1148        }
1149
1150        Ok(())
1151    }
1152}
1153
1154impl Drop for RemoteProcessAlloc {
1155    fn drop(&mut self) {
1156        tracing::debug!("dropping RemoteProcessAlloc of world_id {}", self.world_id);
1157    }
1158}
1159
1160/// Computes the offset of the given view for allocation purposes.
1161/// The offset is the first rank in the view. Because we have grouped
1162/// this view from an extent, we know its ranks are contiguous.
1163fn offset_of_view(view: &View) -> Option<usize> {
1164    let (_point, offset) = view.iter().next()?;
1165    Some(offset)
1166}
1167
1168#[cfg(test)]
1169mod test {
1170    use std::assert_matches::assert_matches;
1171
1172    use hyperactor::ActorRef;
1173    use hyperactor::channel::ChannelRx;
1174    use hyperactor::clock::ClockKind;
1175    use hyperactor::id;
1176    use ndslice::extent;
1177    use tokio::sync::oneshot;
1178
1179    use super::*;
1180    use crate::alloc::ChannelTransport;
1181    use crate::alloc::MockAlloc;
1182    use crate::alloc::MockAllocWrapper;
1183    use crate::alloc::MockAllocator;
1184    use crate::alloc::ProcStopReason;
1185    use crate::proc_mesh::mesh_agent::MeshAgent;
1186
1187    async fn read_all_created(rx: &mut ChannelRx<RemoteProcessProcStateMessage>, alloc_len: usize) {
1188        let mut i: usize = 0;
1189        while i < alloc_len {
1190            let m = rx.recv().await.unwrap();
1191            match m {
1192                RemoteProcessProcStateMessage::Update(ProcState::Created { .. }) => i += 1,
1193                RemoteProcessProcStateMessage::HeartBeat => {}
1194                _ => panic!("unexpected message: {:?}", m),
1195            }
1196        }
1197    }
1198
1199    async fn read_all_running(rx: &mut ChannelRx<RemoteProcessProcStateMessage>, alloc_len: usize) {
1200        let mut i: usize = 0;
1201        while i < alloc_len {
1202            let m = rx.recv().await.unwrap();
1203            match m {
1204                RemoteProcessProcStateMessage::Update(ProcState::Running { .. }) => i += 1,
1205                RemoteProcessProcStateMessage::HeartBeat => {}
1206                _ => panic!("unexpected message: {:?}", m),
1207            }
1208        }
1209    }
1210
1211    async fn read_all_stopped(rx: &mut ChannelRx<RemoteProcessProcStateMessage>, alloc_len: usize) {
1212        let mut i: usize = 0;
1213        while i < alloc_len {
1214            let m = rx.recv().await.unwrap();
1215            match m {
1216                RemoteProcessProcStateMessage::Update(ProcState::Stopped { .. }) => i += 1,
1217                RemoteProcessProcStateMessage::HeartBeat => {}
1218                _ => panic!("unexpected message: {:?}", m),
1219            }
1220        }
1221    }
1222
1223    fn set_procstate_expectations(alloc: &mut MockAlloc, extent: Extent) {
1224        alloc.expect_extent().return_const(extent.clone());
1225        for (i, point) in extent.points().enumerate() {
1226            let proc_id = format!("test[{}]", i).parse().unwrap();
1227            alloc.expect_next().times(1).return_once(|| {
1228                Some(ProcState::Created {
1229                    proc_id,
1230                    point,
1231                    pid: 0,
1232                })
1233            });
1234        }
1235        for i in 0..extent.num_ranks() {
1236            let proc_id = format!("test[{}]", i).parse().unwrap();
1237            let mesh_agent = ActorRef::<MeshAgent>::attest(
1238                format!("test[{}].mesh_agent[{}]", i, i).parse().unwrap(),
1239            );
1240            alloc.expect_next().times(1).return_once(|| {
1241                Some(ProcState::Running {
1242                    proc_id,
1243                    addr: ChannelAddr::Unix("/proc0".parse().unwrap()),
1244                    mesh_agent,
1245                })
1246            });
1247        }
1248        for i in 0..extent.num_ranks() {
1249            let proc_id = format!("test[{}]", i).parse().unwrap();
1250            alloc.expect_next().times(1).return_once(|| {
1251                Some(ProcState::Stopped {
1252                    proc_id,
1253                    reason: ProcStopReason::Unknown,
1254                })
1255            });
1256        }
1257    }
1258
1259    #[timed_test::async_timed_test(timeout_secs = 5)]
1260    async fn test_simple() {
1261        let config = hyperactor::config::global::lock();
1262        let _guard = config.override_key(
1263            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
1264            Duration::from_millis(100),
1265        );
1266        hyperactor_telemetry::initialize_logging(ClockKind::default());
1267        let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
1268        let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
1269        let (_, mut rx) = channel::serve(bootstrap_addr.clone()).await.unwrap();
1270
1271        let extent = extent!(host = 1, gpu = 2);
1272        let tx = channel::dial(serve_addr.clone()).unwrap();
1273
1274        let world_id: WorldId = id!(test_world_id);
1275        let mut alloc = MockAlloc::new();
1276        alloc.expect_world_id().return_const(world_id.clone());
1277        alloc.expect_extent().return_const(extent.clone());
1278
1279        set_procstate_expectations(&mut alloc, extent.clone());
1280
1281        // Final None to terminate the proc state stream.
1282        alloc.expect_next().return_const(None);
1283
1284        let mut allocator = MockAllocator::new();
1285        let total_messages = extent.num_ranks() * 3 + 1;
1286        let mock_wrapper = MockAllocWrapper::new_block_next(
1287            alloc,
1288            // block after create, running, stopped and done.
1289            total_messages,
1290        );
1291        allocator
1292            .expect_allocate()
1293            .times(1)
1294            .return_once(move |_| Ok(mock_wrapper));
1295
1296        let remote_allocator = RemoteProcessAllocator::new();
1297        let handle = tokio::spawn({
1298            let remote_allocator = remote_allocator.clone();
1299            async move {
1300                remote_allocator
1301                    .start_with_allocator(serve_addr, allocator, None)
1302                    .await
1303            }
1304        });
1305
1306        tx.send(RemoteProcessAllocatorMessage::Allocate {
1307            view: extent.clone().into(),
1308            bootstrap_addr,
1309            hosts: vec![],
1310            client_context: None,
1311        })
1312        .await
1313        .unwrap();
1314
1315        // Allocated
1316        let m = rx.recv().await.unwrap();
1317        assert_matches!(
1318            m, RemoteProcessProcStateMessage::Allocated { world_id: recv_world_id, view }
1319                if recv_world_id == world_id && extent == view.extent()
1320        );
1321
1322        // All Created events
1323        let mut rank: usize = 0;
1324        while rank < extent.num_ranks() {
1325            let m = rx.recv().await.unwrap();
1326            match m {
1327                RemoteProcessProcStateMessage::Update(ProcState::Created {
1328                    proc_id,
1329                    point,
1330                    ..
1331                }) => {
1332                    let expected_proc_id = format!("test[{}]", rank).parse().unwrap();
1333                    let expected_point = extent.point_of_rank(rank).unwrap();
1334                    assert_eq!(proc_id, expected_proc_id);
1335                    assert_eq!(point, expected_point);
1336                    rank += 1;
1337                }
1338                RemoteProcessProcStateMessage::HeartBeat => {}
1339                _ => panic!("unexpected message: {:?}", m),
1340            }
1341        }
1342        // All Running events
1343        let mut rank: usize = 0;
1344        while rank < extent.num_ranks() {
1345            let m = rx.recv().await.unwrap();
1346            match m {
1347                RemoteProcessProcStateMessage::Update(ProcState::Running {
1348                    proc_id,
1349                    mesh_agent,
1350                    addr: _,
1351                }) => {
1352                    let expected_proc_id = format!("test[{}]", rank).parse().unwrap();
1353                    let expected_mesh_agent = ActorRef::<MeshAgent>::attest(
1354                        format!("test[{}].mesh_agent[{}]", rank, rank)
1355                            .parse()
1356                            .unwrap(),
1357                    );
1358                    assert_eq!(proc_id, expected_proc_id);
1359                    assert_eq!(mesh_agent, expected_mesh_agent);
1360                    rank += 1;
1361                }
1362                RemoteProcessProcStateMessage::HeartBeat => {}
1363                _ => panic!("unexpected message: {:?}", m),
1364            }
1365        }
1366        // All Stopped events
1367        let mut rank: usize = 0;
1368        while rank < extent.num_ranks() {
1369            let m = rx.recv().await.unwrap();
1370            match m {
1371                RemoteProcessProcStateMessage::Update(ProcState::Stopped {
1372                    proc_id,
1373                    reason: ProcStopReason::Unknown,
1374                }) => {
1375                    let expected_proc_id = format!("test[{}]", rank).parse().unwrap();
1376                    assert_eq!(proc_id, expected_proc_id);
1377                    rank += 1;
1378                }
1379                RemoteProcessProcStateMessage::HeartBeat => {}
1380                _ => panic!("unexpected message: {:?}", m),
1381            }
1382        }
1383        // Done
1384        loop {
1385            let m = rx.recv().await.unwrap();
1386            match m {
1387                RemoteProcessProcStateMessage::Done(id) => {
1388                    assert_eq!(id, world_id);
1389                    break;
1390                }
1391                RemoteProcessProcStateMessage::HeartBeat => {}
1392                _ => panic!("unexpected message: {:?}", m),
1393            }
1394        }
1395
1396        remote_allocator.terminate();
1397        handle.await.unwrap().unwrap();
1398    }
1399
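    // A Stop request after all procs are running stops the inner alloc and
    // results in a Stopped update for every rank.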
1400    #[timed_test::async_timed_test(timeout_secs = 15)]
1401    async fn test_normal_stop() {
1402        let config = hyperactor::config::global::lock();
1403        let _guard = config.override_key(
1404            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
1405            Duration::from_millis(100),
1406        );
1407        hyperactor_telemetry::initialize_logging(ClockKind::default());
1408        let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
1409        let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
1410        let (_, mut rx) = channel::serve(bootstrap_addr.clone()).await.unwrap();
1411
1412        let extent = extent!(host = 1, gpu = 2);
1413        let tx = channel::dial(serve_addr.clone()).unwrap();
1414
1415        let world_id: WorldId = id!(test_world_id);
1416        let mut alloc = MockAllocWrapper::new_block_next(
1417            MockAlloc::new(),
1418            // block after all created, all running
1419            extent.num_ranks() * 2,
1420        );
1421        let next_tx = alloc.notify_tx();
1422        alloc.alloc.expect_world_id().return_const(world_id.clone());
1423        alloc.alloc.expect_extent().return_const(extent.clone());
1424
1425        set_procstate_expectations(&mut alloc.alloc, extent.clone());
1426
1427        alloc.alloc.expect_next().return_const(None);
1428        alloc.alloc.expect_stop().times(1).return_once(|| Ok(()));
1429
1430        let mut allocator = MockAllocator::new();
1431        allocator
1432            .expect_allocate()
1433            .times(1)
1434            .return_once(|_| Ok(alloc));
1435
1436        let remote_allocator = RemoteProcessAllocator::new();
1437        let handle = tokio::spawn({
1438            let remote_allocator = remote_allocator.clone();
1439            async move {
1440                remote_allocator
1441                    .start_with_allocator(serve_addr, allocator, None)
1442                    .await
1443            }
1444        });
1445
1446        tx.send(RemoteProcessAllocatorMessage::Allocate {
1447            view: extent.clone().into(),
1448            bootstrap_addr,
1449            hosts: vec![],
1450            client_context: None,
1451        })
1452        .await
1453        .unwrap();
1454
1455        // Allocated
1456        let m = rx.recv().await.unwrap();
1457        assert_matches!(
1458            m,
1459            RemoteProcessProcStateMessage::Allocated { world_id: recv_world_id, view }
1460            if recv_world_id == world_id && view.extent() == extent
1461        );
1462
1463        read_all_created(&mut rx, extent.num_ranks()).await;
1464        read_all_running(&mut rx, extent.num_ranks()).await;
1465
1466        // Allocation finished; now we stop it.
1467        tracing::info!("stopping allocation");
1468        tx.send(RemoteProcessAllocatorMessage::Stop).await.unwrap();
1469        // unblock next() so we can receive all the Stopped updates
1470        next_tx.send(()).unwrap();
1471
1472        read_all_stopped(&mut rx, extent.num_ranks()).await;
1473
1474        remote_allocator.terminate();
1475        handle.await.unwrap().unwrap();
1476    }
1477
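    // A second Allocate request while an allocation is active stops the first
    // alloc (Stopped updates followed by Done) before serving the new one.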
1478    #[timed_test::async_timed_test(timeout_secs = 15)]
1479    async fn test_realloc() {
1480        let config = hyperactor::config::global::lock();
1481        let _guard = config.override_key(
1482            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
1483            Duration::from_millis(100),
1484        );
1485        hyperactor_telemetry::initialize_logging(ClockKind::default());
1486        let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
1487        let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
1488        let (_, mut rx) = channel::serve(bootstrap_addr.clone()).await.unwrap();
1489
1490        let extent = extent!(host = 1, gpu = 2);
1491
1492        let tx = channel::dial(serve_addr.clone()).unwrap();
1493
1494        let world_id: WorldId = id!(test_world_id);
1495        let mut alloc1 = MockAllocWrapper::new_block_next(
1496            MockAlloc::new(),
1497            // block after all created, all running
1498            extent.num_ranks() * 2,
1499        );
1500        let next_tx1 = alloc1.notify_tx();
1501        alloc1
1502            .alloc
1503            .expect_world_id()
1504            .return_const(world_id.clone());
1505        alloc1.alloc.expect_extent().return_const(extent.clone());
1506
1507        set_procstate_expectations(&mut alloc1.alloc, extent.clone());
1508        alloc1.alloc.expect_next().return_const(None);
1509        alloc1.alloc.expect_stop().times(1).return_once(|| Ok(()));
1510        // second allocation
1511        let mut alloc2 = MockAllocWrapper::new_block_next(
1512            MockAlloc::new(),
1513            // block after all created, all running
1514            extent.num_ranks() * 2,
1515        );
1516        let next_tx2 = alloc2.notify_tx();
1517        alloc2
1518            .alloc
1519            .expect_world_id()
1520            .return_const(world_id.clone());
1521        alloc2.alloc.expect_extent().return_const(extent.clone());
1522        set_procstate_expectations(&mut alloc2.alloc, extent.clone());
1523        alloc2.alloc.expect_next().return_const(None);
1524        alloc2.alloc.expect_stop().times(1).return_once(|| Ok(()));
1525
1526        let mut allocator = MockAllocator::new();
1527        allocator
1528            .expect_allocate()
1529            .times(1)
1530            .return_once(|_| Ok(alloc1));
1531        // second alloc
1532        allocator
1533            .expect_allocate()
1534            .times(1)
1535            .return_once(|_| Ok(alloc2));
1536
1537        let remote_allocator = RemoteProcessAllocator::new();
1538        let handle = tokio::spawn({
1539            let remote_allocator = remote_allocator.clone();
1540            async move {
1541                remote_allocator
1542                    .start_with_allocator(serve_addr, allocator, None)
1543                    .await
1544            }
1545        });
1546
1547        tx.send(RemoteProcessAllocatorMessage::Allocate {
1548            view: extent.clone().into(),
1549            bootstrap_addr: bootstrap_addr.clone(),
1550            hosts: vec![],
1551            client_context: None,
1552        })
1553        .await
1554        .unwrap();
1555
1556        // Allocated
1557        let m = rx.recv().await.unwrap();
1558        assert_matches!(
1559            m,
1560            RemoteProcessProcStateMessage::Allocated { world_id: recv_world_id, view }
1561            if recv_world_id == world_id && extent == view.extent()
1562        );
1563
1564        read_all_created(&mut rx, extent.num_ranks()).await;
1565        read_all_running(&mut rx, extent.num_ranks()).await;
1566
1567        // Allocation finished; now we request a new one.
1568        tx.send(RemoteProcessAllocatorMessage::Allocate {
1569            view: extent.clone().into(),
1570            bootstrap_addr,
1571            hosts: vec![],
1572            client_context: None,
1573        })
1574        .await
1575        .unwrap();
1576        // unblock next for the first allocation
1577        next_tx1.send(()).unwrap();
1578        // we expect a stop(), then Stopped proc states, then a new Allocated
1579        read_all_stopped(&mut rx, extent.num_ranks()).await;
1580        let m = rx.recv().await.unwrap();
1581        assert_matches!(m, RemoteProcessProcStateMessage::Done(_));
1582        let m = rx.recv().await.unwrap();
1583        assert_matches!(
1584            m,
1585            RemoteProcessProcStateMessage::Allocated { world_id: recv_world_id, view }
1586            if recv_world_id == world_id && extent == view.extent()
1587        );
1588        // ProcStates for the new allocation
1589        read_all_created(&mut rx, extent.num_ranks()).await;
1590        read_all_running(&mut rx, extent.num_ranks()).await;
1591        // finally stop
1592        tracing::info!("stopping allocation");
1593        tx.send(RemoteProcessAllocatorMessage::Stop).await.unwrap();
1594        // unblock next() so we can receive all the Stopped updates
1595        next_tx2.send(()).unwrap();
1596
1597        read_all_stopped(&mut rx, extent.num_ranks()).await;
1598
1599        remote_allocator.terminate();
1600        handle.await.unwrap().unwrap();
1601    }
1602
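    // If the upstream bootstrap channel is dropped, heartbeats fail and the
    // remote allocator stops the inner alloc on its own.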
1603    #[timed_test::async_timed_test(timeout_secs = 15)]
1604    async fn test_upstream_closed() {
1605        // Use temporary config for this test
1606        let config = hyperactor::config::global::lock();
1607        let _guard1 = config.override_key(
1608            hyperactor::config::MESSAGE_DELIVERY_TIMEOUT,
1609            Duration::from_secs(1),
1610        );
1611        let _guard2 = config.override_key(
1612            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
1613            Duration::from_millis(100),
1614        );
1615
1616        hyperactor_telemetry::initialize_logging(ClockKind::default());
1617        let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
1618        let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
1619        let (_, mut rx) = channel::serve(bootstrap_addr.clone()).await.unwrap();
1620
1621        let extent = extent!(host = 1, gpu = 2);
1622
1623        let tx = channel::dial(serve_addr.clone()).unwrap();
1624
1625        let world_id: WorldId = id!(test_world_id);
1626        let mut alloc = MockAllocWrapper::new_block_next(
1627            MockAlloc::new(),
1628            // block after all created, all running
1629            extent.num_ranks() * 2,
1630        );
1631        let next_tx = alloc.notify_tx();
1632        alloc.alloc.expect_world_id().return_const(world_id.clone());
1633        alloc.alloc.expect_extent().return_const(extent.clone());
1634
1635        set_procstate_expectations(&mut alloc.alloc, extent.clone());
1636
1637        alloc.alloc.expect_next().return_const(None);
1638        // We expect stop() to be called due to the upstream failure;
1639        // use a oneshot channel to synchronize the test with that call.
1640        let (stop_tx, stop_rx) = oneshot::channel();
1641        alloc.alloc.expect_stop().times(1).return_once(|| {
1642            stop_tx.send(()).unwrap();
1643            Ok(())
1644        });
1645
1646        let mut allocator = MockAllocator::new();
1647        allocator
1648            .expect_allocate()
1649            .times(1)
1650            .return_once(|_| Ok(alloc));
1651
1652        let remote_allocator = RemoteProcessAllocator::new();
1653        let handle = tokio::spawn({
1654            let remote_allocator = remote_allocator.clone();
1655            async move {
1656                remote_allocator
1657                    .start_with_allocator(serve_addr, allocator, None)
1658                    .await
1659            }
1660        });
1661
1662        tx.send(RemoteProcessAllocatorMessage::Allocate {
1663            view: extent.clone().into(),
1664            bootstrap_addr,
1665            hosts: vec![],
1666            client_context: None,
1667        })
1668        .await
1669        .unwrap();
1670
1671        // Allocated
1672        let m = rx.recv().await.unwrap();
1673        assert_matches!(
1674            m,
1675            RemoteProcessProcStateMessage::Allocated { world_id: recv_world_id, view }
1676            if recv_world_id == world_id && view.extent() == extent
1677        );
1678
1679        read_all_created(&mut rx, extent.num_ranks()).await;
1680        read_all_running(&mut rx, extent.num_ranks()).await;
1681
1682        // Allocation finished; terminate the upstream connection.
1683        tracing::info!("closing upstream");
1684        drop(rx);
1685        // wait for the heartbeat to expire
1686        #[allow(clippy::disallowed_methods)]
1687        tokio::time::sleep(Duration::from_secs(2)).await;
1688        // wait for the stop to be called
1689        stop_rx.await.unwrap();
1690        // unblock next
1691        next_tx.send(()).unwrap();
1692        remote_allocator.terminate();
1693        handle.await.unwrap().unwrap();
1694    }
1695
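    // A ProcState::Failed update from the inner alloc is forwarded upstream;
    // a subsequent Stop still yields a Done message.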
1696    #[timed_test::async_timed_test(timeout_secs = 15)]
1697    async fn test_inner_alloc_failure() {
1698        let config = hyperactor::config::global::lock();
1699        let _guard = config.override_key(
1700            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
1701            Duration::from_secs(60),
1702        );
1703        hyperactor_telemetry::initialize_logging(ClockKind::default());
1704        let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
1705        let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
1706        let (_, mut rx) = channel::serve(bootstrap_addr.clone()).await.unwrap();
1707
1708        let extent = extent!(host = 1, gpu = 2);
1709
1710        let tx = channel::dial(serve_addr.clone()).unwrap();
1711
1712        let test_world_id: WorldId = id!(test_world_id);
1713        let mut alloc = MockAllocWrapper::new_block_next(
1714            MockAlloc::new(),
1715            // block after the failure update
1716            1,
1717        );
1718        let next_tx = alloc.notify_tx();
1719        alloc
1720            .alloc
1721            .expect_world_id()
1722            .return_const(test_world_id.clone());
1723        alloc.alloc.expect_extent().return_const(extent.clone());
1724        alloc
1725            .alloc
1726            .expect_next()
1727            .times(1)
1728            .return_const(Some(ProcState::Failed {
1729                world_id: test_world_id.clone(),
1730                description: "test".to_string(),
1731            }));
1732        alloc.alloc.expect_next().times(1).return_const(None);
1733
1734        alloc.alloc.expect_stop().times(1).return_once(|| Ok(()));
1735
1736        let mut allocator = MockAllocator::new();
1737        allocator
1738            .expect_allocate()
1739            .times(1)
1740            .return_once(|_| Ok(alloc));
1741
1742        let remote_allocator = RemoteProcessAllocator::new();
1743        let handle = tokio::spawn({
1744            let remote_allocator = remote_allocator.clone();
1745            async move {
1746                remote_allocator
1747                    .start_with_allocator(serve_addr, allocator, None)
1748                    .await
1749            }
1750        });
1751
1752        tx.send(RemoteProcessAllocatorMessage::Allocate {
1753            view: extent.clone().into(),
1754            bootstrap_addr,
1755            hosts: vec![],
1756            client_context: None,
1757        })
1758        .await
1759        .unwrap();
1760
1761        // Allocated
1762        let m = rx.recv().await.unwrap();
1763        assert_matches!(
1764            m,
1765            RemoteProcessProcStateMessage::Allocated { world_id, view }
1766            if world_id == test_world_id && view.extent() == extent
1767        );
1768        // Failed
1769        let m = rx.recv().await.unwrap();
1770        assert_matches!(m, RemoteProcessProcStateMessage::Update(ProcState::Failed {world_id, description}) if world_id == test_world_id && description == "test");
1771
1772        tracing::info!("stopping allocation");
1773        tx.send(RemoteProcessAllocatorMessage::Stop).await.unwrap();
1774        // unblock next() so the alloc can finish stopping
1775        next_tx.send(()).unwrap();
1776        // we expect exactly one Done once the alloc stops successfully.
1777        let m = rx.recv().await.unwrap();
1778        assert_matches!(m, RemoteProcessProcStateMessage::Done(world_id) if world_id == test_world_id);
1779
1780        remote_allocator.terminate();
1781        handle.await.unwrap().unwrap();
1782    }
1783
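    // The trace id from the client context must show up in the AllocSpec
    // constraints under CLIENT_TRACE_ID_LABEL.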
1784    #[timed_test::async_timed_test(timeout_secs = 15)]
1785    async fn test_trace_id_propagation() {
1786        let config = hyperactor::config::global::lock();
1787        let _guard = config.override_key(
1788            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
1789            Duration::from_secs(60),
1790        );
1791        hyperactor_telemetry::initialize_logging(ClockKind::default());
1792        let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
1793        let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
1794        let (_, mut rx) = channel::serve(bootstrap_addr.clone()).await.unwrap();
1795
1796        let extent = extent!(host = 1, gpu = 1);
1797        let tx = channel::dial(serve_addr.clone()).unwrap();
1798        let test_world_id: WorldId = id!(test_world_id);
1799        let test_trace_id = "test_trace_id_12345";
1800
1801        // Create a mock alloc to back the allocation; the allocator below verifies the trace id
1802        let mut alloc = MockAlloc::new();
1803        alloc.expect_world_id().return_const(test_world_id.clone());
1804        alloc.expect_extent().return_const(extent.clone());
1805        alloc.expect_next().return_const(None);
1806
1807        // Create a mock allocator that captures the AllocSpec passed to it
1808        let mut allocator = MockAllocator::new();
1809        allocator
1810            .expect_allocate()
1811            .times(1)
1812            .withf(move |spec: &AllocSpec| {
1813                // Verify that the trace id is correctly set in the constraints
1814                spec.constraints
1815                    .match_labels
1816                    .get(CLIENT_TRACE_ID_LABEL)
1817                    .is_some_and(|trace_id| trace_id == test_trace_id)
1818            })
1819            .return_once(|_| Ok(MockAllocWrapper::new(alloc)));
1820
1821        let remote_allocator = RemoteProcessAllocator::new();
1822        let handle = tokio::spawn({
1823            let remote_allocator = remote_allocator.clone();
1824            async move {
1825                remote_allocator
1826                    .start_with_allocator(serve_addr, allocator, None)
1827                    .await
1828            }
1829        });
1830
1831        // Send allocate message with client context containing trace id
1832        tx.send(RemoteProcessAllocatorMessage::Allocate {
1833            view: extent.clone().into(),
1834            bootstrap_addr,
1835            hosts: vec![],
1836            client_context: Some(ClientContext {
1837                trace_id: test_trace_id.to_string(),
1838            }),
1839        })
1840        .await
1841        .unwrap();
1842
1843        // Verify we get the allocated message
1844        let m = rx.recv().await.unwrap();
1845        assert_matches!(
1846            m,
1847            RemoteProcessProcStateMessage::Allocated { world_id, view }
1848            if world_id == test_world_id && view.extent() == extent
1849        );
1850
1851        // Verify we get the done message since the mock alloc returns None immediately
1852        let m = rx.recv().await.unwrap();
1853        assert_matches!(m, RemoteProcessProcStateMessage::Done(world_id) if world_id == test_world_id);
1854
1855        remote_allocator.terminate();
1856        handle.await.unwrap().unwrap();
1857    }
1858
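    // Without a client context, no trace id label is added to the AllocSpec
    // constraints.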
1859    #[timed_test::async_timed_test(timeout_secs = 15)]
1860    async fn test_trace_id_propagation_no_client_context() {
1861        let config = hyperactor::config::global::lock();
1862        let _guard = config.override_key(
1863            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
1864            Duration::from_secs(60),
1865        );
1866        hyperactor_telemetry::initialize_logging(ClockKind::default());
1867        let serve_addr = ChannelAddr::any(ChannelTransport::Unix);
1868        let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix);
1869        let (_, mut rx) = channel::serve(bootstrap_addr.clone()).await.unwrap();
1870
1871        let extent = extent!(host = 1, gpu = 1);
1872        let tx = channel::dial(serve_addr.clone()).unwrap();
1873        let test_world_id: WorldId = id!(test_world_id);
1874
1875        // Create a mock alloc
1876        let mut alloc = MockAlloc::new();
1877        alloc.expect_world_id().return_const(test_world_id.clone());
1878        alloc.expect_extent().return_const(extent.clone());
1879        alloc.expect_next().return_const(None);
1880
1881        // Create a mock allocator that verifies no trace id is set when client_context is None
1882        let mut allocator = MockAllocator::new();
1883        allocator
1884            .expect_allocate()
1885            .times(1)
1886            .withf(move |spec: &AllocSpec| {
1887                // Verify that no trace id is set in the constraints when client_context is None
1888                spec.constraints.match_labels.is_empty()
1889            })
1890            .return_once(|_| Ok(MockAllocWrapper::new(alloc)));
1891
1892        let remote_allocator = RemoteProcessAllocator::new();
1893        let handle = tokio::spawn({
1894            let remote_allocator = remote_allocator.clone();
1895            async move {
1896                remote_allocator
1897                    .start_with_allocator(serve_addr, allocator, None)
1898                    .await
1899            }
1900        });
1901
1902        // Send allocate message without client context
1903        tx.send(RemoteProcessAllocatorMessage::Allocate {
1904            view: extent.clone().into(),
1905            bootstrap_addr,
1906            hosts: vec![],
1907            client_context: None,
1908        })
1909        .await
1910        .unwrap();
1911
1912        // Verify we get the allocated message
1913        let m = rx.recv().await.unwrap();
1914        assert_matches!(
1915            m,
1916            RemoteProcessProcStateMessage::Allocated { world_id, view }
1917            if world_id == test_world_id && view.extent() == extent
1918        );
1919
1920        // Verify we get the done message since the mock alloc returns None immediately
1921        let m = rx.recv().await.unwrap();
1922        assert_matches!(m, RemoteProcessProcStateMessage::Done(world_id) if world_id == test_world_id);
1923
1924        remote_allocator.terminate();
1925        handle.await.unwrap().unwrap();
1926    }
1927}
1928
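// A minimal sketch (not part of the tests) of how a RemoteProcessAllocator is
// typically hosted, based on the usage in the tests below and assuming a
// bootstrap binary path:
//
//     let allocator = RemoteProcessAllocator::new();
//     let cmd = Command::new("path/to/bootstrap");
//     allocator.start(cmd, ChannelAddr::any(ChannelTransport::Unix), None).await?;
//
// Calling `terminate()` on the same allocator handle shuts the serve loop down.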
1929#[cfg(test)]
1930mod test_alloc {
1931    use std::os::unix::process::ExitStatusExt;
1932
1933    use hyperactor::clock::ClockKind;
1934    use hyperactor::config;
1935    use ndslice::extent;
1936    use nix::sys::signal;
1937    use nix::unistd::Pid;
1938    use timed_test::async_timed_test;
1939
1940    use super::*;
1941
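    // End-to-end: a RemoteProcessAlloc drives two RemoteProcessAllocators over
    // Unix channels and observes Created/Running for the full extent, then
    // Stopped for every proc after stop().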
1942    #[async_timed_test(timeout_secs = 60)]
1943    async fn test_alloc_simple() {
1944        // Use temporary config for this test
1945        let config = hyperactor::config::global::lock();
1946        let _guard1 = config.override_key(
1947            hyperactor::config::MESSAGE_DELIVERY_TIMEOUT,
1948            Duration::from_secs(1),
1949        );
1950        let _guard2 = config.override_key(
1951            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
1952            Duration::from_millis(100),
1953        );
1954        hyperactor_telemetry::initialize_logging(ClockKind::default());
1955
1956        let spec = AllocSpec {
1957            extent: extent!(host = 2, gpu = 2),
1958            constraints: Default::default(),
1959        };
1960        let world_id = WorldId("test_world_id".to_string());
1961        let transport = ChannelTransport::Unix;
1962
1963        let task1_allocator = RemoteProcessAllocator::new();
1964        let task1_addr = ChannelAddr::any(ChannelTransport::Unix);
1965        let task1_addr_string = task1_addr.to_string();
1966        let task1_cmd =
1967            Command::new(buck_resources::get("monarch/hyperactor_mesh/bootstrap").unwrap());
1968        let task2_allocator = RemoteProcessAllocator::new();
1969        let task2_addr = ChannelAddr::any(ChannelTransport::Unix);
1970        let task2_addr_string = task2_addr.to_string();
1971        let task2_cmd =
1972            Command::new(buck_resources::get("monarch/hyperactor_mesh/bootstrap").unwrap());
1973        let task1_allocator_copy = task1_allocator.clone();
1974        let task1_allocator_handle = tokio::spawn(async move {
1975            tracing::info!("spawning task1");
1976            task1_allocator_copy
1977                .start(task1_cmd, task1_addr, None)
1978                .await
1979                .unwrap();
1980        });
1981        let task2_allocator_copy = task2_allocator.clone();
1982        let task2_allocator_handle = tokio::spawn(async move {
1983            task2_allocator_copy
1984                .start(task2_cmd, task2_addr, None)
1985                .await
1986                .unwrap();
1987        });
1988
1989        let mut initializer = MockRemoteProcessAllocInitializer::new();
1990        initializer.expect_initialize_alloc().return_once(move || {
1991            Ok(vec![
1992                RemoteProcessAllocHost {
1993                    hostname: task1_addr_string,
1994                    id: "task1".to_string(),
1995                },
1996                RemoteProcessAllocHost {
1997                    hostname: task2_addr_string,
1998                    id: "task2".to_string(),
1999                },
2000            ])
2001        });
2002        let mut alloc = RemoteProcessAlloc::new(spec.clone(), world_id, transport, 0, initializer)
2003            .await
2004            .unwrap();
2005        let mut procs = HashSet::new();
2006        let mut started_procs = HashSet::new();
2007        let mut proc_points = HashSet::new();
2008        for _ in 0..spec.extent.num_ranks() * 2 {
2009            let proc_state = alloc.next().await.unwrap();
2010            tracing::debug!("test got message: {:?}", proc_state);
2011            match proc_state {
2012                ProcState::Created { proc_id, point, .. } => {
2013                    procs.insert(proc_id);
2014                    proc_points.insert(point);
2015                }
2016                ProcState::Running { proc_id, .. } => {
2017                    assert!(procs.contains(&proc_id));
2018                    started_procs.insert(proc_id);
2019                }
2020                _ => panic!("expected Created or Running"),
2021            }
2022        }
2023        assert_eq!(procs, started_procs);
2024        // ensure every point in the extent is covered
2025        assert!(
2026            spec.extent
2027                .points()
2028                .all(|point| proc_points.contains(&point))
2029        );
2030
2031        // ensure no more pending items
2032        let timeout = hyperactor::clock::RealClock.now() + std::time::Duration::from_millis(1000);
2033        tokio::select! {
2034            _ = hyperactor::clock::RealClock.sleep_until(timeout) => {},
2035            _ = alloc.next() => panic!("expected no more items"),
2036        }
2037
2038        // stop the allocation
2039        alloc.stop().await.unwrap();
2040        for _ in 0..spec.extent.num_ranks() {
2041            let proc_state = alloc.next().await.unwrap();
2042            tracing::info!("test received next proc_state: {:?}", proc_state);
2043            match proc_state {
2044                ProcState::Stopped { proc_id, reason } => {
2045                    assert!(started_procs.remove(&proc_id));
2046                    assert_eq!(reason, ProcStopReason::Stopped);
2047                }
2048                _ => panic!("expected stopped"),
2049            }
2050        }
2051        // Exactly one None
2052        let proc_state = alloc.next().await;
2053        assert!(proc_state.is_none());
2054        // Anything afterwards is None
2055        let proc_state = alloc.next().await;
2056        assert!(proc_state.is_none());
2057
2058        task1_allocator.terminate();
2059        task1_allocator_handle.await.unwrap();
2060        task2_allocator.terminate();
2061        task2_allocator_handle.await.unwrap();
2062    }
2063
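    // Aborting a host's allocator task makes its heartbeats fail, surfacing
    // HostWatchdog stop reasons for that host's procs.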
2064    #[async_timed_test(timeout_secs = 60)]
2065    async fn test_alloc_host_failure() {
2066        // Use temporary config for this test
2067        let config = hyperactor::config::global::lock();
2068        let _guard1 = config.override_key(
2069            hyperactor::config::MESSAGE_DELIVERY_TIMEOUT,
2070            Duration::from_secs(1),
2071        );
2072        let _guard2 = config.override_key(
2073            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
2074            Duration::from_millis(100),
2075        );
2076        hyperactor_telemetry::initialize_logging(ClockKind::default());
2077
2078        let spec = AllocSpec {
2079            extent: extent!(host = 2, gpu = 2),
2080            constraints: Default::default(),
2081        };
2082        let world_id = WorldId("test_world_id".to_string());
2083        let transport = ChannelTransport::Unix;
2084
2085        let task1_allocator = RemoteProcessAllocator::new();
2086        let task1_addr = ChannelAddr::any(ChannelTransport::Unix);
2087        let task1_addr_string = task1_addr.to_string();
2088        let task1_cmd =
2089            Command::new(buck_resources::get("monarch/hyperactor_mesh/bootstrap").unwrap());
2090        let task2_allocator = RemoteProcessAllocator::new();
2091        let task2_addr = ChannelAddr::any(ChannelTransport::Unix);
2092        let task2_addr_string = task2_addr.to_string();
2093        let task2_cmd =
2094            Command::new(buck_resources::get("monarch/hyperactor_mesh/bootstrap").unwrap());
2095        let task1_allocator_copy = task1_allocator.clone();
2096        let task1_allocator_handle = tokio::spawn(async move {
2097            tracing::info!("spawning task1");
2098            task1_allocator_copy
2099                .start(task1_cmd, task1_addr, None)
2100                .await
2101                .unwrap();
2102            tracing::info!("task1 terminated");
2103        });
2104        let task2_allocator_copy = task2_allocator.clone();
2105        let task2_allocator_handle = tokio::spawn(async move {
2106            task2_allocator_copy
2107                .start(task2_cmd, task2_addr, None)
2108                .await
2109                .unwrap();
2110            tracing::info!("task2 terminated");
2111        });
2112
2113        let mut initializer = MockRemoteProcessAllocInitializer::new();
2114        initializer.expect_initialize_alloc().return_once(move || {
2115            Ok(vec![
2116                RemoteProcessAllocHost {
2117                    hostname: task1_addr_string,
2118                    id: "task1".to_string(),
2119                },
2120                RemoteProcessAllocHost {
2121                    hostname: task2_addr_string,
2122                    id: "task2".to_string(),
2123                },
2124            ])
2125        });
2126        let mut alloc = RemoteProcessAlloc::new(spec.clone(), world_id, transport, 0, initializer)
2127            .await
2128            .unwrap();
2129        for _ in 0..spec.extent.num_ranks() * 2 {
2130            match alloc.next().await {
2131                Some(ProcState::Created { .. }) | Some(ProcState::Running { .. }) => {}
2132                _ => panic!("expected Created or Running"),
2133            }
2134        }
2135
2136        // ensure no more pending items
2137        let timeout = hyperactor::clock::RealClock.now() + std::time::Duration::from_millis(1000);
2138        tokio::select! {
2139            _ = hyperactor::clock::RealClock.sleep_until(timeout) => {},
2141            _ = alloc.next() => panic!("expected no more items"),
2142        }
2143
2144        // now we abort task1's allocator and wait for the heartbeat timeout
2145        tracing::info!("aborting task1 allocator");
2146        task1_allocator_handle.abort();
2147        RealClock
2148            .sleep(config::global::get(config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL) * 2)
2149            .await;
2150        for _ in 0..spec.extent.num_ranks() / 2 {
2151            let proc_state = alloc.next().await.unwrap();
2152            tracing::info!("test received next proc_state: {:?}", proc_state);
2153            match proc_state {
2154                ProcState::Stopped { proc_id: _, reason } => {
2155                    assert_eq!(reason, ProcStopReason::HostWatchdog);
2156                }
2157                _ => panic!("expected stopped"),
2158            }
2159        }
2160        // no more events
2161        let timeout = hyperactor::clock::RealClock.now() + std::time::Duration::from_millis(1000);
2162        tokio::select! {
2163            _ = hyperactor::clock::RealClock.sleep_until(timeout) => {},
2165            _ = alloc.next() => panic!("expected no more items"),
2166        }
2167
2168        // abort the second host's allocator
2169        tracing::info!("aborting task2 allocator");
2170        task2_allocator_handle.abort();
2171        RealClock
2172            .sleep(config::global::get(config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL) * 2)
2173            .await;
2174        for _ in 0..spec.extent.num_ranks() / 2 {
2175            let proc_state = alloc.next().await.unwrap();
2176            tracing::info!("test received next proc_state: {:?}", proc_state);
2177            match proc_state {
2178                ProcState::Stopped { proc_id: _, reason } => {
2179                    assert_eq!(reason, ProcStopReason::HostWatchdog);
2180                }
2181                _ => panic!("expected stopped"),
2182            }
2183        }
2184        // now that the alloc has stopped, we always get None
2185        let proc_state = alloc.next().await;
2186        assert!(proc_state.is_none());
2187        // Anything afterwards is None
2188        let proc_state = alloc.next().await;
2189        assert!(proc_state.is_none());
2190    }
2191
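    // One host is given a non-existent bootstrap binary: its allocation fails
    // while the other host's procs come up and can still be stopped cleanly.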
2192    #[async_timed_test(timeout_secs = 15)]
2193    async fn test_alloc_inner_alloc_failure() {
2194        // SAFETY: Test happens in single-threaded code.
2195        unsafe {
2196            std::env::set_var("MONARCH_MESSAGE_DELIVERY_TIMEOUT_SECS", "1");
2197        }
2198        let config = hyperactor::config::global::lock();
2199        let _guard = config.override_key(
2200            hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL,
2201            Duration::from_millis(100),
2202        );
2203        hyperactor_telemetry::initialize_logging(ClockKind::default());
2204
2205        let spec = AllocSpec {
2206            extent: extent!(host = 2, gpu = 2),
2207            constraints: Default::default(),
2208        };
2209        let world_id = WorldId("test_world_id".to_string());
2210        let transport = ChannelTransport::Unix;
2211
2212        let task1_allocator = RemoteProcessAllocator::new();
2213        let task1_addr = ChannelAddr::any(ChannelTransport::Unix);
2214        let task1_addr_string = task1_addr.to_string();
2215        let task1_cmd =
2216            Command::new(buck_resources::get("monarch/hyperactor_mesh/bootstrap").unwrap());
2217        let task2_allocator = RemoteProcessAllocator::new();
2218        let task2_addr = ChannelAddr::any(ChannelTransport::Unix);
2219        let task2_addr_string = task2_addr.to_string();
2220        // non-existent binary to fail the allocation
2221        let task2_cmd = Command::new("/caught/somewhere/in/time");
2222        let task1_allocator_copy = task1_allocator.clone();
2223        let task1_allocator_handle = tokio::spawn(async move {
2224            tracing::info!("spawning task1");
2225            task1_allocator_copy
2226                .start(task1_cmd, task1_addr, None)
2227                .await
2228                .unwrap();
2229        });
2230        let task2_allocator_copy = task2_allocator.clone();
2231        let task2_allocator_handle = tokio::spawn(async move {
2232            task2_allocator_copy
2233                .start(task2_cmd, task2_addr, None)
2234                .await
2235                .unwrap();
2236        });
2237
2238        let mut initializer = MockRemoteProcessAllocInitializer::new();
2239        initializer.expect_initialize_alloc().return_once(move || {
2240            Ok(vec![
2241                RemoteProcessAllocHost {
2242                    hostname: task1_addr_string,
2243                    id: "task1".to_string(),
2244                },
2245                RemoteProcessAllocHost {
2246                    hostname: task2_addr_string,
2247                    id: "task2".to_string(),
2248                },
2249            ])
2250        });
2251        let mut alloc = RemoteProcessAlloc::new(spec.clone(), world_id, transport, 0, initializer)
2252            .await
2253            .unwrap();
2254        let mut procs = HashSet::new();
2255        let mut started_procs = HashSet::new();
2256        let mut proc_points = HashSet::new();
2257        let mut failed = 0;
2258        // task 1 procs + 1 failed event for task 2
2259        for _ in 0..spec.extent.num_ranks() + 1 {
2260            let proc_state = alloc.next().await.unwrap();
2261            tracing::debug!("test got message: {:?}", proc_state);
2262            match proc_state {
2263                ProcState::Created { proc_id, point, .. } => {
2264                    procs.insert(proc_id);
2265                    proc_points.insert(point);
2266                }
2267                ProcState::Running { proc_id, .. } => {
2268                    assert!(procs.contains(&proc_id));
2269                    started_procs.insert(proc_id);
2270                }
2271                ProcState::Failed { .. } => {
2272                    failed += 1;
2273                }
2274                _ => panic!("expected Created, Running or Failed"),
2275            }
2276        }
2277        assert_eq!(procs, started_procs);
2278        assert_eq!(failed, 1);
2279        // ensure coordinate coverage for task 1's half of the extent
2280        for rank in 0..spec.extent.num_ranks() / 2 {
2281            let point = spec.extent.point_of_rank(rank).unwrap();
2282            assert!(proc_points.contains(&point));
2283        }
2284
2285        // ensure no more pending items
2286        let timeout = hyperactor::clock::RealClock.now() + std::time::Duration::from_millis(1000);
2287        tokio::select! {
2288            _ = hyperactor::clock::RealClock.sleep_until(timeout) => {},
2290            _ = alloc.next() => panic!("expected no more items"),
2291        }
2292
2293        // stop the allocation
2294        alloc.stop().await.unwrap();
2295        for _ in 0..spec.extent.num_ranks() / 2 {
2296            let proc_state = alloc.next().await.unwrap();
2297            tracing::info!("test received next proc_state: {:?}", proc_state);
2298            match proc_state {
2299                ProcState::Stopped { proc_id, reason } => {
2300                    assert!(started_procs.remove(&proc_id));
2301                    assert_eq!(reason, ProcStopReason::Stopped);
2302                }
2303                _ => panic!("expected stopped"),
2304            }
2305        }
2306        // Exactly one None
2307        let proc_state = alloc.next().await;
2308        assert!(proc_state.is_none());
2309        // Anything afterwards is None
2310        let proc_state = alloc.next().await;
2311        assert!(proc_state.is_none());
2312
2313        task1_allocator.terminate();
2314        task1_allocator_handle.await.unwrap();
2315        task2_allocator.terminate();
2316        task2_allocator_handle.await.unwrap();
2317    }
2318
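    // SIGINT delivered to the remote_process_alloc binary should tear down all
    // processes spawned on its behalf by the remote process allocators.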
2319    #[tracing_test::traced_test]
2320    #[async_timed_test(timeout_secs = 60)]
2321    async fn test_remote_process_alloc_signal_handler() {
2322        let num_proc_meshes = 5;
2323        let hosts_per_proc_mesh = 5;
2324
2325        let pid_addr = ChannelAddr::any(ChannelTransport::Unix);
2326        let (pid_addr, mut pid_rx) = channel::serve::<u32>(pid_addr).await.unwrap();
2327
2328        let addresses = (0..(num_proc_meshes * hosts_per_proc_mesh))
2329            .map(|_| ChannelAddr::any(ChannelTransport::Unix).to_string())
2330            .collect::<Vec<_>>();
2331
2332        let remote_process_allocators = addresses
2333            .iter()
2334            .map(|addr| {
2335                Command::new(
2336                    buck_resources::get("monarch/hyperactor_mesh/remote_process_allocator")
2337                        .unwrap(),
2338                )
2339                .env("RUST_LOG", "info")
2340                .arg(format!("--addr={addr}"))
2341                .stdout(std::process::Stdio::piped())
2342                .spawn()
2343                .unwrap()
2344            })
2345            .collect::<Vec<_>>();
2346
2347        let done_allocating_addr = ChannelAddr::any(ChannelTransport::Unix);
2348        let (done_allocating_addr, mut done_allocating_rx) =
2349            channel::serve::<()>(done_allocating_addr).await.unwrap();
2350        let mut remote_process_alloc = Command::new(
2351            buck_resources::get("monarch/hyperactor_mesh/remote_process_alloc").unwrap(),
2352        )
2353        .arg(format!("--done-allocating-addr={}", done_allocating_addr))
2354        .arg(format!("--addresses={}", addresses.join(",")))
2355        .arg(format!("--num-proc-meshes={}", num_proc_meshes))
2356        .arg(format!("--hosts-per-proc-mesh={}", hosts_per_proc_mesh))
2357        .arg(format!("--pid-addr={}", pid_addr))
2358        .spawn()
2359        .unwrap();
2360
2361        done_allocating_rx.recv().await.unwrap();
2362        let mut received_pids = Vec::new();
2363        while let Ok(pid) = pid_rx.recv().await {
2364            received_pids.push(pid);
2365            if received_pids.len() == remote_process_allocators.len() {
2366                break;
2367            }
2368        }
2369
2370        signal::kill(
2371            Pid::from_raw(remote_process_alloc.id().unwrap() as i32),
2372            signal::SIGINT,
2373        )
2374        .unwrap();
2375
2376        assert_eq!(
2377            remote_process_alloc.wait().await.unwrap().signal(),
2378            Some(signal::SIGINT as i32)
2379        );
2380
2381        RealClock.sleep(tokio::time::Duration::from_secs(5)).await;
2382
2383        // Assert that the processes spawned by ProcessAllocator have been killed
2384        for child_pid in received_pids {
2385            let pid_check = Command::new("kill")
2386                .arg("-0")
2387                .arg(child_pid.to_string())
2388                .output()
2389                .await
2390                .expect("Failed to check if PID is alive");
2391
2392            assert!(
2393                !pid_check.status.success(),
2394                "PID {} should no longer be alive",
2395                child_pid
2396            );
2397        }
2398
2399        // Clean up the remote process allocator processes: SIGINT stops the
2400        // current allocs but not the RemoteProcessAllocator loops.
2401        for mut remote_process_allocator in remote_process_allocators {
2402            remote_process_allocator.kill().await.unwrap();
2403        }
2404    }
2405}