1#![allow(dead_code)] use std::collections::HashMap;
14use std::collections::VecDeque;
15use std::time::Duration;
16
17use async_trait::async_trait;
18use hyperactor::channel;
19use hyperactor::channel::ChannelAddr;
20use hyperactor::mailbox::MailboxServer;
21use hyperactor::mailbox::MailboxServerHandle;
22use hyperactor::proc::Proc;
23use hyperactor::reference as hyperactor_reference;
24use ndslice::view::Extent;
25use tokio::sync::mpsc;
26use tokio::time::sleep;
27
28use super::ProcStopReason;
29use crate::alloc::Alloc;
30use crate::alloc::AllocName;
31use crate::alloc::AllocSpec;
32use crate::alloc::Allocator;
33use crate::alloc::AllocatorError;
34use crate::alloc::ProcState;
35use crate::proc_agent::ProcAgent;
36use crate::shortuuid::ShortUuid;
37
38enum Action {
39 Start(usize),
40 Stop(usize, ProcStopReason),
41 Stopped,
42}
43
/// An [`Allocator`] that hands out [`LocalAlloc`]s.
///
/// The allocator itself is stateless; all per-allocation state lives in the
/// returned [`LocalAlloc`].
pub struct LocalAllocator;
50
51#[async_trait]
52impl Allocator for LocalAllocator {
53 type Alloc = LocalAlloc;
54
55 async fn allocate(&mut self, spec: AllocSpec) -> Result<Self::Alloc, AllocatorError> {
56 let alloc = LocalAlloc::new(spec);
57 tracing::info!(
58 name = "LocalAllocStatus",
59 alloc_name = %alloc.alloc_name(),
60 status = "Allocated",
61 );
62 Ok(alloc)
63 }
64}
65
66struct LocalProc {
67 proc: Proc,
68 create_key: ShortUuid,
69 addr: ChannelAddr,
70 handle: MailboxServerHandle,
71}
72
73pub struct LocalAlloc {
75 spec: AllocSpec,
76 name: ShortUuid,
77 alloc_name: AllocName,
78 procs: HashMap<usize, LocalProc>,
79 queue: VecDeque<ProcState>,
80 todo_tx: mpsc::UnboundedSender<Action>,
81 todo_rx: mpsc::UnboundedReceiver<Action>,
82 stopped: bool,
83 failed: bool,
84}
85
86impl LocalAlloc {
87 pub(crate) fn new(spec: AllocSpec) -> Self {
88 let name = ShortUuid::generate();
89 let (todo_tx, todo_rx) = mpsc::unbounded_channel();
90 for rank in 0..spec.extent.num_ranks() {
91 todo_tx.send(Action::Start(rank)).unwrap();
92 }
93 Self {
94 spec,
95 name: name.clone(),
96 alloc_name: AllocName(name.to_string()),
97 procs: HashMap::new(),
98 queue: VecDeque::new(),
99 todo_tx,
100 todo_rx,
101 stopped: false,
102 failed: false,
103 }
104 }
105
106 pub(crate) fn chaos_monkey(&self) -> impl Fn(usize, ProcStopReason) + 'static {
108 let todo_tx = self.todo_tx.clone();
109 move |rank, reason| {
110 todo_tx.send(Action::Stop(rank, reason)).unwrap();
111 }
112 }
113
114 pub(crate) fn stopper(&self) -> impl Fn() + 'static {
116 let todo_tx = self.todo_tx.clone();
117 let size = self.size();
118 move || {
119 for rank in 0..size {
120 todo_tx
121 .send(Action::Stop(rank, ProcStopReason::Stopped))
122 .unwrap();
123 }
124 todo_tx.send(Action::Stopped).unwrap();
125 }
126 }
127
128 pub(crate) fn name(&self) -> &ShortUuid {
129 &self.name
130 }
131
132 pub(crate) fn size(&self) -> usize {
133 self.spec.extent.num_ranks()
134 }
135}
136
137#[async_trait]
138impl Alloc for LocalAlloc {
139 async fn next(&mut self) -> Option<ProcState> {
140 if self.stopped {
141 return None;
142 }
143 if self.failed && !self.stopped {
144 futures::future::pending::<()>().await;
146 unreachable!("future::pending completed");
147 }
148 let event = loop {
149 if let state @ Some(_) = self.queue.pop_front() {
150 break state;
151 }
152
153 match self.todo_rx.recv().await? {
154 Action::Start(rank) => {
155 let (addr, proc_rx) = loop {
156 match channel::serve(ChannelAddr::any(self.transport())) {
157 Ok(addr_and_proc_rx) => break addr_and_proc_rx,
158 Err(err) => {
159 tracing::error!(
160 "failed to create channel for rank {}: {}",
161 rank,
162 err
163 );
164 sleep(Duration::from_secs(1)).await;
165 continue;
166 }
167 }
168 };
169
170 let proc_name = match &self.spec.proc_name {
171 Some(name) => name.clone(),
172 None => format!("{}_{}", self.alloc_name.name(), rank),
173 };
174 let proc_id = hyperactor_reference::ProcId::with_name(addr.clone(), proc_name);
175
176 let bspan = tracing::info_span!("mesh_agent_bootstrap");
177 let (proc, mesh_agent) = match ProcAgent::bootstrap(proc_id.clone()).await {
178 Ok(proc_and_agent) => proc_and_agent,
179 Err(err) => {
180 let message = format!("failed spawn mesh agent for {}: {}", rank, err);
181 tracing::error!(message);
182 self.failed = true;
185 break Some(ProcState::Failed {
186 alloc_name: self.alloc_name.clone(),
187 description: message,
188 });
189 }
190 };
191 drop(bspan);
192
193 let handle = proc.clone().serve(proc_rx);
195
196 let create_key = ShortUuid::generate();
197
198 self.procs.insert(
199 rank,
200 LocalProc {
201 proc,
202 create_key: create_key.clone(),
203 addr: addr.clone(),
204 handle,
205 },
206 );
207
208 let point = match self.spec.extent.point_of_rank(rank) {
209 Ok(point) => point,
210 Err(err) => {
211 tracing::error!("failed to get point for rank {}: {}", rank, err);
212 return None;
213 }
214 };
215 let created = ProcState::Created {
216 create_key: create_key.clone(),
217 point,
218 pid: std::process::id(),
219 };
220 self.queue.push_back(ProcState::Running {
221 create_key,
222 proc_id,
223 mesh_agent: mesh_agent.bind(),
224 addr,
225 });
226 break Some(created);
227 }
228 Action::Stop(rank, reason) => {
229 let Some(mut proc_to_stop) = self.procs.remove(&rank) else {
230 continue;
231 };
232
233 proc_to_stop.handle.stop("received Action::Stop");
235
236 if let Err(err) = proc_to_stop
237 .proc
238 .destroy_and_wait::<()>(
239 Duration::from_millis(10),
240 None,
241 &reason.to_string(),
242 )
243 .await
244 {
245 tracing::error!("error while stopping proc {}: {}", rank, err);
246 }
247 break Some(ProcState::Stopped {
248 reason,
249 create_key: proc_to_stop.create_key.clone(),
250 });
251 }
252 Action::Stopped => break None,
253 }
254 };
255 self.stopped = event.is_none();
256 event
257 }
258
259 fn spec(&self) -> &AllocSpec {
260 &self.spec
261 }
262
263 fn extent(&self) -> &Extent {
264 &self.spec.extent
265 }
266
267 fn alloc_name(&self) -> &AllocName {
268 &self.alloc_name
269 }
270
271 async fn stop(&mut self) -> Result<(), AllocatorError> {
272 tracing::info!(
273 name = "LocalAllocStatus",
274 alloc_name = %self.alloc_name(),
275 status = "Stopping",
276 );
277 for rank in 0..self.size() {
278 self.todo_tx
279 .send(Action::Stop(rank, ProcStopReason::Stopped))
280 .unwrap();
281 }
282 self.todo_tx.send(Action::Stopped).unwrap();
283 tracing::info!(
284 name = "LocalAllocStatus",
285 alloc_name = %self.alloc_name(),
286 status = "Stop::Sent",
287 "Stop was sent to local procs; check their log to determine if it exited."
288 );
289 Ok(())
290 }
291
292 fn is_local(&self) -> bool {
293 true
294 }
295}
296
297impl Drop for LocalAlloc {
298 fn drop(&mut self) {
299 tracing::info!(
300 name = "LocalAllocStatus",
301 alloc_name = %self.alloc_name(),
302 status = "Dropped",
303 "dropping LocalAlloc of name: {}, alloc_name: {}",
304 self.name,
305 self.alloc_name
306 );
307 }
308}
309
#[cfg(test)]
mod tests {
    use super::*;

    // Instantiate the crate's `alloc_test_suite!` for `LocalAllocator`.
    crate::alloc_test_suite!(LocalAllocator);
}