hyperactor_mesh/v1/host_mesh/
mesh_agent.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

//! The mesh agent actor that manages a host.

11use std::collections::HashMap;
12use std::fmt;
13use std::pin::Pin;
14
15use async_trait::async_trait;
16use enum_as_inner::EnumAsInner;
17use hyperactor::Actor;
18use hyperactor::ActorHandle;
19use hyperactor::ActorId;
20use hyperactor::ActorRef;
21use hyperactor::Context;
22use hyperactor::Handler;
23use hyperactor::Instance;
24use hyperactor::Named;
25use hyperactor::PortRef;
26use hyperactor::Proc;
27use hyperactor::ProcId;
28use hyperactor::RefClient;
29use hyperactor::channel::ChannelTransport;
30use hyperactor::context;
31use hyperactor::host::Host;
32use hyperactor::host::HostError;
33use hyperactor::host::LocalProcManager;
34use hyperactor::host::SingleTerminate;
35use serde::Deserialize;
36use serde::Serialize;
37use tokio::time::Duration;
38
39use crate::bootstrap;
40use crate::bootstrap::BootstrapCommand;
41use crate::bootstrap::BootstrapProcConfig;
42use crate::bootstrap::BootstrapProcManager;
43use crate::proc_mesh::mesh_agent::ProcMeshAgent;
44use crate::resource;
45use crate::resource::ProcSpec;
46use crate::v1::Name;
47
48type ProcManagerSpawnFuture =
49    Pin<Box<dyn Future<Output = anyhow::Result<ActorHandle<ProcMeshAgent>>> + Send>>;
50type ProcManagerSpawnFn = Box<dyn Fn(Proc) -> ProcManagerSpawnFuture + Send + Sync>;
51
52/// Represents the different ways a [`Host`] can be managed by an agent.
53///
54/// A host can either:
55/// - [`Process`] — a host running as an external OS process, managed by
56///   [`BootstrapProcManager`].
57/// - [`Local`] — a host running in-process, managed by
58///   [`LocalProcManager`] with a custom spawn function.
59///
60/// This abstraction lets the same `HostAgent` work across both
61/// out-of-process and in-process execution modes.
62#[derive(EnumAsInner)]
63pub enum HostAgentMode {
64    Process(Host<BootstrapProcManager>),
65    Local(Host<LocalProcManager<ProcManagerSpawnFn>>),
66}
67
68impl HostAgentMode {
69    fn system_proc(&self) -> &Proc {
70        #[allow(clippy::match_same_arms)]
71        match self {
72            HostAgentMode::Process(host) => host.system_proc(),
73            HostAgentMode::Local(host) => host.system_proc(),
74        }
75    }
76
77    async fn terminate_proc(
78        &self,
79        cx: &impl context::Actor,
80        proc: &ProcId,
81        timeout: Duration,
82    ) -> Result<(Vec<ActorId>, Vec<ActorId>), anyhow::Error> {
83        #[allow(clippy::match_same_arms)]
84        match self {
85            HostAgentMode::Process(host) => host.terminate_proc(cx, proc, timeout).await,
86            HostAgentMode::Local(host) => host.terminate_proc(cx, proc, timeout).await,
87        }
88    }
89}
90
91#[derive(Debug)]
92struct ProcCreationState {
93    rank: usize,
94    created: Result<(ProcId, ActorRef<ProcMeshAgent>), HostError>,
95    stopped: bool,
96}
97
98/// A mesh agent is responsible for managing a host iny a [`HostMesh`],
99/// through the resource behaviors defined in [`crate::resource`].
100#[hyperactor::export(
101    handlers=[
102        resource::CreateOrUpdate<ProcSpec>,
103        resource::Stop,
104        resource::GetState<ProcState>,
105        resource::GetRankStatus { cast = true },
106        ShutdownHost
107    ]
108)]
109pub struct HostMeshAgent {
110    host: Option<HostAgentMode>,
111    created: HashMap<Name, ProcCreationState>,
112}
113
114impl fmt::Debug for HostMeshAgent {
115    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
116        f.debug_struct("HostMeshAgent")
117            .field("host", &"..")
118            .field("created", &self.created)
119            .finish()
120    }
121}
122
123#[async_trait]
124impl Actor for HostMeshAgent {
125    type Params = HostAgentMode;
126
127    async fn new(host: HostAgentMode) -> anyhow::Result<Self> {
128        if let HostAgentMode::Process(_) = host {
129            let (directory, file) = hyperactor_telemetry::log_file_path(
130                hyperactor_telemetry::env::Env::current(),
131                None,
132            )
133            .unwrap();
134            eprintln!(
135                "Monarch internal logs are being written to {}/{}.log",
136                directory, file
137            );
138        }
139        Ok(Self {
140            host: Some(host),
141            created: HashMap::new(),
142        })
143    }
144}
145
146#[async_trait]
147impl Handler<resource::CreateOrUpdate<ProcSpec>> for HostMeshAgent {
148    #[tracing::instrument("HostMeshAgent::CreateOrUpdate", level = "info", skip_all, fields(name=%create_or_update.name))]
149    async fn handle(
150        &mut self,
151        _cx: &Context<Self>,
152        create_or_update: resource::CreateOrUpdate<ProcSpec>,
153    ) -> anyhow::Result<()> {
154        if self.created.contains_key(&create_or_update.name) {
155            // Already created: there is no update.
156            return Ok(());
157        }
158
159        let host = self.host.as_mut().expect("host present");
160        let created = match host {
161            HostAgentMode::Process(host) => {
162                host.spawn(
163                    create_or_update.name.clone().to_string(),
164                    BootstrapProcConfig {
165                        create_rank: create_or_update.rank.unwrap(),
166                        client_config_override: create_or_update
167                            .spec
168                            .client_config_override
169                            .clone(),
170                    },
171                )
172                .await
173            }
174            HostAgentMode::Local(host) => {
175                host.spawn(create_or_update.name.clone().to_string(), ())
176                    .await
177            }
178        };
179
180        if let Err(e) = &created {
181            tracing::error!("failed to spawn proc {}: {}", create_or_update.name, e);
182        }
183        self.created.insert(
184            create_or_update.name.clone(),
185            ProcCreationState {
186                rank: create_or_update.rank.unwrap(),
187                created,
188                stopped: false,
189            },
190        );
191
192        Ok(())
193    }
194}
195
196#[async_trait]
197impl Handler<resource::Stop> for HostMeshAgent {
198    async fn handle(&mut self, cx: &Context<Self>, message: resource::Stop) -> anyhow::Result<()> {
199        let host = self.host.as_mut().expect("host present");
200        let manager = host.as_process().map(Host::manager);
201        let timeout = hyperactor::config::global::get(hyperactor::config::PROCESS_EXIT_TIMEOUT);
202        // We don't remove the proc from the state map, instead we just store
203        // its state as Stopped.
204        let proc = self.created.get_mut(&message.name);
205        if let Some(ProcCreationState {
206            created: Ok((proc_id, _)),
207            stopped,
208            ..
209        }) = proc
210        {
211            let proc_status = match manager {
212                Some(manager) => manager.status(proc_id).await,
213                None => None,
214            };
215            // Fetch status from the ProcStatus object if it's available
216            // for more details.
217            // This prevents trying to kill a process that is already dead.
218            let should_stop = if let Some(status) = &proc_status {
219                resource::Status::from(status.clone()).is_healthy()
220            } else {
221                !*stopped
222            };
223            if should_stop {
224                host.terminate_proc(&cx, proc_id, timeout).await?;
225                *stopped = true;
226            }
227        }
228
229        Ok(())
230    }
231}
232
233#[async_trait]
234impl Handler<resource::GetRankStatus> for HostMeshAgent {
235    async fn handle(
236        &mut self,
237        cx: &Context<Self>,
238        get_rank_status: resource::GetRankStatus,
239    ) -> anyhow::Result<()> {
240        use crate::resource::Status;
241        use crate::v1::StatusOverlay;
242
243        let manager = self
244            .host
245            .as_mut()
246            .and_then(|h| h.as_process())
247            .map(Host::manager);
248        let (rank, status) = match self.created.get(&get_rank_status.name) {
249            Some(ProcCreationState {
250                rank,
251                created: Ok((proc_id, _mesh_agent)),
252                stopped,
253            }) => {
254                let proc_status = match manager {
255                    Some(manager) => manager.status(proc_id).await,
256                    None => None,
257                };
258                // Fetch status from the ProcStatus object if it's available
259                // for more details.
260                let status = if let Some(status) = &proc_status {
261                    status.clone().into()
262                } else if *stopped {
263                    resource::Status::Stopped
264                } else {
265                    resource::Status::Running
266                };
267                (*rank, status)
268            }
269            // If the creation failed, show as Failed instead of Stopped even if
270            // the proc was stopped.
271            Some(ProcCreationState {
272                rank,
273                created: Err(e),
274                ..
275            }) => (*rank, Status::Failed(e.to_string())),
276            None => (usize::MAX, Status::NotExist),
277        };
278
279        let overlay = if rank == usize::MAX {
280            StatusOverlay::new()
281        } else {
282            StatusOverlay::try_from_runs(vec![(rank..(rank + 1), status)])
283                .expect("valid single-run overlay")
284        };
285        let result = get_rank_status.reply.send(cx, overlay);
286        // Ignore errors, because returning Err from here would cause the HostMeshAgent
287        // to be stopped, which would take down the entire host. This only means
288        // some actor that requested the rank status failed to receive it.
289        if let Err(e) = result {
290            tracing::warn!(
291                actor = %cx.self_id(),
292                "failed to send GetRankStatus reply to {} due to error: {}",
293                get_rank_status.reply.port_id().actor_id(),
294                e
295            );
296        }
297        Ok(())
298    }
299}
300
301#[derive(Serialize, Deserialize, Debug, Named, Handler, RefClient)]
302pub struct ShutdownHost {
303    /// Grace window: send SIGTERM and wait this long before
304    /// escalating.
305    pub timeout: std::time::Duration,
306    /// Max number of children to terminate concurrently on this host.
307    pub max_in_flight: usize,
308    /// Ack that the agent finished shutdown work (best-effort).
309    #[reply]
310    pub ack: hyperactor::PortRef<()>,
311}
312
313#[async_trait]
314impl Handler<ShutdownHost> for HostMeshAgent {
315    async fn handle(&mut self, cx: &Context<Self>, msg: ShutdownHost) -> anyhow::Result<()> {
316        // Ack immediately so caller can await.
317        msg.ack.send(cx, ())?;
318
319        if let Some(host_mode) = self.host.take() {
320            match host_mode {
321                HostAgentMode::Process(host) => {
322                    let summary = host
323                        .terminate_children(cx, msg.timeout, msg.max_in_flight.clamp(1, 256))
324                        .await;
325                    tracing::info!(?summary, "terminated children on host");
326                }
327                HostAgentMode::Local(host) => {
328                    let summary = host
329                        .terminate_children(cx, msg.timeout, msg.max_in_flight)
330                        .await;
331                    tracing::info!(?summary, "terminated children on local host");
332                }
333            }
334        }
335        // Drop the host to release any resources that somehow survived.
336        let _ = self.host.take();
337
338        Ok(())
339    }
340}
341
342#[derive(Debug, Clone, PartialEq, Eq, Named, Serialize, Deserialize)]
343pub struct ProcState {
344    pub proc_id: ProcId,
345    pub create_rank: usize,
346    pub mesh_agent: ActorRef<ProcMeshAgent>,
347    pub bootstrap_command: Option<BootstrapCommand>,
348    pub proc_status: Option<bootstrap::ProcStatus>,
349}
350
351#[async_trait]
352impl Handler<resource::GetState<ProcState>> for HostMeshAgent {
353    async fn handle(
354        &mut self,
355        cx: &Context<Self>,
356        get_state: resource::GetState<ProcState>,
357    ) -> anyhow::Result<()> {
358        let manager: Option<&BootstrapProcManager> = self
359            .host
360            .as_mut()
361            .expect("host")
362            .as_process()
363            .map(Host::manager);
364        let state = match self.created.get(&get_state.name) {
365            Some(ProcCreationState {
366                rank,
367                created: Ok((proc_id, mesh_agent)),
368                stopped,
369            }) => {
370                let proc_status = match manager {
371                    Some(manager) => manager.status(proc_id).await,
372                    None => None,
373                };
374                // Fetch status from the ProcStatus object if it's available
375                // for more details.
376                let status = if let Some(status) = &proc_status {
377                    status.clone().into()
378                } else if *stopped {
379                    resource::Status::Stopped
380                } else {
381                    resource::Status::Running
382                };
383                resource::State {
384                    name: get_state.name.clone(),
385                    status,
386                    state: Some(ProcState {
387                        proc_id: proc_id.clone(),
388                        create_rank: *rank,
389                        mesh_agent: mesh_agent.clone(),
390                        bootstrap_command: manager.map(|m| m.command().clone()),
391                        proc_status,
392                    }),
393                }
394            }
395            Some(ProcCreationState {
396                created: Err(e), ..
397            }) => resource::State {
398                name: get_state.name.clone(),
399                status: resource::Status::Failed(e.to_string()),
400                state: None,
401            },
402            None => resource::State {
403                name: get_state.name.clone(),
404                status: resource::Status::NotExist,
405                state: None,
406            },
407        };
408
409        let result = get_state.reply.send(cx, state);
410        // Ignore errors, because returning Err from here would cause the HostMeshAgent
411        // to be stopped, which would take down the entire host. This only means
412        // some actor that requested the state of a proc failed to receive it.
413        if let Err(e) = result {
414            tracing::warn!(
415                actor = %cx.self_id(),
416                "failed to send GetState reply to {} due to error: {}",
417                get_state.reply.port_id().actor_id(),
418                e
419            );
420        }
421        Ok(())
422    }
423}
424
425/// A trampoline actor that spawns a [`Host`], and sends a reference to the
426/// corresponding [`HostMeshAgent`] to the provided reply port.
427///
428/// This is used to bootstrap host meshes from proc meshes.
429#[derive(Debug)]
430#[hyperactor::export(
431    spawn = true,
432    handlers=[GetHostMeshAgent]
433)]
434pub(crate) struct HostMeshAgentProcMeshTrampoline {
435    host_mesh_agent: ActorHandle<HostMeshAgent>,
436    reply_port: PortRef<ActorRef<HostMeshAgent>>,
437}
438
439#[async_trait]
440impl Actor for HostMeshAgentProcMeshTrampoline {
441    type Params = (
442        ChannelTransport,
443        PortRef<ActorRef<HostMeshAgent>>,
444        Option<BootstrapCommand>,
445        bool, /* local? */
446    );
447
448    async fn new((transport, reply_port, command, local): Self::Params) -> anyhow::Result<Self> {
449        let host = if local {
450            let spawn: ProcManagerSpawnFn = Box::new(|proc| Box::pin(ProcMeshAgent::boot_v1(proc)));
451            let manager = LocalProcManager::new(spawn);
452            let (host, _) = Host::serve(manager, transport.any()).await?;
453            HostAgentMode::Local(host)
454        } else {
455            let command = match command {
456                Some(command) => command,
457                None => BootstrapCommand::current()?,
458            };
459            tracing::info!("booting host with proc command {:?}", command);
460            let manager = BootstrapProcManager::new(command).unwrap();
461            let (host, _) = Host::serve(manager, transport.any()).await?;
462            HostAgentMode::Process(host)
463        };
464
465        let host_mesh_agent = host
466            .system_proc()
467            .clone()
468            .spawn::<HostMeshAgent>("agent", host)
469            .await?;
470
471        Ok(Self {
472            host_mesh_agent,
473            reply_port,
474        })
475    }
476
477    async fn init(&mut self, this: &Instance<Self>) -> anyhow::Result<()> {
478        self.reply_port.send(this, self.host_mesh_agent.bind())?;
479        Ok(())
480    }
481}
482
483#[derive(Serialize, Deserialize, Debug, Named, Handler, RefClient)]
484pub struct GetHostMeshAgent {
485    #[reply]
486    pub host_mesh_agent: PortRef<ActorRef<HostMeshAgent>>,
487}
488
489#[async_trait]
490impl Handler<GetHostMeshAgent> for HostMeshAgentProcMeshTrampoline {
491    async fn handle(
492        &mut self,
493        cx: &Context<Self>,
494        get_host_mesh_agent: GetHostMeshAgent,
495    ) -> anyhow::Result<()> {
496        get_host_mesh_agent
497            .host_mesh_agent
498            .send(cx, self.host_mesh_agent.bind())?;
499        Ok(())
500    }
501}
502
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;

    use hyperactor::Proc;
    use hyperactor::channel::ChannelTransport;

    use super::*;
    use crate::bootstrap::ProcStatus;
    use crate::resource::CreateOrUpdateClient;
    use crate::resource::GetStateClient;

    /// Creates a proc through a process-backed agent and checks its reported state.
    #[tokio::test]
    #[cfg(fbcode_build)]
    async fn test_basic() {
        // Serve a process-backed host and run the agent on its system proc.
        let (host, _handle) = Host::serve(
            BootstrapProcManager::new(BootstrapCommand::test()).unwrap(),
            ChannelTransport::Unix.any(),
        )
        .await
        .unwrap();
        let host_addr = host.addr().clone();
        let host_agent = host
            .system_proc()
            .clone()
            .spawn::<HostMeshAgent>("agent", HostAgentMode::Process(host))
            .await
            .unwrap();

        let client_proc = Proc::direct(ChannelTransport::Unix.any(), "client".to_string())
            .await
            .unwrap();
        let (client, _client_handle) = client_proc.instance("client").unwrap();

        let name = Name::new("proc1");
        // The proc itself should be direct addressed, with its name directly,
        // and its mesh agent should run in that proc under the name "agent".
        let expected_proc_id = ProcId::Direct(host_addr, name.to_string());
        let expected_mesh_agent: ActorRef<ProcMeshAgent> =
            ActorRef::attest(expected_proc_id.clone().actor_id("agent", 0));

        // First, create the proc, then query its state.
        host_agent
            .create_or_update(
                &client,
                name.clone(),
                resource::Rank::new(0),
                ProcSpec::default(),
            )
            .await
            .unwrap();
        let observed = host_agent.get_state(&client, name.clone()).await.unwrap();

        assert_matches!(
            observed,
            resource::State {
                name: resource_name,
                status: resource::Status::Running,
                state: Some(ProcState {
                    proc_id,
                    mesh_agent,
                    bootstrap_command,
                    proc_status: Some(ProcStatus::Ready {
                        pid: _,
                        started_at: _,
                        addr: _,
                        agent: proc_status_mesh_agent,
                    }),
                    ..
                }),
            } if resource_name == name
                && proc_id == expected_proc_id
                && mesh_agent == expected_mesh_agent
                && bootstrap_command == Some(BootstrapCommand::test())
                && mesh_agent == proc_status_mesh_agent
        );
    }
}