hyperactor_mesh/alloc/
local.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! Support for allocating procs in the local process.
10
11#![allow(dead_code)] // until it is used outside of testing
12
13use std::collections::HashMap;
14use std::collections::VecDeque;
15use std::time::Duration;
16
17use async_trait::async_trait;
18use hyperactor::ProcId;
19use hyperactor::WorldId;
20use hyperactor::channel;
21use hyperactor::channel::ChannelAddr;
22use hyperactor::mailbox::MailboxServer;
23use hyperactor::mailbox::MailboxServerHandle;
24use hyperactor::proc::Proc;
25use ndslice::view::Extent;
26use tokio::sync::mpsc;
27use tokio::time::sleep;
28
29use super::ProcStopReason;
30use crate::alloc::Alloc;
31use crate::alloc::AllocSpec;
32use crate::alloc::Allocator;
33use crate::alloc::AllocatorError;
34use crate::alloc::ProcState;
35use crate::proc_mesh::mesh_agent::ProcMeshAgent;
36use crate::shortuuid::ShortUuid;
37
/// A unit of work queued on a [`LocalAlloc`]'s internal channel and handled,
/// in order, by its `next` implementation.
enum Action {
    /// Start the proc with the given rank.
    Start(usize),
    /// Stop the proc with the given rank, reporting the given reason in the
    /// resulting `ProcState::Stopped` event. Ignored if no proc with that
    /// rank is currently tracked.
    Stop(usize, ProcStopReason),
    /// End the event stream: once handled, `next` returns `None` from then on.
    Stopped,
}
43
/// An allocator that runs procs in the local process. It is primarily useful for testing,
/// or small meshes that can run entirely locally.
///
/// Currently, the allocator will allocate all procs, but each is treated as infallible,
/// since they share fault domain with the client of the alloc.
///
/// Allocating does not start any procs by itself: it only constructs a [`LocalAlloc`],
/// which brings its procs up as its events are consumed.
pub struct LocalAllocator;
50
51#[async_trait]
52impl Allocator for LocalAllocator {
53    type Alloc = LocalAlloc;
54
55    async fn allocate(&mut self, spec: AllocSpec) -> Result<Self::Alloc, AllocatorError> {
56        let alloc = LocalAlloc::new(spec);
57        tracing::info!(
58            name = "LocalAllocStatus",
59            alloc_name = %alloc.world_id(),
60            status = "Allocated",
61        );
62        Ok(alloc)
63    }
64}
65
/// Bookkeeping for a single proc started by a [`LocalAlloc`].
struct LocalProc {
    /// The proc itself; destroyed when the proc is stopped.
    proc: Proc,
    /// Key generated when the proc was started; it is carried by the
    /// `Created`, `Running` and `Stopped` events emitted for this proc.
    create_key: ShortUuid,
    /// Address of the channel the proc's mailbox is served on.
    addr: ChannelAddr,
    /// Handle of the mailbox server, used to stop serving when the proc is
    /// stopped.
    handle: MailboxServerHandle,
}
72
/// A local allocation. It is a collection of procs that are running in the local process.
pub struct LocalAlloc {
    /// The spec this alloc was created from.
    spec: AllocSpec,
    /// Randomly generated name of this alloc; also the basis of `world_id`.
    name: ShortUuid,
    world_id: WorldId, // to provide storage
    /// Currently tracked procs, keyed by rank.
    procs: HashMap<usize, LocalProc>,
    /// Events that are already determined but not yet returned by `next`.
    queue: VecDeque<ProcState>,
    /// Sending half of the action channel; cloned into the testing closures.
    todo_tx: mpsc::UnboundedSender<Action>,
    /// Receiving half of the action channel, drained by `next`.
    todo_rx: mpsc::UnboundedReceiver<Action>,
    /// Set once `next` has reported the end of the event stream.
    stopped: bool,
    /// Set when bootstrapping a mesh agent for one of the procs failed.
    failed: bool,
}
85
86impl LocalAlloc {
87    pub(crate) fn new(spec: AllocSpec) -> Self {
88        let name = ShortUuid::generate();
89        let (todo_tx, todo_rx) = mpsc::unbounded_channel();
90        for rank in 0..spec.extent.num_ranks() {
91            todo_tx.send(Action::Start(rank)).unwrap();
92        }
93        Self {
94            spec,
95            name: name.clone(),
96            world_id: WorldId(name.to_string()),
97            procs: HashMap::new(),
98            queue: VecDeque::new(),
99            todo_tx,
100            todo_rx,
101            stopped: false,
102            failed: false,
103        }
104    }
105
106    /// A chaos monkey that can be used to stop procs at random.
107    pub(crate) fn chaos_monkey(&self) -> impl Fn(usize, ProcStopReason) + 'static {
108        let todo_tx = self.todo_tx.clone();
109        move |rank, reason| {
110            todo_tx.send(Action::Stop(rank, reason)).unwrap();
111        }
112    }
113
114    /// A function to shut down the alloc for testing purposes.
115    pub(crate) fn stopper(&self) -> impl Fn() + 'static {
116        let todo_tx = self.todo_tx.clone();
117        let size = self.size();
118        move || {
119            for rank in 0..size {
120                todo_tx
121                    .send(Action::Stop(rank, ProcStopReason::Stopped))
122                    .unwrap();
123            }
124            todo_tx.send(Action::Stopped).unwrap();
125        }
126    }
127
128    pub(crate) fn name(&self) -> &ShortUuid {
129        &self.name
130    }
131
132    pub(crate) fn size(&self) -> usize {
133        self.spec.extent.num_ranks()
134    }
135}
136
137#[async_trait]
138impl Alloc for LocalAlloc {
139    async fn next(&mut self) -> Option<ProcState> {
140        if self.stopped {
141            return None;
142        }
143        if self.failed && !self.stopped {
144            // Failed alloc. Wait for stop().
145            futures::future::pending::<()>().await;
146            unreachable!("future::pending completed");
147        }
148        let event = loop {
149            if let state @ Some(_) = self.queue.pop_front() {
150                break state;
151            }
152
153            match self.todo_rx.recv().await? {
154                Action::Start(rank) => {
155                    let (addr, proc_rx) = loop {
156                        match channel::serve(ChannelAddr::any(self.transport())) {
157                            Ok(addr_and_proc_rx) => break addr_and_proc_rx,
158                            Err(err) => {
159                                tracing::error!(
160                                    "failed to create channel for rank {}: {}",
161                                    rank,
162                                    err
163                                );
164                                #[allow(clippy::disallowed_methods)]
165                                sleep(Duration::from_secs(1)).await;
166                                continue;
167                            }
168                        }
169                    };
170
171                    let proc_id = match &self.spec.proc_name {
172                        Some(name) => ProcId::Direct(addr.clone(), name.clone()),
173                        None => ProcId::Ranked(self.world_id.clone(), rank),
174                    };
175
176                    let bspan = tracing::info_span!("mesh_agent_bootstrap");
177                    let (proc, mesh_agent) = match ProcMeshAgent::bootstrap(proc_id.clone()).await {
178                        Ok(proc_and_agent) => proc_and_agent,
179                        Err(err) => {
180                            let message = format!("failed spawn mesh agent for {}: {}", rank, err);
181                            tracing::error!(message);
182                            // It's unclear if this is actually recoverable in a practical sense,
183                            // so we give up.
184                            self.failed = true;
185                            break Some(ProcState::Failed {
186                                world_id: self.world_id.clone(),
187                                description: message,
188                            });
189                        }
190                    };
191                    drop(bspan);
192
193                    // Undeliverable messages get forwarded to the mesh agent.
194                    let handle = proc.clone().serve(proc_rx);
195
196                    let create_key = ShortUuid::generate();
197
198                    self.procs.insert(
199                        rank,
200                        LocalProc {
201                            proc,
202                            create_key: create_key.clone(),
203                            addr: addr.clone(),
204                            handle,
205                        },
206                    );
207
208                    let point = match self.spec.extent.point_of_rank(rank) {
209                        Ok(point) => point,
210                        Err(err) => {
211                            tracing::error!("failed to get point for rank {}: {}", rank, err);
212                            return None;
213                        }
214                    };
215                    let created = ProcState::Created {
216                        create_key: create_key.clone(),
217                        point,
218                        pid: std::process::id(),
219                    };
220                    self.queue.push_back(ProcState::Running {
221                        create_key,
222                        proc_id,
223                        mesh_agent: mesh_agent.bind(),
224                        addr,
225                    });
226                    break Some(created);
227                }
228                Action::Stop(rank, reason) => {
229                    let Some(mut proc_to_stop) = self.procs.remove(&rank) else {
230                        continue;
231                    };
232
233                    // Stop serving the mailbox.
234                    proc_to_stop.handle.stop("received Action::Stop");
235
236                    if let Err(err) = proc_to_stop
237                        .proc
238                        .destroy_and_wait::<()>(Duration::from_millis(10), None)
239                        .await
240                    {
241                        tracing::error!("error while stopping proc {}: {}", rank, err);
242                    }
243                    break Some(ProcState::Stopped {
244                        reason,
245                        create_key: proc_to_stop.create_key.clone(),
246                    });
247                }
248                Action::Stopped => break None,
249            }
250        };
251        self.stopped = event.is_none();
252        event
253    }
254
255    fn spec(&self) -> &AllocSpec {
256        &self.spec
257    }
258
259    fn extent(&self) -> &Extent {
260        &self.spec.extent
261    }
262
263    fn world_id(&self) -> &WorldId {
264        &self.world_id
265    }
266
267    async fn stop(&mut self) -> Result<(), AllocatorError> {
268        tracing::info!(
269            name = "LocalAllocStatus",
270            alloc_name = %self.world_id(),
271            status = "Stopping",
272        );
273        for rank in 0..self.size() {
274            self.todo_tx
275                .send(Action::Stop(rank, ProcStopReason::Stopped))
276                .unwrap();
277        }
278        self.todo_tx.send(Action::Stopped).unwrap();
279        tracing::info!(
280            name = "LocalAllocStatus",
281            alloc_name = %self.world_id(),
282            status = "Stop::Sent",
283            "Stop was sent to local procs; check their log to determine if it exited."
284        );
285        Ok(())
286    }
287
288    fn is_local(&self) -> bool {
289        true
290    }
291}
292
293impl Drop for LocalAlloc {
294    fn drop(&mut self) {
295        tracing::info!(
296            name = "LocalAllocStatus",
297            alloc_name = %self.world_id(),
298            status = "Dropped",
299            "dropping LocalAlloc of name: {}, world id: {}",
300            self.name,
301            self.world_id
302        );
303    }
304}
305
#[cfg(test)]
mod tests {
    use super::*;

    // Instantiates the crate's `alloc_test_suite!` macro (defined outside
    // this file) for `LocalAllocator`.
    crate::alloc_test_suite!(LocalAllocator);
}