1use std::str::FromStr;
33use std::sync::Arc;
34
35use serde::Deserialize;
36use serde::Serialize;
37use thiserror::Error;
38
39use crate::Range;
40use crate::Shape; use crate::Slice;
42use crate::SliceError;
43use crate::SliceIterator;
44use crate::parse::Parser;
45use crate::parse::ParserError;
46
47#[derive(Debug, thiserror::Error)]
49pub enum ExtentError {
50 #[error("label/sizes dimension mismatch: {num_labels} != {num_sizes}")]
56 DimMismatch {
57 num_labels: usize,
59 num_sizes: usize,
61 },
62
63 #[error("overlapping label found: {label}")]
68 OverlappingLabel {
69 label: String,
71 },
72}
73
74#[derive(Clone, Deserialize, Serialize, PartialEq, Eq, Hash, Debug)]
82pub struct Extent {
83 inner: Arc<ExtentData>,
84}
85
86fn _assert_extent_traits()
87where
88 Extent: Send + Sync + 'static,
89{
90}
91
92#[derive(Clone, Deserialize, Serialize, PartialEq, Eq, Hash, Debug)]
97struct ExtentData {
98 labels: Vec<String>,
99 sizes: Vec<usize>,
100}
101
102impl From<&Shape> for Extent {
103 fn from(s: &Shape) -> Self {
104 Extent::new(s.labels().to_vec(), s.slice().sizes().to_vec()).unwrap()
106 }
107}
108
109impl From<Shape> for Extent {
110 fn from(s: Shape) -> Self {
111 Extent::from(&s)
112 }
113}
114
115impl Extent {
116 pub fn new(labels: Vec<String>, sizes: Vec<usize>) -> Result<Self, ExtentError> {
118 if labels.len() != sizes.len() {
119 return Err(ExtentError::DimMismatch {
120 num_labels: labels.len(),
121 num_sizes: sizes.len(),
122 });
123 }
124
125 Ok(Self {
126 inner: Arc::new(ExtentData { labels, sizes }),
127 })
128 }
129
130 pub fn unity() -> Extent {
131 Extent::new(vec![], vec![]).unwrap()
132 }
133
134 pub fn labels(&self) -> &[String] {
136 &self.inner.labels
137 }
138
139 pub fn sizes(&self) -> &[usize] {
141 &self.inner.sizes
142 }
143
144 pub fn size(&self, label: &str) -> Option<usize> {
147 self.position(label).map(|pos| self.sizes()[pos])
148 }
149
150 pub fn position(&self, label: &str) -> Option<usize> {
153 self.labels().iter().position(|l| l == label)
154 }
155
156 pub fn rank_of_coords(&self, coords: &[usize]) -> Result<usize, PointError> {
166 let sizes = self.sizes();
167 if coords.len() != sizes.len() {
168 return Err(PointError::DimMismatch {
169 expected: sizes.len(),
170 actual: coords.len(),
171 });
172 }
173 let mut stride = 1;
174 let mut result = 0;
175 for (&c, &size) in coords.iter().rev().zip(sizes.iter().rev()) {
176 if c >= size {
177 return Err(PointError::OutOfRangeIndex { size, index: c });
178 }
179 result += c * stride;
180 stride *= size;
181 }
182 Ok(result)
183 }
184
185 pub fn point(&self, coords: Vec<usize>) -> Result<Point, PointError> {
231 Ok(Point {
232 rank: self.rank_of_coords(&coords)?,
233 extent: self.clone(),
234 })
235 }
236
237 pub fn point_of_rank(&self, rank: usize) -> Result<Point, PointError> {
261 let total = self.num_ranks();
262 if rank >= total {
263 return Err(PointError::OutOfRangeRank { total, rank });
264 }
265 Ok(Point {
266 rank,
267 extent: self.clone(),
268 })
269 }
270
271 pub fn len(&self) -> usize {
276 self.sizes().len()
277 }
278
279 pub fn is_empty(&self) -> bool {
284 self.sizes().is_empty()
285 }
286
287 pub fn num_ranks(&self) -> usize {
292 self.sizes().iter().product()
293 }
294
295 pub fn into_inner(self) -> (Vec<String>, Vec<usize>) {
297 match Arc::try_unwrap(self.inner) {
298 Ok(data) => (data.labels, data.sizes),
299 Err(shared) => (shared.labels.clone(), shared.sizes.clone()),
300 }
301 }
302
303 pub fn to_slice(&self) -> Slice {
305 Slice::new_row_major(self.sizes())
306 }
307
308 pub fn iter(&self) -> impl Iterator<Item = (String, usize)> + use<'_> {
310 self.labels()
311 .iter()
312 .zip(self.sizes().iter())
313 .map(|(l, s)| (l.clone(), *s))
314 }
315
316 pub fn points(&self) -> ExtentPointsIterator<'_> {
318 ExtentPointsIterator::new(self)
319 }
320
321 pub fn concat(&self, other: &Extent) -> Result<Self, ExtentError> {
329 use std::collections::HashSet;
330 let lhs: HashSet<&str> = self.labels().iter().map(|s| s.as_str()).collect();
332 if let Some(dup) = other.labels().iter().find(|l| lhs.contains(l.as_str())) {
333 return Err(ExtentError::OverlappingLabel { label: dup.clone() });
334 }
335 let mut labels = self.labels().to_vec();
337 let mut sizes = self.sizes().to_vec();
338 labels.reserve(other.labels().len());
339 sizes.reserve(other.sizes().len());
340 labels.extend(other.labels().iter().cloned());
341 sizes.extend(other.sizes().iter().copied());
342 Extent::new(labels, sizes)
343 }
344}
345
/// Helpers for rendering dimension labels in `Display` output.
mod labels {
    /// Returns true if `s` can be printed without quotes. That means `s` is
    /// non-empty and made only of ASCII letters, digits and underscores.
    ///
    /// The empty string is deliberately *not* safe. Printed bare, it would
    /// vanish from the output (e.g. `{: 4}` or `=0/4`). Quoting makes it
    /// visible as `""`.
    pub(super) fn is_safe_ident(s: &str) -> bool {
        !s.is_empty() && s.chars().all(|c| c.is_ascii_alphanumeric() || c == '_')
    }

    /// Formats a label for display. Safe identifiers (see [`is_safe_ident`])
    /// are emitted verbatim. Anything else is emitted as a quoted, escaped
    /// string using Rust `Debug` string formatting.
    pub(super) fn fmt_label(s: &str) -> String {
        if is_safe_ident(s) {
            s.to_string()
        } else {
            format!("{:?}", s)
        }
    }
}
377
378impl std::fmt::Display for Extent {
414 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
415 write!(f, "{{")?;
416 for i in 0..self.sizes().len() {
417 write!(
418 f,
419 "{}: {}",
420 labels::fmt_label(&self.labels()[i]),
421 self.sizes()[i]
422 )?;
423 if i + 1 != self.sizes().len() {
424 write!(f, ", ")?;
425 }
426 }
427 write!(f, "}}")
428 }
429}
430
431pub struct ExtentPointsIterator<'a> {
433 extent: &'a Extent,
434 next_rank: usize,
435}
436
437impl<'a> ExtentPointsIterator<'a> {
438 pub fn new(extent: &'a Extent) -> Self {
439 Self {
440 extent,
441 next_rank: 0,
442 }
443 }
444}
445
446impl<'a> Iterator for ExtentPointsIterator<'a> {
447 type Item = Point;
448
449 fn next(&mut self) -> Option<Self::Item> {
452 if self.next_rank == self.extent.num_ranks() {
453 return None;
454 }
455
456 let p = Point {
457 rank: self.next_rank,
458 extent: self.extent.clone(),
459 };
460 self.next_rank += 1;
461 Some(p)
462 }
463}
464
465#[derive(Debug, Error)]
467pub enum PointError {
468 #[error("dimension mismatch: expected {expected}, got {actual}")]
475 DimMismatch {
476 expected: usize,
478 actual: usize,
480 },
481
482 #[error("out of range: total ranks {total}; does not contain rank {rank}")]
489 OutOfRangeRank {
490 total: usize,
492 rank: usize,
494 },
495
496 #[error("out of range: dim size {size}; does not contain index {index}")]
504 OutOfRangeIndex {
505 size: usize,
507 index: usize,
509 },
510
511 #[error("failed to parse point: {reason}")]
513 ParseError { reason: String },
514}
515
516#[derive(Clone, Deserialize, Serialize, PartialEq, Eq, Hash, Debug)]
549pub struct Point {
550 rank: usize,
551 extent: Extent,
552}
553
/// Iterator over a [`Point`]'s coordinates, outermost dimension first.
///
/// The rank is decoded by repeated division. `stride` starts as the product
/// of all sizes. Before each dimension is emitted, `stride` is divided by
/// that dimension's size, giving the dimension's row-major stride.
pub struct CoordIter<'a> {
    sizes: &'a [usize],
    rank: usize,
    stride: usize,
    axis: usize,
}

impl<'a> Iterator for CoordIter<'a> {
    type Item = usize;

    fn next(&mut self) -> Option<Self::Item> {
        let &size = self.sizes.get(self.axis)?;
        self.stride /= size;
        let coord = self.rank / self.stride;
        // Keep only the part of the rank that belongs to inner dimensions.
        self.rank %= self.stride;
        self.axis += 1;
        Some(coord)
    }

    /// Exact: one item per dimension not yet visited.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.sizes.len().saturating_sub(self.axis);
        (remaining, Some(remaining))
    }
}

impl ExactSizeIterator for CoordIter<'_> {}
610
611impl<'a> IntoIterator for &'a Point {
612 type Item = usize;
613 type IntoIter = CoordIter<'a>;
614
615 fn into_iter(self) -> Self::IntoIter {
634 self.coords_iter()
635 }
636}
637
638fn _assert_point_traits()
639where
640 Point: Send + Sync + 'static,
641{
642}
643
644pub trait InExtent {
660 fn in_(self, extent: &Extent) -> Result<Point, PointError>;
661}
662
663impl InExtent for Vec<usize> {
664 fn in_(self, extent: &Extent) -> Result<Point, PointError> {
669 extent.point(self)
670 }
671}
672
673impl Point {
674 pub fn coords_iter(&self) -> CoordIter<'_> {
675 CoordIter {
676 sizes: self.extent.sizes(),
677 rank: self.rank,
678 stride: self.extent.sizes().iter().product(),
679 axis: 0,
680 }
681 }
682
683 pub fn coord(&self, i: usize) -> usize {
701 self.coords_iter()
702 .nth(i)
703 .expect("coord(i): axis out of bounds")
704 }
705
706 pub fn coords(&self) -> Vec<usize> {
721 self.coords_iter().collect()
722 }
723
724 pub fn rank(&self) -> usize {
727 self.rank
728 }
729
730 pub fn extent(&self) -> &Extent {
733 &self.extent
734 }
735
736 pub fn len(&self) -> usize {
751 self.extent.len()
752 }
753
754 pub fn is_empty(&self) -> bool {
770 self.extent.len() == 0
771 }
772}
773
774impl std::fmt::Display for Point {
810 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
811 let labels = self.extent.labels();
812 let sizes = self.extent.sizes();
813 let coords = self.coords();
814
815 for i in 0..labels.len() {
816 write!(
817 f,
818 "{}={}/{}",
819 labels::fmt_label(&labels[i]),
820 coords[i],
821 sizes[i]
822 )?;
823 if i + 1 != labels.len() {
824 write!(f, ",")?;
825 }
826 }
827 Ok(())
828 }
829}
830
831impl FromStr for Point {
832 type Err = PointError;
833
834 fn from_str(s: &str) -> Result<Self, Self::Err> {
835 let s = s.trim();
836
837 if s.is_empty() {
838 let empty_extent = Extent::unity();
839 return empty_extent.point(vec![]);
840 }
841
842 let mut labels = Vec::new();
843 let mut coords = Vec::new();
844 let mut sizes = Vec::new();
845
846 let mut chars = s.chars().peekable();
847
848 while chars.peek().is_some() {
849 while chars.peek() == Some(&' ') {
850 chars.next();
851 }
852
853 if chars.peek().is_none() {
854 break;
855 }
856
857 let label = if chars.peek() == Some(&'"') {
858 chars.next(); let mut label = String::new();
860 let mut escaped = false;
861
862 for ch in chars.by_ref() {
864 if escaped {
865 match ch {
866 '"' => label.push('"'),
867 '\\' => label.push('\\'),
868 _ => {
869 label.push('\\');
870 label.push(ch);
871 }
872 }
873 escaped = false;
874 } else if ch == '\\' {
875 escaped = true;
876 } else if ch == '"' {
877 break;
878 } else {
879 label.push(ch);
880 }
881 }
882
883 if label.is_empty() {
884 return Err(PointError::ParseError {
885 reason: "empty quoted label".to_string(),
886 });
887 }
888
889 label
890 } else {
891 let mut label = String::new();
892 while let Some(&ch) = chars.peek() {
893 if ch == '=' || ch == ' ' {
894 break;
895 }
896 label.push(chars.next().unwrap());
897 }
898
899 if label.is_empty() {
900 return Err(PointError::ParseError {
901 reason: "missing label".to_string(),
902 });
903 }
904
905 label
906 };
907
908 while chars.peek() == Some(&' ') {
909 chars.next();
910 }
911
912 if chars.next() != Some('=') {
913 return Err(PointError::ParseError {
914 reason: format!("expected '=' after label '{}'", label),
915 });
916 }
917
918 while chars.peek() == Some(&' ') {
919 chars.next();
920 }
921
922 let mut coord = String::new();
923 while let Some(&ch) = chars.peek() {
924 if ch == '/' || ch == ' ' {
925 break;
926 }
927 coord.push(chars.next().unwrap());
928 }
929
930 if coord.is_empty() {
931 return Err(PointError::ParseError {
932 reason: format!("missing coordinate for dimension '{}'", label),
933 });
934 }
935
936 while chars.peek() == Some(&' ') {
937 chars.next();
938 }
939
940 if chars.next() != Some('/') {
941 return Err(PointError::ParseError {
942 reason: format!("expected '/' after coordinate for dimension '{}'", label),
943 });
944 }
945
946 while chars.peek() == Some(&' ') {
947 chars.next();
948 }
949
950 let mut size = String::new();
951 while let Some(&ch) = chars.peek() {
952 if ch == ',' || ch == ' ' {
953 break;
954 }
955 size.push(chars.next().unwrap());
956 }
957
958 if size.is_empty() {
959 return Err(PointError::ParseError {
960 reason: format!("missing size for dimension '{}'", label),
961 });
962 }
963
964 let coord = coord.parse::<usize>().map_err(|e| PointError::ParseError {
965 reason: format!(
966 "invalid coordinate '{}' for dimension '{}': {}",
967 coord, label, e
968 ),
969 })?;
970
971 let size = size.parse::<usize>().map_err(|e| PointError::ParseError {
972 reason: format!("invalid size '{}' for dimension '{}': {}", size, label, e),
973 })?;
974
975 labels.push(label);
976 coords.push(coord);
977 sizes.push(size);
978
979 while chars.peek() == Some(&' ') {
980 chars.next();
981 }
982
983 if chars.peek() == Some(&',') {
984 chars.next(); while chars.peek() == Some(&' ') {
986 chars.next();
987 }
988 if chars.peek().is_none() {
990 return Err(PointError::ParseError {
991 reason: "trailing comma".to_string(),
992 });
993 }
994 }
995 }
996
997 let extent = Extent::new(labels, sizes).map_err(|e| PointError::ParseError {
998 reason: format!("failed to create extent: {}", e),
999 })?;
1000
1001 extent.point(coords)
1002 }
1003}
1004
1005#[derive(Debug, Error)]
1007pub enum ViewError {
1008 #[error("no such dimension: {0}")]
1010 InvalidDim(String),
1011
1012 #[error("empty range: {range} for dimension {dim} of size {size}")]
1014 EmptyRange {
1015 range: Range,
1016 dim: String,
1017 size: usize,
1018 },
1019
1020 #[error(transparent)]
1021 ExtentError(#[from] ExtentError),
1022
1023 #[error("invalid range: selected ranks {selected} not a subset of base {base} ")]
1024 InvalidRange {
1025 base: Box<Region>,
1026 selected: Box<Region>,
1027 },
1028}
1029
1030#[derive(Debug, Error)]
1032pub enum RegionError {
1033 #[error("invalid point: this point does not belong to this region: {0}")]
1034 InvalidPoint(String),
1035
1036 #[error("out of range base rank: this base rank {0} does not belong to this region: {0}")]
1037 OutOfRangeBaseRank(usize, String),
1038}
1039
1040#[derive(Debug, Clone, Serialize, Deserialize, Hash, PartialEq, Eq)]
1047pub struct Region {
1048 labels: Vec<String>,
1049 slice: Slice,
1050}
1051
1052impl Region {
1053 #[allow(dead_code)]
1054 fn empty() -> Region {
1055 Region {
1056 labels: Vec::new(),
1057 slice: Slice::new(0, Vec::new(), Vec::new()).unwrap(),
1058 }
1059 }
1060
1061 #[allow(dead_code)]
1064 pub fn new(labels: Vec<String>, slice: Slice) -> Self {
1065 Self { labels, slice }
1066 }
1067
1068 pub fn labels(&self) -> &[String] {
1070 &self.labels
1071 }
1072
1073 pub fn slice(&self) -> &Slice {
1076 &self.slice
1077 }
1078
1079 pub fn into_inner(self) -> (Vec<String>, Slice) {
1081 (self.labels, self.slice)
1082 }
1083
1084 pub fn extent(&self) -> Extent {
1086 Extent::new(self.labels.clone(), self.slice.sizes().to_vec()).unwrap()
1087 }
1088
1089 pub fn is_subset(&self, other: &Region) -> bool {
1092 let mut left = self.slice.iter().peekable();
1093 let mut right = other.slice.iter().peekable();
1094
1095 loop {
1096 match (left.peek(), right.peek()) {
1097 (Some(l), Some(r)) => {
1098 if l < r {
1099 return false;
1100 } else if l == r {
1101 left.next();
1102 right.next();
1103 } else {
1104 right.next();
1106 }
1107 }
1108 (Some(_), None) => return false,
1109 (None, _) => return true,
1110 }
1111 }
1112 }
1113
1114 pub fn remap(&self, target: &Region) -> Option<impl Iterator<Item = usize> + '_> {
1134 if !target.is_subset(self) {
1135 return None;
1136 }
1137
1138 let mut ours = self.slice.iter().enumerate();
1139 let mut theirs = target.slice.iter();
1140
1141 Some(std::iter::from_fn(move || {
1142 let needle = theirs.next()?;
1143 loop {
1144 let (index, value) = ours.next().unwrap();
1145 if value == needle {
1146 break Some(index);
1147 }
1148 }
1149 }))
1150 }
1151
1152 pub fn num_ranks(&self) -> usize {
1154 self.slice.len()
1155 }
1156
1157 pub fn base_rank_of_point(&self, p: Point) -> Result<usize, RegionError> {
1160 if p.extent() != &self.extent() {
1161 return Err(RegionError::InvalidPoint(
1162 "mismatched extent: p must be a point in this region’s extent".to_string(),
1163 ));
1164 }
1165
1166 Ok(self
1167 .slice()
1168 .location(&p.coords())
1169 .expect("should have valid location since extent is checked"))
1170 }
1171
1172 pub fn point_of_base_rank(&self, rank: usize) -> Result<Point, RegionError> {
1175 let coords = self
1176 .slice()
1177 .coordinates(rank)
1178 .map_err(|e| RegionError::OutOfRangeBaseRank(rank, e.to_string()))?;
1179 Ok(self
1180 .extent()
1181 .point(coords)
1182 .expect("should have valid point since coords is from this region"))
1183 }
1184}
1185
1186impl From<Extent> for Region {
1189 fn from(extent: Extent) -> Self {
1190 Region {
1191 labels: extent.labels().to_vec(),
1192 slice: extent.to_slice(),
1193 }
1194 }
1195}
1196
1197impl From<&Shape> for Region {
1198 fn from(s: &Shape) -> Self {
1199 Region {
1200 labels: s.labels().to_vec(),
1201 slice: s.slice().clone(),
1202 }
1203 }
1204}
1205
1206impl From<Shape> for Region {
1207 fn from(s: Shape) -> Self {
1208 Region::from(&s)
1209 }
1210}
1211
1212impl std::fmt::Display for Region {
1252 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1253 if self.slice.offset() != 0 {
1254 write!(f, "{}+", self.slice.offset())?;
1255 }
1256 for i in 0..self.labels.len() {
1257 write!(
1258 f,
1259 "{}={}/{}",
1260 labels::fmt_label(&self.labels[i]),
1261 self.slice.sizes()[i],
1262 self.slice.strides()[i]
1263 )?;
1264 if i + 1 != self.labels.len() {
1265 write!(f, ",")?;
1266 }
1267 }
1268 Ok(())
1269 }
1270}
1271
1272#[derive(Debug, thiserror::Error)]
1273pub enum RegionParseError {
1274 #[error(transparent)]
1275 ParserError(#[from] ParserError),
1276
1277 #[error(transparent)]
1278 SliceError(#[from] SliceError),
1279}
1280
1281impl std::str::FromStr for Region {
1302 type Err = RegionParseError;
1303
1304 fn from_str(s: &str) -> Result<Self, Self::Err> {
1305 let mut parser = Parser::new(s, &["+", "=", ",", "/"]);
1306
1307 let offset: usize = if let Ok(offset) = parser.try_parse() {
1308 parser.expect("+")?;
1309 offset
1310 } else {
1311 0
1312 };
1313
1314 let mut labels = Vec::new();
1315 let mut sizes = Vec::new();
1316 let mut strides = Vec::new();
1317
1318 while !parser.is_empty() {
1319 if !labels.is_empty() {
1320 parser.expect(",")?;
1321 }
1322
1323 let label = if parser.peek_char() == Some('"') {
1325 parser.parse_string_literal()?
1326 } else {
1327 parser.next_or_err("label")?.to_string()
1328 };
1329 labels.push(label);
1330
1331 parser.expect("=")?;
1332 sizes.push(parser.try_parse()?);
1333 parser.expect("/")?;
1334 strides.push(parser.try_parse()?);
1335 }
1336
1337 Ok(Region {
1338 labels,
1339 slice: Slice::new(offset, sizes, strides)?,
1340 })
1341 }
1342}
1343
1344pub trait BuildFromRegion<T>: Sized {
1357 type Error;
1358
1359 fn build_dense(region: Region, values: Vec<T>) -> Result<Self, Self::Error>;
1361
1362 fn build_dense_unchecked(region: Region, values: Vec<T>) -> Self;
1364}
1365
1366pub trait BuildFromRegionIndexed<T>: Sized {
1380 type Error;
1381
1382 fn build_indexed(
1385 region: Region,
1386 pairs: impl IntoIterator<Item = (usize, T)>,
1387 ) -> Result<Self, Self::Error>;
1388}
1389
1390pub trait CollectMeshExt<T>: Iterator<Item = T> + Sized {
1404 fn collect_mesh<M>(self, region: Region) -> Result<M, M::Error>
1405 where
1406 M: BuildFromRegion<T>;
1407}
1408
1409impl<I, T> CollectMeshExt<T> for I
1412where
1413 I: Iterator<Item = T> + Sized,
1414{
1415 fn collect_mesh<M>(self, region: Region) -> Result<M, M::Error>
1416 where
1417 M: BuildFromRegion<T>,
1418 {
1419 M::build_dense(region, self.collect())
1420 }
1421}
1422
/// The number of values supplied did not match the number of ranks.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct InvalidCardinality {
    /// Number of ranks in the region.
    pub expected: usize,
    /// Number of values supplied.
    pub actual: usize,
}
1429
1430pub trait CollectExactMeshExt<T>: ExactSizeIterator<Item = T> + Sized {
1437 fn collect_exact_mesh<M>(self, region: Region) -> Result<M, M::Error>
1438 where
1439 M: BuildFromRegion<T>,
1440 M::Error: From<InvalidCardinality>;
1441}
1442
1443impl<I, T> CollectExactMeshExt<T> for I
1446where
1447 I: ExactSizeIterator<Item = T> + Sized,
1448{
1449 fn collect_exact_mesh<M>(self, region: Region) -> Result<M, M::Error>
1450 where
1451 M: BuildFromRegion<T>,
1452 M::Error: From<InvalidCardinality>,
1453 {
1454 let expected = region.num_ranks();
1455 let actual = self.len();
1456 if actual != expected {
1457 return Err(M::Error::from(InvalidCardinality { expected, actual }));
1458 }
1459 Ok(M::build_dense_unchecked(region, self.collect()))
1460 }
1461}
1462
1463pub trait CollectIndexedMeshExt<T>: Iterator<Item = (usize, T)> + Sized {
1485 fn collect_indexed<M>(self, region: Region) -> Result<M, M::Error>
1486 where
1487 M: BuildFromRegionIndexed<T>;
1488}
1489
1490impl<I, T> CollectIndexedMeshExt<T> for I
1493where
1494 I: Iterator<Item = (usize, T)> + Sized,
1495{
1496 #[inline]
1497 fn collect_indexed<M>(self, region: Region) -> Result<M, M::Error>
1498 where
1499 M: BuildFromRegionIndexed<T>,
1500 {
1501 M::build_indexed(region, self)
1502 }
1503}
1504
1505pub trait MapIntoExt: Ranked {
1507 fn map_into<M, U>(&self, f: impl Fn(&Self::Item) -> U) -> M
1508 where
1509 Self: Sized,
1510 M: BuildFromRegion<U>,
1511 {
1512 let region = self.region().clone();
1513 let n = region.num_ranks();
1514 let values: Vec<U> = (0..n).map(|i| f(self.get(i).unwrap())).collect();
1515 M::build_dense_unchecked(region, values)
1516 }
1517
1518 fn try_map_into<M, U, E>(self, f: impl Fn(&Self::Item) -> Result<U, E>) -> Result<M, E>
1519 where
1520 Self: Sized,
1521 M: BuildFromRegion<U>,
1522 {
1523 let region = self.region().clone();
1524 let n = region.num_ranks();
1525 let mut out = Vec::with_capacity(n);
1526 for i in 0..n {
1527 out.push(f(self.get(i).unwrap())?);
1528 }
1529 Ok(M::build_dense_unchecked(region, out))
1530 }
1531}
1532
1533impl<T: Ranked> MapIntoExt for T {}
1536
1537pub trait View: Sized {
1539 type Item;
1541
1542 type View: View;
1544
1545 fn region(&self) -> Region;
1547
1548 fn get(&self, rank: usize) -> Option<Self::Item>;
1552
1553 #[allow(clippy::result_large_err)] fn subset(&self, region: Region) -> Result<Self::View, ViewError>;
1558}
1559
1560impl View for Region {
1562 type Item = usize;
1564
1565 type View = Region;
1567
1568 fn region(&self) -> Region {
1569 self.clone()
1570 }
1571
1572 fn subset(&self, region: Region) -> Result<Region, ViewError> {
1573 if region.is_subset(self) {
1574 Ok(region)
1575 } else {
1576 Err(ViewError::InvalidRange {
1577 base: Box::new(self.clone()),
1578 selected: Box::new(region),
1579 })
1580 }
1581 }
1582
1583 fn get(&self, rank: usize) -> Option<Self::Item> {
1584 self.slice.get(rank).ok()
1585 }
1586}
1587
1588impl View for Extent {
1590 type Item = usize;
1592
1593 type View = Region;
1596
1597 fn region(&self) -> Region {
1598 Region {
1599 labels: self.labels().to_vec(),
1600 slice: self.to_slice(),
1601 }
1602 }
1603
1604 fn subset(&self, region: Region) -> Result<Region, ViewError> {
1605 self.region().subset(region)
1606 }
1607
1608 fn get(&self, rank: usize) -> Option<Self::Item> {
1609 if rank < self.num_ranks() {
1610 Some(rank)
1611 } else {
1612 None
1613 }
1614 }
1615}
1616
1617pub trait Ranked: Sized {
1620 type Item: 'static;
1622
1623 fn region(&self) -> &Region;
1625
1626 fn get(&self, rank: usize) -> Option<&Self::Item>;
1628}
1629
1630pub trait RankedSliceable: Ranked {
1638 fn sliced(&self, region: Region) -> Self;
1642}
1643
1644impl<T: RankedSliceable> View for T
1645where
1646 T::Item: Clone + 'static,
1647{
1648 type Item = T::Item;
1649 type View = Self;
1650
1651 fn region(&self) -> Region {
1652 <Self as Ranked>::region(self).clone()
1653 }
1654
1655 fn get(&self, rank: usize) -> Option<Self::Item> {
1656 <Self as Ranked>::get(self, rank).cloned()
1657 }
1658
1659 fn subset(&self, region: Region) -> Result<Self, ViewError> {
1660 if !region.is_subset(self.region()) {
1661 return Err(ViewError::InvalidRange {
1662 base: Box::new(self.region().clone()),
1663 selected: Box::new(region.clone()),
1664 });
1665 }
1666
1667 Ok(self.sliced(region))
1668 }
1669}
1670
1671pub struct ViewIterator {
1673 extent: Extent, pos: SliceIterator, }
1676
1677impl Iterator for ViewIterator {
1678 type Item = (Point, usize);
1679 fn next(&mut self) -> Option<Self::Item> {
1680 let rank = self.pos.next()?;
1682 let coords = self.pos.slice.coordinates(rank).unwrap();
1684 let point = coords.in_(&self.extent).unwrap();
1685 Some((point, rank))
1686 }
1687}
1688
1689pub trait ViewExt: View {
1691 #[allow(clippy::result_large_err)] fn range<R: Into<Range>>(&self, dim: &str, range: R) -> Result<Self::View, ViewError>;
1718
1719 #[allow(clippy::result_large_err)] fn group_by(&self, dim: &str) -> Result<impl Iterator<Item = Self::View>, ViewError>;
1756
1757 fn extent(&self) -> Extent;
1759
1760 fn iter<'a>(&'a self) -> impl Iterator<Item = (Point, Self::Item)> + 'a;
1762
1763 fn values<'a>(&'a self) -> impl Iterator<Item = Self::Item> + 'a;
1765}
1766
1767impl<T: View> ViewExt for T {
1768 fn range<R: Into<Range>>(&self, dim: &str, range: R) -> Result<Self::View, ViewError> {
1769 let (labels, slice) = self.region().into_inner();
1770 let range = range.into();
1771 let dim = labels
1772 .iter()
1773 .position(|l| dim == l)
1774 .ok_or_else(|| ViewError::InvalidDim(dim.to_string()))?;
1775 let (mut offset, mut sizes, mut strides) = slice.into_inner();
1776 let (begin, end, step) = range.resolve(sizes[dim]);
1777 if end <= begin {
1778 return Err(ViewError::EmptyRange {
1779 range,
1780 dim: dim.to_string(),
1781 size: sizes[dim],
1782 });
1783 }
1784
1785 offset += strides[dim] * begin;
1786 sizes[dim] = (end - begin).div_ceil(step);
1787 strides[dim] *= step;
1788 let slice = Slice::new(offset, sizes, strides).unwrap();
1789
1790 self.subset(Region { labels, slice })
1791 }
1792
1793 fn group_by(&self, dim: &str) -> Result<impl Iterator<Item = Self::View>, ViewError> {
1794 let (labels, slice) = self.region().into_inner();
1795
1796 let dim = labels
1797 .iter()
1798 .position(|l| dim == l)
1799 .ok_or_else(|| ViewError::InvalidDim(dim.to_string()))?;
1800
1801 let (offset, sizes, strides) = slice.into_inner();
1802 let mut ranks_iter = Slice::new(offset, sizes[..dim].to_vec(), strides[..dim].to_vec())
1803 .unwrap()
1804 .iter();
1805
1806 let labels = labels[dim..].to_vec();
1807 let sizes = sizes[dim..].to_vec();
1808 let strides = strides[dim..].to_vec();
1809
1810 Ok(std::iter::from_fn(move || {
1811 let rank = ranks_iter.next()?;
1812 let slice = Slice::new(rank, sizes.clone(), strides.clone()).unwrap();
1813 Some(
1815 self.subset(Region {
1816 labels: labels.clone(),
1817 slice,
1818 })
1819 .unwrap(),
1820 )
1821 }))
1822 }
1823
1824 fn extent(&self) -> Extent {
1825 let (labels, slice) = self.region().into_inner();
1826 Extent::new(labels, slice.sizes().to_vec()).unwrap()
1827 }
1828
1829 fn iter(&self) -> impl Iterator<Item = (Point, Self::Item)> + '_ {
1830 let points = ViewIterator {
1831 extent: self.extent(),
1832 pos: self.region().slice().iter(),
1833 };
1834
1835 points.map(|(point, _)| (point.clone(), self.get(point.rank()).unwrap()))
1836 }
1837
1838 fn values(&self) -> impl Iterator<Item = Self::Item> + '_ {
1839 (0usize..self.extent().num_ranks()).map(|rank| self.get(rank).unwrap())
1840 }
1841}
1842
/// Builds an `Extent` from `label = size` pairs, e.g. `extent!(x = 4, y = 5)`.
/// `extent!()` is the 0-dimensional extent. Each label is an identifier and
/// becomes its own name as a string.
///
/// # Panics
///
/// Never in practice: the label and size lists always have equal length.
#[macro_export]
macro_rules! extent {
    ( $( $label:ident = $size:expr ),* $(,)? ) => {
        $crate::view::Extent::new(
            vec![ $( stringify!($label).to_string() ),* ],
            vec![ $( $size ),* ],
        )
        .unwrap()
    };
}
1868
1869#[cfg(test)]
1870mod test {
1871 use super::labels::*;
1872 use super::*;
1873 use crate::Shape;
1874 use crate::shape;
1875 use crate::slice::CartesianIterator;
1876
1877 #[test]
1878 fn test_is_safe_ident() {
1879 assert!(is_safe_ident("x"));
1880 assert!(is_safe_ident("gpu_0"));
1881 assert!(!is_safe_ident("dim/0"));
1882 assert!(!is_safe_ident("x y"));
1883 assert!(!is_safe_ident("x=y"));
1884 }
1885 #[test]
1886 fn test_fmt_label() {
1887 assert_eq!(fmt_label("x"), "x");
1888 assert_eq!(fmt_label("dim/0"), "\"dim/0\"");
1889 }
1890
1891 #[test]
1892 fn test_points_basic() {
1893 let extent = extent!(x = 4, y = 5, z = 6);
1894 let _p1 = extent.point(vec![1, 2, 3]).unwrap();
1895 let _p2 = vec![1, 2, 3].in_(&extent).unwrap();
1896
1897 assert_eq!(extent.num_ranks(), 4 * 5 * 6);
1898
1899 let p3 = extent.point_of_rank(0).unwrap();
1900 assert_eq!(p3.coords(), &[0, 0, 0]);
1901 assert_eq!(p3.rank(), 0);
1902
1903 let p4 = extent.point_of_rank(1).unwrap();
1904 assert_eq!(p4.coords(), &[0, 0, 1]);
1905 assert_eq!(p4.rank(), 1);
1906
1907 let p5 = extent.point_of_rank(2).unwrap();
1908 assert_eq!(p5.coords(), &[0, 0, 2]);
1909 assert_eq!(p5.rank(), 2);
1910
1911 let p6 = extent.point_of_rank(6 * 5 + 1).unwrap();
1912 assert_eq!(p6.coords(), &[1, 0, 1]);
1913 assert_eq!(p6.rank(), 6 * 5 + 1);
1914 assert_eq!(p6.coord(0), 1);
1915 assert_eq!(p6.coord(1), 0);
1916 assert_eq!(p6.coord(2), 1);
1917
1918 assert_eq!(extent.points().collect::<Vec<_>>().len(), 4 * 5 * 6);
1919 for (rank, point) in extent.points().enumerate() {
1920 let c = point.coords();
1921 let (x, y, z) = (c[0], c[1], c[2]);
1922 assert_eq!(z + y * 6 + x * 6 * 5, rank);
1923 assert_eq!(point.rank(), rank);
1924 }
1925 }
1926
1927 #[test]
1928 fn points_iterates_ranks_in_row_major_order() {
1929 let ext = extent!(x = 2, y = 3, z = 4); let mut it = ext.points();
1931
1932 for expected_rank in 0..ext.num_ranks() {
1933 let p = it.next().expect("expected another Point");
1934 assert_eq!(
1935 p.rank, expected_rank,
1936 "ranks must be consecutive in row-major order"
1937 );
1938 }
1939 assert!(
1940 it.next().is_none(),
1941 "iterator must be exhausted after num_ranks items"
1942 );
1943 }
1944
1945 #[test]
1946 fn points_iterates_single_point_for_0d_extent() {
1947 let ext = extent!();
1949 let mut it = ext.points();
1950
1951 let p = it
1952 .next()
1953 .expect("0-D extent should yield exactly one point");
1954 assert_eq!(p.rank, 0);
1955 assert_eq!(p.extent, ext);
1956
1957 assert!(
1958 it.next().is_none(),
1959 "no more points after the single 0-D point"
1960 );
1961 }
1962
1963 macro_rules! assert_view {
1964 ($view:expr, $extent:expr, $( $($coord:expr),+ => $rank:expr );* $(;)?) => {
1965 let view = $view;
1966 assert_eq!(view.extent(), $extent);
1967 let expected: Vec<_> = vec![$(($extent.point(vec![$($coord),+]).unwrap(), $rank)),*];
1968 let actual: Vec<_> = ViewExt::iter(&view).collect();
1969 assert_eq!(actual, expected);
1970 };
1971 }
1972
1973 #[test]
1974 fn test_view_basic() {
1975 let extent = extent!(x = 4, y = 4);
1976 assert_view!(
1977 extent.range("x", 0..2).unwrap(),
1978 extent!(x = 2, y = 4),
1979 0, 0 => 0;
1980 0, 1 => 1;
1981 0, 2 => 2;
1982 0, 3 => 3;
1983 1, 0 => 4;
1984 1, 1 => 5;
1985 1, 2 => 6;
1986 1, 3 => 7;
1987 );
1988 assert_view!(
1989 extent.range("x", 1).unwrap().range("y", 2..).unwrap(),
1990 extent!(x = 1, y = 2),
1991 0, 0 => 6;
1992 0, 1 => 7;
1993 );
1994 assert_view!(
1995 extent.range("y", Range(0, None, 2)).unwrap(),
1996 extent!(x = 4, y = 2),
1997 0, 0 => 0;
1998 0, 1 => 2;
1999 1, 0 => 4;
2000 1, 1 => 6;
2001 2, 0 => 8;
2002 2, 1 => 10;
2003 3, 0 => 12;
2004 3, 1 => 14;
2005 );
2006 assert_view!(
2007 extent.range("y", Range(0, None, 2)).unwrap().range("x", 2..).unwrap(),
2008 extent!(x = 2, y = 2),
2009 0, 0 => 8;
2010 0, 1 => 10;
2011 1, 0 => 12;
2012 1, 1 => 14;
2013 );
2014
2015 let extent = extent!(x = 10, y = 2);
2016 assert_view!(
2017 extent.range("x", Range(0, None, 2)).unwrap(),
2018 extent!(x = 5, y = 2),
2019 0, 0 => 0;
2020 0, 1 => 1;
2021 1, 0 => 4;
2022 1, 1 => 5;
2023 2, 0 => 8;
2024 2, 1 => 9;
2025 3, 0 => 12;
2026 3, 1 => 13;
2027 4, 0 => 16;
2028 4, 1 => 17;
2029 );
2030 assert_view!(
2031 extent.range("x", Range(0, None, 2)).unwrap().range("x", 2..).unwrap().range("y", 1).unwrap(),
2032 extent!(x = 3, y = 1),
2033 0, 0 => 9;
2034 1, 0 => 13;
2035 2, 0 => 17;
2036 );
2037
2038 let extent = extent!(zone = 4, host = 2, gpu = 8);
2039 assert_view!(
2040 extent.range("zone", 0).unwrap().range("gpu", Range(0, None, 2)).unwrap(),
2041 extent!(zone = 1, host = 2, gpu = 4),
2042 0, 0, 0 => 0;
2043 0, 0, 1 => 2;
2044 0, 0, 2 => 4;
2045 0, 0, 3 => 6;
2046 0, 1, 0 => 8;
2047 0, 1, 1 => 10;
2048 0, 1, 2 => 12;
2049 0, 1, 3 => 14;
2050 );
2051
2052 let extent = extent!(x = 3);
2053 assert_view!(
2054 extent.range("x", Range(0, None, 2)).unwrap(),
2055 extent!(x = 2),
2056 0 => 0;
2057 1 => 2;
2058 );
2059 }
2060
/// `Point::coord(axis)` returns the coordinate stored for each axis.
#[test]
fn test_point_indexing() {
    let ext = Extent::new(vec!["x".into(), "y".into(), "z".into()], vec![4, 5, 6]).unwrap();
    let expected = [1, 2, 3];
    let p = ext.point(expected.to_vec()).unwrap();

    for (axis, &want) in expected.iter().enumerate() {
        assert_eq!(p.coord(axis), want);
    }
}
2070
/// Asking a 2-d point for axis 5 must panic.
#[test]
#[should_panic]
fn test_point_indexing_out_of_bounds() {
    let ext = Extent::new(vec!["x".into(), "y".into()], vec![4, 5]).unwrap();
    let p = ext.point(vec![1, 2]).unwrap();
    // Only axes 0 and 1 exist.
    let _out_of_range = p.coord(5);
}
2079
/// `&Point` implements `IntoIterator`, yielding the coordinates in axis order,
/// both via an explicit `into_iter()` call and via a `for` loop.
#[test]
fn test_point_into_iter() {
    let ext = Extent::new(vec!["x".into(), "y".into(), "z".into()], vec![4, 5, 6]).unwrap();
    let p = ext.point(vec![1, 2, 3]).unwrap();

    let collected = (&p).into_iter().collect::<Vec<usize>>();
    assert_eq!(collected, vec![1, 2, 3]);

    let mut total = 0usize;
    for c in &p {
        total += c;
    }
    assert_eq!(total, 6);
}
2094
/// `Extent::iter` yields `(label, size)` pairs in declaration order.
#[test]
fn test_extent_basic() {
    let ext = extent!(x = 10, y = 5, z = 1);
    let want = [("x", 10), ("y", 5), ("z", 1)]
        .iter()
        .map(|&(label, size)| (label.to_string(), size))
        .collect::<Vec<(String, usize)>>();
    let got: Vec<_> = ext.iter().collect();
    assert_eq!(got, want);
}
2107
/// `Display` for `Extent` renders `{label: size, ...}`. Labels such as `dim/0`
/// are rendered in double quotes, and the 0-d extent renders as `{}`.
#[test]
fn test_extent_display() {
    let cases = vec![
        (
            Extent::new(vec!["x".into(), "y".into(), "z".into()], vec![4, 5, 6]).unwrap(),
            "{x: 4, y: 5, z: 6}",
        ),
        (
            Extent::new(vec!["dim/0".into(), "dim/1".into()], vec![4, 5]).unwrap(),
            "{\"dim/0\": 4, \"dim/1\": 5}",
        ),
        (Extent::new(vec![], vec![]).unwrap(), "{}"),
    ];

    for (ext, rendered) in cases {
        assert_eq!(ext.to_string(), rendered);
    }
}
2119
/// `position` and `size` agree with the enumeration order of `iter()`, and
/// return `None` for labels the extent does not contain.
#[test]
fn extent_label_helpers() {
    let ext = extent!(zone = 3, host = 2, gpu = 4);

    ext.iter().enumerate().for_each(|(idx, (label, size))| {
        assert_eq!(ext.position(&label), Some(idx));
        assert_eq!(ext.size(&label), Some(size));
    });

    let unknown = "nope";
    assert!(ext.position(unknown).is_none());
    assert!(ext.size(unknown).is_none());
}
2130
/// A 0-dimensional extent has one rank and one point with no coordinates.
#[test]
fn test_extent_0d() {
    let scalar = Extent::new(vec![], vec![]).unwrap();
    assert_eq!(scalar.num_ranks(), 1);

    let all: Vec<_> = scalar.points().collect();
    assert_eq!(all.len(), 1);

    let only = &all[0];
    assert_eq!(only.coords(), &[]);
    assert_eq!(only.rank(), 0);

    // The coordinate iterator is empty from the start and stays exhausted.
    let mut coords = only.into_iter();
    assert_eq!(coords.len(), 0);
    assert!(coords.next().is_none());
    assert!(coords.next().is_none());
}
2147
/// `Extent::concat` appends dimensions in order. The empty extent is its
/// identity, shared labels are rejected with `OverlappingLabel`, and the
/// operation is associative.
#[test]
fn test_extent_concat() {
    // Asserts that `res` failed with `OverlappingLabel` and returns that label.
    fn overlapping_label<T: std::fmt::Debug>(res: Result<T, ExtentError>, why: &str) -> String {
        assert!(res.is_err(), "{}", why);
        match res.unwrap_err() {
            ExtentError::OverlappingLabel { label } => label,
            other => panic!("Expected OverlappingLabel error, got {:?}", other),
        }
    }

    let xy = extent!(x = 2, y = 3);
    let zw = extent!(z = 4, w = 5);

    // Disjoint labels: dimensions are appended left to right.
    let xyzw = xy.concat(&zw).unwrap();
    assert_eq!(xyzw.labels(), &["x", "y", "z", "w"]);
    assert_eq!(xyzw.sizes(), &[2, 3, 4, 5]);
    assert_eq!(xyzw.num_ranks(), 2 * 3 * 4 * 5);

    // The empty extent is a two-sided identity, including with itself.
    let unit = extent!();
    for joined in [xy.concat(&unit).unwrap(), unit.concat(&xy).unwrap()] {
        assert_eq!(joined.labels(), &["x", "y"]);
        assert_eq!(joined.sizes(), &[2, 3]);
    }
    let unit_twice = unit.concat(&unit).unwrap();
    assert_eq!(unit_twice.labels(), &[] as &[String]);
    assert_eq!(unit_twice.sizes(), &[] as &[usize]);
    assert_eq!(unit_twice.num_ranks(), 1);

    // Self-concatenation clashes on the first label.
    let clash = overlapping_label(
        xy.concat(&xy),
        "Self-concatenation should error due to overlapping labels",
    );
    assert_eq!(clash, "x");

    // Points of the concatenation carry the combined extent.
    let combined = xy.concat(&zw).unwrap();
    let pt = combined.point(vec![1, 2, 3, 4]).unwrap();
    assert_eq!(pt.coords(), vec![1, 2, 3, 4]);
    assert_eq!(pt.extent(), &combined);

    // A shared label is rejected even when both sides give it the same size.
    let shared = overlapping_label(
        extent!(x = 2, y = 3).concat(&extent!(y = 3, z = 4)),
        "Should error on overlapping labels even with same size",
    );
    assert_eq!(shared, "y");

    // Operand order determines label order.
    let z_only = extent!(z = 4);
    assert_eq!(xy.concat(&z_only).unwrap().labels(), &["x", "y", "z"]);
    assert_eq!(z_only.concat(&xy).unwrap().labels(), &["z", "x", "y"]);

    // Associativity: (m ++ n) ++ o == m ++ (n ++ o).
    let (m, n, o) = (extent!(x = 2), extent!(y = 3), extent!(z = 4));
    let left = m.concat(&n).and_then(|mn| mn.concat(&o)).unwrap();
    let right = n.concat(&o).and_then(|no| m.concat(&no)).unwrap();
    assert_eq!(left, right);
    assert_eq!(left.labels(), &["x", "y", "z"]);
    assert_eq!(left.sizes(), &[2, 3, 4]);
    assert_eq!(left.num_ranks(), 2 * 3 * 4);
}
2240
/// `Extent::unity()` behaves like an explicitly built 0-d extent: it is
/// empty, has one rank, and its only point has rank 0 and no coordinates.
#[test]
fn extent_unity_equiv_to_0d() {
    let unit = Extent::unity();
    assert!(unit.is_empty());
    assert_eq!(unit.num_ranks(), 1);

    let pts = unit.points().collect::<Vec<_>>();
    assert_eq!(pts.len(), 1);
    let origin = &pts[0];
    assert_eq!(origin.rank(), 0);
    assert!(origin.coords().is_empty());
}
2251
/// `Display` for `Point` renders `label=coord/size` joined by commas. A point
/// needs one coordinate per dimension, and the 0-d point renders as "".
#[test]
fn test_point_display() {
    let xyz = Extent::new(vec!["x".into(), "y".into(), "z".into()], vec![4, 5, 6]).unwrap();
    let p = xyz.point(vec![1, 2, 3]).unwrap();
    assert_eq!(p.to_string(), "x=1/4,y=2/5,z=3/6");

    // Too few coordinates for a 3-d extent.
    assert!(xyz.point(vec![]).is_err());

    let scalar = Extent::new(vec![], vec![]).unwrap();
    assert_eq!(scalar.point(vec![]).unwrap().to_string(), "");
}
2264
/// Labels containing `/` or `,` are rendered in double quotes, both in the
/// extent's and in the point's `Display` output.
#[test]
fn test_point_display_with_quoted_labels() {
    let labels = vec!["dim/0".to_string(), "dim,1".to_string()];
    let ext = Extent::new(labels, vec![3, 5]).unwrap();
    assert_eq!(ext.to_string(), "{\"dim/0\": 3, \"dim,1\": 5}");

    let rendered = ext.point(vec![1, 2]).unwrap().to_string();
    assert_eq!(rendered, "\"dim/0\"=1/3,\"dim,1\"=2/5");
}
2277
/// Ranks of a strided sub-shape, expressed on the root mesh, map back to
/// dense local ranks 0..n when converted to points of the sub-shape's extent.
#[test]
fn test_relative_point() {
    // Converts a root-mesh rank into a point of `shape`'s own extent.
    fn relative_point(rank_on_root_mesh: usize, shape: &Shape) -> anyhow::Result<Point> {
        let slice = shape.slice();
        let coords = slice.coordinates(rank_on_root_mesh)?;
        let ext = Extent::new(shape.labels().to_vec(), slice.sizes().to_vec())?;
        let point = ext.point(coords)?;
        Ok(point)
    }

    let root_shape = shape! { replicas = 4, hosts = 4, gpus = 4 };
    // replicas {0, 3}, hosts {1, 3}, gpus {0, 2}.
    let sliced = root_shape
        .select("replicas", crate::Range(0, Some(4), 3))
        .and_then(|s| s.select("hosts", crate::Range(1, Some(4), 2)))
        .and_then(|s| s.select("gpus", crate::Range(0, Some(4), 2)))
        .unwrap();

    let root_ranks = [4, 6, 12, 14, 52, 54, 60, 62];
    assert_eq!(sliced.slice().iter().collect::<Vec<_>>(), &root_ranks);

    let local: Vec<usize> = root_ranks
        .iter()
        .map(|&r| relative_point(r, &sliced).unwrap().rank())
        .collect();
    assert_eq!(local, (0..8).collect::<Vec<usize>>());
}
2321
/// `group_by(label)` splits an extent into sub-views. Grouping by `gpu` gives
/// 4 * 4 = 16 views of shape `{gpu: 8}`; grouping by `zone` gives one view.
#[test]
fn test_iter_subviews() {
    let ext = extent!(zone = 4, host = 4, gpu = 8);

    assert_eq!(ext.group_by("gpu").unwrap().count(), 16);
    assert_eq!(ext.group_by("zone").unwrap().count(), 1);

    let mut groups = ext.group_by("gpu").unwrap();

    // First group: ranks 0..8.
    assert_view!(
        groups.next().unwrap(),
        extent!(gpu = 8),
        0 => 0;
        1 => 1;
        2 => 2;
        3 => 3;
        4 => 4;
        5 => 5;
        6 => 6;
        7 => 7;
    );

    // Second group: ranks 8..16.
    assert_view!(
        groups.next().unwrap(),
        extent!(gpu = 8),
        0 => 8;
        1 => 9;
        2 => 10;
        3 => 11;
        4 => 12;
        5 => 13;
        6 => 14;
        7 => 15;
    );
}
2355
/// `values()` yields the base ranks covered by a view, in order.
#[test]
fn test_view_values() {
    let ext = extent!(x = 4, y = 4);
    let all: Vec<_> = ext.values().collect();
    assert_eq!(all, (0..16).collect::<Vec<_>>());

    // Fixing y = 1 picks one rank per x, four apart.
    let column = ext.range("y", 1).unwrap();
    assert_eq!(column.values().collect::<Vec<_>>(), [1, 5, 9, 13]);
}
2366
/// `Region::is_subset` holds transitively for nested ranges, and fails for a
/// range that only partially overlaps another.
#[test]
fn region_is_subset_algebra() {
    let whole = extent!(x = 5, y = 4);
    let band = whole.range("x", 1..4).unwrap(); // x in [1, 4)
    let cell = band.range("y", 1..3).unwrap(); // band narrowed to y in [1, 3)
    let head = whole.range("x", 0..2).unwrap(); // x in [0, 2): overlaps band partially

    let (w, a, b, c) = (whole.region(), band.region(), cell.region(), head.region());

    assert!(b.is_subset(&a));
    assert!(b.is_subset(&w));
    assert!(a.is_subset(&w));

    assert!(!c.is_subset(&a));
    assert!(c.is_subset(&w));
}
2381
/// `Region::remap(&sub)` yields, for each rank of `sub`, the corresponding
/// rank within the receiver region.
#[test]
fn test_remap() {
    let region: Region = extent!(x = 4, y = 4).into();

    // Remapping a region onto itself is the identity.
    assert_eq!(
        region.remap(&region).unwrap().collect::<Vec<_>>(),
        (0..16).collect::<Vec<_>>()
    );

    // Rows x = 2 and x = 3 occupy ranks 8..16 of the 4x4 region.
    let subset = region.range("x", 2..).unwrap();
    assert_eq!(subset.num_ranks(), 8);
    assert_eq!(
        region.remap(&subset).unwrap().collect::<Vec<_>>(),
        vec![8, 9, 10, 11, 12, 13, 14, 15],
    );

    // Narrowing further to y = 1 leaves ranks 9 and 13.
    let subset = subset.range("y", 1).unwrap();
    assert_eq!(subset.num_ranks(), 2);
    assert_eq!(
        region.remap(&subset).unwrap().collect::<Vec<_>>(),
        vec![9, 13],
    );

    // Remap is relative to the receiver: within replica 1, gpus 1..3 are
    // local ranks 1 and 2 (not their ranks in the full extent).
    let ext = extent!(replica = 8, gpu = 4);
    let replica1 = ext.range("replica", 1).unwrap();
    assert_eq!(replica1.extent(), extent!(replica = 1, gpu = 4));
    let replica1_gpu12 = replica1.range("gpu", 1..3).unwrap();
    assert_eq!(replica1_gpu12.extent(), extent!(replica = 1, gpu = 2));
    assert_eq!(
        replica1.remap(&replica1_gpu12).unwrap().collect::<Vec<_>>(),
        vec![1, 2],
    );
}
2418
/// Round-trips between local points of a region and ranks of the base extent
/// the region was carved from (`base_rank_of_point` / `point_of_base_rank`).
#[test]
fn test_base_local_rank_conversion() {
    // The point with local `rank` in `region`'s own extent.
    fn point(rank: usize, region: &Region) -> Point {
        region.extent().point_of_rank(rank).unwrap()
    }

    let extent = extent!(replicas = 4, gpus = 2);
    // replicas 1..3 covers base ranks 2..6.
    let region = extent.range("replicas", 1..3).unwrap();

    // A point built from the base extent itself is rejected by the region.
    assert!(region
        .base_rank_of_point(extent.point_of_rank(0).unwrap())
        .is_err());

    // Local ranks 0..4 correspond to base ranks 2..6, in both directions.
    for (local, base) in (0..4).zip(2..6) {
        assert_eq!(region.base_rank_of_point(point(local, &region)).unwrap(), base);
        assert_eq!(region.point_of_base_rank(base).unwrap(), point(local, &region));
    }

    // Base ranks just outside the region are rejected.
    assert!(region.point_of_base_rank(1).is_err());
    assert!(region.point_of_base_rank(6).is_err());

    // Local replica 1 (base replica 2), gpu 1 is base rank 2 * 2 + 1 = 5.
    let subset = region
        .range("replicas", 1..2)
        .unwrap()
        .range("gpus", 1..2)
        .unwrap();
    assert_eq!(subset.base_rank_of_point(point(0, &subset)).unwrap(), 5);
    assert_eq!(subset.point_of_base_rank(5).unwrap(), point(0, &subset));
    assert!(subset.point_of_base_rank(4).is_err());
    assert!(subset.point_of_base_rank(6).is_err());
}
2471
2472 use proptest::prelude::*;
2473
2474 use crate::strategy::gen_extent;
2475 use crate::strategy::gen_region;
2476 use crate::strategy::gen_region_strided;
2477
2478 proptest! {
2479 #[test]
2480 fn test_region_parser(region in gen_region(1..=5, 1024)) {
2481 assert_eq!(
2483 region,
2484 region.to_string().parse::<Region>().unwrap(),
2485 "failed to roundtrip region {}", region
2486 );
2487 }
2488 }
2489
2490 proptest! {
2502 #[test]
2503 fn region_parser_with_offset_roundtrips(region in gen_region(1..=4, 8)) {
2504 let (labels, slice) = region.clone().into_inner();
2505 let region_off = Region {
2506 labels,
2507 slice: Slice::new(8, slice.sizes().to_vec(), slice.strides().to_vec()).unwrap(),
2508 };
2509 let s = region_off.to_string();
2510 let parsed: Region = s.parse().unwrap();
2511 prop_assert_eq!(parsed, region_off);
2512 }
2513 }
2514
2515 proptest! {
2526 #[test]
2527 fn region_strided_display_parse_roundtrips(
2528 region in gen_region_strided(1..=4, 6, 3, 16)
2529 ) {
2530 let s = region.to_string();
2541 let parsed: Region = s.parse().unwrap();
2542 prop_assert_eq!(parsed, region);
2543 }
2544 }
2545
2546 proptest! {
2557 #[test]
2558 fn region_strided_display_matches_slice(
2559 region in gen_region_strided(1..=4, 6, 3, 16)
2560 ) {
2561 let s = region.to_string();
2562 let slice = region.slice();
2563
2564 if slice.offset() != 0 {
2566 let prefix: Vec<_> = s.split('+').collect();
2567 prop_assert!(prefix.len() > 1, "expected offset+ form in {}", s);
2568 let offset_str = prefix[0];
2569 let offset_val: usize = offset_str.parse().unwrap();
2570 prop_assert_eq!(offset_val, slice.offset(), "offset mismatch in {}", s);
2571 } else {
2572 prop_assert!(!s.contains('+'), "unexpected +offset in {}", s);
2573 }
2574
2575 let body = s.split('+').next_back().unwrap(); let parts: Vec<_> = body.split(',').collect();
2578 prop_assert_eq!(parts.len(), slice.sizes().len());
2579
2580 for (i, part) in parts.iter().enumerate() {
2581 let rhs = part.split('=').nth(1).unwrap();
2583 let mut nums = rhs.split('/');
2584 let size_val: usize = nums.next().unwrap().parse().unwrap();
2585 let stride_val: usize = nums.next().unwrap().parse().unwrap();
2586
2587 prop_assert_eq!(size_val, slice.sizes()[i], "size mismatch at dim {} in {}", i, s);
2588 prop_assert_eq!(stride_val, slice.strides()[i], "stride mismatch at dim {} in {}", i, s);
2589 }
2590 }
2591 }
2592
/// Every sample point parses back from its own `Display` output.
#[test]
fn test_point_from_str_round_trip() {
    let samples = [
        extent!(x = 4, y = 5, z = 6).point(vec![1, 2, 3]).unwrap(),
        extent!(host = 2, gpu = 8).point(vec![0, 7]).unwrap(),
        extent!().point(vec![]).unwrap(),
        extent!(x = 10).point(vec![5]).unwrap(),
    ];

    samples.iter().for_each(|p| {
        let back: Point = p.to_string().parse().unwrap();
        assert_eq!(p, &back);
    });
}
2606
/// Well-formed `label=coord/size` strings parse to the expected point.
#[test]
fn test_point_from_str_basic() {
    // (input text, extent of the expected point, expected coordinates)
    let cases = vec![
        ("x=1/4,y=2/5", extent!(x = 4, y = 5), vec![1, 2]),
        ("host=0/2,gpu=7/8", extent!(host = 2, gpu = 8), vec![0, 7]),
        ("z=3/6", extent!(z = 6), vec![3]),
        // The empty string is the single point of a 0-d extent.
        ("", extent!(), vec![]),
        // Whitespace around tokens is tolerated.
        (" x = 1 / 4 , y = 2 / 5 ", extent!(x = 4, y = 5), vec![1, 2]),
    ];

    for (text, ext, coords) in cases {
        let got = text.parse::<Point>().unwrap();
        let want = ext.point(coords).unwrap();
        assert_eq!(got, want, "failed to parse: {}", text);
    }
}
2624
/// Points whose labels are rendered in quotes parse back correctly, both
/// from `to_string()` output and from the equivalent literal.
#[test]
fn test_point_from_str_quoted() {
    const RENDERED: &str = "\"dim/0\"=1/3,\"dim,1\"=2/5";

    let ext = Extent::new(vec!["dim/0".into(), "dim,1".into()], vec![3, 5]).unwrap();
    let original = ext.point(vec![1, 2]).unwrap();

    let shown = original.to_string();
    assert_eq!(shown, RENDERED);

    for text in [shown.as_str(), RENDERED] {
        let parsed: Point = text.parse().unwrap();
        assert_eq!(parsed, original);
    }
}
2640
/// Malformed point strings must fail to parse.
#[test]
fn test_point_from_str_error_cases() {
    const MALFORMED: &[&str] = &[
        "x=1,y=2/5",     // missing "/size"
        "x=1/4,y=2",     // missing "/size" on the second axis
        "x=1/4,y=/5",    // missing coordinate
        "x=/4,y=2/5",    // missing coordinate on the first axis
        "x=1/4,y=2/",    // missing size
        "x=1/,y=2/5",    // missing size on the first axis
        "x=1/4=5,y=2/5", // extra '='
        "x=1/4/6,y=2/5", // extra '/'
        "x=abc/4,y=2/5", // non-numeric coordinate
        "x=1/abc,y=2/5", // non-numeric size
        "=1/4,y=2/5",    // empty label
        "x=1/4,",        // trailing comma
        "x=1/4,=2/5",    // empty second label
        "x=1/4,y",       // label without value
        "x",             // bare label
        "x=",            // empty value
        "x=1/4,y=10/5",  // coordinate out of range
    ];

    for &input in MALFORMED {
        let outcome = input.parse::<Point>();
        assert!(outcome.is_err(), "Expected error for input: '{}'", input);
    }
}
2669
/// A coordinate that is not below its axis size is reported as
/// `OutOfRangeIndex` with the offending size and index.
#[test]
fn test_point_from_str_coordinate_validation() {
    // x = 5 on an axis of size 4.
    let result = "x=5/4,y=2/5".parse::<Point>();
    assert!(
        result.is_err(),
        "Expected error for out-of-bounds coordinate"
    );

    match result.unwrap_err() {
        PointError::OutOfRangeIndex { size, index } => assert_eq!((size, index), (4, 5)),
        _ => panic!("Expected OutOfRangeIndex error"),
    }
}
2688
/// Parsing reconstructs the extent (labels and sizes) alongside the coordinates.
#[test]
fn test_point_from_str_consistency_validation() {
    let p = "x=1/4,y=2/5,z=3/6".parse::<Point>().unwrap();
    let ext = p.extent();

    assert_eq!(ext.labels(), &["x", "y", "z"]);
    assert_eq!(ext.sizes(), &[4, 5, 6]);
    assert_eq!(p.coords(), vec![1, 2, 3]);
}
2702
2703 proptest! {
2704 #[test]
2707 fn point_coord_and_iter_agree(extent in gen_extent(0..=4, 8)) {
2708 for p in extent.points() {
2709 let via_coords = p.coords();
2710 let via_into_iter: Vec<_> = (&p).into_iter().collect();
2711 prop_assert_eq!(via_into_iter, via_coords.clone(), "coord_iter mismatch for {}", p);
2712
2713 for (i, &coord) in via_coords.iter().enumerate() {
2714 prop_assert_eq!(p.coord(i), coord, "coord(i) mismatch at axis {} for {}", i, p);
2715 }
2716 }
2717 }
2718
2719 #[test]
2721 fn points_count_matches_num_ranks(extent in gen_extent(0..=4, 8)) {
2722 let c = extent.points().count();
2723 prop_assert_eq!(c, extent.num_ranks(), "count {} != num_ranks {}", c, extent.num_ranks());
2724 }
2725
2726 #[test]
2730 fn coord_iter_exact_size_invariants(extent in gen_extent(0..=4, 8)) {
2731 for p in extent.points() {
2732 let mut it = (&p).into_iter();
2733
2734 let mut remaining = p.len();
2737 prop_assert_eq!(it.len(), remaining);
2738 prop_assert_eq!(it.size_hint(), (remaining, Some(remaining)));
2739
2740 let mut yielded = Vec::with_capacity(remaining);
2742
2743 while let Some(v) = it.next() {
2746 yielded.push(v);
2747 remaining -= 1;
2748 prop_assert_eq!(it.len(), remaining);
2749 prop_assert_eq!(it.size_hint(), (remaining, Some(remaining)));
2750 }
2751
2752 prop_assert_eq!(remaining, 0);
2755 prop_assert!(it.next().is_none());
2756 prop_assert!(it.next().is_none());
2757
2758 prop_assert_eq!(yielded, p.coords());
2760 }
2761 }
2762
2763 #[test]
2767 fn rank_of_coords_dim_mismatch(extent in gen_extent(0..=4, 8)) {
2768 let want = extent.len();
2769 let wrong = if want == 0 { 1 } else { want - 1 };
2771 let bad = vec![0usize; wrong];
2772
2773 match extent.rank_of_coords(&bad).unwrap_err() {
2774 PointError::DimMismatch { expected, actual } => {
2775 prop_assert_eq!(expected, want, "expected len mismatch");
2776 prop_assert_eq!(actual, wrong, "actual len mismatch");
2777 }
2778 other => prop_assert!(false, "expected DimMismatch, got {:?}", other),
2779 }
2780 }
2781
2782 #[test]
2787 fn rank_of_coords_out_of_range_index(extent in gen_extent(1..=4, 8)) {
2788 let sizes = extent.sizes().to_vec();
2790 let mut coords = vec![0usize; sizes.len()];
2792 let axis = 0usize;
2794 coords[axis] = sizes[axis];
2795
2796 match extent.rank_of_coords(&coords).unwrap_err() {
2797 PointError::OutOfRangeIndex { size, index } => {
2798 prop_assert_eq!(size, sizes[axis], "reported size mismatch");
2799 prop_assert_eq!(index, sizes[axis], "reported index mismatch");
2800 }
2801 other => prop_assert!(false, "expected OutOfRangeIndex, got {:?}", other),
2802 }
2803 }
2804
2805 #[test]
2808 fn point_of_rank_out_of_range(extent in gen_extent(0..=4, 8)) {
2809 let total = extent.num_ranks(); match extent.point_of_rank(total).unwrap_err() {
2811 PointError::OutOfRangeRank { total: t, rank: r } => {
2812 prop_assert_eq!(t, total, "reported total mismatch");
2813 prop_assert_eq!(r, total, "reported rank mismatch");
2814 }
2815 other => prop_assert!(false, "expected OutOfRangeRank, got {:?}", other),
2816 }
2817 }
2818
2819 #[test]
2821 fn point_display_parse_round_trip(extent in gen_extent(0..=4, 8)) {
2822 for point in extent.points() {
2823 let display = point.to_string();
2824 let parsed: Point = display.parse().unwrap();
2825 prop_assert_eq!(parsed, point, "round-trip failed for point: {}", display);
2826 }
2827 }
2828 }
2829
2830 proptest! {
2831 #[test]
2834 fn points_iterator_equivalent_to_legacy_cartesian(extent in gen_extent(0..=4, 8)) {
2835 let sizes = extent.sizes().to_vec();
2836
2837 let legacy_len = CartesianIterator::new(sizes.clone()).count();
2839 prop_assert_eq!(legacy_len, extent.num_ranks());
2840
2841 let legacy_coords = CartesianIterator::new(sizes);
2844 for (step, (p, coords)) in extent.points().zip(legacy_coords).enumerate() {
2845 let old_rank = extent.rank_of_coords(&coords).expect("valid legacy coords");
2846 prop_assert_eq!(p.rank(), old_rank, "rank mismatch at step {} for coords {:?}", step, coords);
2847 prop_assert_eq!(p.coords(), coords, "coords mismatch at step {}", step);
2848 }
2849 }
2850 }
2851}