hyperactor_mesh/alloc/
local.rs1#![allow(dead_code)] use std::collections::HashMap;
14use std::collections::VecDeque;
15use std::time::Duration;
16
17use async_trait::async_trait;
18use hyperactor::ProcId;
19use hyperactor::WorldId;
20use hyperactor::channel;
21use hyperactor::channel::ChannelAddr;
22use hyperactor::mailbox::MailboxServer;
23use hyperactor::mailbox::MailboxServerHandle;
24use hyperactor::proc::Proc;
25use ndslice::view::Extent;
26use tokio::sync::mpsc;
27use tokio::time::sleep;
28
29use super::ProcStopReason;
30use crate::alloc::Alloc;
31use crate::alloc::AllocSpec;
32use crate::alloc::Allocator;
33use crate::alloc::AllocatorError;
34use crate::alloc::ProcState;
35use crate::proc_mesh::mesh_agent::ProcMeshAgent;
36use crate::shortuuid::ShortUuid;
37
38enum Action {
39 Start(usize),
40 Stop(usize, ProcStopReason),
41 Stopped,
42}
43
/// An [`Allocator`] whose allocs bootstrap every proc inside the current
/// process (see `LocalAlloc`).
pub struct LocalAllocator;
51#[async_trait]
52impl Allocator for LocalAllocator {
53 type Alloc = LocalAlloc;
54
55 async fn allocate(&mut self, spec: AllocSpec) -> Result<Self::Alloc, AllocatorError> {
56 Ok(LocalAlloc::new(spec))
57 }
58}
59
60struct LocalProc {
61 proc: Proc,
62 create_key: ShortUuid,
63 addr: ChannelAddr,
64 handle: MailboxServerHandle,
65}
66
67pub struct LocalAlloc {
69 spec: AllocSpec,
70 name: ShortUuid,
71 world_id: WorldId, procs: HashMap<usize, LocalProc>,
73 queue: VecDeque<ProcState>,
74 todo_tx: mpsc::UnboundedSender<Action>,
75 todo_rx: mpsc::UnboundedReceiver<Action>,
76 stopped: bool,
77 failed: bool,
78}
79
80impl LocalAlloc {
81 pub(crate) fn new(spec: AllocSpec) -> Self {
82 let name = ShortUuid::generate();
83 let (todo_tx, todo_rx) = mpsc::unbounded_channel();
84 for rank in 0..spec.extent.num_ranks() {
85 todo_tx.send(Action::Start(rank)).unwrap();
86 }
87 Self {
88 spec,
89 name: name.clone(),
90 world_id: WorldId(name.to_string()),
91 procs: HashMap::new(),
92 queue: VecDeque::new(),
93 todo_tx,
94 todo_rx,
95 stopped: false,
96 failed: false,
97 }
98 }
99
100 pub(crate) fn chaos_monkey(&self) -> impl Fn(usize, ProcStopReason) + 'static {
102 let todo_tx = self.todo_tx.clone();
103 move |rank, reason| {
104 todo_tx.send(Action::Stop(rank, reason)).unwrap();
105 }
106 }
107
108 pub(crate) fn stopper(&self) -> impl Fn() + 'static {
110 let todo_tx = self.todo_tx.clone();
111 let size = self.size();
112 move || {
113 for rank in 0..size {
114 todo_tx
115 .send(Action::Stop(rank, ProcStopReason::Stopped))
116 .unwrap();
117 }
118 todo_tx.send(Action::Stopped).unwrap();
119 }
120 }
121
122 pub(crate) fn name(&self) -> &ShortUuid {
123 &self.name
124 }
125
126 pub(crate) fn size(&self) -> usize {
127 self.spec.extent.num_ranks()
128 }
129}
130
131#[async_trait]
132impl Alloc for LocalAlloc {
133 async fn next(&mut self) -> Option<ProcState> {
134 if self.stopped {
135 return None;
136 }
137 if self.failed && !self.stopped {
138 futures::future::pending::<()>().await;
140 unreachable!("future::pending completed");
141 }
142 let event = loop {
143 if let state @ Some(_) = self.queue.pop_front() {
144 break state;
145 }
146
147 match self.todo_rx.recv().await? {
148 Action::Start(rank) => {
149 let (addr, proc_rx) = loop {
150 match channel::serve(ChannelAddr::any(self.transport())) {
151 Ok(addr_and_proc_rx) => break addr_and_proc_rx,
152 Err(err) => {
153 tracing::error!(
154 "failed to create channel for rank {}: {}",
155 rank,
156 err
157 );
158 #[allow(clippy::disallowed_methods)]
159 sleep(Duration::from_secs(1)).await;
160 continue;
161 }
162 }
163 };
164
165 let proc_id = match &self.spec.proc_name {
166 Some(name) => ProcId::Direct(addr.clone(), name.clone()),
167 None => ProcId::Ranked(self.world_id.clone(), rank),
168 };
169
170 let bspan = tracing::info_span!("mesh_agent_bootstrap");
171 let (proc, mesh_agent) = match ProcMeshAgent::bootstrap(proc_id.clone()).await {
172 Ok(proc_and_agent) => proc_and_agent,
173 Err(err) => {
174 let message = format!("failed spawn mesh agent for {}: {}", rank, err);
175 tracing::error!(message);
176 self.failed = true;
179 break Some(ProcState::Failed {
180 world_id: self.world_id.clone(),
181 description: message,
182 });
183 }
184 };
185 drop(bspan);
186
187 let handle = proc.clone().serve(proc_rx);
189
190 let create_key = ShortUuid::generate();
191
192 self.procs.insert(
193 rank,
194 LocalProc {
195 proc,
196 create_key: create_key.clone(),
197 addr: addr.clone(),
198 handle,
199 },
200 );
201
202 let point = match self.spec.extent.point_of_rank(rank) {
203 Ok(point) => point,
204 Err(err) => {
205 tracing::error!("failed to get point for rank {}: {}", rank, err);
206 return None;
207 }
208 };
209 let created = ProcState::Created {
210 create_key: create_key.clone(),
211 point,
212 pid: std::process::id(),
213 };
214 self.queue.push_back(ProcState::Running {
215 create_key,
216 proc_id,
217 mesh_agent: mesh_agent.bind(),
218 addr,
219 });
220 break Some(created);
221 }
222 Action::Stop(rank, reason) => {
223 let Some(mut proc_to_stop) = self.procs.remove(&rank) else {
224 continue;
225 };
226
227 proc_to_stop.handle.stop("received Action::Stop");
229
230 if let Err(err) = proc_to_stop
231 .proc
232 .destroy_and_wait::<()>(Duration::from_millis(10), None)
233 .await
234 {
235 tracing::error!("error while stopping proc {}: {}", rank, err);
236 }
237 break Some(ProcState::Stopped {
238 reason,
239 create_key: proc_to_stop.create_key.clone(),
240 });
241 }
242 Action::Stopped => break None,
243 }
244 };
245 self.stopped = event.is_none();
246 event
247 }
248
249 fn spec(&self) -> &AllocSpec {
250 &self.spec
251 }
252
253 fn extent(&self) -> &Extent {
254 &self.spec.extent
255 }
256
257 fn world_id(&self) -> &WorldId {
258 &self.world_id
259 }
260
261 async fn stop(&mut self) -> Result<(), AllocatorError> {
262 for rank in 0..self.size() {
263 self.todo_tx
264 .send(Action::Stop(rank, ProcStopReason::Stopped))
265 .unwrap();
266 }
267 self.todo_tx.send(Action::Stopped).unwrap();
268 Ok(())
269 }
270
271 fn is_local(&self) -> bool {
272 true
273 }
274}
275
276impl Drop for LocalAlloc {
277 fn drop(&mut self) {
278 tracing::debug!(
279 "dropping LocalAlloc of name: {}, world id: {}",
280 self.name,
281 self.world_id
282 );
283 }
284}
285
#[cfg(test)]
mod tests {
    // Instantiates the crate's `alloc_test_suite!` for `LocalAllocator`; the
    // glob import brings the allocator (and its dependencies) into scope.
    use super::*;

    crate::alloc_test_suite!(LocalAllocator);
}