hyperactor_mesh/
mesh_selection.rs1#![allow(dead_code)]
11
12use ndslice::Selection;
13use ndslice::shape::Range;
14use ndslice::shape::Shape;
15
/// Mapping strategy carried by [`AlignPolicy::Mapped`].
///
/// NOTE(review): the per-variant meanings below are inferred from the variant
/// names; the code that interprets them is not visible here — confirm.
// `Debug`/`PartialEq`/`Eq`/`Hash` let policies be logged, compared in
// assertions, and used as map keys; `AlignPolicy` relies on these derives.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum MappingMode {
    /// Presumably: dimensions present on only one side are broadcast over.
    BroadcastMissing,
    /// Presumably: coordinates must correspond exactly between the two sides.
    ExactMatch,
}

/// Policy selecting how an origin is aligned with a target.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum AlignPolicy {
    /// Broadcast; this variant carries no mapping configuration.
    Broadcast,
    /// Align through an explicit mapping governed by `mode`.
    Mapped {
        /// The [`MappingMode`] to apply when mapping.
        mode: MappingMode,
    },
}
37
38fn verify_dimension_sizes(origin: &Shape, target: &Shape) -> Result<(), anyhow::Error> {
40 for label in origin.labels() {
41 if let Some(target_pos) = target.labels().iter().position(|l| l == label) {
42 let origin_pos = origin.labels().iter().position(|l| l == label).unwrap();
43 let origin_size = origin.slice().sizes()[origin_pos];
44 let target_size = target.slice().sizes()[target_pos];
45
46 if origin_size != target_size {
47 return Err(anyhow::Error::msg(format!(
48 "dimension {} has mismatched sizes: {} vs {}",
49 label, origin_size, target_size
50 )));
51 }
52 }
53 }
54 Ok(())
55}
56
57fn exact_mapping<'a>(
59 target_selection: &Selection,
60 coords: &'a [(String, usize)],
61 origin_rank: &usize,
62) -> Result<Vec<(&'a str, Range)>, anyhow::Error> {
63 let coord_dim = coords.iter().map(|(_, d)| *d).collect::<Vec<_>>();
64 if target_selection.contains(&coord_dim) {
65 Ok(coords
66 .iter()
67 .map(|(label, index)| (label.as_str(), Range::from(*index)))
68 .collect::<Vec<(&'a str, Range)>>())
69 } else {
70 Err(anyhow::Error::msg(format!(
71 "origin rank {} is not selected in target",
72 origin_rank
73 )))
74 }
75}