hyperactor_mesh/alloc/
local.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! Support for allocating procs in the local process.
10
11#![allow(dead_code)] // until it is used outside of testing
12
13use std::collections::HashMap;
14use std::collections::VecDeque;
15use std::time::Duration;
16
17use async_trait::async_trait;
18use hyperactor::ProcId;
19use hyperactor::WorldId;
20use hyperactor::channel;
21use hyperactor::channel::ChannelAddr;
22use hyperactor::mailbox::MailboxServer;
23use hyperactor::mailbox::MailboxServerHandle;
24use hyperactor::proc::Proc;
25use ndslice::view::Extent;
26use tokio::sync::mpsc;
27use tokio::time::sleep;
28
29use super::ProcStopReason;
30use crate::alloc::Alloc;
31use crate::alloc::AllocSpec;
32use crate::alloc::Allocator;
33use crate::alloc::AllocatorError;
34use crate::alloc::ProcState;
35use crate::proc_mesh::mesh_agent::ProcMeshAgent;
36use crate::shortuuid::ShortUuid;
37
38enum Action {
39    Start(usize),
40    Stop(usize, ProcStopReason),
41    Stopped,
42}
43
/// Allocator whose procs all live inside the calling process.
///
/// Mainly intended for tests and for small meshes that can run entirely
/// locally.
///
/// Every requested proc is allocated, and each one is currently treated as
/// infallible because it shares a fault domain with the client of the alloc.
pub struct LocalAllocator;
50
51#[async_trait]
52impl Allocator for LocalAllocator {
53    type Alloc = LocalAlloc;
54
55    async fn allocate(&mut self, spec: AllocSpec) -> Result<Self::Alloc, AllocatorError> {
56        Ok(LocalAlloc::new(spec))
57    }
58}
59
60struct LocalProc {
61    proc: Proc,
62    create_key: ShortUuid,
63    addr: ChannelAddr,
64    handle: MailboxServerHandle,
65}
66
67/// A local allocation. It is a collection of procs that are running in the local process.
68pub struct LocalAlloc {
69    spec: AllocSpec,
70    name: ShortUuid,
71    world_id: WorldId, // to provide storage
72    procs: HashMap<usize, LocalProc>,
73    queue: VecDeque<ProcState>,
74    todo_tx: mpsc::UnboundedSender<Action>,
75    todo_rx: mpsc::UnboundedReceiver<Action>,
76    stopped: bool,
77    failed: bool,
78}
79
80impl LocalAlloc {
81    pub(crate) fn new(spec: AllocSpec) -> Self {
82        let name = ShortUuid::generate();
83        let (todo_tx, todo_rx) = mpsc::unbounded_channel();
84        for rank in 0..spec.extent.num_ranks() {
85            todo_tx.send(Action::Start(rank)).unwrap();
86        }
87        Self {
88            spec,
89            name: name.clone(),
90            world_id: WorldId(name.to_string()),
91            procs: HashMap::new(),
92            queue: VecDeque::new(),
93            todo_tx,
94            todo_rx,
95            stopped: false,
96            failed: false,
97        }
98    }
99
100    /// A chaos monkey that can be used to stop procs at random.
101    pub(crate) fn chaos_monkey(&self) -> impl Fn(usize, ProcStopReason) + 'static {
102        let todo_tx = self.todo_tx.clone();
103        move |rank, reason| {
104            todo_tx.send(Action::Stop(rank, reason)).unwrap();
105        }
106    }
107
108    /// A function to shut down the alloc for testing purposes.
109    pub(crate) fn stopper(&self) -> impl Fn() + 'static {
110        let todo_tx = self.todo_tx.clone();
111        let size = self.size();
112        move || {
113            for rank in 0..size {
114                todo_tx
115                    .send(Action::Stop(rank, ProcStopReason::Stopped))
116                    .unwrap();
117            }
118            todo_tx.send(Action::Stopped).unwrap();
119        }
120    }
121
122    pub(crate) fn name(&self) -> &ShortUuid {
123        &self.name
124    }
125
126    pub(crate) fn size(&self) -> usize {
127        self.spec.extent.num_ranks()
128    }
129}
130
131#[async_trait]
132impl Alloc for LocalAlloc {
133    async fn next(&mut self) -> Option<ProcState> {
134        if self.stopped {
135            return None;
136        }
137        if self.failed && !self.stopped {
138            // Failed alloc. Wait for stop().
139            futures::future::pending::<()>().await;
140            unreachable!("future::pending completed");
141        }
142        let event = loop {
143            if let state @ Some(_) = self.queue.pop_front() {
144                break state;
145            }
146
147            match self.todo_rx.recv().await? {
148                Action::Start(rank) => {
149                    let (addr, proc_rx) = loop {
150                        match channel::serve(ChannelAddr::any(self.transport())) {
151                            Ok(addr_and_proc_rx) => break addr_and_proc_rx,
152                            Err(err) => {
153                                tracing::error!(
154                                    "failed to create channel for rank {}: {}",
155                                    rank,
156                                    err
157                                );
158                                #[allow(clippy::disallowed_methods)]
159                                sleep(Duration::from_secs(1)).await;
160                                continue;
161                            }
162                        }
163                    };
164
165                    let proc_id = match &self.spec.proc_name {
166                        Some(name) => ProcId::Direct(addr.clone(), name.clone()),
167                        None => ProcId::Ranked(self.world_id.clone(), rank),
168                    };
169
170                    let bspan = tracing::info_span!("mesh_agent_bootstrap");
171                    let (proc, mesh_agent) = match ProcMeshAgent::bootstrap(proc_id.clone()).await {
172                        Ok(proc_and_agent) => proc_and_agent,
173                        Err(err) => {
174                            let message = format!("failed spawn mesh agent for {}: {}", rank, err);
175                            tracing::error!(message);
176                            // It's unclear if this is actually recoverable in a practical sense,
177                            // so we give up.
178                            self.failed = true;
179                            break Some(ProcState::Failed {
180                                world_id: self.world_id.clone(),
181                                description: message,
182                            });
183                        }
184                    };
185                    drop(bspan);
186
187                    // Undeliverable messages get forwarded to the mesh agent.
188                    let handle = proc.clone().serve(proc_rx);
189
190                    let create_key = ShortUuid::generate();
191
192                    self.procs.insert(
193                        rank,
194                        LocalProc {
195                            proc,
196                            create_key: create_key.clone(),
197                            addr: addr.clone(),
198                            handle,
199                        },
200                    );
201
202                    let point = match self.spec.extent.point_of_rank(rank) {
203                        Ok(point) => point,
204                        Err(err) => {
205                            tracing::error!("failed to get point for rank {}: {}", rank, err);
206                            return None;
207                        }
208                    };
209                    let created = ProcState::Created {
210                        create_key: create_key.clone(),
211                        point,
212                        pid: std::process::id(),
213                    };
214                    self.queue.push_back(ProcState::Running {
215                        create_key,
216                        proc_id,
217                        mesh_agent: mesh_agent.bind(),
218                        addr,
219                    });
220                    break Some(created);
221                }
222                Action::Stop(rank, reason) => {
223                    let Some(mut proc_to_stop) = self.procs.remove(&rank) else {
224                        continue;
225                    };
226
227                    // Stop serving the mailbox.
228                    proc_to_stop.handle.stop("received Action::Stop");
229
230                    if let Err(err) = proc_to_stop
231                        .proc
232                        .destroy_and_wait::<()>(Duration::from_millis(10), None)
233                        .await
234                    {
235                        tracing::error!("error while stopping proc {}: {}", rank, err);
236                    }
237                    break Some(ProcState::Stopped {
238                        reason,
239                        create_key: proc_to_stop.create_key.clone(),
240                    });
241                }
242                Action::Stopped => break None,
243            }
244        };
245        self.stopped = event.is_none();
246        event
247    }
248
249    fn spec(&self) -> &AllocSpec {
250        &self.spec
251    }
252
253    fn extent(&self) -> &Extent {
254        &self.spec.extent
255    }
256
257    fn world_id(&self) -> &WorldId {
258        &self.world_id
259    }
260
261    async fn stop(&mut self) -> Result<(), AllocatorError> {
262        for rank in 0..self.size() {
263            self.todo_tx
264                .send(Action::Stop(rank, ProcStopReason::Stopped))
265                .unwrap();
266        }
267        self.todo_tx.send(Action::Stopped).unwrap();
268        Ok(())
269    }
270
271    fn is_local(&self) -> bool {
272        true
273    }
274}
275
276impl Drop for LocalAlloc {
277    fn drop(&mut self) {
278        tracing::debug!(
279            "dropping LocalAlloc of name: {}, world id: {}",
280            self.name,
281            self.world_id
282        );
283    }
284}
285
#[cfg(test)]
mod tests {
    use super::*;

    // Instantiate the crate's shared allocator test suite for `LocalAllocator`.
    crate::alloc_test_suite!(LocalAllocator);
}