hyperactor_mesh/resource/
mesh.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9#![allow(dead_code)]
10
11//! This module defines common types for mesh resources. Meshes are managed as
12//! resources, usually by a controller actor implementing the [`crate::resource`]
13//! behavior.
14//!
15//! The mesh controller manages all aspects of the mesh lifecycle, and the owning
16//! actor uses the resource behavior directly to query the state of the mesh.
17
18use hyperactor::Bind;
19use hyperactor::Unbind;
20use ndslice::Extent;
21use serde::Deserialize;
22use serde::Serialize;
23use typeuri::Named;
24
25use crate::resource::Resource;
26use crate::resource::Status;
27use crate::v1::ValueMesh;
28
29/// Mesh specs
30#[derive(Debug, Named, Serialize, Deserialize)]
31pub struct Spec<S> {
32    /// All meshes have an extent
33    extent: Extent,
34    // supervisor: PortHandle<SupervisionEvent(?)>
35    /// The mesh-specific spec.
36    spec: S,
37}
38
39/// Mesh states
40#[derive(Debug, Named, Bind, Unbind, Serialize, Deserialize)]
41pub struct State<S> {
42    /// The current status for each rank in the mesh.
43    pub statuses: ValueMesh<Status>,
44    /// Mesh-specific state.
45    pub state: S,
46}
47
48/// A mesh trait bundles a set of types that together define a mesh resource.
49pub trait Mesh {
50    /// The mesh-specific specification for this resource.
51    type Spec: typeuri::Named
52        + Serialize
53        + for<'de> Deserialize<'de>
54        + Send
55        + Sync
56        + std::fmt::Debug;
57
58    /// The mesh-specific state for this resource.
59    type State: typeuri::Named
60        + Serialize
61        + for<'de> Deserialize<'de>
62        + Send
63        + Sync
64        + std::fmt::Debug;
65}
66
67impl<M: Mesh> Resource for M {
68    type Spec = Spec<M::Spec>;
69    type State = State<M::State>;
70}
71
#[cfg(test)]
mod test {
    use hyperactor::Actor;
    use hyperactor::Context;
    use hyperactor::Handler;

    use super::*;
    use crate::resource::Controller;
    use crate::resource::CreateOrUpdate;
    use crate::resource::GetState;
    use crate::resource::Stop;

    // Expands each `binding: MessageType => body` entry into an
    // `impl Handler<MessageType>` for the given actor, with `body` as the
    // whole handler body. This kind of lightweight handler definition may be
    // worth upstreaming into `hyperactor`.
    macro_rules! handler {
        (
            $actor:path,
            $(
                $arg:ident: $msg:ty => $body:expr
            ),* $(,)?
        ) => {
            $(
                #[async_trait::async_trait]
                impl Handler<$msg> for $actor {
                    async fn handle(
                        &mut self,
                        // The handler bodies in this module never use the context.
                        #[allow(unused_variables)]
                        cx: &Context<Self>,
                        $arg: $msg
                    ) -> anyhow::Result<()> {
                        $body
                    }
                }
            )*
        };
    }

    /// A mesh with unit spec and state, used only to exercise the traits.
    #[derive(Debug, Named, Serialize, Deserialize)]
    struct TestMesh;

    impl Mesh for TestMesh {
        type Spec = ();
        type State = ();
    }

    /// A stub controller actor for [`TestMesh`]; its handlers are never run.
    #[derive(Debug, Default, Named, Serialize, Deserialize)]
    struct TestMeshController;

    impl Actor for TestMeshController {}

    // Give TestMeshController a handler for every message of the Controller
    // behavior for TestMesh, so that the assertion below can hold.
    handler! {
        TestMeshController,
        _message: CreateOrUpdate<Spec<()>> => unimplemented!(),
        _message: GetState<State<()>> => unimplemented!(),
        _message: Stop => unimplemented!(),
    }

    hyperactor::assert_behaves!(TestMeshController as Controller<TestMesh>);

    /// A `State` survives a bincode round trip with its fields intact.
    #[test]
    fn test_state_serialize_and_deserialize_with_bincode() {
        let region: ndslice::Region = ndslice::extent!(x = 5).into();
        let rank_count = region.num_ranks();
        let statuses = ValueMesh::new(region, vec![Status::Running; rank_count]).unwrap();
        let original = State { statuses, state: 0 };

        let bytes = bincode::serialize(&original).expect("serialization failed");
        let round_tripped: State<i32> =
            bincode::deserialize(&bytes).expect("deserialization failed");

        assert_eq!(round_tripped.state, original.state);
        assert_eq!(round_tripped.statuses, original.statuses);
    }
}