hyperactor_mesh/
test_utils.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use async_trait::async_trait;
10use hyperactor::Actor;
11use hyperactor::Bind;
12use hyperactor::Context;
13use hyperactor::Handler;
14use hyperactor::Unbind;
15use hyperactor::channel::ChannelTransport;
16use serde::Deserialize;
17use serde::Serialize;
18use typeuri::Named;
19
20use crate::host_mesh::HostMesh;
21
/// Message that can be sent to an [`EmptyActor`].
///
/// Carries no data (an empty tuple struct). It derives serde's
/// `Serialize`/`Deserialize` plus hyperactor's `Named`, `Bind` and `Unbind`
/// so it can be used as an actor message in tests.
#[derive(Serialize, Deserialize, Debug, Named, Clone, Bind, Unbind)]
pub struct EmptyMessage();
25
/// A stateless test actor whose only handler accepts [`EmptyMessage`].
///
/// Useful as a minimal actor to spawn in mesh tests.
// NOTE(review): `spawn = true` and `cast = true` are interpreted by the
// `hyperactor::export` macro; presumably they make the actor spawnable and
// allow `EmptyMessage` to be cast to it — confirm against the macro docs.
#[derive(Debug, PartialEq, Default)]
#[hyperactor::export(
    spawn = true,
    handlers = [
        EmptyMessage { cast = true },
    ],
)]
pub struct EmptyActor();

// Empty impl: relies entirely on the `Actor` trait's default methods.
impl Actor for EmptyActor {}
36
37#[async_trait]
38impl Handler<EmptyMessage> for EmptyActor {
39    async fn handle(&mut self, _: &Context<Self>, _: EmptyMessage) -> Result<(), anyhow::Error> {
40        Ok(())
41    }
42}
43
44/// Create a local in-process host mesh with `n` hosts, all running in
45/// the current process using `Local` channel transport.
46///
47/// This is similar to [`HostMesh::local_in_process`] but supports
48/// multiple hosts. All hosts use [`LocalProcManager`] with
49/// [`ChannelTransport::Local`], so there is no IPC overhead.
50///
51/// # Examples
52///
53/// ```ignore
54/// let mut host_mesh = test_utils::local_host_mesh(4).await;
55/// let proc_mesh = host_mesh
56///     .spawn(instance, "test", ndslice::extent!(gpu = 8))
57///     .await
58///     .unwrap();
59/// // ... do something with the proc mesh ...
60/// // shutdown the host mesh.
61/// let _ = host_mesh.shutdown(&instance).await;
62/// ```
63pub async fn local_host_mesh(n: usize) -> HostMesh {
64    let addrs = (0..n).map(|_| ChannelTransport::Local.any()).collect();
65    let host_mesh = HostMesh::local_n_in_process(addrs).await.unwrap();
66    HostMesh::take(host_mesh)
67}