hyperactor_mesh/v1/host_mesh/
mesh_agent.rs1use std::collections::HashMap;
12use std::fmt;
13use std::pin::Pin;
14
15use async_trait::async_trait;
16use enum_as_inner::EnumAsInner;
17use hyperactor::Actor;
18use hyperactor::ActorHandle;
19use hyperactor::ActorRef;
20use hyperactor::Context;
21use hyperactor::Handler;
22use hyperactor::Instance;
23use hyperactor::Named;
24use hyperactor::PortRef;
25use hyperactor::Proc;
26use hyperactor::ProcId;
27use hyperactor::RefClient;
28use hyperactor::channel::ChannelTransport;
29use hyperactor::host::Host;
30use hyperactor::host::HostError;
31use hyperactor::host::LocalProcManager;
32use serde::Deserialize;
33use serde::Serialize;
34
35use crate::bootstrap;
36use crate::bootstrap::BootstrapCommand;
37use crate::bootstrap::BootstrapProcManager;
38use crate::proc_mesh::mesh_agent::ProcMeshAgent;
39use crate::resource;
40use crate::v1::Name;
41
42type ProcManagerSpawnFuture =
43 Pin<Box<dyn Future<Output = anyhow::Result<ActorHandle<ProcMeshAgent>>> + Send>>;
44type ProcManagerSpawnFn = Box<dyn Fn(Proc) -> ProcManagerSpawnFuture + Send + Sync>;
45
46#[derive(EnumAsInner)]
57pub enum HostAgentMode {
58 Process(Host<BootstrapProcManager>),
59 Local(Host<LocalProcManager<ProcManagerSpawnFn>>),
60}
61
62impl HostAgentMode {
63 fn system_proc(&self) -> &Proc {
64 #[allow(clippy::match_same_arms)]
65 match self {
66 HostAgentMode::Process(host) => host.system_proc(),
67 HostAgentMode::Local(host) => host.system_proc(),
68 }
69 }
70}
71
72#[hyperactor::export(
75 handlers=[
76 resource::CreateOrUpdate<()>,
77 resource::GetState<ProcState>,
78 resource::GetRankStatus,
79 ShutdownHost
80 ]
81)]
82pub struct HostMeshAgent {
83 host: Option<HostAgentMode>,
84 created: HashMap<Name, (usize, Result<(ProcId, ActorRef<ProcMeshAgent>), HostError>)>,
85}
86
87impl fmt::Debug for HostMeshAgent {
88 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89 f.debug_struct("HostMeshAgent")
90 .field("host", &"..")
91 .field("created", &self.created)
92 .finish()
93 }
94}
95
96#[async_trait]
97impl Actor for HostMeshAgent {
98 type Params = HostAgentMode;
99
100 async fn new(host: HostAgentMode) -> anyhow::Result<Self> {
101 Ok(Self {
102 host: Some(host),
103 created: HashMap::new(),
104 })
105 }
106}
107
108#[async_trait]
109impl Handler<resource::CreateOrUpdate<()>> for HostMeshAgent {
110 #[tracing::instrument("HostMeshAgent::CreateOrUpdate", level = "info", skip_all, fields(name=%create_or_update.name))]
111 async fn handle(
112 &mut self,
113 _cx: &Context<Self>,
114 create_or_update: resource::CreateOrUpdate<()>,
115 ) -> anyhow::Result<()> {
116 if self.created.contains_key(&create_or_update.name) {
117 return Ok(());
119 }
120
121 let host = self.host.as_mut().expect("host present");
122 let created = match host {
123 HostAgentMode::Process(host) => {
124 host.spawn(create_or_update.name.clone().to_string()).await
125 }
126 HostAgentMode::Local(host) => {
127 host.spawn(create_or_update.name.clone().to_string()).await
128 }
129 };
130
131 if let Err(e) = &created {
132 tracing::error!("failed to spawn proc {}: {}", create_or_update.name, e);
133 }
134 self.created.insert(
135 create_or_update.name.clone(),
136 (create_or_update.rank.unwrap(), created),
137 );
138
139 Ok(())
140 }
141}
142
143#[async_trait]
144impl Handler<resource::GetRankStatus> for HostMeshAgent {
145 async fn handle(
146 &mut self,
147 cx: &Context<Self>,
148 get_rank_status: resource::GetRankStatus,
149 ) -> anyhow::Result<()> {
150 let Some(created) = self.created.get(&get_rank_status.name) else {
151 get_rank_status
153 .reply
154 .send(cx, (usize::MAX, resource::Status::NotExist).into())?;
155 return Ok(());
156 };
157
158 let rank_status = match created {
159 (rank, Ok(_)) => (*rank, resource::Status::Running),
160 (rank, Err(e)) => (*rank, resource::Status::Failed(e.to_string())),
161 };
162 get_rank_status.reply.send(cx, rank_status.into())?;
163
164 Ok(())
165 }
166}
167
168#[derive(Serialize, Deserialize, Debug, Named, Handler, RefClient)]
169pub struct ShutdownHost {
170 pub timeout: std::time::Duration,
173 pub max_in_flight: usize,
175 #[reply]
177 pub ack: hyperactor::PortRef<()>,
178}
179
180#[async_trait]
181impl Handler<ShutdownHost> for HostMeshAgent {
182 async fn handle(&mut self, cx: &Context<Self>, msg: ShutdownHost) -> anyhow::Result<()> {
183 msg.ack.send(cx, ())?;
185
186 if let Some(host_mode) = self.host.take() {
187 match host_mode {
188 HostAgentMode::Process(host) => {
189 let summary = host
190 .terminate_children(msg.timeout, msg.max_in_flight.clamp(1, 256))
191 .await;
192 tracing::info!(?summary, "terminated children on host");
193 }
194 HostAgentMode::Local(host) => {
195 let summary = host
196 .terminate_children(msg.timeout, msg.max_in_flight)
197 .await;
198 tracing::info!(?summary, "terminated children on local host");
199 }
200 }
201 }
202 let _ = self.host.take();
204
205 Ok(())
206 }
207}
208
209#[derive(Debug, Clone, PartialEq, Eq, Named, Serialize, Deserialize)]
210pub struct ProcState {
211 pub proc_id: ProcId,
212 pub mesh_agent: ActorRef<ProcMeshAgent>,
213 pub bootstrap_command: Option<BootstrapCommand>,
214 pub proc_status: Option<bootstrap::ProcStatus>,
215}
216
217#[async_trait]
218impl Handler<resource::GetState<ProcState>> for HostMeshAgent {
219 async fn handle(
220 &mut self,
221 cx: &Context<Self>,
222 get_state: resource::GetState<ProcState>,
223 ) -> anyhow::Result<()> {
224 let manager = self
225 .host
226 .as_mut()
227 .expect("host")
228 .as_process()
229 .map(Host::manager);
230 let state = match self.created.get(&get_state.name) {
231 Some((_rank, Ok((proc_id, mesh_agent)))) => resource::State {
232 name: get_state.name.clone(),
233 status: resource::Status::Running,
234 state: Some(ProcState {
235 proc_id: proc_id.clone(),
236 mesh_agent: mesh_agent.clone(),
237 bootstrap_command: manager.map(|m| m.command().clone()),
238 proc_status: match manager {
239 Some(manager) => Some(manager.status(proc_id).await.unwrap()),
240 None => None,
241 },
242 }),
243 },
244 Some((_rank, Err(e))) => resource::State {
245 name: get_state.name.clone(),
246 status: resource::Status::Failed(e.to_string()),
247 state: None,
248 },
249 None => resource::State {
250 name: get_state.name.clone(),
251 status: resource::Status::NotExist,
252 state: None,
253 },
254 };
255
256 get_state.reply.send(cx, state)?;
257 Ok(())
258 }
259}
260
261#[derive(Debug)]
266#[hyperactor::export(
267 spawn = true,
268 handlers=[GetHostMeshAgent]
269)]
270pub(crate) struct HostMeshAgentProcMeshTrampoline {
271 host_mesh_agent: ActorHandle<HostMeshAgent>,
272 reply_port: PortRef<ActorRef<HostMeshAgent>>,
273}
274
275#[async_trait]
276impl Actor for HostMeshAgentProcMeshTrampoline {
277 type Params = (
278 ChannelTransport,
279 PortRef<ActorRef<HostMeshAgent>>,
280 Option<BootstrapCommand>,
281 bool, );
283
284 async fn new((transport, reply_port, command, local): Self::Params) -> anyhow::Result<Self> {
285 let host = if local {
286 let spawn: ProcManagerSpawnFn = Box::new(|proc| Box::pin(ProcMeshAgent::boot_v1(proc)));
287 let manager = LocalProcManager::new(spawn);
288 let (host, _) = Host::serve(manager, transport.any()).await?;
289 HostAgentMode::Local(host)
290 } else {
291 let command = match command {
292 Some(command) => command,
293 None => BootstrapCommand::current()?,
294 };
295 tracing::info!("booting host with proc command {:?}", command);
296 let manager = BootstrapProcManager::new(command);
297 let (host, _) = Host::serve(manager, transport.any()).await?;
298 HostAgentMode::Process(host)
299 };
300
301 let host_mesh_agent = host
302 .system_proc()
303 .clone()
304 .spawn::<HostMeshAgent>("agent", host)
305 .await?;
306
307 Ok(Self {
308 host_mesh_agent,
309 reply_port,
310 })
311 }
312
313 async fn init(&mut self, this: &Instance<Self>) -> anyhow::Result<()> {
314 self.reply_port.send(this, self.host_mesh_agent.bind())?;
315 Ok(())
316 }
317}
318
319#[derive(Serialize, Deserialize, Debug, Named, Handler, RefClient)]
320pub struct GetHostMeshAgent {
321 #[reply]
322 pub host_mesh_agent: PortRef<ActorRef<HostMeshAgent>>,
323}
324
325#[async_trait]
326impl Handler<GetHostMeshAgent> for HostMeshAgentProcMeshTrampoline {
327 async fn handle(
328 &mut self,
329 cx: &Context<Self>,
330 get_host_mesh_agent: GetHostMeshAgent,
331 ) -> anyhow::Result<()> {
332 get_host_mesh_agent
333 .host_mesh_agent
334 .send(cx, self.host_mesh_agent.bind())?;
335 Ok(())
336 }
337}
338
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;

    use hyperactor::Proc;
    use hyperactor::channel::ChannelTransport;

    use super::*;
    use crate::bootstrap::ProcStatus;
    use crate::resource::CreateOrUpdateClient;
    use crate::resource::GetStateClient;

    /// The test does the following:
    /// - spawns a `HostMeshAgent` on a process-backed host;
    /// - asks it to create one proc;
    /// - checks that `GetState` reports that proc as running, with the expected
    ///   ids, bootstrap command, and ready agent.
    #[tokio::test]
    async fn test_basic() {
        let (host, _handle) = Host::serve(
            BootstrapProcManager::new(BootstrapCommand::test()),
            ChannelTransport::Unix.any(),
        )
        .await
        .unwrap();

        let host_addr = host.addr().clone();
        let system_proc = host.system_proc().clone();
        let host_agent = system_proc
            .spawn::<HostMeshAgent>("agent", HostAgentMode::Process(host))
            .await
            .unwrap();

        let client_proc = Proc::direct(ChannelTransport::Unix.any(), "client".to_string())
            .await
            .unwrap();
        let (client, _client_handle) = client_proc.instance("client").unwrap();

        let name = Name::new("proc1");

        host_agent
            .create_or_update(&client, name.clone(), resource::Rank::new(0), ())
            .await
            .unwrap();

        let state = host_agent.get_state(&client, name.clone()).await.unwrap();
        assert_eq!(state.name, name);
        assert_matches!(state.status, resource::Status::Running);

        let ProcState {
            proc_id,
            mesh_agent,
            bootstrap_command,
            proc_status,
        } = state
            .state
            .expect("a running proc should report a ProcState");
        let ready_agent = match proc_status {
            Some(ProcStatus::Ready { agent, .. }) => agent,
            other => panic!("expected a ready proc status, got {other:?}"),
        };

        let expected_proc_id = ProcId::Direct(host_addr.clone(), name.to_string());
        let expected_agent: ActorRef<ProcMeshAgent> =
            ActorRef::attest(expected_proc_id.clone().actor_id("agent", 0));

        assert_eq!(proc_id, expected_proc_id);
        assert_eq!(mesh_agent, expected_agent);
        assert_eq!(bootstrap_command, Some(BootstrapCommand::test()));
        assert_eq!(mesh_agent, ready_agent);
    }
}