hyperactor_mesh/alloc/
local.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! Support for allocating procs in the local process.
10
11#![allow(dead_code)] // until it is used outside of testing
12
13use std::collections::HashMap;
14use std::collections::VecDeque;
15use std::time::Duration;
16
17use async_trait::async_trait;
18use hyperactor::ProcId;
19use hyperactor::WorldId;
20use hyperactor::channel;
21use hyperactor::channel::ChannelAddr;
22use hyperactor::channel::ChannelTransport;
23use hyperactor::mailbox::MailboxServer;
24use hyperactor::mailbox::MailboxServerHandle;
25use hyperactor::proc::Proc;
26use ndslice::view::Extent;
27use tokio::sync::mpsc;
28use tokio::time::sleep;
29
30use super::ProcStopReason;
31use crate::alloc::Alloc;
32use crate::alloc::AllocSpec;
33use crate::alloc::Allocator;
34use crate::alloc::AllocatorError;
35use crate::alloc::ProcState;
36use crate::proc_mesh::mesh_agent::MeshAgent;
37use crate::shortuuid::ShortUuid;
38
/// A request consumed by the event loop in [`LocalAlloc`]'s `next()`.
enum Action {
    /// Bootstrap and start the proc for the given rank.
    Start(usize),
    /// Stop the proc for the given rank, reporting the given reason.
    /// Ignored if no proc is currently running for that rank.
    Stop(usize, ProcStopReason),
    /// End the allocation: `next()` returns `None` from then on.
    Stopped,
}
44
/// An allocator that runs procs in the local process. It is primarily useful for testing,
/// or small meshes that can run entirely locally.
///
/// Every rank in the requested extent gets a proc. Because the procs share a fault domain
/// with the client of the alloc, individual procs are not expected to fail. However, if
/// bootstrapping a proc's mesh agent fails, the alloc reports `ProcState::Failed` and does
/// not start any further procs.
pub struct LocalAllocator;
51
52#[async_trait]
53impl Allocator for LocalAllocator {
54    type Alloc = LocalAlloc;
55
56    async fn allocate(&mut self, spec: AllocSpec) -> Result<Self::Alloc, AllocatorError> {
57        Ok(LocalAlloc::new(spec))
58    }
59}
60
/// Bookkeeping for a proc that a [`LocalAlloc`] has started and not yet stopped.
struct LocalProc {
    /// The proc itself; destroyed when its rank is stopped.
    proc: Proc,
    /// Address of the channel on which this proc's mailbox is served.
    addr: ChannelAddr,
    /// Handle to the mailbox server; stopped when the rank is stopped.
    handle: MailboxServerHandle,
}
66
/// A local allocation. It is a collection of procs that are running in the local process.
///
/// Work is driven through an internal action channel (`todo_tx` / `todo_rx`). The
/// constructor enqueues a start request for every rank. `stop()`, `stopper()` and
/// `chaos_monkey()` enqueue stop requests. `next()` drains the channel and turns each
/// action into `ProcState` events.
pub struct LocalAlloc {
    /// The spec this alloc was created from; its extent defines the set of ranks.
    spec: AllocSpec,
    /// Name of this alloc, produced by `ShortUuid::generate()` at construction.
    name: ShortUuid,
    /// World id derived from `name`. It is stored here so that `world_id()` can return a
    /// reference to it.
    world_id: WorldId,
    /// Procs that have been started and not yet stopped, keyed by rank.
    procs: HashMap<usize, LocalProc>,
    /// Events that are ready to be returned by subsequent `next()` calls.
    queue: VecDeque<ProcState>,
    /// Sender half of the action channel. Clones are handed out by the test helpers.
    todo_tx: mpsc::UnboundedSender<Action>,
    /// Receiver half of the action channel, consumed by `next()`.
    todo_rx: mpsc::UnboundedReceiver<Action>,
    /// Set once `next()` has returned `None`.
    stopped: bool,
    /// Set when bootstrapping a proc fails.
    failed: bool,
    /// Channel transport used to serve each proc's mailbox.
    transport: ChannelTransport,
}
80
81impl LocalAlloc {
82    fn new(spec: AllocSpec) -> Self {
83        Self::new_with_transport(spec, ChannelTransport::Local)
84    }
85
86    pub(crate) fn new_with_transport(spec: AllocSpec, transport: ChannelTransport) -> Self {
87        let name = ShortUuid::generate();
88        let (todo_tx, todo_rx) = mpsc::unbounded_channel();
89        for rank in 0..spec.extent.num_ranks() {
90            todo_tx.send(Action::Start(rank)).unwrap();
91        }
92        Self {
93            spec,
94            name: name.clone(),
95            world_id: WorldId(name.to_string()),
96            procs: HashMap::new(),
97            queue: VecDeque::new(),
98            todo_tx,
99            todo_rx,
100            stopped: false,
101            failed: false,
102            transport,
103        }
104    }
105
106    /// A chaos monkey that can be used to stop procs at random.
107    pub(crate) fn chaos_monkey(&self) -> impl Fn(usize, ProcStopReason) + 'static {
108        let todo_tx = self.todo_tx.clone();
109        move |rank, reason| {
110            todo_tx.send(Action::Stop(rank, reason)).unwrap();
111        }
112    }
113
114    /// A function to shut down the alloc for testing purposes.
115    pub(crate) fn stopper(&self) -> impl Fn() + 'static {
116        let todo_tx = self.todo_tx.clone();
117        let size = self.size();
118        move || {
119            for rank in 0..size {
120                todo_tx
121                    .send(Action::Stop(rank, ProcStopReason::Stopped))
122                    .unwrap();
123            }
124            todo_tx.send(Action::Stopped).unwrap();
125        }
126    }
127
128    pub(crate) fn name(&self) -> &ShortUuid {
129        &self.name
130    }
131
132    pub(crate) fn size(&self) -> usize {
133        self.spec.extent.num_ranks()
134    }
135}
136
137#[async_trait]
138impl Alloc for LocalAlloc {
139    async fn next(&mut self) -> Option<ProcState> {
140        if self.stopped {
141            return None;
142        }
143        if self.failed && !self.stopped {
144            // Failed alloc. Wait for stop().
145            futures::future::pending::<()>().await;
146            unreachable!("future::pending completed");
147        }
148        let event = loop {
149            if let state @ Some(_) = self.queue.pop_front() {
150                break state;
151            }
152
153            match self.todo_rx.recv().await? {
154                Action::Start(rank) => {
155                    let proc_id = ProcId::Ranked(self.world_id.clone(), rank);
156                    let bspan = tracing::info_span!("mesh_agent_bootstrap");
157                    let (proc, mesh_agent) = match MeshAgent::bootstrap(proc_id.clone()).await {
158                        Ok(proc_and_agent) => proc_and_agent,
159                        Err(err) => {
160                            let message = format!("failed spawn mesh agent for {}: {}", rank, err);
161                            tracing::error!(message);
162                            // It's unclear if this is actually recoverable in a practical sense,
163                            // so we give up.
164                            self.failed = true;
165                            break Some(ProcState::Failed {
166                                world_id: self.world_id.clone(),
167                                description: message,
168                            });
169                        }
170                    };
171                    drop(bspan);
172
173                    let (addr, proc_rx) = loop {
174                        match channel::serve(ChannelAddr::any(self.transport())).await {
175                            Ok(addr_and_proc_rx) => break addr_and_proc_rx,
176                            Err(err) => {
177                                tracing::error!(
178                                    "failed to create channel for rank {}: {}",
179                                    rank,
180                                    err
181                                );
182                                #[allow(clippy::disallowed_methods)]
183                                sleep(Duration::from_secs(1)).await;
184                                continue;
185                            }
186                        }
187                    };
188
189                    // Undeliverable messages get forwarded to the mesh agent.
190                    let handle = proc.clone().serve(proc_rx);
191
192                    self.procs.insert(
193                        rank,
194                        LocalProc {
195                            proc,
196                            addr: addr.clone(),
197                            handle,
198                        },
199                    );
200
201                    let point = match self.spec.extent.point_of_rank(rank) {
202                        Ok(point) => point,
203                        Err(err) => {
204                            tracing::error!("failed to get point for rank {}: {}", rank, err);
205                            return None;
206                        }
207                    };
208                    let created = ProcState::Created {
209                        proc_id: proc_id.clone(),
210                        point,
211                        pid: std::process::id(),
212                    };
213                    self.queue.push_back(ProcState::Running {
214                        proc_id,
215                        mesh_agent: mesh_agent.bind(),
216                        addr,
217                    });
218                    break Some(created);
219                }
220                Action::Stop(rank, reason) => {
221                    let Some(mut proc_to_stop) = self.procs.remove(&rank) else {
222                        continue;
223                    };
224
225                    // Stop serving the mailbox.
226                    proc_to_stop.handle.stop("received Action::Stop");
227
228                    if let Err(err) = proc_to_stop
229                        .proc
230                        .destroy_and_wait(Duration::from_millis(10), None)
231                        .await
232                    {
233                        tracing::error!("error while stopping proc {}: {}", rank, err);
234                    }
235                    break Some(ProcState::Stopped {
236                        reason,
237                        proc_id: proc_to_stop.proc.proc_id().clone(),
238                    });
239                }
240                Action::Stopped => break None,
241            }
242        };
243        self.stopped = event.is_none();
244        event
245    }
246
247    fn extent(&self) -> &Extent {
248        &self.spec.extent
249    }
250
251    fn world_id(&self) -> &WorldId {
252        &self.world_id
253    }
254
255    fn transport(&self) -> ChannelTransport {
256        self.transport.clone()
257    }
258
259    async fn stop(&mut self) -> Result<(), AllocatorError> {
260        for rank in 0..self.size() {
261            self.todo_tx
262                .send(Action::Stop(rank, ProcStopReason::Stopped))
263                .unwrap();
264        }
265        self.todo_tx.send(Action::Stopped).unwrap();
266        Ok(())
267    }
268}
269
270impl Drop for LocalAlloc {
271    fn drop(&mut self) {
272        tracing::debug!(
273            "dropping LocalAlloc of name: {}, world id: {}",
274            self.name,
275            self.world_id
276        );
277    }
278}
279
#[cfg(test)]
mod tests {
    use super::*;

    // Instantiate the shared `alloc_test_suite!` tests (defined elsewhere in the crate)
    // for `LocalAllocator`.
    crate::alloc_test_suite!(LocalAllocator);
}