1use std::collections::HashMap;
12use std::fmt;
13use std::pin::Pin;
14
15use async_trait::async_trait;
16use enum_as_inner::EnumAsInner;
17use hyperactor::Actor;
18use hyperactor::ActorHandle;
19use hyperactor::ActorId;
20use hyperactor::ActorRef;
21use hyperactor::Context;
22use hyperactor::Handler;
23use hyperactor::Instance;
24use hyperactor::Named;
25use hyperactor::PortRef;
26use hyperactor::Proc;
27use hyperactor::ProcId;
28use hyperactor::RefClient;
29use hyperactor::channel::ChannelTransport;
30use hyperactor::context;
31use hyperactor::host::Host;
32use hyperactor::host::HostError;
33use hyperactor::host::LocalProcManager;
34use hyperactor::host::SingleTerminate;
35use serde::Deserialize;
36use serde::Serialize;
37use tokio::time::Duration;
38
39use crate::bootstrap;
40use crate::bootstrap::BootstrapCommand;
41use crate::bootstrap::BootstrapProcConfig;
42use crate::bootstrap::BootstrapProcManager;
43use crate::proc_mesh::mesh_agent::ProcMeshAgent;
44use crate::resource;
45use crate::resource::ProcSpec;
46use crate::v1::Name;
47
/// Boxed, pinned, `Send` future whose output is the result of spawning a
/// `ProcMeshAgent` (an `ActorHandle<ProcMeshAgent>` on success).
type ProcManagerSpawnFuture =
    Pin<Box<dyn Future<Output = anyhow::Result<ActorHandle<ProcMeshAgent>>> + Send>>;
/// Boxed `Send + Sync` callback that takes a `Proc` and returns a
/// `ProcManagerSpawnFuture`. It is the spawn-function type parameter of the
/// `LocalProcManager` used by `HostAgentMode::Local`.
type ProcManagerSpawnFn = Box<dyn Fn(Proc) -> ProcManagerSpawnFuture + Send + Sync>;
51
/// The host owned by a `HostMeshAgent`, distinguished by the kind of proc
/// manager the host uses.
///
/// `EnumAsInner` derives per-variant accessors; this file relies on
/// `as_process()` to reach the `BootstrapProcManager`-backed host.
#[derive(EnumAsInner)]
pub enum HostAgentMode {
    /// Host whose procs are managed by a `BootstrapProcManager`.
    Process(Host<BootstrapProcManager>),
    /// Host whose procs are managed by a `LocalProcManager` driven by a
    /// `ProcManagerSpawnFn`.
    Local(Host<LocalProcManager<ProcManagerSpawnFn>>),
}
67
68impl HostAgentMode {
69 fn system_proc(&self) -> &Proc {
70 #[allow(clippy::match_same_arms)]
71 match self {
72 HostAgentMode::Process(host) => host.system_proc(),
73 HostAgentMode::Local(host) => host.system_proc(),
74 }
75 }
76
77 async fn terminate_proc(
78 &self,
79 cx: &impl context::Actor,
80 proc: &ProcId,
81 timeout: Duration,
82 ) -> Result<(Vec<ActorId>, Vec<ActorId>), anyhow::Error> {
83 #[allow(clippy::match_same_arms)]
84 match self {
85 HostAgentMode::Process(host) => host.terminate_proc(cx, proc, timeout).await,
86 HostAgentMode::Local(host) => host.terminate_proc(cx, proc, timeout).await,
87 }
88 }
89}
90
/// Bookkeeping kept by `HostMeshAgent` for one proc it was asked to create.
#[derive(Debug)]
struct ProcCreationState {
    /// Rank taken from the create request.
    rank: usize,
    /// Outcome of the spawn: the proc id and a reference to its
    /// `ProcMeshAgent` on success, or the `HostError` on failure.
    created: Result<(ProcId, ActorRef<ProcMeshAgent>), HostError>,
    /// Set to `true` after the `resource::Stop` handler has terminated the proc.
    stopped: bool,
}
97
/// Actor that owns a host and creates, stops, and reports on procs on it.
///
/// The exported handlers below are the messages this actor answers.
#[hyperactor::export(
    handlers=[
        resource::CreateOrUpdate<ProcSpec>,
        resource::Stop,
        resource::GetState<ProcState>,
        resource::GetRankStatus { cast = true },
        ShutdownHost
    ]
)]
pub struct HostMeshAgent {
    /// The managed host. Becomes `None` once a `ShutdownHost` message has
    /// been handled.
    host: Option<HostAgentMode>,
    /// Creation state per proc name; failed spawns are recorded here too.
    created: HashMap<Name, ProcCreationState>,
}
113
114impl fmt::Debug for HostMeshAgent {
115 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
116 f.debug_struct("HostMeshAgent")
117 .field("host", &"..")
118 .field("created", &self.created)
119 .finish()
120 }
121}
122
123#[async_trait]
124impl Actor for HostMeshAgent {
125 type Params = HostAgentMode;
126
127 async fn new(host: HostAgentMode) -> anyhow::Result<Self> {
128 if let HostAgentMode::Process(_) = host {
129 let (directory, file) = hyperactor_telemetry::log_file_path(
130 hyperactor_telemetry::env::Env::current(),
131 None,
132 )
133 .unwrap();
134 eprintln!(
135 "Monarch internal logs are being written to {}/{}.log",
136 directory, file
137 );
138 }
139 Ok(Self {
140 host: Some(host),
141 created: HashMap::new(),
142 })
143 }
144}
145
146#[async_trait]
147impl Handler<resource::CreateOrUpdate<ProcSpec>> for HostMeshAgent {
148 #[tracing::instrument("HostMeshAgent::CreateOrUpdate", level = "info", skip_all, fields(name=%create_or_update.name))]
149 async fn handle(
150 &mut self,
151 _cx: &Context<Self>,
152 create_or_update: resource::CreateOrUpdate<ProcSpec>,
153 ) -> anyhow::Result<()> {
154 if self.created.contains_key(&create_or_update.name) {
155 return Ok(());
157 }
158
159 let host = self.host.as_mut().expect("host present");
160 let created = match host {
161 HostAgentMode::Process(host) => {
162 host.spawn(
163 create_or_update.name.clone().to_string(),
164 BootstrapProcConfig {
165 create_rank: create_or_update.rank.unwrap(),
166 client_config_override: create_or_update
167 .spec
168 .client_config_override
169 .clone(),
170 },
171 )
172 .await
173 }
174 HostAgentMode::Local(host) => {
175 host.spawn(create_or_update.name.clone().to_string(), ())
176 .await
177 }
178 };
179
180 if let Err(e) = &created {
181 tracing::error!("failed to spawn proc {}: {}", create_or_update.name, e);
182 }
183 self.created.insert(
184 create_or_update.name.clone(),
185 ProcCreationState {
186 rank: create_or_update.rank.unwrap(),
187 created,
188 stopped: false,
189 },
190 );
191
192 Ok(())
193 }
194}
195
196#[async_trait]
197impl Handler<resource::Stop> for HostMeshAgent {
198 async fn handle(&mut self, cx: &Context<Self>, message: resource::Stop) -> anyhow::Result<()> {
199 let host = self.host.as_mut().expect("host present");
200 let manager = host.as_process().map(Host::manager);
201 let timeout = hyperactor::config::global::get(hyperactor::config::PROCESS_EXIT_TIMEOUT);
202 let proc = self.created.get_mut(&message.name);
205 if let Some(ProcCreationState {
206 created: Ok((proc_id, _)),
207 stopped,
208 ..
209 }) = proc
210 {
211 let proc_status = match manager {
212 Some(manager) => manager.status(proc_id).await,
213 None => None,
214 };
215 let should_stop = if let Some(status) = &proc_status {
219 resource::Status::from(status.clone()).is_healthy()
220 } else {
221 !*stopped
222 };
223 if should_stop {
224 host.terminate_proc(&cx, proc_id, timeout).await?;
225 *stopped = true;
226 }
227 }
228
229 Ok(())
230 }
231}
232
233#[async_trait]
234impl Handler<resource::GetRankStatus> for HostMeshAgent {
235 async fn handle(
236 &mut self,
237 cx: &Context<Self>,
238 get_rank_status: resource::GetRankStatus,
239 ) -> anyhow::Result<()> {
240 use crate::resource::Status;
241 use crate::v1::StatusOverlay;
242
243 let manager = self
244 .host
245 .as_mut()
246 .and_then(|h| h.as_process())
247 .map(Host::manager);
248 let (rank, status) = match self.created.get(&get_rank_status.name) {
249 Some(ProcCreationState {
250 rank,
251 created: Ok((proc_id, _mesh_agent)),
252 stopped,
253 }) => {
254 let proc_status = match manager {
255 Some(manager) => manager.status(proc_id).await,
256 None => None,
257 };
258 let status = if let Some(status) = &proc_status {
261 status.clone().into()
262 } else if *stopped {
263 resource::Status::Stopped
264 } else {
265 resource::Status::Running
266 };
267 (*rank, status)
268 }
269 Some(ProcCreationState {
272 rank,
273 created: Err(e),
274 ..
275 }) => (*rank, Status::Failed(e.to_string())),
276 None => (usize::MAX, Status::NotExist),
277 };
278
279 let overlay = if rank == usize::MAX {
280 StatusOverlay::new()
281 } else {
282 StatusOverlay::try_from_runs(vec![(rank..(rank + 1), status)])
283 .expect("valid single-run overlay")
284 };
285 let result = get_rank_status.reply.send(cx, overlay);
286 if let Err(e) = result {
290 tracing::warn!(
291 actor = %cx.self_id(),
292 "failed to send GetRankStatus reply to {} due to error: {}",
293 get_rank_status.reply.port_id().actor_id(),
294 e
295 );
296 }
297 Ok(())
298 }
299}
300
/// Message asking a `HostMeshAgent` to terminate the children of its host and
/// give up the host.
#[derive(Serialize, Deserialize, Debug, Named, Handler, RefClient)]
pub struct ShutdownHost {
    /// Timeout forwarded to the host's `terminate_children` call.
    pub timeout: std::time::Duration,
    /// Concurrency limit forwarded to `terminate_children`; the handler may
    /// clamp it before forwarding.
    pub max_in_flight: usize,
    /// Reply port; the handler sends `()` on it before starting termination.
    #[reply]
    pub ack: hyperactor::PortRef<()>,
}
312
313#[async_trait]
314impl Handler<ShutdownHost> for HostMeshAgent {
315 async fn handle(&mut self, cx: &Context<Self>, msg: ShutdownHost) -> anyhow::Result<()> {
316 msg.ack.send(cx, ())?;
318
319 if let Some(host_mode) = self.host.take() {
320 match host_mode {
321 HostAgentMode::Process(host) => {
322 let summary = host
323 .terminate_children(cx, msg.timeout, msg.max_in_flight.clamp(1, 256))
324 .await;
325 tracing::info!(?summary, "terminated children on host");
326 }
327 HostAgentMode::Local(host) => {
328 let summary = host
329 .terminate_children(cx, msg.timeout, msg.max_in_flight)
330 .await;
331 tracing::info!(?summary, "terminated children on local host");
332 }
333 }
334 }
335 let _ = self.host.take();
337
338 Ok(())
339 }
340}
341
/// Per-proc state returned by the `resource::GetState<ProcState>` handler for
/// a successfully created proc.
#[derive(Debug, Clone, PartialEq, Eq, Named, Serialize, Deserialize)]
pub struct ProcState {
    /// Id of the proc.
    pub proc_id: ProcId,
    /// Rank recorded when the proc was created.
    pub create_rank: usize,
    /// Reference to the proc's `ProcMeshAgent`.
    pub mesh_agent: ActorRef<ProcMeshAgent>,
    /// Command of the host's `BootstrapProcManager`; `None` when the host has
    /// no such manager (a `HostAgentMode::Local` host).
    pub bootstrap_command: Option<BootstrapCommand>,
    /// Status reported by the `BootstrapProcManager` for this proc, if any.
    pub proc_status: Option<bootstrap::ProcStatus>,
}
350
351#[async_trait]
352impl Handler<resource::GetState<ProcState>> for HostMeshAgent {
353 async fn handle(
354 &mut self,
355 cx: &Context<Self>,
356 get_state: resource::GetState<ProcState>,
357 ) -> anyhow::Result<()> {
358 let manager: Option<&BootstrapProcManager> = self
359 .host
360 .as_mut()
361 .expect("host")
362 .as_process()
363 .map(Host::manager);
364 let state = match self.created.get(&get_state.name) {
365 Some(ProcCreationState {
366 rank,
367 created: Ok((proc_id, mesh_agent)),
368 stopped,
369 }) => {
370 let proc_status = match manager {
371 Some(manager) => manager.status(proc_id).await,
372 None => None,
373 };
374 let status = if let Some(status) = &proc_status {
377 status.clone().into()
378 } else if *stopped {
379 resource::Status::Stopped
380 } else {
381 resource::Status::Running
382 };
383 resource::State {
384 name: get_state.name.clone(),
385 status,
386 state: Some(ProcState {
387 proc_id: proc_id.clone(),
388 create_rank: *rank,
389 mesh_agent: mesh_agent.clone(),
390 bootstrap_command: manager.map(|m| m.command().clone()),
391 proc_status,
392 }),
393 }
394 }
395 Some(ProcCreationState {
396 created: Err(e), ..
397 }) => resource::State {
398 name: get_state.name.clone(),
399 status: resource::Status::Failed(e.to_string()),
400 state: None,
401 },
402 None => resource::State {
403 name: get_state.name.clone(),
404 status: resource::Status::NotExist,
405 state: None,
406 },
407 };
408
409 let result = get_state.reply.send(cx, state);
410 if let Err(e) = result {
414 tracing::warn!(
415 actor = %cx.self_id(),
416 "failed to send GetState reply to {} due to error: {}",
417 get_state.reply.port_id().actor_id(),
418 e
419 );
420 }
421 Ok(())
422 }
423}
424
/// Actor that boots a host together with a `HostMeshAgent` on that host's
/// system proc, and hands out a reference to the agent.
///
/// The reference is sent once on `reply_port` during `init`, and again to
/// anyone who sends `GetHostMeshAgent`.
#[derive(Debug)]
#[hyperactor::export(
    spawn = true,
    handlers=[GetHostMeshAgent]
)]
pub(crate) struct HostMeshAgentProcMeshTrampoline {
    /// Handle to the agent spawned in `new`.
    host_mesh_agent: ActorHandle<HostMeshAgent>,
    /// Port that receives the agent reference during `init`.
    reply_port: PortRef<ActorRef<HostMeshAgent>>,
}
438
439#[async_trait]
440impl Actor for HostMeshAgentProcMeshTrampoline {
441 type Params = (
442 ChannelTransport,
443 PortRef<ActorRef<HostMeshAgent>>,
444 Option<BootstrapCommand>,
445 bool, );
447
448 async fn new((transport, reply_port, command, local): Self::Params) -> anyhow::Result<Self> {
449 let host = if local {
450 let spawn: ProcManagerSpawnFn = Box::new(|proc| Box::pin(ProcMeshAgent::boot_v1(proc)));
451 let manager = LocalProcManager::new(spawn);
452 let (host, _) = Host::serve(manager, transport.any()).await?;
453 HostAgentMode::Local(host)
454 } else {
455 let command = match command {
456 Some(command) => command,
457 None => BootstrapCommand::current()?,
458 };
459 tracing::info!("booting host with proc command {:?}", command);
460 let manager = BootstrapProcManager::new(command).unwrap();
461 let (host, _) = Host::serve(manager, transport.any()).await?;
462 HostAgentMode::Process(host)
463 };
464
465 let host_mesh_agent = host
466 .system_proc()
467 .clone()
468 .spawn::<HostMeshAgent>("agent", host)
469 .await?;
470
471 Ok(Self {
472 host_mesh_agent,
473 reply_port,
474 })
475 }
476
477 async fn init(&mut self, this: &Instance<Self>) -> anyhow::Result<()> {
478 self.reply_port.send(this, self.host_mesh_agent.bind())?;
479 Ok(())
480 }
481}
482
/// Request for a reference to the `HostMeshAgent` held by a
/// `HostMeshAgentProcMeshTrampoline`.
#[derive(Serialize, Deserialize, Debug, Named, Handler, RefClient)]
pub struct GetHostMeshAgent {
    /// Reply port on which the agent reference is sent.
    #[reply]
    pub host_mesh_agent: PortRef<ActorRef<HostMeshAgent>>,
}
488
489#[async_trait]
490impl Handler<GetHostMeshAgent> for HostMeshAgentProcMeshTrampoline {
491 async fn handle(
492 &mut self,
493 cx: &Context<Self>,
494 get_host_mesh_agent: GetHostMeshAgent,
495 ) -> anyhow::Result<()> {
496 get_host_mesh_agent
497 .host_mesh_agent
498 .send(cx, self.host_mesh_agent.bind())?;
499 Ok(())
500 }
501}
502
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;

    use hyperactor::Proc;
    use hyperactor::channel::ChannelTransport;

    use super::*;
    use crate::bootstrap::ProcStatus;
    use crate::resource::CreateOrUpdateClient;
    use crate::resource::GetStateClient;

    /// End-to-end check: create one proc through a process-backed
    /// `HostMeshAgent` and verify the state it reports.
    ///
    /// Compiled only when the `fbcode_build` cfg is set.
    #[tokio::test]
    #[cfg(fbcode_build)]
    async fn test_basic() {
        // Serve a host over a Unix channel using the test bootstrap command.
        let (host, _handle) = Host::serve(
            BootstrapProcManager::new(BootstrapCommand::test()).unwrap(),
            ChannelTransport::Unix.any(),
        )
        .await
        .unwrap();

        // Capture the address and system proc before the host is moved into
        // the agent.
        let host_addr = host.addr().clone();
        let system_proc = host.system_proc().clone();
        let host_agent = system_proc
            .spawn::<HostMeshAgent>("agent", HostAgentMode::Process(host))
            .await
            .unwrap();

        // A separate proc and instance from which the requests are sent.
        let client_proc = Proc::direct(ChannelTransport::Unix.any(), "client".to_string())
            .await
            .unwrap();
        let (client, _client_handle) = client_proc.instance("client").unwrap();

        let name = Name::new("proc1");

        // Ask the agent to create the proc at rank 0 with the default spec.
        host_agent
            .create_or_update(
                &client,
                name.clone(),
                resource::Rank::new(0),
                ProcSpec::default(),
            )
            .await
            .unwrap();
        // The proc must be reported as running and ready, with ids derived
        // from the host address and proc name, and with the mesh agent
        // matching the one in the manager's status.
        assert_matches!(
            host_agent.get_state(&client, name.clone()).await.unwrap(),
            resource::State {
                name: resource_name,
                status: resource::Status::Running,
                state: Some(ProcState {
                    proc_id,
                    mesh_agent,
                    bootstrap_command,
                    proc_status: Some(ProcStatus::Ready { pid: _, started_at: _, addr: _, agent: proc_status_mesh_agent}),
                    ..
                }),
            } if name == resource_name
              && proc_id == ProcId::Direct(host_addr.clone(), name.to_string())
              && mesh_agent == ActorRef::attest(ProcId::Direct(host_addr.clone(), name.to_string()).actor_id("agent", 0)) && bootstrap_command == Some(BootstrapCommand::test())
              && mesh_agent == proc_status_mesh_agent
        );
    }
}