hyperactor_mesh/
mesh_selection.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
// until used publicly
10#![allow(dead_code)]
11
12use ndslice::Selection;
13use ndslice::shape::Range;
14use ndslice::shape::Shape;
15
/// Specifies how to handle dimensions that exist in one mesh but not the other.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MappingMode {
    /// Broadcast to any dimensions that exist only in the target mesh.
    /// Dimensions only in the origin mesh are handled by rank iteration.
    BroadcastMissing,
    /// Error if any dimensions exist in only one of the meshes.
    ExactMatch,
}

/// Describes how two meshes should be aligned.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AlignPolicy {
    /// Creates a mapping from each origin rank to the entire target mesh.
    Broadcast,
    /// Maps matching dimensions 1:1 and handles missing dimensions according
    /// to [`MappingMode`].
    Mapped {
        /// Specifies how to handle dimensions that exist in one mesh but not
        /// the other.
        mode: MappingMode,
    },
}
37
38/// Verifies that any matching dimensions between origin and target shapes have equal sizes
39fn verify_dimension_sizes(origin: &Shape, target: &Shape) -> Result<(), anyhow::Error> {
40    for label in origin.labels() {
41        if let Some(target_pos) = target.labels().iter().position(|l| l == label) {
42            let origin_pos = origin.labels().iter().position(|l| l == label).unwrap();
43            let origin_size = origin.slice().sizes()[origin_pos];
44            let target_size = target.slice().sizes()[target_pos];
45
46            if origin_size != target_size {
47                return Err(anyhow::Error::msg(format!(
48                    "dimension {} has mismatched sizes: {} vs {}",
49                    label, origin_size, target_size
50                )));
51            }
52        }
53    }
54    Ok(())
55}
56
57/// Given a set of coordinates, create a set of single-value ranges for each dimension..
58fn exact_mapping<'a>(
59    target_selection: &Selection,
60    coords: &'a [(String, usize)],
61    origin_rank: &usize,
62) -> Result<Vec<(&'a str, Range)>, anyhow::Error> {
63    let coord_dim = coords.iter().map(|(_, d)| *d).collect::<Vec<_>>();
64    if target_selection.contains(&coord_dim) {
65        Ok(coords
66            .iter()
67            .map(|(label, index)| (label.as_str(), Range::from(*index)))
68            .collect::<Vec<(&'a str, Range)>>())
69    } else {
70        Err(anyhow::Error::msg(format!(
71            "origin rank {} is not selected in target",
72            origin_rank
73        )))
74    }
75}