1use std::str::FromStr;
33use std::sync::Arc;
34
35use serde::Deserialize;
36use serde::Serialize;
37use thiserror::Error;
38
39use crate::Range;
40use crate::Shape; use crate::Slice;
42use crate::SliceError;
43use crate::SliceIterator;
44use crate::parse::Parser;
45use crate::parse::ParserError;
46
/// Errors returned when constructing or combining [`Extent`]s.
#[derive(Debug, thiserror::Error)]
pub enum ExtentError {
    /// The number of labels differs from the number of sizes
    /// (returned by [`Extent::new`]).
    #[error("label/sizes dimension mismatch: {num_labels} != {num_sizes}")]
    DimMismatch {
        /// Number of labels supplied.
        num_labels: usize,
        /// Number of sizes supplied.
        num_sizes: usize,
    },

    /// A label of the right-hand extent also appears in the left-hand one
    /// (returned by [`Extent::concat`]).
    #[error("overlapping label found: {label}")]
    OverlappingLabel {
        /// The first label of the right-hand extent found in both.
        label: String,
    },
}
73
/// An ordered set of named dimensions, each with a size.
///
/// The data lives behind an [`Arc`], so cloning an `Extent` only bumps a
/// reference count and clones share the same labels and sizes.
#[derive(Clone, Deserialize, Serialize, PartialEq, Eq, Hash, Debug)]
pub struct Extent {
    inner: Arc<ExtentData>,
}

// Compile-time assertion that `Extent` can be sent/shared across threads and
// stored in `'static` contexts; never called.
fn _assert_extent_traits()
where
    Extent: Send + Sync + 'static,
{
}

/// Backing storage for [`Extent`]: parallel vectors indexed by dimension,
/// outermost (slowest-varying) dimension first.
///
/// NOTE(review): because `Deserialize` is derived, deserialized data does not
/// pass through `Extent::new`, so `labels.len() == sizes.len()` is not
/// re-checked on that path; confirm whether that matters to callers.
#[derive(Clone, Deserialize, Serialize, PartialEq, Eq, Hash, Debug)]
struct ExtentData {
    // One label per dimension.
    labels: Vec<String>,
    // Size of each dimension; same length as `labels`.
    sizes: Vec<usize>,
}
101
102impl From<&Shape> for Extent {
103 fn from(s: &Shape) -> Self {
104 Extent::new(s.labels().to_vec(), s.slice().sizes().to_vec()).unwrap()
106 }
107}
108
109impl From<Shape> for Extent {
110 fn from(s: Shape) -> Self {
111 Extent::from(&s)
112 }
113}
114
115impl Extent {
116 pub fn new(labels: Vec<String>, sizes: Vec<usize>) -> Result<Self, ExtentError> {
118 if labels.len() != sizes.len() {
119 return Err(ExtentError::DimMismatch {
120 num_labels: labels.len(),
121 num_sizes: sizes.len(),
122 });
123 }
124
125 Ok(Self {
126 inner: Arc::new(ExtentData { labels, sizes }),
127 })
128 }
129
130 pub fn unity() -> Extent {
131 Extent::new(vec![], vec![]).unwrap()
132 }
133
134 pub fn labels(&self) -> &[String] {
136 &self.inner.labels
137 }
138
139 pub fn sizes(&self) -> &[usize] {
141 &self.inner.sizes
142 }
143
144 pub fn size(&self, label: &str) -> Option<usize> {
147 self.position(label).map(|pos| self.sizes()[pos])
148 }
149
150 pub fn position(&self, label: &str) -> Option<usize> {
153 self.labels().iter().position(|l| l == label)
154 }
155
156 pub fn rank_of_coords(&self, coords: &[usize]) -> Result<usize, PointError> {
166 let sizes = self.sizes();
167 if coords.len() != sizes.len() {
168 return Err(PointError::DimMismatch {
169 expected: sizes.len(),
170 actual: coords.len(),
171 });
172 }
173 let mut stride = 1;
174 let mut result = 0;
175 for (&c, &size) in coords.iter().rev().zip(sizes.iter().rev()) {
176 if c >= size {
177 return Err(PointError::OutOfRangeIndex { size, index: c });
178 }
179 result += c * stride;
180 stride *= size;
181 }
182 Ok(result)
183 }
184
185 pub fn point(&self, coords: Vec<usize>) -> Result<Point, PointError> {
231 Ok(Point {
232 rank: self.rank_of_coords(&coords)?,
233 extent: self.clone(),
234 })
235 }
236
237 pub fn point_of_rank(&self, rank: usize) -> Result<Point, PointError> {
261 let total = self.num_ranks();
262 if rank >= total {
263 return Err(PointError::OutOfRangeRank { total, rank });
264 }
265 Ok(Point {
266 rank,
267 extent: self.clone(),
268 })
269 }
270
271 pub fn len(&self) -> usize {
276 self.sizes().len()
277 }
278
279 pub fn is_empty(&self) -> bool {
284 self.sizes().is_empty()
285 }
286
287 pub fn num_ranks(&self) -> usize {
292 self.sizes().iter().product()
293 }
294
295 pub fn into_inner(self) -> (Vec<String>, Vec<usize>) {
297 match Arc::try_unwrap(self.inner) {
298 Ok(data) => (data.labels, data.sizes),
299 Err(shared) => (shared.labels.clone(), shared.sizes.clone()),
300 }
301 }
302
303 pub fn to_slice(&self) -> Slice {
305 Slice::new_row_major(self.sizes())
306 }
307
308 pub fn iter(&self) -> impl Iterator<Item = (String, usize)> + use<'_> {
310 self.labels()
311 .iter()
312 .zip(self.sizes().iter())
313 .map(|(l, s)| (l.clone(), *s))
314 }
315
316 pub fn points(&self) -> ExtentPointsIterator<'_> {
318 ExtentPointsIterator::new(self)
319 }
320
321 pub fn concat(&self, other: &Extent) -> Result<Self, ExtentError> {
329 use std::collections::HashSet;
330 let lhs: HashSet<&str> = self.labels().iter().map(|s| s.as_str()).collect();
332 if let Some(dup) = other.labels().iter().find(|l| lhs.contains(l.as_str())) {
333 return Err(ExtentError::OverlappingLabel { label: dup.clone() });
334 }
335 let mut labels = self.labels().to_vec();
337 let mut sizes = self.sizes().to_vec();
338 labels.reserve(other.labels().len());
339 sizes.reserve(other.sizes().len());
340 labels.extend(other.labels().iter().cloned());
341 sizes.extend(other.sizes().iter().copied());
342 Extent::new(labels, sizes)
343 }
344}
345
mod labels {
    /// Returns `true` if `s` can be printed as a bare (unquoted) label.
    ///
    /// A label is safe when it is non-empty and consists only of ASCII
    /// alphanumerics and underscores. The empty string is not safe: printed
    /// bare it would disappear from the output (e.g. `=0/2` or `{: 4}`), and
    /// that text cannot be parsed back.
    pub(super) fn is_safe_ident(s: &str) -> bool {
        !s.is_empty() && s.chars().all(|c| c.is_ascii_alphanumeric() || c == '_')
    }

    /// Formats a label for display.
    ///
    /// Safe labels (see [`is_safe_ident`]) are returned as-is. All others are
    /// rendered as a double-quoted string using Rust's `Debug` escaping.
    pub(super) fn fmt_label(s: &str) -> String {
        if is_safe_ident(s) {
            s.to_string()
        } else {
            format!("{:?}", s)
        }
    }
}
377
378impl std::fmt::Display for Extent {
414 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
415 write!(f, "{{")?;
416 for i in 0..self.sizes().len() {
417 write!(
418 f,
419 "{}: {}",
420 labels::fmt_label(&self.labels()[i]),
421 self.sizes()[i]
422 )?;
423 if i + 1 != self.sizes().len() {
424 write!(f, ", ")?;
425 }
426 }
427 write!(f, "}}")
428 }
429}
430
431pub struct ExtentPointsIterator<'a> {
433 extent: &'a Extent,
434 next_rank: usize,
435}
436
437impl<'a> ExtentPointsIterator<'a> {
438 pub fn new(extent: &'a Extent) -> Self {
439 Self {
440 extent,
441 next_rank: 0,
442 }
443 }
444}
445
446impl<'a> Iterator for ExtentPointsIterator<'a> {
447 type Item = Point;
448
449 fn next(&mut self) -> Option<Self::Item> {
452 if self.next_rank == self.extent.num_ranks() {
453 return None;
454 }
455
456 let p = Point {
457 rank: self.next_rank,
458 extent: self.extent.clone(),
459 };
460 self.next_rank += 1;
461 Some(p)
462 }
463}
464
/// Errors produced when constructing or parsing a [`Point`].
#[derive(Debug, Error)]
pub enum PointError {
    /// The number of coordinates differs from the extent's number of dimensions.
    #[error("dimension mismatch: expected {expected}, got {actual}")]
    DimMismatch {
        /// Number of dimensions of the extent.
        expected: usize,
        /// Number of coordinates supplied.
        actual: usize,
    },

    /// A rank is not below the extent's total number of ranks.
    #[error("out of range: total ranks {total}; does not contain rank {rank}")]
    OutOfRangeRank {
        /// Total number of ranks in the extent.
        total: usize,
        /// The offending rank.
        rank: usize,
    },

    /// A coordinate is not below its dimension's size.
    #[error("out of range: dim size {size}; does not contain index {index}")]
    OutOfRangeIndex {
        /// Size of the dimension.
        size: usize,
        /// The offending coordinate.
        index: usize,
    },

    /// The textual form of a point could not be parsed (`FromStr for Point`).
    #[error("failed to parse point: {reason}")]
    ParseError { reason: String },
}
515
/// A single point of an [`Extent`], stored as its row-major rank plus the
/// extent it belongs to. Coordinates are decoded from the rank on demand.
///
/// NOTE(review): the derived `Deserialize` does not re-check
/// `rank < extent.num_ranks()` the way [`Extent::point_of_rank`] does.
#[derive(Clone, Deserialize, Serialize, PartialEq, Eq, Hash, Debug)]
pub struct Point {
    // Row-major linear index of the point within `extent`.
    rank: usize,
    // The extent this point belongs to; cloning it is an `Arc` bump.
    extent: Extent,
}
553
/// Iterator over the coordinates of a [`Point`], outermost dimension first.
///
/// Decodes a row-major rank one axis at a time. `stride` must start as the
/// product of all `sizes` and is divided by each axis size in turn.
pub struct CoordIter<'a> {
    sizes: &'a [usize],
    rank: usize,
    stride: usize,
    axis: usize,
}

impl<'a> Iterator for CoordIter<'a> {
    type Item = usize;

    fn next(&mut self) -> Option<Self::Item> {
        let &axis_size = self.sizes.get(self.axis)?;
        self.stride /= axis_size;
        let coord = self.rank / self.stride;
        self.rank %= self.stride;
        self.axis += 1;
        Some(coord)
    }

    /// Exactly one coordinate remains per unvisited axis.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let left = self.sizes.len().saturating_sub(self.axis);
        (left, Some(left))
    }
}

impl ExactSizeIterator for CoordIter<'_> {}
610
611impl<'a> IntoIterator for &'a Point {
612 type Item = usize;
613 type IntoIter = CoordIter<'a>;
614
615 fn into_iter(self) -> Self::IntoIter {
634 self.coords_iter()
635 }
636}
637
638fn _assert_point_traits()
639where
640 Point: Send + Sync + 'static,
641{
642}
643
644pub trait InExtent {
660 fn in_(self, extent: &Extent) -> Result<Point, PointError>;
661}
662
663impl InExtent for Vec<usize> {
664 fn in_(self, extent: &Extent) -> Result<Point, PointError> {
669 extent.point(self)
670 }
671}
672
673impl Point {
674 pub fn coords_iter(&self) -> CoordIter<'_> {
675 CoordIter {
676 sizes: self.extent.sizes(),
677 rank: self.rank,
678 stride: self.extent.sizes().iter().product(),
679 axis: 0,
680 }
681 }
682
683 pub fn coord(&self, i: usize) -> usize {
701 self.coords_iter()
702 .nth(i)
703 .expect("coord(i): axis out of bounds")
704 }
705
706 pub fn coords(&self) -> Vec<usize> {
721 self.coords_iter().collect()
722 }
723
724 pub fn rank(&self) -> usize {
727 self.rank
728 }
729
730 pub fn extent(&self) -> &Extent {
733 &self.extent
734 }
735
736 pub fn len(&self) -> usize {
751 self.extent.len()
752 }
753
754 pub fn is_empty(&self) -> bool {
770 self.extent.len() == 0
771 }
772
773 pub fn format_as_dict(&self) -> String {
785 format!(
786 "{{{}}}",
787 self.extent()
788 .labels()
789 .iter()
790 .zip(self.coords_iter())
791 .zip(self.extent().sizes())
792 .map(|((label, coord), size)| format!("'{}': {}/{}", label, coord, size))
793 .collect::<Vec<_>>()
794 .join(", ")
795 )
796 }
797}
798
799impl std::fmt::Display for Point {
835 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
836 let labels = self.extent.labels();
837 let sizes = self.extent.sizes();
838 let coords = self.coords();
839
840 for i in 0..labels.len() {
841 write!(
842 f,
843 "{}={}/{}",
844 labels::fmt_label(&labels[i]),
845 coords[i],
846 sizes[i]
847 )?;
848 if i + 1 != labels.len() {
849 write!(f, ",")?;
850 }
851 }
852 Ok(())
853 }
854}
855
856impl FromStr for Point {
857 type Err = PointError;
858
859 fn from_str(s: &str) -> Result<Self, Self::Err> {
860 let s = s.trim();
861
862 if s.is_empty() {
863 let empty_extent = Extent::unity();
864 return empty_extent.point(vec![]);
865 }
866
867 let mut labels = Vec::new();
868 let mut coords = Vec::new();
869 let mut sizes = Vec::new();
870
871 let mut chars = s.chars().peekable();
872
873 while chars.peek().is_some() {
874 while chars.peek() == Some(&' ') {
875 chars.next();
876 }
877
878 if chars.peek().is_none() {
879 break;
880 }
881
882 let label = if chars.peek() == Some(&'"') {
883 chars.next(); let mut label = String::new();
885 let mut escaped = false;
886
887 for ch in chars.by_ref() {
889 if escaped {
890 match ch {
891 '"' => label.push('"'),
892 '\\' => label.push('\\'),
893 _ => {
894 label.push('\\');
895 label.push(ch);
896 }
897 }
898 escaped = false;
899 } else if ch == '\\' {
900 escaped = true;
901 } else if ch == '"' {
902 break;
903 } else {
904 label.push(ch);
905 }
906 }
907
908 if label.is_empty() {
909 return Err(PointError::ParseError {
910 reason: "empty quoted label".to_string(),
911 });
912 }
913
914 label
915 } else {
916 let mut label = String::new();
917 while let Some(&ch) = chars.peek() {
918 if ch == '=' || ch == ' ' {
919 break;
920 }
921 label.push(chars.next().unwrap());
922 }
923
924 if label.is_empty() {
925 return Err(PointError::ParseError {
926 reason: "missing label".to_string(),
927 });
928 }
929
930 label
931 };
932
933 while chars.peek() == Some(&' ') {
934 chars.next();
935 }
936
937 if chars.next() != Some('=') {
938 return Err(PointError::ParseError {
939 reason: format!("expected '=' after label '{}'", label),
940 });
941 }
942
943 while chars.peek() == Some(&' ') {
944 chars.next();
945 }
946
947 let mut coord = String::new();
948 while let Some(&ch) = chars.peek() {
949 if ch == '/' || ch == ' ' {
950 break;
951 }
952 coord.push(chars.next().unwrap());
953 }
954
955 if coord.is_empty() {
956 return Err(PointError::ParseError {
957 reason: format!("missing coordinate for dimension '{}'", label),
958 });
959 }
960
961 while chars.peek() == Some(&' ') {
962 chars.next();
963 }
964
965 if chars.next() != Some('/') {
966 return Err(PointError::ParseError {
967 reason: format!("expected '/' after coordinate for dimension '{}'", label),
968 });
969 }
970
971 while chars.peek() == Some(&' ') {
972 chars.next();
973 }
974
975 let mut size = String::new();
976 while let Some(&ch) = chars.peek() {
977 if ch == ',' || ch == ' ' {
978 break;
979 }
980 size.push(chars.next().unwrap());
981 }
982
983 if size.is_empty() {
984 return Err(PointError::ParseError {
985 reason: format!("missing size for dimension '{}'", label),
986 });
987 }
988
989 let coord = coord.parse::<usize>().map_err(|e| PointError::ParseError {
990 reason: format!(
991 "invalid coordinate '{}' for dimension '{}': {}",
992 coord, label, e
993 ),
994 })?;
995
996 let size = size.parse::<usize>().map_err(|e| PointError::ParseError {
997 reason: format!("invalid size '{}' for dimension '{}': {}", size, label, e),
998 })?;
999
1000 labels.push(label);
1001 coords.push(coord);
1002 sizes.push(size);
1003
1004 while chars.peek() == Some(&' ') {
1005 chars.next();
1006 }
1007
1008 if chars.peek() == Some(&',') {
1009 chars.next(); while chars.peek() == Some(&' ') {
1011 chars.next();
1012 }
1013 if chars.peek().is_none() {
1015 return Err(PointError::ParseError {
1016 reason: "trailing comma".to_string(),
1017 });
1018 }
1019 }
1020 }
1021
1022 let extent = Extent::new(labels, sizes).map_err(|e| PointError::ParseError {
1023 reason: format!("failed to create extent: {}", e),
1024 })?;
1025
1026 extent.point(coords)
1027 }
1028}
1029
1030#[derive(Debug, Error)]
1032pub enum ViewError {
1033 #[error("no such dimension: {0}")]
1035 InvalidDim(String),
1036
1037 #[error("empty range: {range} for dimension {dim} of size {size}")]
1039 EmptyRange {
1040 range: Range,
1041 dim: String,
1042 size: usize,
1043 },
1044
1045 #[error(transparent)]
1046 ExtentError(#[from] ExtentError),
1047
1048 #[error("invalid range: selected ranks {selected} not a subset of base {base} ")]
1049 InvalidRange {
1050 base: Box<Region>,
1051 selected: Box<Region>,
1052 },
1053}
1054
1055#[derive(Debug, Error)]
1057pub enum RegionError {
1058 #[error("invalid point: this point does not belong to this region: {0}")]
1059 InvalidPoint(String),
1060
1061 #[error("out of range base rank: this base rank {0} does not belong to this region: {0}")]
1062 OutOfRangeBaseRank(usize, String),
1063}
1064
/// A labeled [`Slice`]: one label per slice dimension. The slice selects a
/// (possibly offset and strided) set of base ranks.
#[derive(Debug, Clone, Serialize, Deserialize, Hash, PartialEq, Eq)]
pub struct Region {
    // Dimension labels; expected to match the slice's number of dimensions.
    labels: Vec<String>,
    // Offset, sizes and strides mapping region coordinates to base ranks.
    slice: Slice,
}
1076
1077impl Region {
1078 #[allow(dead_code)]
1079 fn empty() -> Region {
1080 Region {
1081 labels: Vec::new(),
1082 slice: Slice::new(0, Vec::new(), Vec::new()).unwrap(),
1083 }
1084 }
1085
1086 #[allow(dead_code)]
1089 pub fn new(labels: Vec<String>, slice: Slice) -> Self {
1090 Self { labels, slice }
1091 }
1092
1093 pub fn labels(&self) -> &[String] {
1095 &self.labels
1096 }
1097
1098 pub fn slice(&self) -> &Slice {
1101 &self.slice
1102 }
1103
1104 pub fn into_inner(self) -> (Vec<String>, Slice) {
1106 (self.labels, self.slice)
1107 }
1108
1109 pub fn extent(&self) -> Extent {
1111 Extent::new(self.labels.clone(), self.slice.sizes().to_vec()).unwrap()
1112 }
1113
1114 pub fn is_subset(&self, other: &Region) -> bool {
1117 let mut left = self.slice.iter().peekable();
1118 let mut right = other.slice.iter().peekable();
1119
1120 loop {
1121 match (left.peek(), right.peek()) {
1122 (Some(l), Some(r)) => {
1123 if l < r {
1124 return false;
1125 } else if l == r {
1126 left.next();
1127 right.next();
1128 } else {
1129 right.next();
1131 }
1132 }
1133 (Some(_), None) => return false,
1134 (None, _) => return true,
1135 }
1136 }
1137 }
1138
1139 pub fn remap(&self, target: &Region) -> Option<impl Iterator<Item = usize> + '_> {
1159 if !target.is_subset(self) {
1160 return None;
1161 }
1162
1163 let mut ours = self.slice.iter().enumerate();
1164 let mut theirs = target.slice.iter();
1165
1166 Some(std::iter::from_fn(move || {
1167 let needle = theirs.next()?;
1168 loop {
1169 let (index, value) = ours.next().unwrap();
1170 if value == needle {
1171 break Some(index);
1172 }
1173 }
1174 }))
1175 }
1176
1177 pub fn num_ranks(&self) -> usize {
1179 self.slice.len()
1180 }
1181
1182 pub fn base_rank_of_point(&self, p: Point) -> Result<usize, RegionError> {
1185 if p.extent() != &self.extent() {
1186 return Err(RegionError::InvalidPoint(
1187 "mismatched extent: p must be a point in this region’s extent".to_string(),
1188 ));
1189 }
1190
1191 Ok(self
1192 .slice()
1193 .location(&p.coords())
1194 .expect("should have valid location since extent is checked"))
1195 }
1196
1197 pub fn point_of_base_rank(&self, rank: usize) -> Result<Point, RegionError> {
1200 let coords = self
1201 .slice()
1202 .coordinates(rank)
1203 .map_err(|e| RegionError::OutOfRangeBaseRank(rank, e.to_string()))?;
1204 Ok(self
1205 .extent()
1206 .point(coords)
1207 .expect("should have valid point since coords is from this region"))
1208 }
1209}
1210
1211impl From<Extent> for Region {
1214 fn from(extent: Extent) -> Self {
1215 Region {
1216 labels: extent.labels().to_vec(),
1217 slice: extent.to_slice(),
1218 }
1219 }
1220}
1221
1222impl From<&Shape> for Region {
1223 fn from(s: &Shape) -> Self {
1224 Region {
1225 labels: s.labels().to_vec(),
1226 slice: s.slice().clone(),
1227 }
1228 }
1229}
1230
1231impl From<Shape> for Region {
1232 fn from(s: Shape) -> Self {
1233 Region::from(&s)
1234 }
1235}
1236
1237impl std::fmt::Display for Region {
1277 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1278 if self.slice.offset() != 0 {
1279 write!(f, "{}+", self.slice.offset())?;
1280 }
1281 for i in 0..self.labels.len() {
1282 write!(
1283 f,
1284 "{}={}/{}",
1285 labels::fmt_label(&self.labels[i]),
1286 self.slice.sizes()[i],
1287 self.slice.strides()[i]
1288 )?;
1289 if i + 1 != self.labels.len() {
1290 write!(f, ",")?;
1291 }
1292 }
1293 Ok(())
1294 }
1295}
1296
/// Errors from parsing a [`Region`] with `FromStr`.
#[derive(Debug, thiserror::Error)]
pub enum RegionParseError {
    /// The input was not well-formed.
    #[error(transparent)]
    ParserError(#[from] ParserError),

    /// The parsed offset/sizes/strides were rejected by `Slice::new`.
    #[error(transparent)]
    SliceError(#[from] SliceError),
}
1305
impl std::str::FromStr for Region {
    type Err = RegionParseError;

    /// Parses the format produced by `Display for Region`: an optional
    /// `offset+` prefix, then `label=size/stride` entries separated by `,`.
    /// A label starting with `"` is parsed as a string literal.
    ///
    /// # Errors
    ///
    /// [`RegionParseError::ParserError`] on malformed input.
    /// [`RegionParseError::SliceError`] if `Slice::new` rejects the parsed
    /// offset, sizes and strides.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Tokenize with `+`, `=`, `,` and `/` as delimiters.
        let mut parser = Parser::new(s, &["+", "=", ",", "/"]);

        // Optional leading offset; defaults to 0.
        // NOTE(review): this relies on `try_parse` leaving the token in place
        // when it fails, so that a leading label is still available below.
        // Confirm against `Parser`.
        let offset: usize = if let Ok(offset) = parser.try_parse() {
            parser.expect("+")?;
            offset
        } else {
            0
        };

        let mut labels = Vec::new();
        let mut sizes = Vec::new();
        let mut strides = Vec::new();

        while !parser.is_empty() {
            // Entries after the first must be preceded by a comma.
            if !labels.is_empty() {
                parser.expect(",")?;
            }

            let label = if parser.peek_char() == Some('"') {
                parser.parse_string_literal()?
            } else {
                parser.next_or_err("label")?.to_string()
            };
            labels.push(label);

            // `=size/stride`, mirroring what `Display for Region` writes.
            parser.expect("=")?;
            sizes.push(parser.try_parse()?);
            parser.expect("/")?;
            strides.push(parser.try_parse()?);
        }

        Ok(Region {
            labels,
            slice: Slice::new(offset, sizes, strides)?,
        })
    }
}
1368
/// Construction of a ranked container from a [`Region`] and dense values.
pub trait BuildFromRegion<T>: Sized {
    /// Error returned by [`build_dense`](Self::build_dense).
    type Error;

    /// Builds from `values` in rank order. Which checks are performed (e.g.
    /// that the count matches `region.num_ranks()`) is up to the implementor.
    fn build_dense(region: Region, values: Vec<T>) -> Result<Self, Self::Error>;

    /// Like [`build_dense`](Self::build_dense) but infallible. Callers in this
    /// module only invoke it with exactly `region.num_ranks()` values.
    fn build_dense_unchecked(region: Region, values: Vec<T>) -> Self;
}

/// Construction of a ranked container from a [`Region`] and indexed values.
pub trait BuildFromRegionIndexed<T>: Sized {
    /// Error returned by [`build_indexed`](Self::build_indexed).
    type Error;

    /// Builds from `(index, value)` pairs. Presumably each index is a rank in
    /// `region`. How missing or duplicate indices are handled is up to the
    /// implementor — TODO confirm.
    fn build_indexed(
        region: Region,
        pairs: impl IntoIterator<Item = (usize, T)>,
    ) -> Result<Self, Self::Error>;
}
1414
1415pub trait CollectMeshExt<T>: Iterator<Item = T> + Sized {
1429 fn collect_mesh<M>(self, region: Region) -> Result<M, M::Error>
1430 where
1431 M: BuildFromRegion<T>;
1432}
1433
1434impl<I, T> CollectMeshExt<T> for I
1437where
1438 I: Iterator<Item = T> + Sized,
1439{
1440 fn collect_mesh<M>(self, region: Region) -> Result<M, M::Error>
1441 where
1442 M: BuildFromRegion<T>,
1443 {
1444 M::build_dense(region, self.collect())
1445 }
1446}
1447
/// The number of values did not match the number of ranks in the region
/// (produced by [`CollectExactMeshExt::collect_exact_mesh`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct InvalidCardinality {
    /// Number of ranks in the region.
    pub expected: usize,
    /// Number of values supplied.
    pub actual: usize,
}
1454
1455pub trait CollectExactMeshExt<T>: ExactSizeIterator<Item = T> + Sized {
1462 fn collect_exact_mesh<M>(self, region: Region) -> Result<M, M::Error>
1463 where
1464 M: BuildFromRegion<T>,
1465 M::Error: From<InvalidCardinality>;
1466}
1467
1468impl<I, T> CollectExactMeshExt<T> for I
1471where
1472 I: ExactSizeIterator<Item = T> + Sized,
1473{
1474 fn collect_exact_mesh<M>(self, region: Region) -> Result<M, M::Error>
1475 where
1476 M: BuildFromRegion<T>,
1477 M::Error: From<InvalidCardinality>,
1478 {
1479 let expected = region.num_ranks();
1480 let actual = self.len();
1481 if actual != expected {
1482 return Err(M::Error::from(InvalidCardinality { expected, actual }));
1483 }
1484 Ok(M::build_dense_unchecked(region, self.collect()))
1485 }
1486}
1487
1488pub trait CollectIndexedMeshExt<T>: Iterator<Item = (usize, T)> + Sized {
1510 fn collect_indexed<M>(self, region: Region) -> Result<M, M::Error>
1511 where
1512 M: BuildFromRegionIndexed<T>;
1513}
1514
1515impl<I, T> CollectIndexedMeshExt<T> for I
1518where
1519 I: Iterator<Item = (usize, T)> + Sized,
1520{
1521 #[inline]
1522 fn collect_indexed<M>(self, region: Region) -> Result<M, M::Error>
1523 where
1524 M: BuildFromRegionIndexed<T>,
1525 {
1526 M::build_indexed(region, self)
1527 }
1528}
1529
1530pub trait MapIntoExt: Ranked {
1532 fn map_into<M, U>(&self, f: impl Fn(&Self::Item) -> U) -> M
1533 where
1534 Self: Sized,
1535 M: BuildFromRegion<U>,
1536 {
1537 let region = self.region().clone();
1538 let n = region.num_ranks();
1539 let values: Vec<U> = (0..n).map(|i| f(self.get(i).unwrap())).collect();
1540 M::build_dense_unchecked(region, values)
1541 }
1542
1543 fn try_map_into<M, U, E>(self, f: impl Fn(&Self::Item) -> Result<U, E>) -> Result<M, E>
1544 where
1545 Self: Sized,
1546 M: BuildFromRegion<U>,
1547 {
1548 let region = self.region().clone();
1549 let n = region.num_ranks();
1550 let mut out = Vec::with_capacity(n);
1551 for i in 0..n {
1552 out.push(f(self.get(i).unwrap())?);
1553 }
1554 Ok(M::build_dense_unchecked(region, out))
1555 }
1556}
1557
1558impl<T: Ranked> MapIntoExt for T {}
1561
/// A read-only view that maps positions within a [`Region`] to items.
pub trait View: Sized {
    /// The value yielded for each rank.
    type Item;

    /// The type returned by [`subset`](Self::subset).
    type View: View;

    /// The region this view covers.
    fn region(&self) -> Region;

    /// The item at view-local position `rank` (0-based), or `None` if out of
    /// range.
    fn get(&self, rank: usize) -> Option<Self::Item>;

    /// Restricts the view to `region`, which must be a subset of this view's
    /// region.
    // `ViewError` exceeds clippy's `result_large_err` size threshold.
    #[allow(clippy::result_large_err)]
    fn subset(&self, region: Region) -> Result<Self::View, ViewError>;
}
1584
1585impl View for Region {
1587 type Item = usize;
1589
1590 type View = Region;
1592
1593 fn region(&self) -> Region {
1594 self.clone()
1595 }
1596
1597 fn subset(&self, region: Region) -> Result<Region, ViewError> {
1598 if region.is_subset(self) {
1599 Ok(region)
1600 } else {
1601 Err(ViewError::InvalidRange {
1602 base: Box::new(self.clone()),
1603 selected: Box::new(region),
1604 })
1605 }
1606 }
1607
1608 fn get(&self, rank: usize) -> Option<Self::Item> {
1609 self.slice.get(rank).ok()
1610 }
1611}
1612
1613impl View for Extent {
1615 type Item = usize;
1617
1618 type View = Region;
1621
1622 fn region(&self) -> Region {
1623 Region {
1624 labels: self.labels().to_vec(),
1625 slice: self.to_slice(),
1626 }
1627 }
1628
1629 fn subset(&self, region: Region) -> Result<Region, ViewError> {
1630 self.region().subset(region)
1631 }
1632
1633 fn get(&self, rank: usize) -> Option<Self::Item> {
1634 if rank < self.num_ranks() {
1635 Some(rank)
1636 } else {
1637 None
1638 }
1639 }
1640}
1641
/// A container holding one item per rank of a [`Region`].
pub trait Ranked: Sized {
    /// The stored item type.
    type Item: 'static;

    /// The region the container covers.
    fn region(&self) -> &Region;

    /// A reference to the item at view-local position `rank`, if any.
    fn get(&self, rank: usize) -> Option<&Self::Item>;
}

/// A [`Ranked`] container that can be restricted to a sub-region. Every such
/// type automatically implements [`View`].
pub trait RankedSliceable: Ranked {
    /// Returns the container restricted to `region`. The blanket `View` impl
    /// only calls this after checking that `region` is a subset of
    /// `self.region()`.
    fn sliced(&self, region: Region) -> Self;
}
1668
1669impl<T: RankedSliceable> View for T
1670where
1671 T::Item: Clone + 'static,
1672{
1673 type Item = T::Item;
1674 type View = Self;
1675
1676 fn region(&self) -> Region {
1677 <Self as Ranked>::region(self).clone()
1678 }
1679
1680 fn get(&self, rank: usize) -> Option<Self::Item> {
1681 <Self as Ranked>::get(self, rank).cloned()
1682 }
1683
1684 fn subset(&self, region: Region) -> Result<Self, ViewError> {
1685 if !region.is_subset(self.region()) {
1686 return Err(ViewError::InvalidRange {
1687 base: Box::new(self.region().clone()),
1688 selected: Box::new(region.clone()),
1689 });
1690 }
1691
1692 Ok(self.sliced(region))
1693 }
1694}
1695
1696pub struct ViewIterator {
1698 extent: Extent, pos: SliceIterator, }
1701
1702impl Iterator for ViewIterator {
1703 type Item = (Point, usize);
1704 fn next(&mut self) -> Option<Self::Item> {
1705 let rank = self.pos.next()?;
1707 let coords = self.pos.slice.coordinates(rank).unwrap();
1709 let point = coords.in_(&self.extent).unwrap();
1710 Some((point, rank))
1711 }
1712}
1713
/// Convenience operations available on every [`View`].
pub trait ViewExt: View {
    /// Restricts dimension `dim` to `range` (optionally stepped).
    ///
    /// # Errors
    ///
    /// [`ViewError::InvalidDim`] if there is no such dimension,
    /// [`ViewError::EmptyRange`] if the range selects nothing, or any error
    /// from [`View::subset`].
    // `ViewError` exceeds clippy's `result_large_err` size threshold.
    #[allow(clippy::result_large_err)]
    fn range<R: Into<Range>>(&self, dim: &str, range: R) -> Result<Self::View, ViewError>;

    /// Yields one sub-view per combination of coordinates in the dimensions
    /// before `dim`. Each sub-view keeps `dim` and every later dimension.
    ///
    /// # Errors
    ///
    /// [`ViewError::InvalidDim`] if there is no such dimension.
    // `ViewError` exceeds clippy's `result_large_err` size threshold.
    #[allow(clippy::result_large_err)]
    fn group_by(&self, dim: &str) -> Result<impl Iterator<Item = Self::View>, ViewError>;

    /// The extent (labels and sizes) of the view's region.
    fn extent(&self) -> Extent;

    /// `(point, item)` pairs in rank order.
    fn iter<'a>(&'a self) -> impl Iterator<Item = (Point, Self::Item)> + 'a;

    /// Items in rank order.
    fn values<'a>(&'a self) -> impl Iterator<Item = Self::Item> + 'a;
}
1791
1792impl<T: View> ViewExt for T {
1793 fn range<R: Into<Range>>(&self, dim: &str, range: R) -> Result<Self::View, ViewError> {
1794 let (labels, slice) = self.region().into_inner();
1795 let range = range.into();
1796 let dim = labels
1797 .iter()
1798 .position(|l| dim == l)
1799 .ok_or_else(|| ViewError::InvalidDim(dim.to_string()))?;
1800 let (mut offset, mut sizes, mut strides) = slice.into_inner();
1801 let (begin, end, step) = range.resolve(sizes[dim]);
1802 if end <= begin {
1803 return Err(ViewError::EmptyRange {
1804 range,
1805 dim: dim.to_string(),
1806 size: sizes[dim],
1807 });
1808 }
1809
1810 offset += strides[dim] * begin;
1811 sizes[dim] = (end - begin).div_ceil(step);
1812 strides[dim] *= step;
1813 let slice = Slice::new(offset, sizes, strides).unwrap();
1814
1815 self.subset(Region { labels, slice })
1816 }
1817
1818 fn group_by(&self, dim: &str) -> Result<impl Iterator<Item = Self::View>, ViewError> {
1819 let (labels, slice) = self.region().into_inner();
1820
1821 let dim = labels
1822 .iter()
1823 .position(|l| dim == l)
1824 .ok_or_else(|| ViewError::InvalidDim(dim.to_string()))?;
1825
1826 let (offset, sizes, strides) = slice.into_inner();
1827 let mut ranks_iter = Slice::new(offset, sizes[..dim].to_vec(), strides[..dim].to_vec())
1828 .unwrap()
1829 .iter();
1830
1831 let labels = labels[dim..].to_vec();
1832 let sizes = sizes[dim..].to_vec();
1833 let strides = strides[dim..].to_vec();
1834
1835 Ok(std::iter::from_fn(move || {
1836 let rank = ranks_iter.next()?;
1837 let slice = Slice::new(rank, sizes.clone(), strides.clone()).unwrap();
1838 Some(
1840 self.subset(Region {
1841 labels: labels.clone(),
1842 slice,
1843 })
1844 .unwrap(),
1845 )
1846 }))
1847 }
1848
1849 fn extent(&self) -> Extent {
1850 let (labels, slice) = self.region().into_inner();
1851 Extent::new(labels, slice.sizes().to_vec()).unwrap()
1852 }
1853
1854 fn iter(&self) -> impl Iterator<Item = (Point, Self::Item)> + '_ {
1855 let points = ViewIterator {
1856 extent: self.extent(),
1857 pos: self.region().slice().iter(),
1858 };
1859
1860 points.map(|(point, _)| (point.clone(), self.get(point.rank()).unwrap()))
1861 }
1862
1863 fn values(&self) -> impl Iterator<Item = Self::Item> + '_ {
1864 (0usize..self.extent().num_ranks()).map(|rank| self.get(rank).unwrap())
1865 }
1866}
1867
/// Builds an [`Extent`](crate::view::Extent) from `label = size` pairs, e.g.
/// `extent!(x = 4, y = 5)`. `extent!()` yields the 0-dimensional extent.
/// Labels are taken literally from the identifiers.
#[macro_export]
macro_rules! extent {
    ( $( $label:ident = $size:expr ),* $(,)? ) => {{
        let labels: ::std::vec::Vec<::std::string::String> =
            ::std::vec![ $( ::std::string::String::from(stringify!($label)) ),* ];
        let sizes = ::std::vec![ $( $size ),* ];
        $crate::view::Extent::new(labels, sizes).unwrap()
    }};
}
1893
1894#[cfg(test)]
1895mod test {
1896 use super::labels::*;
1897 use super::*;
1898 use crate::Shape;
1899 use crate::shape;
1900 use crate::slice::CartesianIterator;
1901
1902 #[test]
1903 fn test_is_safe_ident() {
1904 assert!(is_safe_ident("x"));
1905 assert!(is_safe_ident("gpu_0"));
1906 assert!(!is_safe_ident("dim/0"));
1907 assert!(!is_safe_ident("x y"));
1908 assert!(!is_safe_ident("x=y"));
1909 }
1910 #[test]
1911 fn test_fmt_label() {
1912 assert_eq!(fmt_label("x"), "x");
1913 assert_eq!(fmt_label("dim/0"), "\"dim/0\"");
1914 }
1915
1916 #[test]
1917 fn test_points_basic() {
1918 let extent = extent!(x = 4, y = 5, z = 6);
1919 let _p1 = extent.point(vec![1, 2, 3]).unwrap();
1920 let _p2 = vec![1, 2, 3].in_(&extent).unwrap();
1921
1922 assert_eq!(extent.num_ranks(), 4 * 5 * 6);
1923
1924 let p3 = extent.point_of_rank(0).unwrap();
1925 assert_eq!(p3.coords(), &[0, 0, 0]);
1926 assert_eq!(p3.rank(), 0);
1927
1928 let p4 = extent.point_of_rank(1).unwrap();
1929 assert_eq!(p4.coords(), &[0, 0, 1]);
1930 assert_eq!(p4.rank(), 1);
1931
1932 let p5 = extent.point_of_rank(2).unwrap();
1933 assert_eq!(p5.coords(), &[0, 0, 2]);
1934 assert_eq!(p5.rank(), 2);
1935
1936 let p6 = extent.point_of_rank(6 * 5 + 1).unwrap();
1937 assert_eq!(p6.coords(), &[1, 0, 1]);
1938 assert_eq!(p6.rank(), 6 * 5 + 1);
1939 assert_eq!(p6.coord(0), 1);
1940 assert_eq!(p6.coord(1), 0);
1941 assert_eq!(p6.coord(2), 1);
1942
1943 assert_eq!(extent.points().collect::<Vec<_>>().len(), 4 * 5 * 6);
1944 for (rank, point) in extent.points().enumerate() {
1945 let c = point.coords();
1946 let (x, y, z) = (c[0], c[1], c[2]);
1947 assert_eq!(z + y * 6 + x * 6 * 5, rank);
1948 assert_eq!(point.rank(), rank);
1949 }
1950 }
1951
1952 #[test]
1953 fn points_iterates_ranks_in_row_major_order() {
1954 let ext = extent!(x = 2, y = 3, z = 4); let mut it = ext.points();
1956
1957 for expected_rank in 0..ext.num_ranks() {
1958 let p = it.next().expect("expected another Point");
1959 assert_eq!(
1960 p.rank, expected_rank,
1961 "ranks must be consecutive in row-major order"
1962 );
1963 }
1964 assert!(
1965 it.next().is_none(),
1966 "iterator must be exhausted after num_ranks items"
1967 );
1968 }
1969
1970 #[test]
1971 fn points_iterates_single_point_for_0d_extent() {
1972 let ext = extent!();
1974 let mut it = ext.points();
1975
1976 let p = it
1977 .next()
1978 .expect("0-D extent should yield exactly one point");
1979 assert_eq!(p.rank, 0);
1980 assert_eq!(p.extent, ext);
1981
1982 assert!(
1983 it.next().is_none(),
1984 "no more points after the single 0-D point"
1985 );
1986 }
1987
1988 macro_rules! assert_view {
1989 ($view:expr, $extent:expr, $( $($coord:expr),+ => $rank:expr );* $(;)?) => {
1990 let view = $view;
1991 assert_eq!(view.extent(), $extent);
1992 let expected: Vec<_> = vec![$(($extent.point(vec![$($coord),+]).unwrap(), $rank)),*];
1993 let actual: Vec<_> = ViewExt::iter(&view).collect();
1994 assert_eq!(actual, expected);
1995 };
1996 }
1997
#[test]
fn test_view_basic() {
    // Dense 4x4 grid; the expected ranks below follow rank = 4 * x + y.
    let grid = extent!(x = 4, y = 4);

    // The first two rows cover the first eight ranks in row-major order.
    let first_rows = grid.range("x", 0..2).unwrap();
    assert_view!(
        first_rows,
        extent!(x = 2, y = 4),
        0, 0 => 0;
        0, 1 => 1;
        0, 2 => 2;
        0, 3 => 3;
        1, 0 => 4;
        1, 1 => 5;
        1, 2 => 6;
        1, 3 => 7;
    );

    // Selecting a single index keeps the axis with size 1; ranges chain.
    let row1_tail = grid.range("x", 1).unwrap().range("y", 2..).unwrap();
    assert_view!(
        row1_tail,
        extent!(x = 1, y = 2),
        0, 0 => 6;
        0, 1 => 7;
    );

    // Every other column (stride 2 along `y`).
    let even_cols = grid.range("y", Range(0, None, 2)).unwrap();
    assert_view!(
        even_cols,
        extent!(x = 4, y = 2),
        0, 0 => 0;
        0, 1 => 2;
        1, 0 => 4;
        1, 1 => 6;
        2, 0 => 8;
        2, 1 => 10;
        3, 0 => 12;
        3, 1 => 14;
    );

    // A strided axis combined with a plain range on the other axis.
    let even_cols_lower_rows = grid
        .range("y", Range(0, None, 2))
        .unwrap()
        .range("x", 2..)
        .unwrap();
    assert_view!(
        even_cols_lower_rows,
        extent!(x = 2, y = 2),
        0, 0 => 8;
        0, 1 => 10;
        1, 0 => 12;
        1, 1 => 14;
    );

    // A 10x2 grid; expected ranks follow rank = 2 * x + y.
    let tall = extent!(x = 10, y = 2);
    let even_rows = tall.range("x", Range(0, None, 2)).unwrap();
    assert_view!(
        even_rows,
        extent!(x = 5, y = 2),
        0, 0 => 0;
        0, 1 => 1;
        1, 0 => 4;
        1, 1 => 5;
        2, 0 => 8;
        2, 1 => 9;
        3, 0 => 12;
        3, 1 => 13;
        4, 0 => 16;
        4, 1 => 17;
    );

    // A second range on an already-strided axis indexes into the strided view:
    // x in {0, 2, 4, 6, 8}, then `2..` keeps {4, 6, 8}.
    let nested = tall
        .range("x", Range(0, None, 2))
        .unwrap()
        .range("x", 2..)
        .unwrap()
        .range("y", 1)
        .unwrap();
    assert_view!(
        nested,
        extent!(x = 3, y = 1),
        0, 0 => 9;
        1, 0 => 13;
        2, 0 => 17;
    );

    // Three axes; expected ranks follow rank = 16 * zone + 8 * host + gpu.
    let cluster = extent!(zone = 4, host = 2, gpu = 8);
    let zone0_even_gpus = cluster
        .range("zone", 0)
        .unwrap()
        .range("gpu", Range(0, None, 2))
        .unwrap();
    assert_view!(
        zone0_even_gpus,
        extent!(zone = 1, host = 2, gpu = 4),
        0, 0, 0 => 0;
        0, 0, 1 => 2;
        0, 0, 2 => 4;
        0, 0, 3 => 6;
        0, 1, 0 => 8;
        0, 1, 1 => 10;
        0, 1, 2 => 12;
        0, 1, 3 => 14;
    );

    // Stride 2 over an odd-sized axis of 3 keeps indices {0, 2}.
    let line = extent!(x = 3);
    let evens = line.range("x", Range(0, None, 2)).unwrap();
    assert_view!(
        evens,
        extent!(x = 2),
        0 => 0;
        1 => 2;
    );
}
2085
#[test]
fn test_point_indexing() {
    let labels: Vec<String> = ["x", "y", "z"].iter().map(|s| s.to_string()).collect();
    let ext = Extent::new(labels, vec![4, 5, 6]).unwrap();
    let p = ext.point(vec![1, 2, 3]).unwrap();

    // `coord(axis)` returns the coordinate along that axis, in label order.
    for (axis, want) in [1, 2, 3].iter().enumerate() {
        assert_eq!(p.coord(axis), *want);
    }
}
2095
#[test]
#[should_panic]
fn test_point_indexing_out_of_bounds() {
    // A two-axis point has no axis 5, so `coord` is expected to panic.
    let ext = Extent::new(vec!["x".into(), "y".into()], vec![4, 5]).unwrap();
    let p = ext.point(vec![1, 2]).unwrap();
    let missing_axis = 5;
    let _ = p.coord(missing_axis);
}
2104
#[test]
fn test_point_into_iter() {
    let ext = Extent::new(vec!["x".into(), "y".into(), "z".into()], vec![4, 5, 6]).unwrap();
    let p = ext.point(vec![1, 2, 3]).unwrap();

    // An explicit `into_iter` on `&Point` yields the coordinates in axis order.
    let collected = (&p).into_iter().collect::<Vec<usize>>();
    assert_eq!(collected, vec![1, 2, 3]);

    // `&Point` can also drive a `for` loop directly.
    let mut total = 0;
    for c in &p {
        total += c;
    }
    assert_eq!(total, 6);
}
2119
#[test]
fn test_extent_basic() {
    let ext = extent!(x = 10, y = 5, z = 1);

    // `iter()` yields (label, size) pairs in declaration order.
    let expected: Vec<(String, usize)> = [("x", 10), ("y", 5), ("z", 1)]
        .iter()
        .map(|&(label, size)| (label.to_string(), size))
        .collect();
    let actual: Vec<_> = ext.iter().collect();
    assert_eq!(actual, expected);
}
2132
#[test]
fn test_extent_display() {
    // Simple labels are rendered bare as `{label: size, ...}`.
    let plain = Extent::new(vec!["x".into(), "y".into(), "z".into()], vec![4, 5, 6]).unwrap();
    assert_eq!(plain.to_string(), "{x: 4, y: 5, z: 6}");

    // Labels such as `dim/0` are rendered inside double quotes.
    let quoted = Extent::new(vec!["dim/0".into(), "dim/1".into()], vec![4, 5]).unwrap();
    assert_eq!(quoted.to_string(), "{\"dim/0\": 4, \"dim/1\": 5}");

    // The zero-dimensional extent renders as empty braces.
    let empty = Extent::new(vec![], vec![]).unwrap();
    assert_eq!(empty.to_string(), "{}");
}
2144
#[test]
fn extent_label_helpers() {
    let e = extent!(zone = 3, host = 2, gpu = 4);

    // Each label's position is its index in `iter()`, and `size` agrees with `iter()`.
    e.iter().enumerate().for_each(|(axis, (label, size))| {
        assert_eq!(e.position(&label), Some(axis));
        assert_eq!(e.size(&label), Some(size));
    });

    // Unknown labels produce `None` from both lookups.
    let missing = "nope";
    assert_eq!(e.position(missing), None);
    assert_eq!(e.size(missing), None);
}
2155
#[test]
fn test_extent_0d() {
    // An extent with no axes still holds exactly one point, at rank 0 with no coordinates.
    let scalar = Extent::new(vec![], vec![]).unwrap();
    assert_eq!(scalar.num_ranks(), 1);

    let points: Vec<_> = scalar.points().collect();
    assert_eq!(points.len(), 1);
    let origin = &points[0];
    assert_eq!(origin.coords(), &[]);
    assert_eq!(origin.rank(), 0);

    // Its coordinate iterator is empty and stays exhausted.
    let mut coords = origin.into_iter();
    assert_eq!(coords.len(), 0);
    assert!(coords.next().is_none());
    assert!(coords.next().is_none());
}
2172
#[test]
fn test_extent_concat() {
    let extent1 = extent!(x = 2, y = 3);
    let extent2 = extent!(z = 4, w = 5);

    // Disjoint extents concatenate axis-wise; the rank count is the product of all sizes.
    let result = extent1.concat(&extent2).unwrap();
    assert_eq!(result.labels(), &["x", "y", "z", "w"]);
    assert_eq!(result.sizes(), &[2, 3, 4, 5]);
    assert_eq!(result.num_ranks(), 2 * 3 * 4 * 5);

    // The empty extent is a left and right identity.
    let empty = extent!();
    let result = extent1.concat(&empty).unwrap();
    assert_eq!(result.labels(), &["x", "y"]);
    assert_eq!(result.sizes(), &[2, 3]);

    let result = empty.concat(&extent1).unwrap();
    assert_eq!(result.labels(), &["x", "y"]);
    assert_eq!(result.sizes(), &[2, 3]);

    let result = empty.concat(&empty).unwrap();
    assert_eq!(result.labels(), &[] as &[String]);
    assert_eq!(result.sizes(), &[] as &[usize]);
    assert_eq!(result.num_ranks(), 1);

    // Self-concatenation shares every label, so it must fail and report `x`.
    let result = extent1.concat(&extent1);
    assert!(
        result.is_err(),
        "Self-concatenation should error due to overlapping labels"
    );
    match result.unwrap_err() {
        // `assert_eq!` (not `assert!(a == b)`) so a failure prints the offending label.
        ExtentError::OverlappingLabel { label } => assert_eq!(label, "x"),
        other => panic!("Expected OverlappingLabel error, got {:?}", other),
    }

    // Points of a concatenated extent carry one coordinate per combined axis.
    let result = extent1.concat(&extent2).unwrap();
    let point = result.point(vec![1, 2, 3, 4]).unwrap();
    assert_eq!(point.coords(), vec![1, 2, 3, 4]);
    assert_eq!(point.extent(), &result);

    // A shared label is rejected even when both sides agree on its size.
    let extent_a = extent!(x = 2, y = 3);
    let extent_b = extent!(y = 3, z = 4);
    let result = extent_a.concat(&extent_b);
    assert!(
        result.is_err(),
        "Should error on overlapping labels even with same size"
    );
    match result.unwrap_err() {
        ExtentError::OverlappingLabel { label } => assert_eq!(label, "y"),
        other => panic!("Expected OverlappingLabel error, got {:?}", other),
    }

    // Concatenation is order-sensitive: labels follow operand order.
    let extent_x = extent!(x = 2, y = 3);
    let extent_y = extent!(z = 4);
    assert_eq!(
        extent_x.concat(&extent_y).unwrap().labels(),
        &["x", "y", "z"]
    );
    assert_eq!(
        extent_y.concat(&extent_x).unwrap().labels(),
        &["z", "x", "y"]
    );

    // Concatenation is also associative.
    let extent_m = extent!(x = 2);
    let extent_n = extent!(y = 3);
    let extent_o = extent!(z = 4);

    let left_assoc = extent_m
        .concat(&extent_n)
        .unwrap()
        .concat(&extent_o)
        .unwrap();
    let right_assoc = extent_m
        .concat(&extent_n.concat(&extent_o).unwrap())
        .unwrap();

    assert_eq!(left_assoc, right_assoc);
    assert_eq!(left_assoc.labels(), &["x", "y", "z"]);
    assert_eq!(left_assoc.sizes(), &[2, 3, 4]);
    assert_eq!(left_assoc.num_ranks(), 2 * 3 * 4);
}
2265
#[test]
fn extent_unity_equiv_to_0d() {
    // `Extent::unity()` behaves like an extent with no axes: one point at rank 0.
    let unit = Extent::unity();
    assert!(unit.is_empty());
    assert_eq!(unit.num_ranks(), 1);

    let pts = unit.points().collect::<Vec<_>>();
    assert_eq!(pts.len(), 1);
    let origin = &pts[0];
    assert_eq!(origin.rank(), 0);
    assert!(origin.coords().is_empty());
}
2276
#[test]
fn test_point_display() {
    let ext = Extent::new(vec!["x".into(), "y".into(), "z".into()], vec![4, 5, 6]).unwrap();

    // Each axis renders as `label=coord/size`, comma-separated.
    let p = ext.point(vec![1, 2, 3]).unwrap();
    assert_eq!(p.to_string(), "x=1/4,y=2/5,z=3/6");

    // A point must supply one coordinate per axis.
    assert!(ext.point(vec![]).is_err());

    // The zero-dimensional point renders as the empty string.
    let scalar = Extent::new(vec![], vec![]).unwrap();
    let origin = scalar.point(vec![]).unwrap();
    assert_eq!(origin.to_string(), "");
}
2289
#[test]
fn test_point_display_with_quoted_labels() {
    // The labels `dim/0` and `dim,1` appear double-quoted both in the extent's
    // and in the point's textual form.
    let labels = vec![String::from("dim/0"), String::from("dim,1")];
    let ext = Extent::new(labels, vec![3, 5]).unwrap();
    assert_eq!(ext.to_string(), "{\"dim/0\": 3, \"dim,1\": 5}");

    let p = ext.point(vec![1, 2]).unwrap();
    assert_eq!(p.to_string(), "\"dim/0\"=1/3,\"dim,1\"=2/5");
}
2302
#[test]
fn test_relative_point() {
    // Maps a rank of the root mesh to the matching point of `shape`; that point's
    // rank is its position within the (possibly sliced) shape.
    fn relative_point(rank_on_root_mesh: usize, shape: &Shape) -> anyhow::Result<Point> {
        let slice = shape.slice();
        let coords = slice.coordinates(rank_on_root_mesh)?;
        let extent = Extent::new(shape.labels().to_vec(), slice.sizes().to_vec())?;
        Ok(extent.point(coords)?)
    }

    // replicas {0, 3} x hosts {1, 3} x gpus {0, 2} of a 4x4x4 mesh.
    let root_shape = shape! { replicas = 4, hosts = 4, gpus = 4 };
    let sliced_shape = root_shape
        .select("replicas", crate::Range(0, Some(4), 3))
        .unwrap()
        .select("hosts", crate::Range(1, Some(4), 2))
        .unwrap()
        .select("gpus", crate::Range(0, Some(4), 2))
        .unwrap();

    let ranks_on_root_mesh = &[4, 6, 12, 14, 52, 54, 60, 62];
    let sliced_ranks: Vec<_> = sliced_shape.slice().iter().collect();
    assert_eq!(sliced_ranks, ranks_on_root_mesh);

    // Each selected root rank maps to consecutive local ranks 0..8.
    let relative: Vec<_> = ranks_on_root_mesh
        .iter()
        .map(|&r| relative_point(r, &sliced_shape).unwrap().rank())
        .collect();
    assert_eq!(relative, (0..8).collect::<Vec<_>>());
}
2346
#[test]
fn test_iter_subviews() {
    let extent = extent!(zone = 4, host = 4, gpu = 8);

    // Grouping by `gpu` yields 16 groups; grouping by `zone` yields a single group.
    let gpu_group_count = extent.group_by("gpu").unwrap().count();
    assert_eq!(gpu_group_count, 16);
    let zone_group_count = extent.group_by("zone").unwrap().count();
    assert_eq!(zone_group_count, 1);

    // Each `gpu` group is a 1-D view of size 8; the first two cover ranks 0..8 and 8..16.
    let mut groups = extent.group_by("gpu").unwrap();

    let first = groups.next().unwrap();
    assert_view!(
        first,
        extent!(gpu = 8),
        0 => 0;
        1 => 1;
        2 => 2;
        3 => 3;
        4 => 4;
        5 => 5;
        6 => 6;
        7 => 7;
    );

    let second = groups.next().unwrap();
    assert_view!(
        second,
        extent!(gpu = 8),
        0 => 8;
        1 => 9;
        2 => 10;
        3 => 11;
        4 => 12;
        5 => 13;
        6 => 14;
        7 => 15;
    );
}
2380
#[test]
fn test_view_values() {
    let grid = extent!(x = 4, y = 4);

    // The full extent enumerates every rank in order.
    let all: Vec<_> = grid.values().collect();
    assert_eq!(all, (0..16).collect::<Vec<_>>());

    // Fixing y = 1 picks one rank from each row of four.
    let column = grid.range("y", 1).unwrap();
    let picked: Vec<_> = column.values().collect();
    assert_eq!(picked, vec![1, 5, 9, 13]);
}
2391
#[test]
fn region_is_subset_algebra() {
    let full = extent!(x = 5, y = 4);
    let middle_rows = full.range("x", 1..4).unwrap();
    let inner_block = middle_rows.range("y", 1..3).unwrap();
    let leading_rows = full.range("x", 0..2).unwrap();

    // Successive restrictions are subsets of everything they were carved from.
    assert!(inner_block.region().is_subset(&middle_rows.region()));
    assert!(inner_block.region().is_subset(&full.region()));
    assert!(middle_rows.region().is_subset(&full.region()));

    // x in 0..2 is not contained in x in 1..4, but it is still inside the full extent.
    assert!(!leading_rows.region().is_subset(&middle_rows.region()));
    assert!(leading_rows.region().is_subset(&full.region()));
}
2406
#[test]
fn test_remap() {
    let region: Region = extent!(x = 4, y = 4).into();

    // Remapping a region onto itself is the identity on its ranks.
    // (Was `®ion` — a mis-decoded `&reg` entity — which does not compile.)
    assert_eq!(
        region.remap(&region).unwrap().collect::<Vec<_>>(),
        (0..16).collect::<Vec<_>>()
    );

    // `remap` yields the ranks of the subset expressed in `region`'s rank space.
    let subset = region.range("x", 2..).unwrap();
    assert_eq!(subset.num_ranks(), 8);
    assert_eq!(
        region.remap(&subset).unwrap().collect::<Vec<_>>(),
        vec![8, 9, 10, 11, 12, 13, 14, 15],
    );

    let subset = subset.range("y", 1).unwrap();
    assert_eq!(subset.num_ranks(), 2);
    assert_eq!(
        region.remap(&subset).unwrap().collect::<Vec<_>>(),
        vec![9, 13],
    );

    // Remapping against a non-root region gives ranks local to that region
    // (1 and 2 within replica 1, not the base ranks).
    let ext = extent!(replica = 8, gpu = 4);
    let replica1 = ext.range("replica", 1).unwrap();
    assert_eq!(replica1.extent(), extent!(replica = 1, gpu = 4));
    let replica1_gpu12 = replica1.range("gpu", 1..3).unwrap();
    assert_eq!(replica1_gpu12.extent(), extent!(replica = 1, gpu = 2));
    assert_eq!(
        replica1.remap(&replica1_gpu12).unwrap().collect::<Vec<_>>(),
        vec![1, 2],
    );
}
2443
#[test]
fn test_base_local_rank_conversion() {
    // The point at local `rank` within `region`'s own extent.
    fn point(rank: usize, region: &Region) -> Point {
        region.extent().point_of_rank(rank).unwrap()
    }

    // Replicas 1..3 of a 4x2 extent occupy base ranks 2..=5.
    let extent = extent!(replicas = 4, gpus = 2);
    let region = extent.range("replicas", 1..3).unwrap();

    // A point taken from the base extent (base rank 0, outside the region) is rejected.
    assert!(
        region
            .base_rank_of_point(extent.point_of_rank(0).unwrap())
            .is_err()
    );

    // Local -> base. (The original used `®ion`, a mis-decoded `&reg` entity that
    // does not compile; it is restored to `&region` throughout.)
    assert_eq!(region.base_rank_of_point(point(0, &region)).unwrap(), 2);
    assert_eq!(region.base_rank_of_point(point(1, &region)).unwrap(), 3);
    assert_eq!(region.base_rank_of_point(point(2, &region)).unwrap(), 4);
    assert_eq!(region.base_rank_of_point(point(3, &region)).unwrap(), 5);

    // Base -> local.
    assert_eq!(region.point_of_base_rank(2).unwrap(), point(0, &region));
    assert_eq!(region.point_of_base_rank(3).unwrap(), point(1, &region));
    assert_eq!(region.point_of_base_rank(4).unwrap(), point(2, &region));
    assert_eq!(region.point_of_base_rank(5).unwrap(), point(3, &region));

    // Base ranks just outside the region have no local point.
    assert!(region.point_of_base_rank(1).is_err());
    assert!(region.point_of_base_rank(6).is_err());

    // Narrowing to the region's replica 1 and gpu 1 leaves only base rank 5.
    let subset = region
        .range("replicas", 1..2)
        .unwrap()
        .range("gpus", 1..2)
        .unwrap();
    assert_eq!(subset.base_rank_of_point(point(0, &subset)).unwrap(), 5);
    assert_eq!(subset.point_of_base_rank(5).unwrap(), point(0, &subset));
    assert!(subset.point_of_base_rank(4).is_err());
    assert!(subset.point_of_base_rank(6).is_err());
}
2496
2497 use proptest::prelude::*;
2498
2499 use crate::strategy::gen_extent;
2500 use crate::strategy::gen_region;
2501 use crate::strategy::gen_region_strided;
2502
2503 proptest! {
2504 #[test]
2505 fn test_region_parser(region in gen_region(1..=5, 1024)) {
2506 assert_eq!(
2508 region,
2509 region.to_string().parse::<Region>().unwrap(),
2510 "failed to roundtrip region {}", region
2511 );
2512 }
2513 }
2514
2515 proptest! {
2527 #[test]
2528 fn region_parser_with_offset_roundtrips(region in gen_region(1..=4, 8)) {
2529 let (labels, slice) = region.clone().into_inner();
2530 let region_off = Region {
2531 labels,
2532 slice: Slice::new(8, slice.sizes().to_vec(), slice.strides().to_vec()).unwrap(),
2533 };
2534 let s = region_off.to_string();
2535 let parsed: Region = s.parse().unwrap();
2536 prop_assert_eq!(parsed, region_off);
2537 }
2538 }
2539
2540 proptest! {
2551 #[test]
2552 fn region_strided_display_parse_roundtrips(
2553 region in gen_region_strided(1..=4, 6, 3, 16)
2554 ) {
2555 let s = region.to_string();
2566 let parsed: Region = s.parse().unwrap();
2567 prop_assert_eq!(parsed, region);
2568 }
2569 }
2570
2571 proptest! {
2582 #[test]
2583 fn region_strided_display_matches_slice(
2584 region in gen_region_strided(1..=4, 6, 3, 16)
2585 ) {
2586 let s = region.to_string();
2587 let slice = region.slice();
2588
2589 if slice.offset() != 0 {
2591 let prefix: Vec<_> = s.split('+').collect();
2592 prop_assert!(prefix.len() > 1, "expected offset+ form in {}", s);
2593 let offset_str = prefix[0];
2594 let offset_val: usize = offset_str.parse().unwrap();
2595 prop_assert_eq!(offset_val, slice.offset(), "offset mismatch in {}", s);
2596 } else {
2597 prop_assert!(!s.contains('+'), "unexpected +offset in {}", s);
2598 }
2599
2600 let body = s.split('+').next_back().unwrap(); let parts: Vec<_> = body.split(',').collect();
2603 prop_assert_eq!(parts.len(), slice.sizes().len());
2604
2605 for (i, part) in parts.iter().enumerate() {
2606 let rhs = part.split('=').nth(1).unwrap();
2608 let mut nums = rhs.split('/');
2609 let size_val: usize = nums.next().unwrap().parse().unwrap();
2610 let stride_val: usize = nums.next().unwrap().parse().unwrap();
2611
2612 prop_assert_eq!(size_val, slice.sizes()[i], "size mismatch at dim {} in {}", i, s);
2613 prop_assert_eq!(stride_val, slice.strides()[i], "stride mismatch at dim {} in {}", i, s);
2614 }
2615 }
2616 }
2617
#[test]
fn test_point_from_str_round_trip() {
    // A point's `Display` output parses back to an equal point, including the 0-d point.
    let samples = [
        extent!(x = 4, y = 5, z = 6).point(vec![1, 2, 3]).unwrap(),
        extent!(host = 2, gpu = 8).point(vec![0, 7]).unwrap(),
        extent!().point(vec![]).unwrap(),
        extent!(x = 10).point(vec![5]).unwrap(),
    ];

    for original in samples.iter() {
        let reparsed: Point = original.to_string().parse().unwrap();
        assert_eq!(*original, reparsed);
    }
}
2631
#[test]
fn test_point_from_str_basic() {
    // (input text, extent the parsed point should have, its coordinates)
    let cases = vec![
        ("x=1/4,y=2/5", extent!(x = 4, y = 5), vec![1, 2]),
        ("host=0/2,gpu=7/8", extent!(host = 2, gpu = 8), vec![0, 7]),
        ("z=3/6", extent!(z = 6), vec![3]),
        // The empty string is the zero-dimensional point.
        ("", extent!(), vec![]),
        // Whitespace around tokens is accepted.
        (" x = 1 / 4 , y = 2 / 5 ", extent!(x = 4, y = 5), vec![1, 2]),
    ];

    for (text, ext, coords) in cases {
        let got = text.parse::<Point>().unwrap();
        let want = ext.point(coords).unwrap();
        assert_eq!(got, want, "failed to parse: {}", text);
    }
}
2649
#[test]
fn test_point_from_str_quoted() {
    let ext = Extent::new(vec!["dim/0".into(), "dim,1".into()], vec![3, 5]).unwrap();
    let p = ext.point(vec![1, 2]).unwrap();
    let expected_text = "\"dim/0\"=1/3,\"dim,1\"=2/5";

    // These labels are quoted when the point is displayed...
    let rendered = p.to_string();
    assert_eq!(rendered, expected_text);

    // ...and the quoted form parses back, whether from `Display` or a literal.
    for text in [rendered.as_str(), expected_text].iter() {
        let parsed: Point = text.parse().unwrap();
        assert_eq!(parsed, p);
    }
}
2665
#[test]
fn test_point_from_str_error_cases() {
    // Every one of these inputs must be rejected by `Point::from_str`.
    let malformed = [
        "x=1,y=2/5",     // `x` has no size
        "x=1/4,y=2",     // `y` has no size
        "x=1/4,y=/5",    // empty coordinate
        "x=/4,y=2/5",    // empty coordinate
        "x=1/4,y=2/",    // empty size
        "x=1/,y=2/5",    // empty size
        "x=1/4=5,y=2/5", // extra `=`
        "x=1/4/6,y=2/5", // extra `/`
        "x=abc/4,y=2/5", // non-numeric coordinate
        "x=1/abc,y=2/5", // non-numeric size
        "=1/4,y=2/5",    // empty label
        "x=1/4,",        // trailing comma
        "x=1/4,=2/5",    // empty label after a comma
        "x=1/4,y",       // label with no value
        "x",             // label with no value
        "x=",            // empty value
        "x=1/4,y=10/5",  // coordinate not below its size
    ];

    for &input in malformed.iter() {
        let result: Result<Point, PointError> = input.parse();
        assert!(result.is_err(), "Expected error for input: '{}'", input);
    }
}
2694
#[test]
fn test_point_from_str_coordinate_validation() {
    // Coordinate 5 on an axis of size 4 must be reported as `OutOfRangeIndex`.
    let outcome: Result<Point, PointError> = "x=5/4,y=2/5".parse();
    assert!(
        outcome.is_err(),
        "Expected error for out-of-bounds coordinate"
    );

    if let PointError::OutOfRangeIndex { size, index } = outcome.unwrap_err() {
        assert_eq!(size, 4);
        assert_eq!(index, 5);
    } else {
        panic!("Expected OutOfRangeIndex error");
    }
}
2713
#[test]
fn test_point_from_str_consistency_validation() {
    // Parsing rebuilds both the extent (labels and sizes) and the coordinates.
    let parsed = "x=1/4,y=2/5,z=3/6".parse::<Point>().unwrap();
    let extent = parsed.extent();
    assert_eq!(extent.labels(), &["x", "y", "z"]);
    assert_eq!(extent.sizes(), &[4, 5, 6]);
    assert_eq!(parsed.coords(), vec![1, 2, 3]);
}
2727
2728 proptest! {
2729 #[test]
2732 fn point_coord_and_iter_agree(extent in gen_extent(0..=4, 8)) {
2733 for p in extent.points() {
2734 let via_coords = p.coords();
2735 let via_into_iter: Vec<_> = (&p).into_iter().collect();
2736 prop_assert_eq!(via_into_iter, via_coords.clone(), "coord_iter mismatch for {}", p);
2737
2738 for (i, &coord) in via_coords.iter().enumerate() {
2739 prop_assert_eq!(p.coord(i), coord, "coord(i) mismatch at axis {} for {}", i, p);
2740 }
2741 }
2742 }
2743
2744 #[test]
2746 fn points_count_matches_num_ranks(extent in gen_extent(0..=4, 8)) {
2747 let c = extent.points().count();
2748 prop_assert_eq!(c, extent.num_ranks(), "count {} != num_ranks {}", c, extent.num_ranks());
2749 }
2750
2751 #[test]
2755 fn coord_iter_exact_size_invariants(extent in gen_extent(0..=4, 8)) {
2756 for p in extent.points() {
2757 let mut it = (&p).into_iter();
2758
2759 let mut remaining = p.len();
2762 prop_assert_eq!(it.len(), remaining);
2763 prop_assert_eq!(it.size_hint(), (remaining, Some(remaining)));
2764
2765 let mut yielded = Vec::with_capacity(remaining);
2767
2768 while let Some(v) = it.next() {
2771 yielded.push(v);
2772 remaining -= 1;
2773 prop_assert_eq!(it.len(), remaining);
2774 prop_assert_eq!(it.size_hint(), (remaining, Some(remaining)));
2775 }
2776
2777 prop_assert_eq!(remaining, 0);
2780 prop_assert!(it.next().is_none());
2781 prop_assert!(it.next().is_none());
2782
2783 prop_assert_eq!(yielded, p.coords());
2785 }
2786 }
2787
2788 #[test]
2792 fn rank_of_coords_dim_mismatch(extent in gen_extent(0..=4, 8)) {
2793 let want = extent.len();
2794 let wrong = if want == 0 { 1 } else { want - 1 };
2796 let bad = vec![0usize; wrong];
2797
2798 match extent.rank_of_coords(&bad).unwrap_err() {
2799 PointError::DimMismatch { expected, actual } => {
2800 prop_assert_eq!(expected, want, "expected len mismatch");
2801 prop_assert_eq!(actual, wrong, "actual len mismatch");
2802 }
2803 other => prop_assert!(false, "expected DimMismatch, got {:?}", other),
2804 }
2805 }
2806
2807 #[test]
2812 fn rank_of_coords_out_of_range_index(extent in gen_extent(1..=4, 8)) {
2813 let sizes = extent.sizes().to_vec();
2815 let mut coords = vec![0usize; sizes.len()];
2817 let axis = 0usize;
2819 coords[axis] = sizes[axis];
2820
2821 match extent.rank_of_coords(&coords).unwrap_err() {
2822 PointError::OutOfRangeIndex { size, index } => {
2823 prop_assert_eq!(size, sizes[axis], "reported size mismatch");
2824 prop_assert_eq!(index, sizes[axis], "reported index mismatch");
2825 }
2826 other => prop_assert!(false, "expected OutOfRangeIndex, got {:?}", other),
2827 }
2828 }
2829
2830 #[test]
2833 fn point_of_rank_out_of_range(extent in gen_extent(0..=4, 8)) {
2834 let total = extent.num_ranks(); match extent.point_of_rank(total).unwrap_err() {
2836 PointError::OutOfRangeRank { total: t, rank: r } => {
2837 prop_assert_eq!(t, total, "reported total mismatch");
2838 prop_assert_eq!(r, total, "reported rank mismatch");
2839 }
2840 other => prop_assert!(false, "expected OutOfRangeRank, got {:?}", other),
2841 }
2842 }
2843
2844 #[test]
2846 fn point_display_parse_round_trip(extent in gen_extent(0..=4, 8)) {
2847 for point in extent.points() {
2848 let display = point.to_string();
2849 let parsed: Point = display.parse().unwrap();
2850 prop_assert_eq!(parsed, point, "round-trip failed for point: {}", display);
2851 }
2852 }
2853 }
2854
2855 proptest! {
2856 #[test]
2859 fn points_iterator_equivalent_to_legacy_cartesian(extent in gen_extent(0..=4, 8)) {
2860 let sizes = extent.sizes().to_vec();
2861
2862 let legacy_len = CartesianIterator::new(sizes.clone()).count();
2864 prop_assert_eq!(legacy_len, extent.num_ranks());
2865
2866 let legacy_coords = CartesianIterator::new(sizes);
2869 for (step, (p, coords)) in extent.points().zip(legacy_coords).enumerate() {
2870 let old_rank = extent.rank_of_coords(&coords).expect("valid legacy coords");
2871 prop_assert_eq!(p.rank(), old_rank, "rank mismatch at step {} for coords {:?}", step, coords);
2872 prop_assert_eq!(p.coords(), coords, "coords mismatch at step {}", step);
2873 }
2874 }
2875 }
2876}