monarch_tensor_worker/
device_mesh.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9#![allow(dead_code)] // Temporary, until code is exercised by the worker.
10
11use std::collections::HashMap;
12
13use anyhow::Context;
14use anyhow::Result;
15use anyhow::ensure;
16use itertools::Itertools;
17use itertools::izip;
18use ndslice::Slice;
19use pyo3::pyclass;
20use serde::Deserialize;
21use serde::Serialize;
22
23/// A single dimension in a [`DeviceMesh`], relative to a specific rank.
24/// Each dimension is the set of ranks that share all _other_ mesh coordinates
25/// with the owning rank while ranging over the dimension's coordinates.
26///
27/// For example, if a mesh has 3 dimensions, and the owning rank's coordinates
28/// are `(1, 2, 3)`, then:
29/// * [`Dim`] 0 will be `(*, 2, 3)`
30/// * [`Dim`] 1 will be `(1, *, 3)`
31/// * [`Dim`] 2 will be `(1, 2, *)`
32#[derive(Debug, Clone, Serialize, Deserialize)]
33#[pyclass(frozen, module = "monarch_tensor_worker._internal")]
34#[pyo3(get_all)]
35pub struct Dim {
36    /// The name of the dimension.
37    name: String,
38    /// The rank of this worker within the dimension's process group.
39    rank: usize,
40    /// The size of the dimension.
41    size: usize,
42    /// The ordered set of ranks within the dimension. `members[rank]` is always
43    /// equal to the owning rank.
44    members: Vec<usize>,
45}
46
47impl Dim {
48    pub fn members(&self) -> &[usize] {
49        &self.members
50    }
51    pub fn size(&self) -> usize {
52        self.size
53    }
54}
55
56/// A device mesh represents each (named) dimension ([`Dim`]) of a
57/// multi-dimensional mesh, relative to a specific rank.
58#[derive(Debug, Clone, Serialize, Deserialize)]
59#[pyclass(frozen, module = "monarch_tensor_worker._internal")]
60#[pyo3(get_all)]
61pub struct DeviceMesh {
62    /// Each dim in the mesh.
63    dims: HashMap<String, Dim>,
64    /// All ranks (i.e., the full device mesh).
65    all_ranks: Vec<usize>,
66}
67
68impl DeviceMesh {
69    /// Create a new [`DeviceMesh`] with the provided dimension names and
70    /// multi-dimensional slice. `rank` is the owning (self) rank.
71    pub fn new(names: Vec<String>, ranks: Slice, rank: usize) -> Result<Self> {
72        let mut dims = HashMap::new();
73        let coordinates = ranks.coordinates(rank)?;
74
75        // Check that all vecs are the same length.
76        // coordinates == sizes == strides is enforced by Slice.
77        ensure!(
78            names.len() == coordinates.len(),
79            "names and coordinates mismatch in length: names={names:#?}, coordinates={coordinates:#?}",
80        );
81        for (coordinate, name, size, stride) in
82            izip!(coordinates, names, ranks.sizes(), ranks.strides())
83        {
84            let start = rank - stride * coordinate;
85            let members: Vec<usize> = (start..start + stride * size).step_by(*stride).collect();
86            assert_eq!(members[coordinate], rank);
87            dims.insert(
88                name.clone(),
89                Dim {
90                    name,
91                    rank: coordinate,
92                    size: *size,
93                    members,
94                },
95            );
96        }
97
98        Ok(Self {
99            dims,
100            all_ranks: ranks.iter().collect(),
101        })
102    }
103
104    /// Return a dict of dimension names to their corresponding ranks.
105    pub fn ranks(&self) -> HashMap<String, usize> {
106        self.dims.iter().map(|(n, d)| (n.clone(), d.rank)).collect()
107    }
108
109    /// Return a dict of dimension names to their corresponding size.
110    pub fn sizes(&self) -> HashMap<String, usize> {
111        self.dims.iter().map(|(n, d)| (n.clone(), d.size)).collect()
112    }
113
114    pub fn dim(&self, name: &str) -> Option<&Dim> {
115        self.dims.get(name)
116    }
117
118    /// Return all the ranks that participate in collectives across the given
119    /// dim names.
120    pub fn get_ranks_for_dim_slice(&self, names: &[String]) -> Result<Vec<usize>> {
121        // Early returns for empty case.
122        if names.is_empty() {
123            return Ok(vec![]);
124        }
125
126        // Early return for single dimension cases
127        if let [name] = names {
128            return Ok(self
129                .dims
130                .get(name)
131                .with_context(|| format!("no dim with name {}", name))?
132                .members
133                .clone());
134        }
135
136        // Get all the dimensions
137        let dims: Vec<&Dim> = names
138            .iter()
139            .map(|n| {
140                self.dims
141                    .get(n)
142                    .with_context(|| format!("no dim with name {}", n))
143            })
144            .collect::<Result<Vec<_>, _>>()?;
145
146        // Calculate strides
147        let strides: Vec<usize> = dims
148            .iter()
149            .map(|d| match d.members.as_slice() {
150                [d0, d1, ..] => d1 - d0,
151                _ => 0,
152            })
153            .collect();
154
155        // Calculate start value
156        let start = dims[0].members[dims[0].rank]
157            - dims
158                .iter()
159                .zip(&strides)
160                .map(|(d, &s)| s * d.rank)
161                .sum::<usize>();
162
163        // Generate all combinations of indices and calculate ranks
164        Ok(dims
165            .iter()
166            .map(|d| 0..d.size)
167            .multi_cartesian_product()
168            .map(|idxs| {
169                start
170                    + idxs
171                        .into_iter()
172                        .zip(&strides)
173                        .map(|(i, &s)| i * s)
174                        .sum::<usize>()
175            })
176            .collect())
177    }
178}
179
#[cfg(test)]
mod tests {
    use super::*;

    /// Turn string literals into the owned names the mesh API expects.
    fn owned(names: &[&str]) -> Vec<String> {
        names.iter().map(|n| n.to_string()).collect()
    }

    #[test]
    fn basic() {
        // 2x3 mesh with strides (3, 1), viewed from rank 1.
        let layout = Slice::new(0, vec![2, 3], vec![3, 1]).unwrap();
        let mesh = DeviceMesh::new(owned(&["x", "y"]), layout, 1).unwrap();

        assert_eq!(mesh.dims.len(), 2);
        assert_eq!(mesh.all_ranks.len(), 6);

        let x = &mesh.dims["x"];
        assert_eq!(x.rank, 0);
        assert_eq!(x.members, vec![1, 4]);

        let y = &mesh.dims["y"];
        assert_eq!(y.rank, 1);
        assert_eq!(y.members, vec![0, 1, 2]);
    }

    #[test]
    fn get_ranks_for_dim_slice() -> Result<()> {
        // 2D mesh test (2x3)
        let layout = Slice::new(0, vec![2, 3], vec![3, 1])?;
        let mesh = DeviceMesh::new(owned(&["x", "y"]), layout, 1)?;
        assert!(mesh.get_ranks_for_dim_slice(&[])?.is_empty());
        assert_eq!(
            mesh.get_ranks_for_dim_slice(&owned(&["y"]))?,
            mesh.dims["y"].members,
        );
        assert_eq!(
            mesh.get_ranks_for_dim_slice(&owned(&["x", "y"]))?,
            mesh.all_ranks,
        );

        // 3D mesh test (2x2x2)
        let layout = Slice::new(0, vec![2, 2, 2], vec![4, 2, 1])?;
        let mesh = DeviceMesh::new(owned(&["x", "y", "z"]), layout, 1)?;
        assert!(mesh.get_ranks_for_dim_slice(&[])?.is_empty());
        assert_eq!(mesh.get_ranks_for_dim_slice(&owned(&["x"]))?, vec![1, 5]);
        assert_eq!(mesh.get_ranks_for_dim_slice(&owned(&["y"]))?, vec![1, 3]);
        assert_eq!(mesh.get_ranks_for_dim_slice(&owned(&["z"]))?, vec![0, 1]);
        assert_eq!(
            mesh.get_ranks_for_dim_slice(&owned(&["x", "y"]))?,
            vec![1, 3, 5, 7],
        );
        assert_eq!(
            mesh.get_ranks_for_dim_slice(&owned(&["y", "z"]))?,
            vec![0, 1, 2, 3],
        );
        assert_eq!(
            mesh.get_ranks_for_dim_slice(&owned(&["x", "y", "z"]))?,
            vec![0, 1, 2, 3, 4, 5, 6, 7]
        );

        Ok(())
    }
}
245}