hyperactor_mesh/alloc/
local.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! Support for allocating procs in the local process.
10
11#![allow(dead_code)] // until it is used outside of testing
12
13use std::collections::HashMap;
14use std::collections::VecDeque;
15use std::time::Duration;
16
17use async_trait::async_trait;
18use hyperactor::channel;
19use hyperactor::channel::ChannelAddr;
20use hyperactor::mailbox::MailboxServer;
21use hyperactor::mailbox::MailboxServerHandle;
22use hyperactor::proc::Proc;
23use hyperactor::reference as hyperactor_reference;
24use ndslice::view::Extent;
25use tokio::sync::mpsc;
26use tokio::time::sleep;
27
28use super::ProcStopReason;
29use crate::alloc::Alloc;
30use crate::alloc::AllocName;
31use crate::alloc::AllocSpec;
32use crate::alloc::Allocator;
33use crate::alloc::AllocatorError;
34use crate::alloc::ProcState;
35use crate::proc_agent::ProcAgent;
36use crate::shortuuid::ShortUuid;
37
38enum Action {
39    Start(usize),
40    Stop(usize, ProcStopReason),
41    Stopped,
42}
43
/// Allocator whose procs all live inside the current process.
///
/// Mainly intended for tests and for small meshes that can run entirely locally.
///
/// Every requested proc is allocated. Each proc is currently treated as infallible,
/// because it shares a fault domain with the client of the alloc.
pub struct LocalAllocator;
50
51#[async_trait]
52impl Allocator for LocalAllocator {
53    type Alloc = LocalAlloc;
54
55    async fn allocate(&mut self, spec: AllocSpec) -> Result<Self::Alloc, AllocatorError> {
56        let alloc = LocalAlloc::new(spec);
57        tracing::info!(
58            name = "LocalAllocStatus",
59            alloc_name = %alloc.alloc_name(),
60            status = "Allocated",
61        );
62        Ok(alloc)
63    }
64}
65
66struct LocalProc {
67    proc: Proc,
68    create_key: ShortUuid,
69    addr: ChannelAddr,
70    handle: MailboxServerHandle,
71}
72
73/// A local allocation. It is a collection of procs that are running in the local process.
74pub struct LocalAlloc {
75    spec: AllocSpec,
76    name: ShortUuid,
77    alloc_name: AllocName,
78    procs: HashMap<usize, LocalProc>,
79    queue: VecDeque<ProcState>,
80    todo_tx: mpsc::UnboundedSender<Action>,
81    todo_rx: mpsc::UnboundedReceiver<Action>,
82    stopped: bool,
83    failed: bool,
84}
85
86impl LocalAlloc {
87    pub(crate) fn new(spec: AllocSpec) -> Self {
88        let name = ShortUuid::generate();
89        let (todo_tx, todo_rx) = mpsc::unbounded_channel();
90        for rank in 0..spec.extent.num_ranks() {
91            todo_tx.send(Action::Start(rank)).unwrap();
92        }
93        Self {
94            spec,
95            name: name.clone(),
96            alloc_name: AllocName(name.to_string()),
97            procs: HashMap::new(),
98            queue: VecDeque::new(),
99            todo_tx,
100            todo_rx,
101            stopped: false,
102            failed: false,
103        }
104    }
105
106    /// A chaos monkey that can be used to stop procs at random.
107    pub(crate) fn chaos_monkey(&self) -> impl Fn(usize, ProcStopReason) + 'static {
108        let todo_tx = self.todo_tx.clone();
109        move |rank, reason| {
110            todo_tx.send(Action::Stop(rank, reason)).unwrap();
111        }
112    }
113
114    /// A function to shut down the alloc for testing purposes.
115    pub(crate) fn stopper(&self) -> impl Fn() + 'static {
116        let todo_tx = self.todo_tx.clone();
117        let size = self.size();
118        move || {
119            for rank in 0..size {
120                todo_tx
121                    .send(Action::Stop(rank, ProcStopReason::Stopped))
122                    .unwrap();
123            }
124            todo_tx.send(Action::Stopped).unwrap();
125        }
126    }
127
128    pub(crate) fn name(&self) -> &ShortUuid {
129        &self.name
130    }
131
132    pub(crate) fn size(&self) -> usize {
133        self.spec.extent.num_ranks()
134    }
135}
136
137#[async_trait]
138impl Alloc for LocalAlloc {
139    async fn next(&mut self) -> Option<ProcState> {
140        if self.stopped {
141            return None;
142        }
143        if self.failed && !self.stopped {
144            // Failed alloc. Wait for stop().
145            futures::future::pending::<()>().await;
146            unreachable!("future::pending completed");
147        }
148        let event = loop {
149            if let state @ Some(_) = self.queue.pop_front() {
150                break state;
151            }
152
153            match self.todo_rx.recv().await? {
154                Action::Start(rank) => {
155                    let (addr, proc_rx) = loop {
156                        match channel::serve(ChannelAddr::any(self.transport())) {
157                            Ok(addr_and_proc_rx) => break addr_and_proc_rx,
158                            Err(err) => {
159                                tracing::error!(
160                                    "failed to create channel for rank {}: {}",
161                                    rank,
162                                    err
163                                );
164                                sleep(Duration::from_secs(1)).await;
165                                continue;
166                            }
167                        }
168                    };
169
170                    let proc_name = match &self.spec.proc_name {
171                        Some(name) => name.clone(),
172                        None => format!("{}_{}", self.alloc_name.name(), rank),
173                    };
174                    let proc_id = hyperactor_reference::ProcId::with_name(addr.clone(), proc_name);
175
176                    let bspan = tracing::info_span!("mesh_agent_bootstrap");
177                    let (proc, mesh_agent) = match ProcAgent::bootstrap(proc_id.clone()).await {
178                        Ok(proc_and_agent) => proc_and_agent,
179                        Err(err) => {
180                            let message = format!("failed spawn mesh agent for {}: {}", rank, err);
181                            tracing::error!(message);
182                            // It's unclear if this is actually recoverable in a practical sense,
183                            // so we give up.
184                            self.failed = true;
185                            break Some(ProcState::Failed {
186                                alloc_name: self.alloc_name.clone(),
187                                description: message,
188                            });
189                        }
190                    };
191                    drop(bspan);
192
193                    // Undeliverable messages get forwarded to the mesh agent.
194                    let handle = proc.clone().serve(proc_rx);
195
196                    let create_key = ShortUuid::generate();
197
198                    self.procs.insert(
199                        rank,
200                        LocalProc {
201                            proc,
202                            create_key: create_key.clone(),
203                            addr: addr.clone(),
204                            handle,
205                        },
206                    );
207
208                    let point = match self.spec.extent.point_of_rank(rank) {
209                        Ok(point) => point,
210                        Err(err) => {
211                            tracing::error!("failed to get point for rank {}: {}", rank, err);
212                            return None;
213                        }
214                    };
215                    let created = ProcState::Created {
216                        create_key: create_key.clone(),
217                        point,
218                        pid: std::process::id(),
219                    };
220                    self.queue.push_back(ProcState::Running {
221                        create_key,
222                        proc_id,
223                        mesh_agent: mesh_agent.bind(),
224                        addr,
225                    });
226                    break Some(created);
227                }
228                Action::Stop(rank, reason) => {
229                    let Some(mut proc_to_stop) = self.procs.remove(&rank) else {
230                        continue;
231                    };
232
233                    // Stop serving the mailbox.
234                    proc_to_stop.handle.stop("received Action::Stop");
235
236                    if let Err(err) = proc_to_stop
237                        .proc
238                        .destroy_and_wait::<()>(
239                            Duration::from_millis(10),
240                            None,
241                            &reason.to_string(),
242                        )
243                        .await
244                    {
245                        tracing::error!("error while stopping proc {}: {}", rank, err);
246                    }
247                    break Some(ProcState::Stopped {
248                        reason,
249                        create_key: proc_to_stop.create_key.clone(),
250                    });
251                }
252                Action::Stopped => break None,
253            }
254        };
255        self.stopped = event.is_none();
256        event
257    }
258
259    fn spec(&self) -> &AllocSpec {
260        &self.spec
261    }
262
263    fn extent(&self) -> &Extent {
264        &self.spec.extent
265    }
266
267    fn alloc_name(&self) -> &AllocName {
268        &self.alloc_name
269    }
270
271    async fn stop(&mut self) -> Result<(), AllocatorError> {
272        tracing::info!(
273            name = "LocalAllocStatus",
274            alloc_name = %self.alloc_name(),
275            status = "Stopping",
276        );
277        for rank in 0..self.size() {
278            self.todo_tx
279                .send(Action::Stop(rank, ProcStopReason::Stopped))
280                .unwrap();
281        }
282        self.todo_tx.send(Action::Stopped).unwrap();
283        tracing::info!(
284            name = "LocalAllocStatus",
285            alloc_name = %self.alloc_name(),
286            status = "Stop::Sent",
287            "Stop was sent to local procs; check their log to determine if it exited."
288        );
289        Ok(())
290    }
291
292    fn is_local(&self) -> bool {
293        true
294    }
295}
296
297impl Drop for LocalAlloc {
298    fn drop(&mut self) {
299        tracing::info!(
300            name = "LocalAllocStatus",
301            alloc_name = %self.alloc_name(),
302            status = "Dropped",
303            "dropping LocalAlloc of name: {}, alloc_name: {}",
304            self.name,
305            self.alloc_name
306        );
307    }
308}
309
#[cfg(test)]
mod tests {
    use super::*;

    // Instantiate the crate's `alloc_test_suite!` macro for `LocalAllocator`.
    crate::alloc_test_suite!(LocalAllocator);
}