1use std::str::FromStr;
33use std::sync::Arc;
34
35use serde::Deserialize;
36use serde::Serialize;
37use thiserror::Error;
38
39use crate::Range;
40use crate::Shape; use crate::Slice;
42use crate::SliceError;
43use crate::SliceIterator;
44use crate::parse::Parser;
45use crate::parse::ParserError;
46
/// Errors that can occur when constructing or combining [`Extent`]s.
#[derive(Debug, thiserror::Error)]
pub enum ExtentError {
    /// The number of labels passed to [`Extent::new`] differs from the
    /// number of sizes.
    #[error("label/sizes dimension mismatch: {num_labels} != {num_sizes}")]
    DimMismatch {
        /// Number of labels supplied.
        num_labels: usize,
        /// Number of sizes supplied.
        num_sizes: usize,
    },

    /// A label is present in both operands (returned by [`Extent::concat`]).
    #[error("overlapping label found: {label}")]
    OverlappingLabel {
        /// The first label of the right-hand extent that also appears in
        /// the left-hand extent.
        label: String,
    },
}
73
/// A labeled multidimensional shape: an ordered list of named dimensions,
/// each with a size. Ranks are assigned in row-major order (the last
/// dimension varies fastest; see [`Extent::rank_of_coords`]).
///
/// The labels and sizes live behind an [`Arc`], so cloning an `Extent` only
/// bumps a reference count.
#[derive(Clone, Deserialize, Serialize, PartialEq, Eq, Hash, Debug)]
pub struct Extent {
    inner: Arc<ExtentData>,
}
85
// Compile-time check: this never runs, but the crate stops compiling if
// `Extent` ever loses `Send`, `Sync` or `'static`.
fn _assert_extent_traits()
where
    Extent: Send + Sync + 'static,
{
}
91
/// Shared storage behind [`Extent`]. `labels[i]` names the dimension whose
/// size is `sizes[i]`; [`Extent::new`] rejects vectors of different lengths.
#[derive(Clone, Deserialize, Serialize, PartialEq, Eq, Hash, Debug)]
struct ExtentData {
    labels: Vec<String>,
    sizes: Vec<usize>,
}
101
102impl From<&Shape> for Extent {
103 fn from(s: &Shape) -> Self {
104 Extent::new(s.labels().to_vec(), s.slice().sizes().to_vec()).unwrap()
106 }
107}
108
109impl From<Shape> for Extent {
110 fn from(s: Shape) -> Self {
111 Extent::from(&s)
112 }
113}
114
115impl Extent {
116 pub fn new(labels: Vec<String>, sizes: Vec<usize>) -> Result<Self, ExtentError> {
118 if labels.len() != sizes.len() {
119 return Err(ExtentError::DimMismatch {
120 num_labels: labels.len(),
121 num_sizes: sizes.len(),
122 });
123 }
124
125 Ok(Self {
126 inner: Arc::new(ExtentData { labels, sizes }),
127 })
128 }
129
130 pub fn unity() -> Extent {
131 Extent::new(vec![], vec![]).unwrap()
132 }
133
134 pub fn labels(&self) -> &[String] {
136 &self.inner.labels
137 }
138
139 pub fn sizes(&self) -> &[usize] {
141 &self.inner.sizes
142 }
143
144 pub fn size(&self, label: &str) -> Option<usize> {
147 self.position(label).map(|pos| self.sizes()[pos])
148 }
149
150 pub fn position(&self, label: &str) -> Option<usize> {
153 self.labels().iter().position(|l| l == label)
154 }
155
156 pub fn rank_of_coords(&self, coords: &[usize]) -> Result<usize, PointError> {
166 let sizes = self.sizes();
167 if coords.len() != sizes.len() {
168 return Err(PointError::DimMismatch {
169 expected: sizes.len(),
170 actual: coords.len(),
171 });
172 }
173 let mut stride = 1;
174 let mut result = 0;
175 for (&c, &size) in coords.iter().rev().zip(sizes.iter().rev()) {
176 if c >= size {
177 return Err(PointError::OutOfRangeIndex { size, index: c });
178 }
179 result += c * stride;
180 stride *= size;
181 }
182 Ok(result)
183 }
184
185 pub fn point(&self, coords: Vec<usize>) -> Result<Point, PointError> {
231 Ok(Point {
232 rank: self.rank_of_coords(&coords)?,
233 extent: self.clone(),
234 })
235 }
236
237 pub fn point_of_rank(&self, rank: usize) -> Result<Point, PointError> {
261 let total = self.num_ranks();
262 if rank >= total {
263 return Err(PointError::OutOfRangeRank { total, rank });
264 }
265 Ok(Point {
266 rank,
267 extent: self.clone(),
268 })
269 }
270
271 pub fn len(&self) -> usize {
276 self.sizes().len()
277 }
278
279 pub fn is_empty(&self) -> bool {
284 self.sizes().is_empty()
285 }
286
287 pub fn num_ranks(&self) -> usize {
292 self.sizes().iter().product()
293 }
294
295 pub fn into_inner(self) -> (Vec<String>, Vec<usize>) {
297 match Arc::try_unwrap(self.inner) {
298 Ok(data) => (data.labels, data.sizes),
299 Err(shared) => (shared.labels.clone(), shared.sizes.clone()),
300 }
301 }
302
303 pub fn to_slice(&self) -> Slice {
305 Slice::new_row_major(self.sizes())
306 }
307
308 pub fn iter(&self) -> impl Iterator<Item = (String, usize)> + use<'_> {
310 self.labels()
311 .iter()
312 .zip(self.sizes().iter())
313 .map(|(l, s)| (l.clone(), *s))
314 }
315
316 pub fn points(&self) -> ExtentPointsIterator<'_> {
318 ExtentPointsIterator::new(self)
319 }
320
321 pub fn concat(&self, other: &Extent) -> Result<Self, ExtentError> {
329 use std::collections::HashSet;
330 let lhs: HashSet<&str> = self.labels().iter().map(|s| s.as_str()).collect();
332 if let Some(dup) = other.labels().iter().find(|l| lhs.contains(l.as_str())) {
333 return Err(ExtentError::OverlappingLabel { label: dup.clone() });
334 }
335 let mut labels = self.labels().to_vec();
337 let mut sizes = self.sizes().to_vec();
338 labels.reserve(other.labels().len());
339 sizes.reserve(other.sizes().len());
340 labels.extend(other.labels().iter().cloned());
341 sizes.extend(other.sizes().iter().copied());
342 Extent::new(labels, sizes)
343 }
344}
345
mod labels {
    /// Returns `true` if `s` can be printed without quotes: a non-empty ASCII
    /// identifier (letter or `_`, then letters, digits or `_`).
    ///
    /// Labels that are empty or start with a digit must be quoted. Otherwise
    /// `""` would print as nothing, and a label such as `"0"` at the start
    /// of a region's `Display` output would be re-read as an offset.
    pub(super) fn is_safe_ident(s: &str) -> bool {
        let mut chars = s.chars();
        match chars.next() {
            Some(first) if first.is_ascii_alphabetic() || first == '_' => {
                chars.all(|c| c.is_ascii_alphanumeric() || c == '_')
            }
            _ => false,
        }
    }

    /// Formats a label for display: bare if [`is_safe_ident`], otherwise
    /// as a Rust-style quoted string (`{:?}`).
    pub(super) fn fmt_label(s: &str) -> String {
        if is_safe_ident(s) {
            s.to_string()
        } else {
            format!("{:?}", s)
        }
    }
}
377
378impl std::fmt::Display for Extent {
414 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
415 write!(f, "{{")?;
416 for i in 0..self.sizes().len() {
417 write!(
418 f,
419 "{}: {}",
420 labels::fmt_label(&self.labels()[i]),
421 self.sizes()[i]
422 )?;
423 if i + 1 != self.sizes().len() {
424 write!(f, ", ")?;
425 }
426 }
427 write!(f, "}}")
428 }
429}
430
/// Iterator over every [`Point`] of an [`Extent`], in increasing rank
/// (row-major) order. Created by [`Extent::points`].
pub struct ExtentPointsIterator<'a> {
    // The extent being enumerated.
    extent: &'a Extent,
    // Rank of the next point to yield.
    next_rank: usize,
}
436
437impl<'a> ExtentPointsIterator<'a> {
438 pub fn new(extent: &'a Extent) -> Self {
439 Self {
440 extent,
441 next_rank: 0,
442 }
443 }
444}
445
446impl<'a> Iterator for ExtentPointsIterator<'a> {
447 type Item = Point;
448
449 fn next(&mut self) -> Option<Self::Item> {
452 if self.next_rank == self.extent.num_ranks() {
453 return None;
454 }
455
456 let p = Point {
457 rank: self.next_rank,
458 extent: self.extent.clone(),
459 };
460 self.next_rank += 1;
461 Some(p)
462 }
463}
464
/// Errors from constructing or parsing a [`Point`].
#[derive(Debug, Error)]
pub enum PointError {
    /// The number of coordinates differs from the extent's number of
    /// dimensions.
    #[error("dimension mismatch: expected {expected}, got {actual}")]
    DimMismatch {
        /// Number of dimensions of the extent.
        expected: usize,
        /// Number of coordinates supplied.
        actual: usize,
    },

    /// A rank is not below the extent's total number of ranks.
    #[error("out of range: total ranks {total}; does not contain rank {rank}")]
    OutOfRangeRank {
        /// Total number of ranks in the extent.
        total: usize,
        /// The requested rank.
        rank: usize,
    },

    /// A coordinate is not below its dimension's size.
    #[error("out of range: dim size {size}; does not contain index {index}")]
    OutOfRangeIndex {
        /// Size of the offending dimension.
        size: usize,
        /// The out-of-range coordinate.
        index: usize,
    },

    /// Textual input (see `Point::from_str`) was malformed.
    #[error("failed to parse point: {reason}")]
    ParseError { reason: String },
}
515
/// A single point of an [`Extent`], stored as its row-major rank together
/// with the extent it belongs to. Coordinates are derived from the rank on
/// demand (see [`Point::coords`]).
#[derive(Clone, Deserialize, Serialize, PartialEq, Eq, Hash, Debug)]
pub struct Point {
    rank: usize,
    extent: Extent,
}
553
/// Iterator over a [`Point`]'s coordinates, one per dimension, outermost
/// first. Created by [`Point::coords_iter`].
pub struct CoordIter<'a> {
    // Sizes of the point's extent, outermost first.
    sizes: &'a [usize],
    // Rank still to be decomposed; reduced modulo `stride` at each step.
    rank: usize,
    // Before each step: the product of `sizes[axis..]`. `next` divides it by
    // the current dimension's size to get that dimension's stride.
    stride: usize,
    // Index of the next dimension to yield.
    axis: usize,
}
579
580impl<'a> Iterator for CoordIter<'a> {
581 type Item = usize;
582
583 fn next(&mut self) -> Option<Self::Item> {
588 if self.axis >= self.sizes.len() {
589 return None;
590 }
591 self.stride /= self.sizes[self.axis];
592 let q = self.rank / self.stride;
593 self.rank %= self.stride;
594 self.axis += 1;
595 Some(q)
596 }
597
598 fn size_hint(&self) -> (usize, Option<usize>) {
604 let rem = self.sizes.len().saturating_sub(self.axis);
605 (rem, Some(rem))
606 }
607}
608
// `CoordIter::size_hint` is exact, so the iterator can report its length.
impl ExactSizeIterator for CoordIter<'_> {}
610
611impl<'a> IntoIterator for &'a Point {
612 type Item = usize;
613 type IntoIter = CoordIter<'a>;
614
615 fn into_iter(self) -> Self::IntoIter {
634 self.coords_iter()
635 }
636}
637
// Compile-time check: this never runs, but the crate stops compiling if
// `Point` ever loses `Send`, `Sync` or `'static`.
fn _assert_point_traits()
where
    Point: Send + Sync + 'static,
{
}
643
/// Converts coordinate-like values into a [`Point`] of a given extent, so
/// callers can write e.g. `vec![1, 2, 3].in_(&extent)`.
pub trait InExtent {
    /// Interprets `self` as coordinates within `extent`.
    fn in_(self, extent: &Extent) -> Result<Point, PointError>;
}
662
663impl InExtent for Vec<usize> {
664 fn in_(self, extent: &Extent) -> Result<Point, PointError> {
669 extent.point(self)
670 }
671}
672
673impl Point {
674 pub fn coords_iter(&self) -> CoordIter<'_> {
675 CoordIter {
676 sizes: self.extent.sizes(),
677 rank: self.rank,
678 stride: self.extent.sizes().iter().product(),
679 axis: 0,
680 }
681 }
682
683 pub fn coord(&self, i: usize) -> usize {
701 self.coords_iter()
702 .nth(i)
703 .expect("coord(i): axis out of bounds")
704 }
705
706 pub fn coords(&self) -> Vec<usize> {
721 self.coords_iter().collect()
722 }
723
724 pub fn rank(&self) -> usize {
727 self.rank
728 }
729
730 pub fn extent(&self) -> &Extent {
733 &self.extent
734 }
735
736 pub fn len(&self) -> usize {
751 self.extent.len()
752 }
753
754 pub fn is_empty(&self) -> bool {
770 self.extent.len() == 0
771 }
772
773 pub fn format_as_dict(&self) -> String {
785 format!(
786 "{{{}}}",
787 self.extent()
788 .labels()
789 .iter()
790 .zip(self.coords_iter())
791 .zip(self.extent().sizes())
792 .map(|((label, coord), size)| format!("'{}': {}/{}", label, coord, size))
793 .collect::<Vec<_>>()
794 .join(", ")
795 )
796 }
797}
798
799impl std::fmt::Display for Point {
835 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
836 let labels = self.extent.labels();
837 let sizes = self.extent.sizes();
838 let coords = self.coords();
839
840 for i in 0..labels.len() {
841 write!(
842 f,
843 "{}={}/{}",
844 labels::fmt_label(&labels[i]),
845 coords[i],
846 sizes[i]
847 )?;
848 if i + 1 != labels.len() {
849 write!(f, ",")?;
850 }
851 }
852 Ok(())
853 }
854}
855
856impl FromStr for Point {
857 type Err = PointError;
858
859 fn from_str(s: &str) -> Result<Self, Self::Err> {
860 let s = s.trim();
861
862 if s.is_empty() {
863 let empty_extent = Extent::unity();
864 return empty_extent.point(vec![]);
865 }
866
867 let mut labels = Vec::new();
868 let mut coords = Vec::new();
869 let mut sizes = Vec::new();
870
871 let mut chars = s.chars().peekable();
872
873 while chars.peek().is_some() {
874 while chars.peek() == Some(&' ') {
875 chars.next();
876 }
877
878 if chars.peek().is_none() {
879 break;
880 }
881
882 let label = if chars.peek() == Some(&'"') {
883 chars.next(); let mut label = String::new();
885 let mut escaped = false;
886
887 for ch in chars.by_ref() {
889 if escaped {
890 match ch {
891 '"' => label.push('"'),
892 '\\' => label.push('\\'),
893 _ => {
894 label.push('\\');
895 label.push(ch);
896 }
897 }
898 escaped = false;
899 } else if ch == '\\' {
900 escaped = true;
901 } else if ch == '"' {
902 break;
903 } else {
904 label.push(ch);
905 }
906 }
907
908 if label.is_empty() {
909 return Err(PointError::ParseError {
910 reason: "empty quoted label".to_string(),
911 });
912 }
913
914 label
915 } else {
916 let mut label = String::new();
917 while let Some(&ch) = chars.peek() {
918 if ch == '=' || ch == ' ' {
919 break;
920 }
921 label.push(chars.next().unwrap());
922 }
923
924 if label.is_empty() {
925 return Err(PointError::ParseError {
926 reason: "missing label".to_string(),
927 });
928 }
929
930 label
931 };
932
933 while chars.peek() == Some(&' ') {
934 chars.next();
935 }
936
937 if chars.next() != Some('=') {
938 return Err(PointError::ParseError {
939 reason: format!("expected '=' after label '{}'", label),
940 });
941 }
942
943 while chars.peek() == Some(&' ') {
944 chars.next();
945 }
946
947 let mut coord = String::new();
948 while let Some(&ch) = chars.peek() {
949 if ch == '/' || ch == ' ' {
950 break;
951 }
952 coord.push(chars.next().unwrap());
953 }
954
955 if coord.is_empty() {
956 return Err(PointError::ParseError {
957 reason: format!("missing coordinate for dimension '{}'", label),
958 });
959 }
960
961 while chars.peek() == Some(&' ') {
962 chars.next();
963 }
964
965 if chars.next() != Some('/') {
966 return Err(PointError::ParseError {
967 reason: format!("expected '/' after coordinate for dimension '{}'", label),
968 });
969 }
970
971 while chars.peek() == Some(&' ') {
972 chars.next();
973 }
974
975 let mut size = String::new();
976 while let Some(&ch) = chars.peek() {
977 if ch == ',' || ch == ' ' {
978 break;
979 }
980 size.push(chars.next().unwrap());
981 }
982
983 if size.is_empty() {
984 return Err(PointError::ParseError {
985 reason: format!("missing size for dimension '{}'", label),
986 });
987 }
988
989 let coord = coord.parse::<usize>().map_err(|e| PointError::ParseError {
990 reason: format!(
991 "invalid coordinate '{}' for dimension '{}': {}",
992 coord, label, e
993 ),
994 })?;
995
996 let size = size.parse::<usize>().map_err(|e| PointError::ParseError {
997 reason: format!("invalid size '{}' for dimension '{}': {}", size, label, e),
998 })?;
999
1000 labels.push(label);
1001 coords.push(coord);
1002 sizes.push(size);
1003
1004 while chars.peek() == Some(&' ') {
1005 chars.next();
1006 }
1007
1008 if chars.peek() == Some(&',') {
1009 chars.next(); while chars.peek() == Some(&' ') {
1011 chars.next();
1012 }
1013 if chars.peek().is_none() {
1015 return Err(PointError::ParseError {
1016 reason: "trailing comma".to_string(),
1017 });
1018 }
1019 }
1020 }
1021
1022 let extent = Extent::new(labels, sizes).map_err(|e| PointError::ParseError {
1023 reason: format!("failed to create extent: {}", e),
1024 })?;
1025
1026 extent.point(coords)
1027 }
1028}
1029
// Gives `Point` the fixed type name "ndslice::Point" for the `typeuri`
// naming trait.
impl typeuri::Named for Point {
    fn typename() -> &'static str {
        "ndslice::Point"
    }
}
1035
// Makes `Point` usable as a `hyperactor_config` attribute value.
// NOTE(review): the macro's requirements live in hyperactor_config;
// presumably it relies on Point's Named/serde/Display/FromStr impls. Confirm.
hyperactor_config::impl_attrvalue!(Point);
1037
1038#[derive(Debug, Error)]
1040pub enum ViewError {
1041 #[error("no such dimension: {0}")]
1043 InvalidDim(String),
1044
1045 #[error("empty range: {range} for dimension {dim} of size {size}")]
1047 EmptyRange {
1048 range: Range,
1049 dim: String,
1050 size: usize,
1051 },
1052
1053 #[error(transparent)]
1054 ExtentError(#[from] ExtentError),
1055
1056 #[error("invalid range: selected ranks {selected} not a subset of base {base} ")]
1057 InvalidRange {
1058 base: Box<Region>,
1059 selected: Box<Region>,
1060 },
1061}
1062
1063#[derive(Debug, Error)]
1065pub enum RegionError {
1066 #[error("invalid point: this point does not belong to this region: {0}")]
1067 InvalidPoint(String),
1068
1069 #[error("out of range base rank: this base rank {0} does not belong to this region: {0}")]
1070 OutOfRangeBaseRank(usize, String),
1071}
1072
/// A labeled [`Slice`]: a set of base ranks (given by the slice's offset,
/// sizes and strides) with a label for each dimension. Unlike an [`Extent`],
/// a region can describe a strided subset of a larger rank space.
///
/// NOTE(review): one label per slice dimension is expected but not checked
/// by [`Region::new`].
#[derive(Debug, Clone, Serialize, Deserialize, Hash, PartialEq, Eq)]
pub struct Region {
    labels: Vec<String>,
    slice: Slice,
}
1084
1085impl Region {
1086 #[allow(dead_code)]
1087 fn empty() -> Region {
1088 Region {
1089 labels: Vec::new(),
1090 slice: Slice::new(0, Vec::new(), Vec::new()).unwrap(),
1091 }
1092 }
1093
1094 #[allow(dead_code)]
1097 pub fn new(labels: Vec<String>, slice: Slice) -> Self {
1098 Self { labels, slice }
1099 }
1100
1101 pub fn labels(&self) -> &[String] {
1103 &self.labels
1104 }
1105
1106 pub fn slice(&self) -> &Slice {
1109 &self.slice
1110 }
1111
1112 pub fn into_inner(self) -> (Vec<String>, Slice) {
1114 (self.labels, self.slice)
1115 }
1116
1117 pub fn extent(&self) -> Extent {
1119 Extent::new(self.labels.clone(), self.slice.sizes().to_vec()).unwrap()
1120 }
1121
1122 pub fn is_subset(&self, other: &Region) -> bool {
1125 let mut left = self.slice.iter().peekable();
1126 let mut right = other.slice.iter().peekable();
1127
1128 loop {
1129 match (left.peek(), right.peek()) {
1130 (Some(l), Some(r)) => {
1131 if l < r {
1132 return false;
1133 } else if l == r {
1134 left.next();
1135 right.next();
1136 } else {
1137 right.next();
1139 }
1140 }
1141 (Some(_), None) => return false,
1142 (None, _) => return true,
1143 }
1144 }
1145 }
1146
1147 pub fn remap(&self, target: &Region) -> Option<impl Iterator<Item = usize> + '_> {
1167 if !target.is_subset(self) {
1168 return None;
1169 }
1170
1171 let mut ours = self.slice.iter().enumerate();
1172 let mut theirs = target.slice.iter();
1173
1174 Some(std::iter::from_fn(move || {
1175 let needle = theirs.next()?;
1176 loop {
1177 let (index, value) = ours.next().unwrap();
1178 if value == needle {
1179 break Some(index);
1180 }
1181 }
1182 }))
1183 }
1184
1185 pub fn num_ranks(&self) -> usize {
1187 self.slice.len()
1188 }
1189
1190 pub fn base_rank_of_point(&self, p: Point) -> Result<usize, RegionError> {
1193 if p.extent() != &self.extent() {
1194 return Err(RegionError::InvalidPoint(
1195 "mismatched extent: p must be a point in this region’s extent".to_string(),
1196 ));
1197 }
1198
1199 Ok(self
1200 .slice()
1201 .location(&p.coords())
1202 .expect("should have valid location since extent is checked"))
1203 }
1204
1205 pub fn point_of_base_rank(&self, rank: usize) -> Result<Point, RegionError> {
1208 let coords = self
1209 .slice()
1210 .coordinates(rank)
1211 .map_err(|e| RegionError::OutOfRangeBaseRank(rank, e.to_string()))?;
1212 Ok(self
1213 .extent()
1214 .point(coords)
1215 .expect("should have valid point since coords is from this region"))
1216 }
1217}
1218
1219impl From<Extent> for Region {
1222 fn from(extent: Extent) -> Self {
1223 Region {
1224 labels: extent.labels().to_vec(),
1225 slice: extent.to_slice(),
1226 }
1227 }
1228}
1229
1230impl From<&Shape> for Region {
1231 fn from(s: &Shape) -> Self {
1232 Region {
1233 labels: s.labels().to_vec(),
1234 slice: s.slice().clone(),
1235 }
1236 }
1237}
1238
1239impl From<Shape> for Region {
1240 fn from(s: Shape) -> Self {
1241 Region::from(&s)
1242 }
1243}
1244
1245impl std::fmt::Display for Region {
1285 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1286 if self.slice.offset() != 0 {
1287 write!(f, "{}+", self.slice.offset())?;
1288 }
1289 for i in 0..self.labels.len() {
1290 write!(
1291 f,
1292 "{}={}/{}",
1293 labels::fmt_label(&self.labels[i]),
1294 self.slice.sizes()[i],
1295 self.slice.strides()[i]
1296 )?;
1297 if i + 1 != self.labels.len() {
1298 write!(f, ",")?;
1299 }
1300 }
1301 Ok(())
1302 }
1303}
1304
/// Errors from parsing a [`Region`] from text.
#[derive(Debug, thiserror::Error)]
pub enum RegionParseError {
    /// Tokenization or syntax error reported by the parser.
    #[error(transparent)]
    ParserError(#[from] ParserError),

    /// The parsed offset/sizes/strides did not form a valid [`Slice`].
    #[error(transparent)]
    SliceError(#[from] SliceError),
}
1313
impl std::str::FromStr for Region {
    type Err = RegionParseError;

    /// Parses the `Display` form of a region: an optional `offset+` prefix
    /// followed by comma-separated `label=size/stride` entries, where labels
    /// may be double-quoted.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let mut parser = Parser::new(s, &["+", "=", ",", "/"]);

        // Optional leading offset. NOTE(review): this relies on a failed
        // `try_parse` not consuming the first label token; see crate::parse.
        let offset: usize = if let Ok(offset) = parser.try_parse() {
            parser.expect("+")?;
            offset
        } else {
            0
        };

        let mut labels = Vec::new();
        let mut sizes = Vec::new();
        let mut strides = Vec::new();

        while !parser.is_empty() {
            // Entries after the first must be preceded by a comma.
            if !labels.is_empty() {
                parser.expect(",")?;
            }

            // Quoted labels may contain delimiter characters.
            let label = if parser.peek_char() == Some('"') {
                parser.parse_string_literal()?
            } else {
                parser.next_or_err("label")?.to_string()
            };
            labels.push(label);

            parser.expect("=")?;
            sizes.push(parser.try_parse()?);
            parser.expect("/")?;
            strides.push(parser.try_parse()?);
        }

        Ok(Region {
            labels,
            slice: Slice::new(offset, sizes, strides)?,
        })
    }
}
1376
/// Builds a region-shaped collection from a [`Region`] and a dense vector of
/// values.
pub trait BuildFromRegion<T>: Sized {
    /// Error reported by [`BuildFromRegion::build_dense`].
    type Error;

    /// Builds from `values`, presumably one per rank of `region` in rank
    /// order; implementations report invalid input through `Self::Error`.
    fn build_dense(region: Region, values: Vec<T>) -> Result<Self, Self::Error>;

    /// Like `build_dense`, without validation. The callers in this module
    /// (`collect_exact_mesh`, `MapIntoExt`) supply exactly
    /// `region.num_ranks()` values.
    fn build_dense_unchecked(region: Region, values: Vec<T>) -> Self;
}
1398
/// Builds a region-shaped collection from `(rank, value)` pairs.
pub trait BuildFromRegionIndexed<T>: Sized {
    /// Error reported when the pairs cannot form a collection over `region`.
    type Error;

    /// Builds from `pairs`. Each `usize` is presumably a rank within
    /// `region`; ordering and duplicate/missing-rank rules are up to the
    /// implementation (confirm with implementors).
    fn build_indexed(
        region: Region,
        pairs: impl IntoIterator<Item = (usize, T)>,
    ) -> Result<Self, Self::Error>;
}
1422
/// Collects an iterator of values into a region-shaped collection.
pub trait CollectMeshExt<T>: Iterator<Item = T> + Sized {
    /// Gathers every item and hands them, with `region`, to
    /// [`BuildFromRegion::build_dense`]; validation is left to `M`.
    fn collect_mesh<M>(self, region: Region) -> Result<M, M::Error>
    where
        M: BuildFromRegion<T>;
}
1441
1442impl<I, T> CollectMeshExt<T> for I
1445where
1446 I: Iterator<Item = T> + Sized,
1447{
1448 fn collect_mesh<M>(self, region: Region) -> Result<M, M::Error>
1449 where
1450 M: BuildFromRegion<T>,
1451 {
1452 M::build_dense(region, self.collect())
1453 }
1454}
1455
/// The number of values supplied does not match the number of ranks in the
/// target region; produced (via `From`) by `collect_exact_mesh`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct InvalidCardinality {
    /// Number of ranks in the region.
    pub expected: usize,
    /// Number of values supplied.
    pub actual: usize,
}
1462
/// Collects an exact-size iterator into a region-shaped collection after
/// checking the item count up front.
pub trait CollectExactMeshExt<T>: ExactSizeIterator<Item = T> + Sized {
    /// Compares the iterator's `len()` with `region.num_ranks()`. On mismatch
    /// it returns [`InvalidCardinality`] converted into `M::Error`; otherwise
    /// it builds with [`BuildFromRegion::build_dense_unchecked`].
    fn collect_exact_mesh<M>(self, region: Region) -> Result<M, M::Error>
    where
        M: BuildFromRegion<T>,
        M::Error: From<InvalidCardinality>;
}
1475
1476impl<I, T> CollectExactMeshExt<T> for I
1479where
1480 I: ExactSizeIterator<Item = T> + Sized,
1481{
1482 fn collect_exact_mesh<M>(self, region: Region) -> Result<M, M::Error>
1483 where
1484 M: BuildFromRegion<T>,
1485 M::Error: From<InvalidCardinality>,
1486 {
1487 let expected = region.num_ranks();
1488 let actual = self.len();
1489 if actual != expected {
1490 return Err(M::Error::from(InvalidCardinality { expected, actual }));
1491 }
1492 Ok(M::build_dense_unchecked(region, self.collect()))
1493 }
1494}
1495
/// Collects `(rank, value)` pairs into a region-shaped collection.
pub trait CollectIndexedMeshExt<T>: Iterator<Item = (usize, T)> + Sized {
    /// Forwards the pairs, with `region`, to
    /// [`BuildFromRegionIndexed::build_indexed`].
    fn collect_indexed<M>(self, region: Region) -> Result<M, M::Error>
    where
        M: BuildFromRegionIndexed<T>;
}
1522
1523impl<I, T> CollectIndexedMeshExt<T> for I
1526where
1527 I: Iterator<Item = (usize, T)> + Sized,
1528{
1529 #[inline]
1530 fn collect_indexed<M>(self, region: Region) -> Result<M, M::Error>
1531 where
1532 M: BuildFromRegionIndexed<T>,
1533 {
1534 M::build_indexed(region, self)
1535 }
1536}
1537
1538pub trait MapIntoExt: Ranked {
1540 fn map_into<M, U>(&self, f: impl Fn(&Self::Item) -> U) -> M
1541 where
1542 Self: Sized,
1543 M: BuildFromRegion<U>,
1544 {
1545 let region = self.region().clone();
1546 let n = region.num_ranks();
1547 let values: Vec<U> = (0..n).map(|i| f(self.get(i).unwrap())).collect();
1548 M::build_dense_unchecked(region, values)
1549 }
1550
1551 fn try_map_into<M, U, E>(self, f: impl Fn(&Self::Item) -> Result<U, E>) -> Result<M, E>
1552 where
1553 Self: Sized,
1554 M: BuildFromRegion<U>,
1555 {
1556 let region = self.region().clone();
1557 let n = region.num_ranks();
1558 let mut out = Vec::with_capacity(n);
1559 for i in 0..n {
1560 out.push(f(self.get(i).unwrap())?);
1561 }
1562 Ok(M::build_dense_unchecked(region, out))
1563 }
1564}
1565
// Every `Ranked` collection gets the `MapIntoExt` helpers.
impl<T: Ranked> MapIntoExt for T {}
1569
/// A read-only view of values laid out over a [`Region`], which can be
/// narrowed to sub-regions.
pub trait View: Sized {
    /// The value produced for each rank.
    type Item;

    /// The type returned by [`View::subset`].
    type View: View;

    /// The region this view covers.
    fn region(&self) -> Region;

    /// The value at position `rank` (0-based in the view's row-major order;
    /// `ViewExt::values` calls it for `0..num_ranks`), or `None` when out of
    /// range.
    fn get(&self, rank: usize) -> Option<Self::Item>;

    /// Restricts the view to `region`, failing if `region` is not contained
    /// in this view's region.
    // `ViewError` is large (it boxes two `Region`s and carries a `Range`),
    // so clippy's result_large_err lint is silenced here.
    #[allow(clippy::result_large_err)]
    fn subset(&self, region: Region) -> Result<Self::View, ViewError>;
}
1592
1593impl View for Region {
1595 type Item = usize;
1597
1598 type View = Region;
1600
1601 fn region(&self) -> Region {
1602 self.clone()
1603 }
1604
1605 fn subset(&self, region: Region) -> Result<Region, ViewError> {
1606 if region.is_subset(self) {
1607 Ok(region)
1608 } else {
1609 Err(ViewError::InvalidRange {
1610 base: Box::new(self.clone()),
1611 selected: Box::new(region),
1612 })
1613 }
1614 }
1615
1616 fn get(&self, rank: usize) -> Option<Self::Item> {
1617 self.slice.get(rank).ok()
1618 }
1619}
1620
1621impl View for Extent {
1623 type Item = usize;
1625
1626 type View = Region;
1629
1630 fn region(&self) -> Region {
1631 Region {
1632 labels: self.labels().to_vec(),
1633 slice: self.to_slice(),
1634 }
1635 }
1636
1637 fn subset(&self, region: Region) -> Result<Region, ViewError> {
1638 self.region().subset(region)
1639 }
1640
1641 fn get(&self, rank: usize) -> Option<Self::Item> {
1642 if rank < self.num_ranks() {
1643 Some(rank)
1644 } else {
1645 None
1646 }
1647 }
1648}
1649
/// A collection holding one value per rank of a [`Region`], accessible by
/// reference.
pub trait Ranked: Sized {
    /// The value stored at each rank.
    type Item: 'static;

    /// The region this collection covers.
    fn region(&self) -> &Region;

    /// Reference to the value at `rank`, or `None` if out of range.
    fn get(&self, rank: usize) -> Option<&Self::Item>;
}
1662
/// A [`Ranked`] collection that can be restricted to a sub-region. Types
/// implementing it with `Clone` items automatically implement [`View`].
pub trait RankedSliceable: Ranked {
    /// Returns the collection restricted to `region`. The blanket `View`
    /// impl calls this only after checking that `region` is a subset of
    /// the collection's region.
    fn sliced(&self, region: Region) -> Self;
}
1676
1677impl<T: RankedSliceable> View for T
1678where
1679 T::Item: Clone + 'static,
1680{
1681 type Item = T::Item;
1682 type View = Self;
1683
1684 fn region(&self) -> Region {
1685 <Self as Ranked>::region(self).clone()
1686 }
1687
1688 fn get(&self, rank: usize) -> Option<Self::Item> {
1689 <Self as Ranked>::get(self, rank).cloned()
1690 }
1691
1692 fn subset(&self, region: Region) -> Result<Self, ViewError> {
1693 if !region.is_subset(self.region()) {
1694 return Err(ViewError::InvalidRange {
1695 base: Box::new(self.region().clone()),
1696 selected: Box::new(region.clone()),
1697 });
1698 }
1699
1700 Ok(self.sliced(region))
1701 }
1702}
1703
/// Iterator behind [`ViewExt::iter`]: walks a view's base ranks and pairs
/// each one with the corresponding [`Point`] of the view's extent.
pub struct ViewIterator {
    // Extent of the view; slice coordinates are turned into points in it.
    extent: Extent,
    // Iterator over the base ranks of the view's slice.
    pos: SliceIterator,
}
1709
1710impl Iterator for ViewIterator {
1711 type Item = (Point, usize);
1712 fn next(&mut self) -> Option<Self::Item> {
1713 let rank = self.pos.next()?;
1715 let coords = self.pos.slice.coordinates(rank).unwrap();
1717 let point = coords.in_(&self.extent).unwrap();
1718 Some((point, rank))
1719 }
1720}
1721
/// Convenience operations available on every [`View`].
pub trait ViewExt: View {
    /// Restricts dimension `dim` to `range`.
    ///
    /// # Errors
    ///
    /// [`ViewError::InvalidDim`] if `dim` is unknown,
    /// [`ViewError::EmptyRange`] if `range` selects nothing, and any error
    /// from [`View::subset`].
    // `ViewError` is large, so clippy's result_large_err lint is silenced.
    #[allow(clippy::result_large_err)]
    fn range<R: Into<Range>>(&self, dim: &str, range: R) -> Result<Self::View, ViewError>;

    /// Splits the view into sub-views, one for each combination of the
    /// coordinates of the dimensions before `dim`. Each sub-view keeps `dim`
    /// and every later dimension.
    ///
    /// # Errors
    ///
    /// [`ViewError::InvalidDim`] if `dim` is unknown.
    // `ViewError` is large, so clippy's result_large_err lint is silenced.
    #[allow(clippy::result_large_err)]
    fn group_by(&self, dim: &str) -> Result<impl Iterator<Item = Self::View>, ViewError>;

    /// The extent (labels and sizes) of this view's region.
    fn extent(&self) -> Extent;

    /// Iterates over `(point, value)` pairs, with points in this view's
    /// own extent.
    fn iter<'a>(&'a self) -> impl Iterator<Item = (Point, Self::Item)> + 'a;

    /// Iterates over the values in rank order.
    fn values<'a>(&'a self) -> impl Iterator<Item = Self::Item> + 'a;
}
1799
1800impl<T: View> ViewExt for T {
1801 fn range<R: Into<Range>>(&self, dim: &str, range: R) -> Result<Self::View, ViewError> {
1802 let (labels, slice) = self.region().into_inner();
1803 let range = range.into();
1804 let dim = labels
1805 .iter()
1806 .position(|l| dim == l)
1807 .ok_or_else(|| ViewError::InvalidDim(dim.to_string()))?;
1808 let (mut offset, mut sizes, mut strides) = slice.into_inner();
1809 let (begin, end, step) = range.resolve(sizes[dim]);
1810 if end <= begin {
1811 return Err(ViewError::EmptyRange {
1812 range,
1813 dim: dim.to_string(),
1814 size: sizes[dim],
1815 });
1816 }
1817
1818 offset += strides[dim] * begin;
1819 sizes[dim] = (end - begin).div_ceil(step);
1820 strides[dim] *= step;
1821 let slice = Slice::new(offset, sizes, strides).unwrap();
1822
1823 self.subset(Region { labels, slice })
1824 }
1825
1826 fn group_by(&self, dim: &str) -> Result<impl Iterator<Item = Self::View>, ViewError> {
1827 let (labels, slice) = self.region().into_inner();
1828
1829 let dim = labels
1830 .iter()
1831 .position(|l| dim == l)
1832 .ok_or_else(|| ViewError::InvalidDim(dim.to_string()))?;
1833
1834 let (offset, sizes, strides) = slice.into_inner();
1835 let mut ranks_iter = Slice::new(offset, sizes[..dim].to_vec(), strides[..dim].to_vec())
1836 .unwrap()
1837 .iter();
1838
1839 let labels = labels[dim..].to_vec();
1840 let sizes = sizes[dim..].to_vec();
1841 let strides = strides[dim..].to_vec();
1842
1843 Ok(std::iter::from_fn(move || {
1844 let rank = ranks_iter.next()?;
1845 let slice = Slice::new(rank, sizes.clone(), strides.clone()).unwrap();
1846 Some(
1848 self.subset(Region {
1849 labels: labels.clone(),
1850 slice,
1851 })
1852 .unwrap(),
1853 )
1854 }))
1855 }
1856
1857 fn extent(&self) -> Extent {
1858 let (labels, slice) = self.region().into_inner();
1859 Extent::new(labels, slice.sizes().to_vec()).unwrap()
1860 }
1861
1862 fn iter(&self) -> impl Iterator<Item = (Point, Self::Item)> + '_ {
1863 let points = ViewIterator {
1864 extent: self.extent(),
1865 pos: self.region().slice().iter(),
1866 };
1867
1868 points.map(|(point, _)| (point.clone(), self.get(point.rank()).unwrap()))
1869 }
1870
1871 fn values(&self) -> impl Iterator<Item = Self::Item> + '_ {
1872 (0usize..self.extent().num_ranks()).map(|rank| self.get(rank).unwrap())
1873 }
1874}
1875
/// Builds a `view::Extent` from `label = size` pairs, e.g.
/// `extent!(x = 4, y = 2)`; `extent!()` gives the 0-dimensional extent.
/// Labels are taken verbatim from the identifiers.
#[macro_export]
macro_rules! extent {
    ( $( $label:ident = $size:expr ),* $(,)? ) => {
        $crate::view::Extent::new(
            ::std::vec![$( stringify!($label).to_string() ),*],
            ::std::vec![$( $size ),*],
        )
        .unwrap()
    };
}
1901
1902#[cfg(test)]
1903mod test {
1904 use super::labels::*;
1905 use super::*;
1906 use crate::Shape;
1907 use crate::shape;
1908 use crate::slice::CartesianIterator;
1909
1910 #[test]
1911 fn test_is_safe_ident() {
1912 assert!(is_safe_ident("x"));
1913 assert!(is_safe_ident("gpu_0"));
1914 assert!(!is_safe_ident("dim/0"));
1915 assert!(!is_safe_ident("x y"));
1916 assert!(!is_safe_ident("x=y"));
1917 }
1918 #[test]
1919 fn test_fmt_label() {
1920 assert_eq!(fmt_label("x"), "x");
1921 assert_eq!(fmt_label("dim/0"), "\"dim/0\"");
1922 }
1923
1924 #[test]
1925 fn test_points_basic() {
1926 let extent = extent!(x = 4, y = 5, z = 6);
1927 let _p1 = extent.point(vec![1, 2, 3]).unwrap();
1928 let _p2 = vec![1, 2, 3].in_(&extent).unwrap();
1929
1930 assert_eq!(extent.num_ranks(), 4 * 5 * 6);
1931
1932 let p3 = extent.point_of_rank(0).unwrap();
1933 assert_eq!(p3.coords(), &[0, 0, 0]);
1934 assert_eq!(p3.rank(), 0);
1935
1936 let p4 = extent.point_of_rank(1).unwrap();
1937 assert_eq!(p4.coords(), &[0, 0, 1]);
1938 assert_eq!(p4.rank(), 1);
1939
1940 let p5 = extent.point_of_rank(2).unwrap();
1941 assert_eq!(p5.coords(), &[0, 0, 2]);
1942 assert_eq!(p5.rank(), 2);
1943
1944 let p6 = extent.point_of_rank(6 * 5 + 1).unwrap();
1945 assert_eq!(p6.coords(), &[1, 0, 1]);
1946 assert_eq!(p6.rank(), 6 * 5 + 1);
1947 assert_eq!(p6.coord(0), 1);
1948 assert_eq!(p6.coord(1), 0);
1949 assert_eq!(p6.coord(2), 1);
1950
1951 assert_eq!(extent.points().collect::<Vec<_>>().len(), 4 * 5 * 6);
1952 for (rank, point) in extent.points().enumerate() {
1953 let c = point.coords();
1954 let (x, y, z) = (c[0], c[1], c[2]);
1955 assert_eq!(z + y * 6 + x * 6 * 5, rank);
1956 assert_eq!(point.rank(), rank);
1957 }
1958 }
1959
1960 #[test]
1961 fn points_iterates_ranks_in_row_major_order() {
1962 let ext = extent!(x = 2, y = 3, z = 4); let mut it = ext.points();
1964
1965 for expected_rank in 0..ext.num_ranks() {
1966 let p = it.next().expect("expected another Point");
1967 assert_eq!(
1968 p.rank, expected_rank,
1969 "ranks must be consecutive in row-major order"
1970 );
1971 }
1972 assert!(
1973 it.next().is_none(),
1974 "iterator must be exhausted after num_ranks items"
1975 );
1976 }
1977
1978 #[test]
1979 fn points_iterates_single_point_for_0d_extent() {
1980 let ext = extent!();
1982 let mut it = ext.points();
1983
1984 let p = it
1985 .next()
1986 .expect("0-D extent should yield exactly one point");
1987 assert_eq!(p.rank, 0);
1988 assert_eq!(p.extent, ext);
1989
1990 assert!(
1991 it.next().is_none(),
1992 "no more points after the single 0-D point"
1993 );
1994 }
1995
1996 macro_rules! assert_view {
1997 ($view:expr, $extent:expr, $( $($coord:expr),+ => $rank:expr );* $(;)?) => {
1998 let view = $view;
1999 assert_eq!(view.extent(), $extent);
2000 let expected: Vec<_> = vec![$(($extent.point(vec![$($coord),+]).unwrap(), $rank)),*];
2001 let actual: Vec<_> = ViewExt::iter(&view).collect();
2002 assert_eq!(actual, expected);
2003 };
2004 }
2005
2006 #[test]
2007 fn test_view_basic() {
2008 let extent = extent!(x = 4, y = 4);
2009 assert_view!(
2010 extent.range("x", 0..2).unwrap(),
2011 extent!(x = 2, y = 4),
2012 0, 0 => 0;
2013 0, 1 => 1;
2014 0, 2 => 2;
2015 0, 3 => 3;
2016 1, 0 => 4;
2017 1, 1 => 5;
2018 1, 2 => 6;
2019 1, 3 => 7;
2020 );
2021 assert_view!(
2022 extent.range("x", 1).unwrap().range("y", 2..).unwrap(),
2023 extent!(x = 1, y = 2),
2024 0, 0 => 6;
2025 0, 1 => 7;
2026 );
2027 assert_view!(
2028 extent.range("y", Range(0, None, 2)).unwrap(),
2029 extent!(x = 4, y = 2),
2030 0, 0 => 0;
2031 0, 1 => 2;
2032 1, 0 => 4;
2033 1, 1 => 6;
2034 2, 0 => 8;
2035 2, 1 => 10;
2036 3, 0 => 12;
2037 3, 1 => 14;
2038 );
2039 assert_view!(
2040 extent.range("y", Range(0, None, 2)).unwrap().range("x", 2..).unwrap(),
2041 extent!(x = 2, y = 2),
2042 0, 0 => 8;
2043 0, 1 => 10;
2044 1, 0 => 12;
2045 1, 1 => 14;
2046 );
2047
2048 let extent = extent!(x = 10, y = 2);
2049 assert_view!(
2050 extent.range("x", Range(0, None, 2)).unwrap(),
2051 extent!(x = 5, y = 2),
2052 0, 0 => 0;
2053 0, 1 => 1;
2054 1, 0 => 4;
2055 1, 1 => 5;
2056 2, 0 => 8;
2057 2, 1 => 9;
2058 3, 0 => 12;
2059 3, 1 => 13;
2060 4, 0 => 16;
2061 4, 1 => 17;
2062 );
2063 assert_view!(
2064 extent.range("x", Range(0, None, 2)).unwrap().range("x", 2..).unwrap().range("y", 1).unwrap(),
2065 extent!(x = 3, y = 1),
2066 0, 0 => 9;
2067 1, 0 => 13;
2068 2, 0 => 17;
2069 );
2070
2071 let extent = extent!(zone = 4, host = 2, gpu = 8);
2072 assert_view!(
2073 extent.range("zone", 0).unwrap().range("gpu", Range(0, None, 2)).unwrap(),
2074 extent!(zone = 1, host = 2, gpu = 4),
2075 0, 0, 0 => 0;
2076 0, 0, 1 => 2;
2077 0, 0, 2 => 4;
2078 0, 0, 3 => 6;
2079 0, 1, 0 => 8;
2080 0, 1, 1 => 10;
2081 0, 1, 2 => 12;
2082 0, 1, 3 => 14;
2083 );
2084
2085 let extent = extent!(x = 3);
2086 assert_view!(
2087 extent.range("x", Range(0, None, 2)).unwrap(),
2088 extent!(x = 2),
2089 0 => 0;
2090 1 => 2;
2091 );
2092 }
2093
2094 #[test]
2095 fn test_point_indexing() {
2096 let extent = Extent::new(vec!["x".into(), "y".into(), "z".into()], vec![4, 5, 6]).unwrap();
2097 let point = extent.point(vec![1, 2, 3]).unwrap();
2098
2099 assert_eq!(point.coord(0), 1);
2100 assert_eq!(point.coord(1), 2);
2101 assert_eq!(point.coord(2), 3);
2102 }
2103
2104 #[test]
2105 #[should_panic]
2106 fn test_point_indexing_out_of_bounds() {
2107 let extent = Extent::new(vec!["x".into(), "y".into()], vec![4, 5]).unwrap();
2108 let point = extent.point(vec![1, 2]).unwrap();
2109
2110 let _ = point.coord(5); }
2112
2113 #[test]
2114 fn test_point_into_iter() {
2115 let extent = Extent::new(vec!["x".into(), "y".into(), "z".into()], vec![4, 5, 6]).unwrap();
2116 let point = extent.point(vec![1, 2, 3]).unwrap();
2117
2118 let coords: Vec<usize> = (&point).into_iter().collect();
2119 assert_eq!(coords, vec![1, 2, 3]);
2120
2121 let mut sum = 0;
2122 for coord in &point {
2123 sum += coord;
2124 }
2125 assert_eq!(sum, 6);
2126 }
2127
2128 #[test]
2129 fn test_extent_basic() {
2130 let extent = extent!(x = 10, y = 5, z = 1);
2131 assert_eq!(
2132 extent.iter().collect::<Vec<_>>(),
2133 vec![
2134 ("x".to_string(), 10),
2135 ("y".to_string(), 5),
2136 ("z".to_string(), 1)
2137 ]
2138 );
2139 }
2140
2141 #[test]
2142 fn test_extent_display() {
2143 let extent = Extent::new(vec!["x".into(), "y".into(), "z".into()], vec![4, 5, 6]).unwrap();
2144 assert_eq!(format!("{}", extent), "{x: 4, y: 5, z: 6}");
2145
2146 let extent = Extent::new(vec!["dim/0".into(), "dim/1".into()], vec![4, 5]).unwrap();
2147 assert_eq!(format!("{}", extent), "{\"dim/0\": 4, \"dim/1\": 5}");
2148
2149 let empty_extent = Extent::new(vec![], vec![]).unwrap();
2150 assert_eq!(format!("{}", empty_extent), "{}");
2151 }
2152
2153 #[test]
2154 fn extent_label_helpers() {
2155 let e = extent!(zone = 3, host = 2, gpu = 4);
2156 for (i, (lbl, sz)) in e.iter().enumerate() {
2157 assert_eq!(e.position(&lbl), Some(i));
2158 assert_eq!(e.size(&lbl), Some(sz));
2159 }
2160 assert_eq!(e.position("nope"), None);
2161 assert_eq!(e.size("nope"), None);
2162 }
2163
2164 #[test]
2165 fn test_extent_0d() {
2166 let e = Extent::new(vec![], vec![]).unwrap();
2167 assert_eq!(e.num_ranks(), 1);
2168
2169 let points: Vec<_> = e.points().collect();
2170 assert_eq!(points.len(), 1);
2171 assert_eq!(points[0].coords(), &[] as &[usize]);
2172 assert_eq!(points[0].rank(), 0);
2173
2174 let mut it = (&points[0]).into_iter();
2176 assert_eq!(it.len(), 0);
2177 assert!(it.next().is_none()); assert!(it.next().is_none()); }
2180
2181 #[test]
2182 fn test_extent_concat() {
2183 let extent1 = extent!(x = 2, y = 3);
2185 let extent2 = extent!(z = 4, w = 5);
2186
2187 let result = extent1.concat(&extent2).unwrap();
2188 assert_eq!(result.labels(), &["x", "y", "z", "w"]);
2189 assert_eq!(result.sizes(), &[2, 3, 4, 5]);
2190 assert_eq!(result.num_ranks(), 2 * 3 * 4 * 5);
2191
2192 let empty = extent!();
2194 let result = extent1.concat(&empty).unwrap();
2195 assert_eq!(result.labels(), &["x", "y"]);
2196 assert_eq!(result.sizes(), &[2, 3]);
2197
2198 let result = empty.concat(&extent1).unwrap();
2199 assert_eq!(result.labels(), &["x", "y"]);
2200 assert_eq!(result.sizes(), &[2, 3]);
2201
2202 let result = empty.concat(&empty).unwrap();
2204 assert_eq!(result.labels(), &[] as &[String]);
2205 assert_eq!(result.sizes(), &[] as &[usize]);
2206 assert_eq!(result.num_ranks(), 1); let result = extent1.concat(&extent1);
2210 assert!(
2211 result.is_err(),
2212 "Self-concatenation should error due to overlapping labels"
2213 );
2214 match result.unwrap_err() {
2215 ExtentError::OverlappingLabel { label } => {
2216 assert!(label == "x"); }
2218 other => panic!("Expected OverlappingLabel error, got {:?}", other),
2219 }
2220
2221 let result = extent1.concat(&extent2).unwrap();
2223 let point = result.point(vec![1, 2, 3, 4]).unwrap();
2224 assert_eq!(point.coords(), vec![1, 2, 3, 4]);
2225 assert_eq!(point.extent(), &result);
2226
2227 let extent_a = extent!(x = 2, y = 3);
2229 let extent_b = extent!(y = 3, z = 4); let result = extent_a.concat(&extent_b);
2231 assert!(
2232 result.is_err(),
2233 "Should error on overlapping labels even with same size"
2234 );
2235 match result.unwrap_err() {
2236 ExtentError::OverlappingLabel { label } => {
2237 assert_eq!(label, "y"); }
2239 other => panic!("Expected OverlappingLabel error, got {:?}", other),
2240 }
2241
2242 let extent_x = extent!(x = 2, y = 3);
2244 let extent_y = extent!(z = 4);
2245 assert_eq!(
2246 extent_x.concat(&extent_y).unwrap().labels(),
2247 &["x", "y", "z"]
2248 );
2249 assert_eq!(
2250 extent_y.concat(&extent_x).unwrap().labels(),
2251 &["z", "x", "y"]
2252 );
2253
2254 let extent_m = extent!(x = 2);
2256 let extent_n = extent!(y = 3);
2257 let extent_o = extent!(z = 4);
2258
2259 let left_assoc = extent_m
2260 .concat(&extent_n)
2261 .unwrap()
2262 .concat(&extent_o)
2263 .unwrap();
2264 let right_assoc = extent_m
2265 .concat(&extent_n.concat(&extent_o).unwrap())
2266 .unwrap();
2267
2268 assert_eq!(left_assoc, right_assoc);
2269 assert_eq!(left_assoc.labels(), &["x", "y", "z"]);
2270 assert_eq!(left_assoc.sizes(), &[2, 3, 4]);
2271 assert_eq!(left_assoc.num_ranks(), 2 * 3 * 4);
2272 }
2273
2274 #[test]
2275 fn extent_unity_equiv_to_0d() {
2276 let e = Extent::unity();
2277 assert!(e.is_empty());
2278 assert_eq!(e.num_ranks(), 1);
2279 let pts: Vec<_> = e.points().collect();
2280 assert_eq!(pts.len(), 1);
2281 assert_eq!(pts[0].rank(), 0);
2282 assert!(pts[0].coords().is_empty());
2283 }
2284
2285 #[test]
2286 fn test_point_display() {
2287 let extent = Extent::new(vec!["x".into(), "y".into(), "z".into()], vec![4, 5, 6]).unwrap();
2288 let point = extent.point(vec![1, 2, 3]).unwrap();
2289 assert_eq!(format!("{}", point), "x=1/4,y=2/5,z=3/6");
2290
2291 assert!(extent.point(vec![]).is_err());
2292
2293 let empty_extent = Extent::new(vec![], vec![]).unwrap();
2294 let empty_point = empty_extent.point(vec![]).unwrap();
2295 assert_eq!(format!("{}", empty_point), "");
2296 }
2297
2298 #[test]
2299 fn test_point_display_with_quoted_labels() {
2300 let ext = Extent::new(vec!["dim/0".into(), "dim,1".into()], vec![3, 5]).unwrap();
2302
2303 assert_eq!(format!("{}", ext), "{\"dim/0\": 3, \"dim,1\": 5}");
2305
2306 let p = ext.point(vec![1, 2]).unwrap();
2308 assert_eq!(format!("{}", p), "\"dim/0\"=1/3,\"dim,1\"=2/5");
2309 }
2310
2311 #[test]
2312 fn test_relative_point() {
2313 pub fn relative_point(rank_on_root_mesh: usize, shape: &Shape) -> anyhow::Result<Point> {
2316 let coords = shape.slice().coordinates(rank_on_root_mesh)?;
2317 let extent = Extent::new(shape.labels().to_vec(), shape.slice().sizes().to_vec())?;
2318 Ok(extent.point(coords)?)
2319 }
2320
2321 let root_shape = shape! { replicas = 4, hosts = 4, gpus = 4 };
2322 let sliced_shape = root_shape
2334 .select("replicas", crate::Range(0, Some(4), 3))
2335 .unwrap()
2336 .select("hosts", crate::Range(1, Some(4), 2))
2337 .unwrap()
2338 .select("gpus", crate::Range(0, Some(4), 2))
2339 .unwrap();
2340 let ranks_on_root_mesh = &[4, 6, 12, 14, 52, 54, 60, 62];
2341 assert_eq!(
2342 sliced_shape.slice().iter().collect::<Vec<_>>(),
2343 ranks_on_root_mesh,
2344 );
2345
2346 let ranks_on_sliced_mesh = ranks_on_root_mesh
2347 .iter()
2348 .map(|&r| relative_point(r, &sliced_shape).unwrap().rank());
2349 assert_eq!(
2350 ranks_on_sliced_mesh.collect::<Vec<_>>(),
2351 vec![0, 1, 2, 3, 4, 5, 6, 7]
2352 );
2353 }
2354
2355 #[test]
2356 fn test_iter_subviews() {
2357 let extent = extent!(zone = 4, host = 4, gpu = 8);
2358
2359 assert_eq!(extent.group_by("gpu").unwrap().count(), 16);
2360 assert_eq!(extent.group_by("zone").unwrap().count(), 1);
2361
2362 let mut parts = extent.group_by("gpu").unwrap();
2363 assert_view!(
2364 parts.next().unwrap(),
2365 extent!(gpu = 8),
2366 0 => 0;
2367 1 => 1;
2368 2 => 2;
2369 3 => 3;
2370 4 => 4;
2371 5 => 5;
2372 6 => 6;
2373 7 => 7;
2374 );
2375 assert_view!(
2376 parts.next().unwrap(),
2377 extent!(gpu = 8),
2378 0 => 8;
2379 1 => 9;
2380 2 => 10;
2381 3 => 11;
2382 4 => 12;
2383 5 => 13;
2384 6 => 14;
2385 7 => 15;
2386 );
2387 }
2388
2389 #[test]
2390 fn test_view_values() {
2391 let extent = extent!(x = 4, y = 4);
2392 assert_eq!(
2393 extent.values().collect::<Vec<_>>(),
2394 (0..16).collect::<Vec<_>>()
2395 );
2396 let region = extent.range("y", 1).unwrap();
2397 assert_eq!(region.values().collect::<Vec<_>>(), vec![1, 5, 9, 13]);
2398 }
2399
2400 #[test]
2401 fn region_is_subset_algebra() {
2402 let e = extent!(x = 5, y = 4);
2403 let a = e.range("x", 1..4).unwrap(); let b = a.range("y", 1..3).unwrap(); let c = e.range("x", 0..2).unwrap(); assert!(b.region().is_subset(&a.region()));
2408 assert!(b.region().is_subset(&e.region()));
2409 assert!(a.region().is_subset(&e.region()));
2410
2411 assert!(!c.region().is_subset(&a.region()));
2412 assert!(c.region().is_subset(&e.region()));
2413 }
2414
2415 #[test]
2416 fn test_remap() {
2417 let region: Region = extent!(x = 4, y = 4).into();
2418 assert_eq!(
2420 region.remap(®ion).unwrap().collect::<Vec<_>>(),
2421 (0..16).collect::<Vec<_>>()
2422 );
2423
2424 let subset = region.range("x", 2..).unwrap();
2425 assert_eq!(subset.num_ranks(), 8);
2426 assert_eq!(
2427 region.remap(&subset).unwrap().collect::<Vec<_>>(),
2428 vec![8, 9, 10, 11, 12, 13, 14, 15],
2429 );
2430
2431 let subset = subset.range("y", 1).unwrap();
2432 assert_eq!(subset.num_ranks(), 2);
2433 assert_eq!(
2434 region.remap(&subset).unwrap().collect::<Vec<_>>(),
2435 vec![9, 13],
2436 );
2437
2438 let ext = extent!(replica = 8, gpu = 4);
2441 let replica1 = ext.range("replica", 1).unwrap();
2442 assert_eq!(replica1.extent(), extent!(replica = 1, gpu = 4));
2443 let replica1_gpu12 = replica1.range("gpu", 1..3).unwrap();
2444 assert_eq!(replica1_gpu12.extent(), extent!(replica = 1, gpu = 2));
2445 assert_eq!(
2447 replica1.remap(&replica1_gpu12).unwrap().collect::<Vec<_>>(),
2448 vec![1, 2],
2449 );
2450 }
2451
2452 #[test]
2453 fn test_base_local_rank_conversion() {
2454 fn point(rank: usize, region: &Region) -> Point {
2455 region.extent().point_of_rank(rank).unwrap()
2456 }
2457
2458 let extent = extent!(replicas = 4, gpus = 2);
2459 let region = extent.range("replicas", 1..3).unwrap();
2460 assert!(
2468 region
2469 .base_rank_of_point(extent.point_of_rank(0).unwrap())
2470 .is_err()
2471 );
2472 assert_eq!(region.base_rank_of_point(point(0, ®ion)).unwrap(), 2);
2474 assert_eq!(region.base_rank_of_point(point(1, ®ion)).unwrap(), 3);
2475 assert_eq!(region.base_rank_of_point(point(2, ®ion)).unwrap(), 4);
2476 assert_eq!(region.base_rank_of_point(point(3, ®ion)).unwrap(), 5);
2477 assert_eq!(region.point_of_base_rank(2).unwrap(), point(0, ®ion));
2479 assert_eq!(region.point_of_base_rank(3).unwrap(), point(1, ®ion));
2480 assert_eq!(region.point_of_base_rank(4).unwrap(), point(2, ®ion));
2481 assert_eq!(region.point_of_base_rank(5).unwrap(), point(3, ®ion));
2482 assert!(region.point_of_base_rank(1).is_err());
2484 assert!(region.point_of_base_rank(6).is_err());
2485
2486 let subset = region
2488 .range("replicas", 1..2)
2489 .unwrap()
2490 .range("gpus", 1..2)
2491 .unwrap();
2492 assert_eq!(subset.base_rank_of_point(point(0, &subset)).unwrap(), 5);
2499 assert_eq!(subset.point_of_base_rank(5).unwrap(), point(0, &subset));
2500 assert!(subset.point_of_base_rank(4).is_err());
2502 assert!(subset.point_of_base_rank(6).is_err());
2503 }
2504
2505 use proptest::prelude::*;
2506
2507 use crate::strategy::gen_extent;
2508 use crate::strategy::gen_region;
2509 use crate::strategy::gen_region_strided;
2510
2511 proptest! {
2512 #[test]
2513 fn test_region_parser(region in gen_region(1..=5, 1024)) {
2514 assert_eq!(
2516 region,
2517 region.to_string().parse::<Region>().unwrap(),
2518 "failed to roundtrip region {}", region
2519 );
2520 }
2521 }
2522
2523 proptest! {
2535 #[test]
2536 fn region_parser_with_offset_roundtrips(region in gen_region(1..=4, 8)) {
2537 let (labels, slice) = region.clone().into_inner();
2538 let region_off = Region {
2539 labels,
2540 slice: Slice::new(8, slice.sizes().to_vec(), slice.strides().to_vec()).unwrap(),
2541 };
2542 let s = region_off.to_string();
2543 let parsed: Region = s.parse().unwrap();
2544 prop_assert_eq!(parsed, region_off);
2545 }
2546 }
2547
2548 proptest! {
2559 #[test]
2560 fn region_strided_display_parse_roundtrips(
2561 region in gen_region_strided(1..=4, 6, 3, 16)
2562 ) {
2563 let s = region.to_string();
2574 let parsed: Region = s.parse().unwrap();
2575 prop_assert_eq!(parsed, region);
2576 }
2577 }
2578
2579 proptest! {
2590 #[test]
2591 fn region_strided_display_matches_slice(
2592 region in gen_region_strided(1..=4, 6, 3, 16)
2593 ) {
2594 let s = region.to_string();
2595 let slice = region.slice();
2596
2597 if slice.offset() != 0 {
2599 let prefix: Vec<_> = s.split('+').collect();
2600 prop_assert!(prefix.len() > 1, "expected offset+ form in {}", s);
2601 let offset_str = prefix[0];
2602 let offset_val: usize = offset_str.parse().unwrap();
2603 prop_assert_eq!(offset_val, slice.offset(), "offset mismatch in {}", s);
2604 } else {
2605 prop_assert!(!s.contains('+'), "unexpected +offset in {}", s);
2606 }
2607
2608 let body = s.split('+').next_back().unwrap(); let parts: Vec<_> = body.split(',').collect();
2611 prop_assert_eq!(parts.len(), slice.sizes().len());
2612
2613 for (i, part) in parts.iter().enumerate() {
2614 let rhs = part.split('=').nth(1).unwrap();
2616 let mut nums = rhs.split('/');
2617 let size_val: usize = nums.next().unwrap().parse().unwrap();
2618 let stride_val: usize = nums.next().unwrap().parse().unwrap();
2619
2620 prop_assert_eq!(size_val, slice.sizes()[i], "size mismatch at dim {} in {}", i, s);
2621 prop_assert_eq!(stride_val, slice.strides()[i], "stride mismatch at dim {} in {}", i, s);
2622 }
2623 }
2624 }
2625
2626 #[test]
2627 fn test_point_from_str_round_trip() {
2628 let points = vec![
2629 extent!(x = 4, y = 5, z = 6).point(vec![1, 2, 3]).unwrap(),
2630 extent!(host = 2, gpu = 8).point(vec![0, 7]).unwrap(),
2631 extent!().point(vec![]).unwrap(),
2632 extent!(x = 10).point(vec![5]).unwrap(),
2633 ];
2634
2635 for point in points {
2636 assert_eq!(point, point.to_string().parse().unwrap());
2637 }
2638 }
2639
2640 #[test]
2641 fn test_point_from_str_basic() {
2642 let cases = vec![
2643 ("x=1/4,y=2/5", extent!(x = 4, y = 5), vec![1, 2]),
2644 ("host=0/2,gpu=7/8", extent!(host = 2, gpu = 8), vec![0, 7]),
2645 ("z=3/6", extent!(z = 6), vec![3]),
2646 ("", extent!(), vec![]), (" x = 1 / 4 , y = 2 / 5 ", extent!(x = 4, y = 5), vec![1, 2]),
2649 ];
2650
2651 for (input, expected_extent, expected_coords) in cases {
2652 let parsed: Point = input.parse().unwrap();
2653 let expected = expected_extent.point(expected_coords).unwrap();
2654 assert_eq!(parsed, expected, "failed to parse: {}", input);
2655 }
2656 }
2657
2658 #[test]
2659 fn test_point_from_str_quoted() {
2660 let extent = Extent::new(vec!["dim/0".into(), "dim,1".into()], vec![3, 5]).unwrap();
2662 let point = extent.point(vec![1, 2]).unwrap();
2663
2664 let display_str = point.to_string();
2665 assert_eq!(display_str, "\"dim/0\"=1/3,\"dim,1\"=2/5");
2666
2667 let parsed: Point = display_str.parse().unwrap();
2668 assert_eq!(parsed, point);
2669
2670 let parsed: Point = "\"dim/0\"=1/3,\"dim,1\"=2/5".parse().unwrap();
2671 assert_eq!(parsed, point);
2672 }
2673
2674 #[test]
2675 fn test_point_from_str_error_cases() {
2676 let error_cases = vec![
2678 "x=1,y=2/5", "x=1/4,y=2", "x=1/4,y=/5", "x=/4,y=2/5", "x=1/4,y=2/", "x=1/,y=2/5", "x=1/4=5,y=2/5", "x=1/4/6,y=2/5", "x=abc/4,y=2/5", "x=1/abc,y=2/5", "=1/4,y=2/5", "x=1/4,", "x=1/4,=2/5", "x=1/4,y", "x", "x=", "x=1/4,y=10/5", ];
2696
2697 for input in error_cases {
2698 let result: Result<Point, PointError> = input.parse();
2699 assert!(result.is_err(), "Expected error for input: '{}'", input);
2700 }
2701 }
2702
2703 #[test]
2704 fn test_point_from_str_coordinate_validation() {
2705 let input = "x=5/4,y=2/5"; let result: Result<Point, PointError> = input.parse();
2708 assert!(
2709 result.is_err(),
2710 "Expected error for out-of-bounds coordinate"
2711 );
2712
2713 match result.unwrap_err() {
2714 PointError::OutOfRangeIndex { size, index } => {
2715 assert_eq!(size, 4);
2716 assert_eq!(index, 5);
2717 }
2718 _ => panic!("Expected OutOfRangeIndex error"),
2719 }
2720 }
2721
2722 #[test]
2723 fn test_point_from_str_consistency_validation() {
2724 let input = "x=1/4,y=2/5,z=3/6";
2729 let parsed: Point = input.parse().unwrap();
2730
2731 assert_eq!(parsed.extent().labels(), &["x", "y", "z"]);
2732 assert_eq!(parsed.extent().sizes(), &[4, 5, 6]);
2733 assert_eq!(parsed.coords(), vec![1, 2, 3]);
2734 }
2735
    // Property tests over extents produced by `gen_extent`. The first argument
    // looks like the allowed dimensionality range; the second presumably bounds
    // axis sizes (TODO confirm against `crate::strategy::gen_extent`).
    proptest! {
        // For every point, `coords()`, iterating `&Point`, and `coord(i)` all
        // report the same coordinates.
        #[test]
        fn point_coord_and_iter_agree(extent in gen_extent(0..=4, 8)) {
            for p in extent.points() {
                let via_coords = p.coords();
                let via_into_iter: Vec<_> = (&p).into_iter().collect();
                prop_assert_eq!(via_into_iter, via_coords.clone(), "coord_iter mismatch for {}", p);

                for (i, &coord) in via_coords.iter().enumerate() {
                    prop_assert_eq!(p.coord(i), coord, "coord(i) mismatch at axis {} for {}", i, p);
                }
            }
        }

        // `points()` enumerates exactly `num_ranks()` points.
        #[test]
        fn points_count_matches_num_ranks(extent in gen_extent(0..=4, 8)) {
            let c = extent.points().count();
            prop_assert_eq!(c, extent.num_ranks(), "count {} != num_ranks {}", c, extent.num_ranks());
        }

        // The coordinate iterator of a point reports an exact `len()` and
        // `size_hint()` before and after every `next()`, yields exactly
        // `p.len()` items, and stays exhausted once it has returned `None`.
        #[test]
        fn coord_iter_exact_size_invariants(extent in gen_extent(0..=4, 8)) {
            for p in extent.points() {
                let mut it = (&p).into_iter();

                // Invariants before any item has been consumed.
                let mut remaining = p.len();
                prop_assert_eq!(it.len(), remaining);
                prop_assert_eq!(it.size_hint(), (remaining, Some(remaining)));

                let mut yielded = Vec::with_capacity(remaining);

                // `while let` (not `for`) so `it` can be queried inside the loop.
                while let Some(v) = it.next() {
                    yielded.push(v);
                    remaining -= 1;
                    prop_assert_eq!(it.len(), remaining);
                    prop_assert_eq!(it.size_hint(), (remaining, Some(remaining)));
                }

                // Fully drained, and repeated polling after exhaustion stays `None`.
                prop_assert_eq!(remaining, 0);
                prop_assert!(it.next().is_none());
                prop_assert!(it.next().is_none());

                prop_assert_eq!(yielded, p.coords());
            }
        }

        // `rank_of_coords` with the wrong number of coordinates reports
        // `DimMismatch` with the expected and actual lengths. One coordinate
        // is used for a 0-D extent, otherwise one fewer than required.
        #[test]
        fn rank_of_coords_dim_mismatch(extent in gen_extent(0..=4, 8)) {
            let want = extent.len();
            let wrong = if want == 0 { 1 } else { want - 1 };
            let bad = vec![0usize; wrong];

            match extent.rank_of_coords(&bad).unwrap_err() {
                PointError::DimMismatch { expected, actual } => {
                    prop_assert_eq!(expected, want, "expected len mismatch");
                    prop_assert_eq!(actual, wrong, "actual len mismatch");
                }
                other => prop_assert!(false, "expected DimMismatch, got {:?}", other),
            }
        }

        // A coordinate equal to its axis size (one past the last valid index)
        // is reported as `OutOfRangeIndex` carrying that size and index. Axis 0
        // is always used; the `1..=` lower bound presumably guarantees it exists.
        #[test]
        fn rank_of_coords_out_of_range_index(extent in gen_extent(1..=4, 8)) {
            let sizes = extent.sizes().to_vec();
            let mut coords = vec![0usize; sizes.len()];
            let axis = 0usize;
            coords[axis] = sizes[axis];

            match extent.rank_of_coords(&coords).unwrap_err() {
                PointError::OutOfRangeIndex { size, index } => {
                    prop_assert_eq!(size, sizes[axis], "reported size mismatch");
                    prop_assert_eq!(index, sizes[axis], "reported index mismatch");
                }
                other => prop_assert!(false, "expected OutOfRangeIndex, got {:?}", other),
            }
        }

        // `point_of_rank(num_ranks())` (one past the last rank) is reported as
        // `OutOfRangeRank` carrying the total and the requested rank.
        #[test]
        fn point_of_rank_out_of_range(extent in gen_extent(0..=4, 8)) {
            let total = extent.num_ranks();
            match extent.point_of_rank(total).unwrap_err() {
                PointError::OutOfRangeRank { total: t, rank: r } => {
                    prop_assert_eq!(t, total, "reported total mismatch");
                    prop_assert_eq!(r, total, "reported rank mismatch");
                }
                other => prop_assert!(false, "expected OutOfRangeRank, got {:?}", other),
            }
        }

        // Every point of every generated extent round-trips through
        // Display and FromStr.
        #[test]
        fn point_display_parse_round_trip(extent in gen_extent(0..=4, 8)) {
            for point in extent.points() {
                let display = point.to_string();
                let parsed: Point = display.parse().unwrap();
                prop_assert_eq!(parsed, point, "round-trip failed for point: {}", display);
            }
        }
    }
2862
2863 proptest! {
2864 #[test]
2867 fn points_iterator_equivalent_to_legacy_cartesian(extent in gen_extent(0..=4, 8)) {
2868 let sizes = extent.sizes().to_vec();
2869
2870 let legacy_len = CartesianIterator::new(sizes.clone()).count();
2872 prop_assert_eq!(legacy_len, extent.num_ranks());
2873
2874 let legacy_coords = CartesianIterator::new(sizes);
2877 for (step, (p, coords)) in extent.points().zip(legacy_coords).enumerate() {
2878 let old_rank = extent.rank_of_coords(&coords).expect("valid legacy coords");
2879 prop_assert_eq!(p.rank(), old_rank, "rank mismatch at step {} for coords {:?}", step, coords);
2880 prop_assert_eq!(p.coords(), coords, "coords mismatch at step {}", step);
2881 }
2882 }
2883 }
2884}