hyperactor_mesh/alloc/
local.rs1#![allow(dead_code)] use std::collections::HashMap;
14use std::collections::VecDeque;
15use std::time::Duration;
16
17use async_trait::async_trait;
18use hyperactor::ProcId;
19use hyperactor::WorldId;
20use hyperactor::channel;
21use hyperactor::channel::ChannelAddr;
22use hyperactor::channel::ChannelTransport;
23use hyperactor::mailbox::MailboxServer;
24use hyperactor::mailbox::MailboxServerHandle;
25use hyperactor::proc::Proc;
26use ndslice::view::Extent;
27use tokio::sync::mpsc;
28use tokio::time::sleep;
29
30use super::ProcStopReason;
31use crate::alloc::Alloc;
32use crate::alloc::AllocSpec;
33use crate::alloc::Allocator;
34use crate::alloc::AllocatorError;
35use crate::alloc::ProcState;
36use crate::proc_mesh::mesh_agent::MeshAgent;
37use crate::shortuuid::ShortUuid;
38
/// Work items queued on a `LocalAlloc`'s internal channel and processed, in
/// order, by `LocalAlloc::next`.
enum Action {
    /// Bootstrap and serve the proc for this rank (queued once per rank at
    /// construction time).
    Start(usize),
    /// Stop the proc at this rank, reporting the given reason in the resulting
    /// `ProcState::Stopped` event. Ignored if no proc is running at that rank.
    Stop(usize, ProcStopReason),
    /// End the event stream: `next` returns `None` from then on.
    Stopped,
}
44
45pub struct LocalAllocator;
51
52#[async_trait]
53impl Allocator for LocalAllocator {
54 type Alloc = LocalAlloc;
55
56 async fn allocate(&mut self, spec: AllocSpec) -> Result<Self::Alloc, AllocatorError> {
57 Ok(LocalAlloc::new(spec))
58 }
59}
60
/// Book-keeping for one proc started by `LocalAlloc`, kept until the proc is
/// stopped.
struct LocalProc {
    /// The proc returned by `MeshAgent::bootstrap`.
    proc: Proc,
    /// Address of the channel whose receiver this proc serves.
    addr: ChannelAddr,
    /// Handle of the mailbox server serving `proc`; it is stopped before the
    /// proc is destroyed.
    handle: MailboxServerHandle,
}
66
/// An allocation produced by [`LocalAllocator`]. It starts and stops procs as
/// its internal action queue is drained by `Alloc::next`.
pub struct LocalAlloc {
    /// The spec this alloc was created from; its extent defines the ranks.
    spec: AllocSpec,
    /// Randomly generated name of this alloc.
    name: ShortUuid,
    /// World id derived from `name` (`WorldId(name.to_string())`).
    world_id: WorldId,
    /// Procs that are currently running, keyed by rank.
    procs: HashMap<usize, LocalProc>,
    /// Events already produced but not yet returned by `next`.
    queue: VecDeque<ProcState>,
    /// Sending side of the action queue; cloned by `chaos_monkey` and `stopper`.
    todo_tx: mpsc::UnboundedSender<Action>,
    /// Receiving side of the action queue, drained by `next`.
    todo_rx: mpsc::UnboundedReceiver<Action>,
    /// Set once `next` has returned `None`.
    stopped: bool,
    /// Set after a proc failed to start; `next` then never resolves.
    failed: bool,
    /// Transport used for each proc's serving channel.
    transport: ChannelTransport,
}
80
81impl LocalAlloc {
82 fn new(spec: AllocSpec) -> Self {
83 Self::new_with_transport(spec, ChannelTransport::Local)
84 }
85
86 pub(crate) fn new_with_transport(spec: AllocSpec, transport: ChannelTransport) -> Self {
87 let name = ShortUuid::generate();
88 let (todo_tx, todo_rx) = mpsc::unbounded_channel();
89 for rank in 0..spec.extent.num_ranks() {
90 todo_tx.send(Action::Start(rank)).unwrap();
91 }
92 Self {
93 spec,
94 name: name.clone(),
95 world_id: WorldId(name.to_string()),
96 procs: HashMap::new(),
97 queue: VecDeque::new(),
98 todo_tx,
99 todo_rx,
100 stopped: false,
101 failed: false,
102 transport,
103 }
104 }
105
106 pub(crate) fn chaos_monkey(&self) -> impl Fn(usize, ProcStopReason) + 'static {
108 let todo_tx = self.todo_tx.clone();
109 move |rank, reason| {
110 todo_tx.send(Action::Stop(rank, reason)).unwrap();
111 }
112 }
113
114 pub(crate) fn stopper(&self) -> impl Fn() + 'static {
116 let todo_tx = self.todo_tx.clone();
117 let size = self.size();
118 move || {
119 for rank in 0..size {
120 todo_tx
121 .send(Action::Stop(rank, ProcStopReason::Stopped))
122 .unwrap();
123 }
124 todo_tx.send(Action::Stopped).unwrap();
125 }
126 }
127
128 pub(crate) fn name(&self) -> &ShortUuid {
129 &self.name
130 }
131
132 pub(crate) fn size(&self) -> usize {
133 self.spec.extent.num_ranks()
134 }
135}
136
137#[async_trait]
138impl Alloc for LocalAlloc {
139 async fn next(&mut self) -> Option<ProcState> {
140 if self.stopped {
141 return None;
142 }
143 if self.failed && !self.stopped {
144 futures::future::pending::<()>().await;
146 unreachable!("future::pending completed");
147 }
148 let event = loop {
149 if let state @ Some(_) = self.queue.pop_front() {
150 break state;
151 }
152
153 match self.todo_rx.recv().await? {
154 Action::Start(rank) => {
155 let proc_id = ProcId::Ranked(self.world_id.clone(), rank);
156 let bspan = tracing::info_span!("mesh_agent_bootstrap");
157 let (proc, mesh_agent) = match MeshAgent::bootstrap(proc_id.clone()).await {
158 Ok(proc_and_agent) => proc_and_agent,
159 Err(err) => {
160 let message = format!("failed spawn mesh agent for {}: {}", rank, err);
161 tracing::error!(message);
162 self.failed = true;
165 break Some(ProcState::Failed {
166 world_id: self.world_id.clone(),
167 description: message,
168 });
169 }
170 };
171 drop(bspan);
172
173 let (addr, proc_rx) = loop {
174 match channel::serve(ChannelAddr::any(self.transport())).await {
175 Ok(addr_and_proc_rx) => break addr_and_proc_rx,
176 Err(err) => {
177 tracing::error!(
178 "failed to create channel for rank {}: {}",
179 rank,
180 err
181 );
182 #[allow(clippy::disallowed_methods)]
183 sleep(Duration::from_secs(1)).await;
184 continue;
185 }
186 }
187 };
188
189 let handle = proc.clone().serve(proc_rx);
191
192 self.procs.insert(
193 rank,
194 LocalProc {
195 proc,
196 addr: addr.clone(),
197 handle,
198 },
199 );
200
201 let point = match self.spec.extent.point_of_rank(rank) {
202 Ok(point) => point,
203 Err(err) => {
204 tracing::error!("failed to get point for rank {}: {}", rank, err);
205 return None;
206 }
207 };
208 let created = ProcState::Created {
209 proc_id: proc_id.clone(),
210 point,
211 pid: std::process::id(),
212 };
213 self.queue.push_back(ProcState::Running {
214 proc_id,
215 mesh_agent: mesh_agent.bind(),
216 addr,
217 });
218 break Some(created);
219 }
220 Action::Stop(rank, reason) => {
221 let Some(mut proc_to_stop) = self.procs.remove(&rank) else {
222 continue;
223 };
224
225 proc_to_stop.handle.stop("received Action::Stop");
227
228 if let Err(err) = proc_to_stop
229 .proc
230 .destroy_and_wait(Duration::from_millis(10), None)
231 .await
232 {
233 tracing::error!("error while stopping proc {}: {}", rank, err);
234 }
235 break Some(ProcState::Stopped {
236 reason,
237 proc_id: proc_to_stop.proc.proc_id().clone(),
238 });
239 }
240 Action::Stopped => break None,
241 }
242 };
243 self.stopped = event.is_none();
244 event
245 }
246
247 fn extent(&self) -> &Extent {
248 &self.spec.extent
249 }
250
251 fn world_id(&self) -> &WorldId {
252 &self.world_id
253 }
254
255 fn transport(&self) -> ChannelTransport {
256 self.transport.clone()
257 }
258
259 async fn stop(&mut self) -> Result<(), AllocatorError> {
260 for rank in 0..self.size() {
261 self.todo_tx
262 .send(Action::Stop(rank, ProcStopReason::Stopped))
263 .unwrap();
264 }
265 self.todo_tx.send(Action::Stopped).unwrap();
266 Ok(())
267 }
268}
269
270impl Drop for LocalAlloc {
271 fn drop(&mut self) {
272 tracing::debug!(
273 "dropping LocalAlloc of name: {}, world id: {}",
274 self.name,
275 self.world_id
276 );
277 }
278}
279
#[cfg(test)]
mod tests {
    use super::*;

    // Instantiate the crate's shared `alloc_test_suite!` tests (macro defined
    // elsewhere in the crate) against `LocalAllocator`.
    crate::alloc_test_suite!(LocalAllocator);
}