1#![allow(dead_code)] use std::collections::HashMap;
12
13use anyhow::Context;
14use anyhow::Result;
15use anyhow::ensure;
16use itertools::Itertools;
17use itertools::izip;
18use ndslice::Slice;
19use pyo3::pyclass;
20use serde::Deserialize;
21use serde::Serialize;
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
33#[pyclass(frozen, module = "monarch_tensor_worker._internal")]
34#[pyo3(get_all)]
35pub struct Dim {
36 name: String,
38 rank: usize,
40 size: usize,
42 members: Vec<usize>,
45}
46
47impl Dim {
48 pub fn members(&self) -> &[usize] {
49 &self.members
50 }
51 pub fn size(&self) -> usize {
52 self.size
53 }
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize)]
59#[pyclass(frozen, module = "monarch_tensor_worker._internal")]
60#[pyo3(get_all)]
61pub struct DeviceMesh {
62 dims: HashMap<String, Dim>,
64 all_ranks: Vec<usize>,
66}
67
68impl DeviceMesh {
69 pub fn new(names: Vec<String>, ranks: Slice, rank: usize) -> Result<Self> {
72 let mut dims = HashMap::new();
73 let coordinates = ranks.coordinates(rank)?;
74
75 ensure!(
78 names.len() == coordinates.len(),
79 "names and coordinates mismatch in length: names={names:#?}, coordinates={coordinates:#?}",
80 );
81 for (coordinate, name, size, stride) in
82 izip!(coordinates, names, ranks.sizes(), ranks.strides())
83 {
84 let start = rank - stride * coordinate;
85 let members: Vec<usize> = (start..start + stride * size).step_by(*stride).collect();
86 assert_eq!(members[coordinate], rank);
87 dims.insert(
88 name.clone(),
89 Dim {
90 name,
91 rank: coordinate,
92 size: *size,
93 members,
94 },
95 );
96 }
97
98 Ok(Self {
99 dims,
100 all_ranks: ranks.iter().collect(),
101 })
102 }
103
104 pub fn ranks(&self) -> HashMap<String, usize> {
106 self.dims.iter().map(|(n, d)| (n.clone(), d.rank)).collect()
107 }
108
109 pub fn sizes(&self) -> HashMap<String, usize> {
111 self.dims.iter().map(|(n, d)| (n.clone(), d.size)).collect()
112 }
113
114 pub fn dim(&self, name: &str) -> Option<&Dim> {
115 self.dims.get(name)
116 }
117
118 pub fn get_ranks_for_dim_slice(&self, names: &[String]) -> Result<Vec<usize>> {
121 if names.is_empty() {
123 return Ok(vec![]);
124 }
125
126 if let [name] = names {
128 return Ok(self
129 .dims
130 .get(name)
131 .with_context(|| format!("no dim with name {}", name))?
132 .members
133 .clone());
134 }
135
136 let dims: Vec<&Dim> = names
138 .iter()
139 .map(|n| {
140 self.dims
141 .get(n)
142 .with_context(|| format!("no dim with name {}", n))
143 })
144 .collect::<Result<Vec<_>, _>>()?;
145
146 let strides: Vec<usize> = dims
148 .iter()
149 .map(|d| match d.members.as_slice() {
150 [d0, d1, ..] => d1 - d0,
151 _ => 0,
152 })
153 .collect();
154
155 let start = dims[0].members[dims[0].rank]
157 - dims
158 .iter()
159 .zip(&strides)
160 .map(|(d, &s)| s * d.rank)
161 .sum::<usize>();
162
163 Ok(dims
165 .iter()
166 .map(|d| 0..d.size)
167 .multi_cartesian_product()
168 .map(|idxs| {
169 start
170 + idxs
171 .into_iter()
172 .zip(&strides)
173 .map(|(i, &s)| i * s)
174 .sum::<usize>()
175 })
176 .collect())
177 }
178}
179
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns string literals into the owned dimension names `DeviceMesh` expects.
    fn dim_names(raw: &[&str]) -> Vec<String> {
        raw.iter().map(|s| s.to_string()).collect()
    }

    #[test]
    fn basic() {
        // 2x3 mesh viewed from rank 1, i.e. coordinates (x=0, y=1).
        let slice = Slice::new(0, vec![2, 3], vec![3, 1]).unwrap();
        let mesh = DeviceMesh::new(dim_names(&["x", "y"]), slice, 1).unwrap();

        assert_eq!(mesh.dims.len(), 2);
        assert_eq!(mesh.all_ranks.len(), 6);

        let x = &mesh.dims["x"];
        assert_eq!(x.rank, 0);
        assert_eq!(x.members, vec![1, 4]);

        let y = &mesh.dims["y"];
        assert_eq!(y.rank, 1);
        assert_eq!(y.members, vec![0, 1, 2]);
    }

    #[test]
    fn get_ranks_for_dim_slice() -> Result<()> {
        // 2x3 mesh viewed from rank 1.
        let mesh = DeviceMesh::new(
            dim_names(&["x", "y"]),
            Slice::new(0, vec![2, 3], vec![3, 1])?,
            1,
        )?;
        assert!(mesh.get_ranks_for_dim_slice(&[])?.is_empty());
        assert_eq!(
            mesh.get_ranks_for_dim_slice(&dim_names(&["y"]))?,
            mesh.dims["y"].members,
        );
        assert_eq!(
            mesh.get_ranks_for_dim_slice(&dim_names(&["x", "y"]))?,
            mesh.all_ranks,
        );

        // 2x2x2 mesh viewed from rank 1, i.e. coordinates (x=0, y=0, z=1).
        let mesh = DeviceMesh::new(
            dim_names(&["x", "y", "z"]),
            Slice::new(0, vec![2, 2, 2], vec![4, 2, 1])?,
            1,
        )?;
        assert!(mesh.get_ranks_for_dim_slice(&[])?.is_empty());

        let expectations: [(&[&str], Vec<usize>); 6] = [
            (&["x"], vec![1, 5]),
            (&["y"], vec![1, 3]),
            (&["z"], vec![0, 1]),
            (&["x", "y"], vec![1, 3, 5, 7]),
            (&["y", "z"], vec![0, 1, 2, 3]),
            (&["x", "y", "z"], vec![0, 1, 2, 3, 4, 5, 6, 7]),
        ];
        for (selected, expected) in &expectations {
            assert_eq!(
                mesh.get_ranks_for_dim_slice(&dim_names(selected))?,
                *expected,
                "dims {selected:?}",
            );
        }

        Ok(())
    }
}