1use std::cell::OnceCell;
12use std::collections::HashMap;
13use std::fmt;
14use std::pin::Pin;
15
16use async_trait::async_trait;
17use enum_as_inner::EnumAsInner;
18use hyperactor::Actor;
19use hyperactor::ActorHandle;
20use hyperactor::ActorId;
21use hyperactor::ActorRef;
22use hyperactor::Context;
23use hyperactor::HandleClient;
24use hyperactor::Handler;
25use hyperactor::Instance;
26use hyperactor::PortHandle;
27use hyperactor::PortRef;
28use hyperactor::Proc;
29use hyperactor::ProcId;
30use hyperactor::RefClient;
31use hyperactor::channel::ChannelTransport;
32use hyperactor::context;
33use hyperactor::host::Host;
34use hyperactor::host::HostError;
35use hyperactor::host::LocalProcManager;
36use hyperactor::host::SingleTerminate;
37use serde::Deserialize;
38use serde::Serialize;
39use tokio::time::Duration;
40use typeuri::Named;
41
42use crate::bootstrap;
43use crate::bootstrap::BootstrapCommand;
44use crate::bootstrap::BootstrapProcConfig;
45use crate::bootstrap::BootstrapProcManager;
46use crate::proc_mesh::mesh_agent::ProcMeshAgent;
47use crate::resource;
48use crate::resource::ProcSpec;
49use crate::v1::Name;
50
51type ProcManagerSpawnFuture =
52 Pin<Box<dyn Future<Output = anyhow::Result<ActorHandle<ProcMeshAgent>>> + Send>>;
53type ProcManagerSpawnFn = Box<dyn Fn(Proc) -> ProcManagerSpawnFuture + Send + Sync>;
54
55#[derive(EnumAsInner)]
66pub enum HostAgentMode {
67 Process(Host<BootstrapProcManager>),
68 Local(Host<LocalProcManager<ProcManagerSpawnFn>>),
69}
70
71impl HostAgentMode {
72 fn system_proc(&self) -> &Proc {
73 #[allow(clippy::match_same_arms)]
74 match self {
75 HostAgentMode::Process(host) => host.system_proc(),
76 HostAgentMode::Local(host) => host.system_proc(),
77 }
78 }
79
80 fn local_proc(&self) -> &Proc {
81 #[allow(clippy::match_same_arms)]
82 match self {
83 HostAgentMode::Process(host) => host.local_proc(),
84 HostAgentMode::Local(host) => host.local_proc(),
85 }
86 }
87
88 async fn terminate_proc(
89 &self,
90 cx: &impl context::Actor,
91 proc: &ProcId,
92 timeout: Duration,
93 ) -> Result<(Vec<ActorId>, Vec<ActorId>), anyhow::Error> {
94 #[allow(clippy::match_same_arms)]
95 match self {
96 HostAgentMode::Process(host) => host.terminate_proc(cx, proc, timeout).await,
97 HostAgentMode::Local(host) => host.terminate_proc(cx, proc, timeout).await,
98 }
99 }
100}
101
102#[derive(Debug)]
103struct ProcCreationState {
104 rank: usize,
105 created: Result<(ProcId, ActorRef<ProcMeshAgent>), HostError>,
106 stopped: bool,
107}
108
109#[hyperactor::export(
112 handlers=[
113 resource::CreateOrUpdate<ProcSpec>,
114 resource::Stop,
115 resource::GetState<ProcState>,
116 resource::GetRankStatus { cast = true },
117 resource::List,
118 ShutdownHost
119 ]
120)]
121pub struct HostMeshAgent {
122 host: Option<HostAgentMode>,
123 created: HashMap<Name, ProcCreationState>,
124 local_mesh_agent: OnceCell<anyhow::Result<ActorHandle<ProcMeshAgent>>>,
126}
127
128impl HostMeshAgent {
129 pub fn new(host: HostAgentMode) -> Self {
131 Self {
132 host: Some(host),
133 created: HashMap::new(),
134 local_mesh_agent: OnceCell::new(),
135 }
136 }
137}
138
139#[async_trait]
140impl Actor for HostMeshAgent {
141 async fn init(&mut self, this: &Instance<Self>) -> Result<(), anyhow::Error> {
142 this.bind::<Self>();
145 match self.host.as_mut().unwrap() {
146 HostAgentMode::Process(host) => {
147 host.serve();
148 let (directory, file) = hyperactor_telemetry::log_file_path(
149 hyperactor_telemetry::env::Env::current(),
150 None,
151 )
152 .unwrap();
153 eprintln!(
154 "Monarch internal logs are being written to {}/{}.log; execution id {}",
155 directory,
156 file,
157 hyperactor_telemetry::env::execution_id(),
158 );
159 }
160 HostAgentMode::Local(host) => {
161 host.serve();
162 }
163 };
164 Ok(())
165 }
166}
167
168impl fmt::Debug for HostMeshAgent {
169 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
170 f.debug_struct("HostMeshAgent")
171 .field("host", &"..")
172 .field("created", &self.created)
173 .finish()
174 }
175}
176
177#[async_trait]
178impl Handler<resource::CreateOrUpdate<ProcSpec>> for HostMeshAgent {
179 #[tracing::instrument("HostMeshAgent::CreateOrUpdate", level = "info", skip_all, fields(name=%create_or_update.name))]
180 async fn handle(
181 &mut self,
182 _cx: &Context<Self>,
183 create_or_update: resource::CreateOrUpdate<ProcSpec>,
184 ) -> anyhow::Result<()> {
185 if self.created.contains_key(&create_or_update.name) {
186 return Ok(());
188 }
189
190 let host = self.host.as_mut().expect("host present");
191 let created = match host {
192 HostAgentMode::Process(host) => {
193 host.spawn(
194 create_or_update.name.clone().to_string(),
195 BootstrapProcConfig {
196 create_rank: create_or_update.rank.unwrap(),
197 client_config_override: create_or_update
198 .spec
199 .client_config_override
200 .clone(),
201 },
202 )
203 .await
204 }
205 HostAgentMode::Local(host) => {
206 host.spawn(create_or_update.name.clone().to_string(), ())
207 .await
208 }
209 };
210
211 if let Err(e) = &created {
212 tracing::error!("failed to spawn proc {}: {}", create_or_update.name, e);
213 }
214 self.created.insert(
215 create_or_update.name.clone(),
216 ProcCreationState {
217 rank: create_or_update.rank.unwrap(),
218 created,
219 stopped: false,
220 },
221 );
222
223 Ok(())
224 }
225}
226
227#[async_trait]
228impl Handler<resource::Stop> for HostMeshAgent {
229 async fn handle(&mut self, cx: &Context<Self>, message: resource::Stop) -> anyhow::Result<()> {
230 let host = self
231 .host
232 .as_mut()
233 .ok_or(anyhow::anyhow!("HostMeshAgent has already shut down"))?;
234 let manager = host.as_process().map(Host::manager);
235 let timeout = hyperactor_config::global::get(hyperactor::config::PROCESS_EXIT_TIMEOUT);
236 let proc = self.created.get_mut(&message.name);
239 if let Some(ProcCreationState {
240 created: Ok((proc_id, _)),
241 stopped,
242 ..
243 }) = proc
244 {
245 let proc_status = match manager {
246 Some(manager) => manager.status(proc_id).await,
247 None => None,
248 };
249 let should_stop = if let Some(status) = &proc_status {
253 resource::Status::from(status.clone()).is_healthy()
254 } else {
255 !*stopped
256 };
257 if should_stop {
258 host.terminate_proc(&cx, proc_id, timeout).await?;
259 *stopped = true;
260 }
261 }
262
263 Ok(())
264 }
265}
266
267#[async_trait]
268impl Handler<resource::GetRankStatus> for HostMeshAgent {
269 async fn handle(
270 &mut self,
271 cx: &Context<Self>,
272 get_rank_status: resource::GetRankStatus,
273 ) -> anyhow::Result<()> {
274 use crate::resource::Status;
275 use crate::v1::StatusOverlay;
276
277 let manager = self
278 .host
279 .as_mut()
280 .and_then(|h| h.as_process())
281 .map(Host::manager);
282 let (rank, status) = match self.created.get(&get_rank_status.name) {
283 Some(ProcCreationState {
284 rank,
285 created: Ok((proc_id, _mesh_agent)),
286 stopped,
287 }) => {
288 let proc_status = match manager {
289 Some(manager) => manager.status(proc_id).await,
290 None => None,
291 };
292 let status = if let Some(status) = &proc_status {
295 status.clone().into()
296 } else if *stopped {
297 resource::Status::Stopped
298 } else {
299 resource::Status::Running
300 };
301 (*rank, status)
302 }
303 Some(ProcCreationState {
306 rank,
307 created: Err(e),
308 ..
309 }) => (*rank, Status::Failed(e.to_string())),
310 None => (usize::MAX, Status::NotExist),
311 };
312
313 let overlay = if rank == usize::MAX {
314 StatusOverlay::new()
315 } else {
316 StatusOverlay::try_from_runs(vec![(rank..(rank + 1), status)])
317 .expect("valid single-run overlay")
318 };
319 let result = get_rank_status.reply.send(cx, overlay);
320 if let Err(e) = result {
324 tracing::warn!(
325 actor = %cx.self_id(),
326 "failed to send GetRankStatus reply to {} due to error: {}",
327 get_rank_status.reply.port_id().actor_id(),
328 e
329 );
330 }
331 Ok(())
332 }
333}
334
335#[derive(Serialize, Deserialize, Debug, Named, Handler, RefClient, HandleClient)]
336pub struct ShutdownHost {
337 pub timeout: std::time::Duration,
340 pub max_in_flight: usize,
342 #[reply]
344 pub ack: hyperactor::PortRef<()>,
345}
346wirevalue::register_type!(ShutdownHost);
347
348#[async_trait]
349impl Handler<ShutdownHost> for HostMeshAgent {
350 async fn handle(&mut self, cx: &Context<Self>, msg: ShutdownHost) -> anyhow::Result<()> {
351 msg.ack.send(cx, ())?;
353
354 if let Some(host_mode) = self.host.take() {
355 match host_mode {
356 HostAgentMode::Process(host) => {
357 let summary = host
358 .terminate_children(cx, msg.timeout, msg.max_in_flight.clamp(1, 256))
359 .await;
360 tracing::info!(?summary, "terminated children on host");
361 }
362 HostAgentMode::Local(host) => {
363 let summary = host
364 .terminate_children(cx, msg.timeout, msg.max_in_flight)
365 .await;
366 tracing::info!(?summary, "terminated children on local host");
367 }
368 }
369 }
370 let _ = self.host.take();
372
373 Ok(())
374 }
375}
376
377#[derive(Debug, Clone, PartialEq, Eq, Named, Serialize, Deserialize)]
378pub struct ProcState {
379 pub proc_id: ProcId,
380 pub create_rank: usize,
381 pub mesh_agent: ActorRef<ProcMeshAgent>,
382 pub bootstrap_command: Option<BootstrapCommand>,
383 pub proc_status: Option<bootstrap::ProcStatus>,
384}
385wirevalue::register_type!(ProcState);
386
387#[async_trait]
388impl Handler<resource::GetState<ProcState>> for HostMeshAgent {
389 async fn handle(
390 &mut self,
391 cx: &Context<Self>,
392 get_state: resource::GetState<ProcState>,
393 ) -> anyhow::Result<()> {
394 let manager: Option<&BootstrapProcManager> = self
395 .host
396 .as_mut()
397 .and_then(|h| h.as_process())
398 .map(Host::manager);
399 let state = match self.created.get(&get_state.name) {
400 Some(ProcCreationState {
401 rank,
402 created: Ok((proc_id, mesh_agent)),
403 stopped,
404 }) => {
405 let proc_status = match manager {
406 Some(manager) => manager.status(proc_id).await,
407 None => None,
408 };
409 let status = if let Some(status) = &proc_status {
412 status.clone().into()
413 } else if *stopped {
414 resource::Status::Stopped
415 } else {
416 resource::Status::Running
417 };
418 resource::State {
419 name: get_state.name.clone(),
420 status,
421 state: Some(ProcState {
422 proc_id: proc_id.clone(),
423 create_rank: *rank,
424 mesh_agent: mesh_agent.clone(),
425 bootstrap_command: manager.map(|m| m.command().clone()),
426 proc_status,
427 }),
428 }
429 }
430 Some(ProcCreationState {
431 created: Err(e), ..
432 }) => resource::State {
433 name: get_state.name.clone(),
434 status: resource::Status::Failed(e.to_string()),
435 state: None,
436 },
437 None => resource::State {
438 name: get_state.name.clone(),
439 status: resource::Status::NotExist,
440 state: None,
441 },
442 };
443
444 let result = get_state.reply.send(cx, state);
445 if let Err(e) = result {
449 tracing::warn!(
450 actor = %cx.self_id(),
451 "failed to send GetState reply to {} due to error: {}",
452 get_state.reply.port_id().actor_id(),
453 e
454 );
455 }
456 Ok(())
457 }
458}
459
460#[async_trait]
461impl Handler<resource::List> for HostMeshAgent {
462 async fn handle(&mut self, cx: &Context<Self>, list: resource::List) -> anyhow::Result<()> {
463 list.reply
464 .send(cx, self.created.keys().cloned().collect())?;
465 Ok(())
466 }
467}
468
469#[derive(Debug, hyperactor::Handler, hyperactor::HandleClient)]
473pub struct GetLocalProc {
474 #[reply]
475 pub proc_mesh_agent: PortHandle<ActorHandle<ProcMeshAgent>>,
476}
477
478#[async_trait]
479impl Handler<GetLocalProc> for HostMeshAgent {
480 async fn handle(
481 &mut self,
482 _cx: &Context<Self>,
483 GetLocalProc { proc_mesh_agent }: GetLocalProc,
484 ) -> anyhow::Result<()> {
485 let agent = self.local_mesh_agent.get_or_init(|| {
486 ProcMeshAgent::boot_v1(self.host.as_ref().unwrap().local_proc().clone())
487 });
488
489 match agent {
490 Err(e) => anyhow::bail!("error booting local proc: {}", e),
491 Ok(agent) => proc_mesh_agent.send(agent.clone())?,
492 };
493
494 Ok(())
495 }
496}
497
498#[derive(Debug)]
503#[hyperactor::export(
504 spawn = true,
505 handlers=[GetHostMeshAgent]
506)]
507pub(crate) struct HostMeshAgentProcMeshTrampoline {
508 host_mesh_agent: ActorHandle<HostMeshAgent>,
509 reply_port: PortRef<ActorRef<HostMeshAgent>>,
510}
511
512#[async_trait]
513impl Actor for HostMeshAgentProcMeshTrampoline {
514 async fn init(&mut self, this: &Instance<Self>) -> anyhow::Result<()> {
515 self.reply_port.send(this, self.host_mesh_agent.bind())?;
516 Ok(())
517 }
518}
519
520#[async_trait]
521impl hyperactor::RemoteSpawn for HostMeshAgentProcMeshTrampoline {
522 type Params = (
523 ChannelTransport,
524 PortRef<ActorRef<HostMeshAgent>>,
525 Option<BootstrapCommand>,
526 bool, );
528
529 async fn new((transport, reply_port, command, local): Self::Params) -> anyhow::Result<Self> {
530 let host = if local {
531 let spawn: ProcManagerSpawnFn =
532 Box::new(|proc| Box::pin(std::future::ready(ProcMeshAgent::boot_v1(proc))));
533 let manager = LocalProcManager::new(spawn);
534 let host = Host::new(manager, transport.any()).await?;
535 HostAgentMode::Local(host)
536 } else {
537 let command = match command {
538 Some(command) => command,
539 None => BootstrapCommand::current()?,
540 };
541 tracing::info!("booting host with proc command {:?}", command);
542 let manager = BootstrapProcManager::new(command).unwrap();
543 let host = Host::new(manager, transport.any()).await?;
544 HostAgentMode::Process(host)
545 };
546
547 let system_proc = host.system_proc().clone();
548 let host_mesh_agent = system_proc.spawn("agent", HostMeshAgent::new(host))?;
549
550 Ok(Self {
551 host_mesh_agent,
552 reply_port,
553 })
554 }
555}
556
557#[derive(Serialize, Deserialize, Debug, Named, Handler, RefClient)]
558pub struct GetHostMeshAgent {
559 #[reply]
560 pub host_mesh_agent: PortRef<ActorRef<HostMeshAgent>>,
561}
562wirevalue::register_type!(GetHostMeshAgent);
563
564#[async_trait]
565impl Handler<GetHostMeshAgent> for HostMeshAgentProcMeshTrampoline {
566 async fn handle(
567 &mut self,
568 cx: &Context<Self>,
569 get_host_mesh_agent: GetHostMeshAgent,
570 ) -> anyhow::Result<()> {
571 get_host_mesh_agent
572 .host_mesh_agent
573 .send(cx, self.host_mesh_agent.bind())?;
574 Ok(())
575 }
576}
577
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;

    use hyperactor::Proc;
    use hyperactor::channel::ChannelTransport;

    use super::*;
    use crate::bootstrap::ProcStatus;
    use crate::resource::CreateOrUpdateClient;
    use crate::resource::GetStateClient;

    /// Creates one proc through a process-backed `HostMeshAgent` and checks
    /// that `get_state` reports it as running, with the expected proc id,
    /// mesh agent reference and bootstrap command, and with a `Ready`
    /// process status naming the same mesh agent.
    #[tokio::test]
    #[cfg(fbcode_build)]
    async fn test_basic() {
        let manager = BootstrapProcManager::new(BootstrapCommand::test()).unwrap();
        let host = Host::new(manager, ChannelTransport::Unix.any())
            .await
            .unwrap();

        // Capture what we need from the host before it moves into the agent.
        let host_addr = host.addr().clone();
        let system_proc = host.system_proc().clone();
        let host_agent = system_proc
            .spawn("agent", HostMeshAgent::new(HostAgentMode::Process(host)))
            .unwrap();

        let client_proc = Proc::direct(ChannelTransport::Unix.any(), "client".to_string()).unwrap();
        let (client, _client_handle) = client_proc.instance("client").unwrap();

        let name = Name::new("proc1").unwrap();

        host_agent
            .create_or_update(
                &client,
                name.clone(),
                resource::Rank::new(0),
                ProcSpec::default(),
            )
            .await
            .unwrap();

        let expected_proc_id = ProcId::Direct(host_addr.clone(), name.to_string());
        let expected_agent: ActorRef<ProcMeshAgent> =
            ActorRef::attest(expected_proc_id.clone().actor_id("agent", 0));

        let observed = host_agent.get_state(&client, name.clone()).await.unwrap();
        assert_matches!(
            observed,
            resource::State {
                name: resource_name,
                status: resource::Status::Running,
                state: Some(ProcState {
                    proc_id,
                    mesh_agent,
                    bootstrap_command,
                    proc_status: Some(ProcStatus::Ready {
                        pid: _,
                        started_at: _,
                        addr: _,
                        agent: proc_status_mesh_agent
                    }),
                    ..
                }),
            } if name == resource_name
                && proc_id == expected_proc_id
                && mesh_agent == expected_agent
                && bootstrap_command == Some(BootstrapCommand::test())
                && mesh_agent == proc_status_mesh_agent
        );
    }
}