hyperactor_mesh/
mesh.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8
9use async_trait::async_trait;
10use hyperactor::RemoteMessage;
11use ndslice::Range;
12use ndslice::Shape;
13use ndslice::ShapeError;
14use ndslice::SliceIterator;
15
16/// A mesh of nodes, organized into the topology described by its shape (see [`Shape`]).
17#[async_trait]
18pub trait Mesh {
19    /// The type of the node contained in the mesh.
20    type Node;
21
22    /// The type of identifiers for this mesh.
23    type Id: RemoteMessage;
24
25    /// The type of a slice of this mesh. Slices should not outlive their
26    /// parent mesh.
27    type Sliced<'a>: Mesh<Node = Self::Node> + 'a
28    where
29        Self: 'a;
30
31    /// The shape of this mesh.
32    fn shape(&self) -> &Shape;
33
34    /// Sub-slice this mesh, specifying the included ranges for
35    /// the dimension with the labeled name.
36    fn select<R: Into<Range>>(&self, label: &str, range: R)
37    -> Result<Self::Sliced<'_>, ShapeError>;
38
39    /// Retrieve contained node at the provided index. The index is
40    /// relative to the shape of the mesh.
41    fn get(&self, index: usize) -> Option<Self::Node>;
42
43    /// Iterate over all the nodes in this mesh.
44    fn iter(&self) -> MeshIter<'_, Self> {
45        MeshIter {
46            mesh: self,
47            slice_iter: self.shape().slice().iter(),
48        }
49    }
50
51    /// The global identifier for this mesh.
52    fn id(&self) -> Self::Id;
53}
54
55/// An iterator over the nodes of a mesh.
56pub struct MeshIter<'a, M: Mesh + ?Sized> {
57    mesh: &'a M,
58    slice_iter: SliceIterator,
59}
60
61impl<M: Mesh> Iterator for MeshIter<'_, M> {
62    type Item = M::Node;
63
64    fn next(&mut self) -> Option<Self::Item> {
65        self.slice_iter
66            .next()
67            .map(|index| self.mesh.get(index).unwrap())
68    }
69}