1#![allow(dead_code)] use std::collections::HashMap;
14use std::collections::VecDeque;
15use std::time::Duration;
16
17use async_trait::async_trait;
18use hyperactor::ProcId;
19use hyperactor::WorldId;
20use hyperactor::channel;
21use hyperactor::channel::ChannelAddr;
22use hyperactor::mailbox::MailboxServer;
23use hyperactor::mailbox::MailboxServerHandle;
24use hyperactor::proc::Proc;
25use ndslice::view::Extent;
26use tokio::sync::mpsc;
27use tokio::time::sleep;
28
29use super::ProcStopReason;
30use crate::alloc::Alloc;
31use crate::alloc::AllocSpec;
32use crate::alloc::Allocator;
33use crate::alloc::AllocatorError;
34use crate::alloc::ProcState;
35use crate::proc_mesh::mesh_agent::ProcMeshAgent;
36use crate::shortuuid::ShortUuid;
37
/// Work items consumed by [`LocalAlloc`]'s `next()`.
enum Action {
    /// Start a proc for the given rank.
    Start(usize),
    /// Stop the proc for the given rank, reporting the given reason.
    Stop(usize, ProcStopReason),
    /// End the allocation: `next()` returns `None` when it processes this.
    Stopped,
}
43
/// An [`Allocator`] that produces [`LocalAlloc`]s, whose procs are
/// bootstrapped inside the current process (they report this process's pid).
pub struct LocalAllocator;
51#[async_trait]
52impl Allocator for LocalAllocator {
53 type Alloc = LocalAlloc;
54
55 async fn allocate(&mut self, spec: AllocSpec) -> Result<Self::Alloc, AllocatorError> {
56 let alloc = LocalAlloc::new(spec);
57 tracing::info!(
58 name = "LocalAllocStatus",
59 alloc_name = %alloc.world_id(),
60 status = "Allocated",
61 );
62 Ok(alloc)
63 }
64}
65
/// Bookkeeping for one proc started by [`LocalAlloc`].
struct LocalProc {
    /// The proc bootstrapped for this rank.
    proc: Proc,
    /// Key reported in this proc's `Created`, `Running` and `Stopped` events.
    create_key: ShortUuid,
    /// Address of the channel this proc's mailbox is served on.
    addr: ChannelAddr,
    /// Handle to the mailbox server; stopped when the proc is stopped.
    handle: MailboxServerHandle,
}
72
/// An [`Alloc`] whose procs all live in the current process.
///
/// Work is driven by [`Action`]s sent on `todo_tx` and consumed by `next()`.
pub struct LocalAlloc {
    /// The spec this allocation was created from.
    spec: AllocSpec,
    /// Randomly generated name of this allocation.
    name: ShortUuid,
    /// World id, built from the string form of `name`.
    world_id: WorldId,
    /// Live procs, keyed by rank.
    procs: HashMap<usize, LocalProc>,
    /// Events returned by `next()` before any further action is processed.
    queue: VecDeque<ProcState>,
    /// Sender for pending actions; cloned by `chaos_monkey` and `stopper`.
    todo_tx: mpsc::UnboundedSender<Action>,
    /// Receiver drained by `next()`.
    todo_rx: mpsc::UnboundedReceiver<Action>,
    /// Set by `next()` when its event loop yields `None` (e.g. on `Action::Stopped`).
    stopped: bool,
    /// Set when a proc's mesh agent failed to bootstrap.
    failed: bool,
}
85
86impl LocalAlloc {
87 pub(crate) fn new(spec: AllocSpec) -> Self {
88 let name = ShortUuid::generate();
89 let (todo_tx, todo_rx) = mpsc::unbounded_channel();
90 for rank in 0..spec.extent.num_ranks() {
91 todo_tx.send(Action::Start(rank)).unwrap();
92 }
93 Self {
94 spec,
95 name: name.clone(),
96 world_id: WorldId(name.to_string()),
97 procs: HashMap::new(),
98 queue: VecDeque::new(),
99 todo_tx,
100 todo_rx,
101 stopped: false,
102 failed: false,
103 }
104 }
105
106 pub(crate) fn chaos_monkey(&self) -> impl Fn(usize, ProcStopReason) + 'static {
108 let todo_tx = self.todo_tx.clone();
109 move |rank, reason| {
110 todo_tx.send(Action::Stop(rank, reason)).unwrap();
111 }
112 }
113
114 pub(crate) fn stopper(&self) -> impl Fn() + 'static {
116 let todo_tx = self.todo_tx.clone();
117 let size = self.size();
118 move || {
119 for rank in 0..size {
120 todo_tx
121 .send(Action::Stop(rank, ProcStopReason::Stopped))
122 .unwrap();
123 }
124 todo_tx.send(Action::Stopped).unwrap();
125 }
126 }
127
128 pub(crate) fn name(&self) -> &ShortUuid {
129 &self.name
130 }
131
132 pub(crate) fn size(&self) -> usize {
133 self.spec.extent.num_ranks()
134 }
135}
136
137#[async_trait]
138impl Alloc for LocalAlloc {
139 async fn next(&mut self) -> Option<ProcState> {
140 if self.stopped {
141 return None;
142 }
143 if self.failed && !self.stopped {
144 futures::future::pending::<()>().await;
146 unreachable!("future::pending completed");
147 }
148 let event = loop {
149 if let state @ Some(_) = self.queue.pop_front() {
150 break state;
151 }
152
153 match self.todo_rx.recv().await? {
154 Action::Start(rank) => {
155 let (addr, proc_rx) = loop {
156 match channel::serve(ChannelAddr::any(self.transport())) {
157 Ok(addr_and_proc_rx) => break addr_and_proc_rx,
158 Err(err) => {
159 tracing::error!(
160 "failed to create channel for rank {}: {}",
161 rank,
162 err
163 );
164 #[allow(clippy::disallowed_methods)]
165 sleep(Duration::from_secs(1)).await;
166 continue;
167 }
168 }
169 };
170
171 let proc_id = match &self.spec.proc_name {
172 Some(name) => ProcId::Direct(addr.clone(), name.clone()),
173 None => ProcId::Ranked(self.world_id.clone(), rank),
174 };
175
176 let bspan = tracing::info_span!("mesh_agent_bootstrap");
177 let (proc, mesh_agent) = match ProcMeshAgent::bootstrap(proc_id.clone()).await {
178 Ok(proc_and_agent) => proc_and_agent,
179 Err(err) => {
180 let message = format!("failed spawn mesh agent for {}: {}", rank, err);
181 tracing::error!(message);
182 self.failed = true;
185 break Some(ProcState::Failed {
186 world_id: self.world_id.clone(),
187 description: message,
188 });
189 }
190 };
191 drop(bspan);
192
193 let handle = proc.clone().serve(proc_rx);
195
196 let create_key = ShortUuid::generate();
197
198 self.procs.insert(
199 rank,
200 LocalProc {
201 proc,
202 create_key: create_key.clone(),
203 addr: addr.clone(),
204 handle,
205 },
206 );
207
208 let point = match self.spec.extent.point_of_rank(rank) {
209 Ok(point) => point,
210 Err(err) => {
211 tracing::error!("failed to get point for rank {}: {}", rank, err);
212 return None;
213 }
214 };
215 let created = ProcState::Created {
216 create_key: create_key.clone(),
217 point,
218 pid: std::process::id(),
219 };
220 self.queue.push_back(ProcState::Running {
221 create_key,
222 proc_id,
223 mesh_agent: mesh_agent.bind(),
224 addr,
225 });
226 break Some(created);
227 }
228 Action::Stop(rank, reason) => {
229 let Some(mut proc_to_stop) = self.procs.remove(&rank) else {
230 continue;
231 };
232
233 proc_to_stop.handle.stop("received Action::Stop");
235
236 if let Err(err) = proc_to_stop
237 .proc
238 .destroy_and_wait::<()>(Duration::from_millis(10), None)
239 .await
240 {
241 tracing::error!("error while stopping proc {}: {}", rank, err);
242 }
243 break Some(ProcState::Stopped {
244 reason,
245 create_key: proc_to_stop.create_key.clone(),
246 });
247 }
248 Action::Stopped => break None,
249 }
250 };
251 self.stopped = event.is_none();
252 event
253 }
254
255 fn spec(&self) -> &AllocSpec {
256 &self.spec
257 }
258
259 fn extent(&self) -> &Extent {
260 &self.spec.extent
261 }
262
263 fn world_id(&self) -> &WorldId {
264 &self.world_id
265 }
266
267 async fn stop(&mut self) -> Result<(), AllocatorError> {
268 tracing::info!(
269 name = "LocalAllocStatus",
270 alloc_name = %self.world_id(),
271 status = "Stopping",
272 );
273 for rank in 0..self.size() {
274 self.todo_tx
275 .send(Action::Stop(rank, ProcStopReason::Stopped))
276 .unwrap();
277 }
278 self.todo_tx.send(Action::Stopped).unwrap();
279 tracing::info!(
280 name = "LocalAllocStatus",
281 alloc_name = %self.world_id(),
282 status = "Stop::Sent",
283 "Stop was sent to local procs; check their log to determine if it exited."
284 );
285 Ok(())
286 }
287
288 fn is_local(&self) -> bool {
289 true
290 }
291}
292
293impl Drop for LocalAlloc {
294 fn drop(&mut self) {
295 tracing::info!(
296 name = "LocalAllocStatus",
297 alloc_name = %self.world_id(),
298 status = "Dropped",
299 "dropping LocalAlloc of name: {}, world id: {}",
300 self.name,
301 self.world_id
302 );
303 }
304}
305
#[cfg(test)]
mod tests {
    use super::*;

    // Instantiates the crate's `alloc_test_suite!` tests for `LocalAllocator`
    // (presumably the generic Alloc conformance suite — see the macro).
    crate::alloc_test_suite!(LocalAllocator);
}