hyperactor_mesh/v1/host_mesh/
mesh_agent.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! The mesh agent actor that manages a host.
10
11use std::collections::HashMap;
12use std::fmt;
13use std::pin::Pin;
14
15use async_trait::async_trait;
16use enum_as_inner::EnumAsInner;
17use hyperactor::Actor;
18use hyperactor::ActorHandle;
19use hyperactor::ActorRef;
20use hyperactor::Context;
21use hyperactor::Handler;
22use hyperactor::Instance;
23use hyperactor::Named;
24use hyperactor::PortRef;
25use hyperactor::Proc;
26use hyperactor::ProcId;
27use hyperactor::RefClient;
28use hyperactor::channel::ChannelTransport;
29use hyperactor::host::Host;
30use hyperactor::host::HostError;
31use hyperactor::host::LocalProcManager;
32use serde::Deserialize;
33use serde::Serialize;
34
35use crate::bootstrap;
36use crate::bootstrap::BootstrapCommand;
37use crate::bootstrap::BootstrapProcManager;
38use crate::proc_mesh::mesh_agent::ProcMeshAgent;
39use crate::resource;
40use crate::v1::Name;
41
42type ProcManagerSpawnFuture =
43    Pin<Box<dyn Future<Output = anyhow::Result<ActorHandle<ProcMeshAgent>>> + Send>>;
44type ProcManagerSpawnFn = Box<dyn Fn(Proc) -> ProcManagerSpawnFuture + Send + Sync>;
45
46/// Represents the different ways a [`Host`] can be managed by an agent.
47///
48/// A host can either:
49/// - [`Process`] — a host running as an external OS process, managed by
50///   [`BootstrapProcManager`].
51/// - [`Local`] — a host running in-process, managed by
52///   [`LocalProcManager`] with a custom spawn function.
53///
54/// This abstraction lets the same `HostAgent` work across both
55/// out-of-process and in-process execution modes.
56#[derive(EnumAsInner)]
57pub enum HostAgentMode {
58    Process(Host<BootstrapProcManager>),
59    Local(Host<LocalProcManager<ProcManagerSpawnFn>>),
60}
61
62impl HostAgentMode {
63    fn system_proc(&self) -> &Proc {
64        #[allow(clippy::match_same_arms)]
65        match self {
66            HostAgentMode::Process(host) => host.system_proc(),
67            HostAgentMode::Local(host) => host.system_proc(),
68        }
69    }
70}
71
72/// A mesh agent is responsible for managing a host iny a [`HostMesh`],
73/// through the resource behaviors defined in [`crate::resource`].
74#[hyperactor::export(
75    handlers=[
76        resource::CreateOrUpdate<()>,
77        resource::GetState<ProcState>,
78        resource::GetRankStatus,
79        ShutdownHost
80    ]
81)]
82pub struct HostMeshAgent {
83    host: Option<HostAgentMode>,
84    created: HashMap<Name, (usize, Result<(ProcId, ActorRef<ProcMeshAgent>), HostError>)>,
85}
86
87impl fmt::Debug for HostMeshAgent {
88    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89        f.debug_struct("HostMeshAgent")
90            .field("host", &"..")
91            .field("created", &self.created)
92            .finish()
93    }
94}
95
96#[async_trait]
97impl Actor for HostMeshAgent {
98    type Params = HostAgentMode;
99
100    async fn new(host: HostAgentMode) -> anyhow::Result<Self> {
101        Ok(Self {
102            host: Some(host),
103            created: HashMap::new(),
104        })
105    }
106}
107
108#[async_trait]
109impl Handler<resource::CreateOrUpdate<()>> for HostMeshAgent {
110    #[tracing::instrument("HostMeshAgent::CreateOrUpdate", level = "info", skip_all, fields(name=%create_or_update.name))]
111    async fn handle(
112        &mut self,
113        _cx: &Context<Self>,
114        create_or_update: resource::CreateOrUpdate<()>,
115    ) -> anyhow::Result<()> {
116        if self.created.contains_key(&create_or_update.name) {
117            // Already created: there is no update.
118            return Ok(());
119        }
120
121        let host = self.host.as_mut().expect("host present");
122        let created = match host {
123            HostAgentMode::Process(host) => {
124                host.spawn(create_or_update.name.clone().to_string()).await
125            }
126            HostAgentMode::Local(host) => {
127                host.spawn(create_or_update.name.clone().to_string()).await
128            }
129        };
130
131        if let Err(e) = &created {
132            tracing::error!("failed to spawn proc {}: {}", create_or_update.name, e);
133        }
134        self.created.insert(
135            create_or_update.name.clone(),
136            (create_or_update.rank.unwrap(), created),
137        );
138
139        Ok(())
140    }
141}
142
143#[async_trait]
144impl Handler<resource::GetRankStatus> for HostMeshAgent {
145    async fn handle(
146        &mut self,
147        cx: &Context<Self>,
148        get_rank_status: resource::GetRankStatus,
149    ) -> anyhow::Result<()> {
150        let Some(created) = self.created.get(&get_rank_status.name) else {
151            // TODO: how can we get the host's rank here? we should model its absence explicitly.
152            get_rank_status
153                .reply
154                .send(cx, (usize::MAX, resource::Status::NotExist).into())?;
155            return Ok(());
156        };
157
158        let rank_status = match created {
159            (rank, Ok(_)) => (*rank, resource::Status::Running),
160            (rank, Err(e)) => (*rank, resource::Status::Failed(e.to_string())),
161        };
162        get_rank_status.reply.send(cx, rank_status.into())?;
163
164        Ok(())
165    }
166}
167
168#[derive(Serialize, Deserialize, Debug, Named, Handler, RefClient)]
169pub struct ShutdownHost {
170    /// Grace window: send SIGTERM and wait this long before
171    /// escalating.
172    pub timeout: std::time::Duration,
173    /// Max number of children to terminate concurrently on this host.
174    pub max_in_flight: usize,
175    /// Ack that the agent finished shutdown work (best-effort).
176    #[reply]
177    pub ack: hyperactor::PortRef<()>,
178}
179
180#[async_trait]
181impl Handler<ShutdownHost> for HostMeshAgent {
182    async fn handle(&mut self, cx: &Context<Self>, msg: ShutdownHost) -> anyhow::Result<()> {
183        // Ack immediately so caller can await.
184        msg.ack.send(cx, ())?;
185
186        if let Some(host_mode) = self.host.take() {
187            match host_mode {
188                HostAgentMode::Process(host) => {
189                    let summary = host
190                        .terminate_children(msg.timeout, msg.max_in_flight.clamp(1, 256))
191                        .await;
192                    tracing::info!(?summary, "terminated children on host");
193                }
194                HostAgentMode::Local(host) => {
195                    let summary = host
196                        .terminate_children(msg.timeout, msg.max_in_flight)
197                        .await;
198                    tracing::info!(?summary, "terminated children on local host");
199                }
200            }
201        }
202        // Drop the host to release any resources that somehow survived.
203        let _ = self.host.take();
204
205        Ok(())
206    }
207}
208
209#[derive(Debug, Clone, PartialEq, Eq, Named, Serialize, Deserialize)]
210pub struct ProcState {
211    pub proc_id: ProcId,
212    pub mesh_agent: ActorRef<ProcMeshAgent>,
213    pub bootstrap_command: Option<BootstrapCommand>,
214    pub proc_status: Option<bootstrap::ProcStatus>,
215}
216
217#[async_trait]
218impl Handler<resource::GetState<ProcState>> for HostMeshAgent {
219    async fn handle(
220        &mut self,
221        cx: &Context<Self>,
222        get_state: resource::GetState<ProcState>,
223    ) -> anyhow::Result<()> {
224        let manager = self
225            .host
226            .as_mut()
227            .expect("host")
228            .as_process()
229            .map(Host::manager);
230        let state = match self.created.get(&get_state.name) {
231            Some((_rank, Ok((proc_id, mesh_agent)))) => resource::State {
232                name: get_state.name.clone(),
233                status: resource::Status::Running,
234                state: Some(ProcState {
235                    proc_id: proc_id.clone(),
236                    mesh_agent: mesh_agent.clone(),
237                    bootstrap_command: manager.map(|m| m.command().clone()),
238                    proc_status: match manager {
239                        Some(manager) => Some(manager.status(proc_id).await.unwrap()),
240                        None => None,
241                    },
242                }),
243            },
244            Some((_rank, Err(e))) => resource::State {
245                name: get_state.name.clone(),
246                status: resource::Status::Failed(e.to_string()),
247                state: None,
248            },
249            None => resource::State {
250                name: get_state.name.clone(),
251                status: resource::Status::NotExist,
252                state: None,
253            },
254        };
255
256        get_state.reply.send(cx, state)?;
257        Ok(())
258    }
259}
260
261/// A trampoline actor that spawns a [`Host`], and sends a reference to the
262/// corresponding [`HostMeshAgent`] to the provided reply port.
263///
264/// This is used to bootstrap host meshes from proc meshes.
265#[derive(Debug)]
266#[hyperactor::export(
267    spawn = true,
268    handlers=[GetHostMeshAgent]
269)]
270pub(crate) struct HostMeshAgentProcMeshTrampoline {
271    host_mesh_agent: ActorHandle<HostMeshAgent>,
272    reply_port: PortRef<ActorRef<HostMeshAgent>>,
273}
274
275#[async_trait]
276impl Actor for HostMeshAgentProcMeshTrampoline {
277    type Params = (
278        ChannelTransport,
279        PortRef<ActorRef<HostMeshAgent>>,
280        Option<BootstrapCommand>,
281        bool, /* local? */
282    );
283
284    async fn new((transport, reply_port, command, local): Self::Params) -> anyhow::Result<Self> {
285        let host = if local {
286            let spawn: ProcManagerSpawnFn = Box::new(|proc| Box::pin(ProcMeshAgent::boot_v1(proc)));
287            let manager = LocalProcManager::new(spawn);
288            let (host, _) = Host::serve(manager, transport.any()).await?;
289            HostAgentMode::Local(host)
290        } else {
291            let command = match command {
292                Some(command) => command,
293                None => BootstrapCommand::current()?,
294            };
295            tracing::info!("booting host with proc command {:?}", command);
296            let manager = BootstrapProcManager::new(command);
297            let (host, _) = Host::serve(manager, transport.any()).await?;
298            HostAgentMode::Process(host)
299        };
300
301        let host_mesh_agent = host
302            .system_proc()
303            .clone()
304            .spawn::<HostMeshAgent>("agent", host)
305            .await?;
306
307        Ok(Self {
308            host_mesh_agent,
309            reply_port,
310        })
311    }
312
313    async fn init(&mut self, this: &Instance<Self>) -> anyhow::Result<()> {
314        self.reply_port.send(this, self.host_mesh_agent.bind())?;
315        Ok(())
316    }
317}
318
319#[derive(Serialize, Deserialize, Debug, Named, Handler, RefClient)]
320pub struct GetHostMeshAgent {
321    #[reply]
322    pub host_mesh_agent: PortRef<ActorRef<HostMeshAgent>>,
323}
324
325#[async_trait]
326impl Handler<GetHostMeshAgent> for HostMeshAgentProcMeshTrampoline {
327    async fn handle(
328        &mut self,
329        cx: &Context<Self>,
330        get_host_mesh_agent: GetHostMeshAgent,
331    ) -> anyhow::Result<()> {
332        get_host_mesh_agent
333            .host_mesh_agent
334            .send(cx, self.host_mesh_agent.bind())?;
335        Ok(())
336    }
337}
338
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;

    use hyperactor::Proc;
    use hyperactor::channel::ChannelTransport;

    use super::*;
    use crate::bootstrap::ProcStatus;
    use crate::resource::CreateOrUpdateClient;
    use crate::resource::GetStateClient;

    /// Creates one proc through the agent and checks the state reported for it.
    #[tokio::test]
    async fn test_basic() {
        let manager = BootstrapProcManager::new(BootstrapCommand::test());
        let (host, _handle) = Host::serve(manager, ChannelTransport::Unix.any())
            .await
            .unwrap();

        let host_addr = host.addr().clone();
        let system_proc = host.system_proc().clone();
        let host_agent = system_proc
            .spawn::<HostMeshAgent>("agent", HostAgentMode::Process(host))
            .await
            .unwrap();

        let client_proc = Proc::direct(ChannelTransport::Unix.any(), "client".to_string())
            .await
            .unwrap();
        let (client, _client_handle) = client_proc.instance("client").unwrap();

        let name = Name::new("proc1");

        // First, create the proc, then query its state:

        host_agent
            .create_or_update(&client, name.clone(), resource::Rank::new(0), ())
            .await
            .unwrap();

        // The proc itself should be direct addressed, with its name directly.
        let expected_proc_id = ProcId::Direct(host_addr.clone(), name.to_string());
        // The mesh agent should run in the same proc, under the name "agent".
        let expected_mesh_agent: ActorRef<ProcMeshAgent> =
            ActorRef::attest(expected_proc_id.clone().actor_id("agent", 0));
        let expected_command = Some(BootstrapCommand::test());

        let state = host_agent.get_state(&client, name.clone()).await.unwrap();
        assert_matches!(
            state,
            resource::State {
                name: resource_name,
                status: resource::Status::Running,
                state: Some(ProcState {
                    proc_id,
                    mesh_agent,
                    bootstrap_command,
                    proc_status: Some(ProcStatus::Ready { agent: proc_status_mesh_agent, .. }),
                }),
            } if name == resource_name
              && proc_id == expected_proc_id
              && mesh_agent == expected_mesh_agent
              && bootstrap_command == expected_command
              && mesh_agent == proc_status_mesh_agent
        );
    }
}