1use std::fmt;
10use std::str::FromStr;
11
12use serde::Deserialize;
13use serde::Serialize;
14
15use crate::DimSliceIterator;
16use crate::Region;
17use crate::Slice;
18use crate::SliceError;
19use crate::selection::Selection;
20use crate::view::Extent;
21
/// Errors produced when building, indexing, selecting within, or parsing a
/// [`Shape`].
#[derive(Debug, thiserror::Error)]
pub enum ShapeError {
    /// The number of labels differs from the number of dimensions of the slice
    /// (see [`Shape::new`]).
    #[error("label slice dimension mismatch: {labels_dim} != {slice_dim}")]
    DimSliceMismatch { labels_dim: usize, slice_dim: usize },

    /// The given labels are not valid for this shape, e.g. a label that
    /// [`Shape::dim`] could not find.
    #[error("invalid labels `{labels:?}`")]
    InvalidLabels { labels: Vec<String> },

    /// A selection range was rejected by the slice as empty (see [`Shape::select`]).
    #[error("empty range {range}")]
    EmptyRange { range: Range },

    /// A range (or a single index, reported as a one-element range) falls
    /// outside dimension `dim`, whose size is `size`.
    #[error("out of range {range} for dimension {dim} of size {size}")]
    OutOfRange {
        range: Range,
        dim: String,
        size: usize,
    },

    /// A selection expression addresses more dimensions than `num_dim`.
    /// NOTE(review): not constructed in this part of the file.
    #[error("selection `{expr}` exceeds dimensionality {num_dim}")]
    SelectionTooDeep { expr: Selection, num_dim: usize },

    /// A selection expression is dynamic and cannot be applied here.
    /// NOTE(review): not constructed in this part of the file.
    #[error("dynamic selection `{expr}`")]
    SelectionDynamic { expr: Selection },

    /// An index falls outside dimension `dim`, whose size is `size`.
    /// NOTE(review): not constructed in this part of the file.
    #[error("{index} out of range for dimension {dim} of size {size}")]
    IndexOutOfRange {
        index: usize,
        dim: String,
        size: usize,
    },

    /// A shape string could not be parsed (see the `FromStr` impl); `reason`
    /// describes the problem.
    #[error("failed to parse shape: {reason}")]
    ParseError { reason: String },

    /// An error from the underlying [`Slice`], displayed transparently.
    #[error(transparent)]
    SliceError(#[from] SliceError),
}
61
/// A set of labeled dimensions laid over a [`Slice`].
///
/// `labels[i]` names dimension `i` of `slice`. [`Shape::new`] checks that the
/// two have the same number of dimensions.
///
/// NOTE(review): the derived `Deserialize` builds the struct directly and does
/// not run the check in [`Shape::new`].
#[derive(Clone, Deserialize, Serialize, PartialEq, Hash, Debug)]
pub struct Shape {
    // One label per dimension, in dimension order.
    labels: Vec<String>,
    // The underlying layout (offset, sizes, strides) that the labels describe.
    slice: Slice,
}
70
71impl Shape {
72 pub fn new(labels: Vec<String>, slice: Slice) -> Result<Self, ShapeError> {
79 if labels.len() != slice.num_dim() {
80 return Err(ShapeError::DimSliceMismatch {
81 labels_dim: labels.len(),
82 slice_dim: slice.num_dim(),
83 });
84 }
85 Ok(Self { labels, slice })
86 }
87
88 pub fn at(&self, label: &str, index: usize) -> Result<Self, ShapeError> {
93 let dim = self.dim(label)?;
94 let slice = self.slice.at(dim, index).map_err(|err| match err {
95 SliceError::IndexOutOfRange { index, total } => ShapeError::OutOfRange {
96 range: Range(index, Some(index + 1), 1),
97 dim: label.to_string(),
98 size: total,
99 },
100 other => other.into(),
101 })?;
102 let mut labels = self.labels.clone();
103 labels.remove(dim);
104 Ok(Self { labels, slice })
105 }
106
107 pub fn select<R: Into<Range>>(&self, label: &str, range: R) -> Result<Self, ShapeError> {
113 let dim = self.dim(label)?;
114 let range = range.into();
115 let (begin, end, step) = range.resolve(self.slice().sizes()[dim]);
116 let slice = self
117 .slice
118 .select(dim, begin, end, step)
119 .map_err(|err| match err {
120 SliceError::EmptyRange { .. } => ShapeError::EmptyRange { range },
121 SliceError::IndexOutOfRange { total, .. } => ShapeError::OutOfRange {
122 range,
123 dim: label.to_string(),
124 size: total,
125 },
126 other => other.into(),
127 })?;
128 let labels = self.labels.clone();
129 Ok(Self { labels, slice })
130 }
131
132 pub fn select_iter(&self, dims: usize) -> Result<SelectIterator<'_>, ShapeError> {
146 let num_dims = self.slice().num_dim();
147 if dims == 0 || dims >= num_dims {
148 return Err(ShapeError::SliceError(SliceError::IndexOutOfRange {
149 index: dims,
150 total: num_dims,
151 }));
152 }
153
154 Ok(SelectIterator {
155 shape: self,
156 iter: self.slice().dim_iter(dims),
157 })
158 }
159
160 pub fn index(&self, indices: Vec<(String, usize)>) -> Result<Shape, ShapeError> {
165 let mut shape = self.clone();
166 for (label, index) in indices {
167 shape = shape.at(&label, index)?;
168 }
169 Ok(shape)
170 }
171
172 pub fn labels(&self) -> &[String] {
174 &self.labels
175 }
176
177 pub fn slice(&self) -> &Slice {
179 &self.slice
180 }
181
182 pub fn coordinates(&self, rank: usize) -> Result<Vec<(String, usize)>, ShapeError> {
184 let coords = self.slice.coordinates(rank)?;
185 Ok(coords
186 .iter()
187 .zip(self.labels.iter())
188 .map(|(i, l)| (l.to_string(), *i))
189 .collect())
190 }
191
192 pub fn dim(&self, label: &str) -> Result<usize, ShapeError> {
193 self.labels
194 .iter()
195 .position(|l| l == label)
196 .ok_or_else(|| ShapeError::InvalidLabels {
197 labels: vec![label.to_string()],
198 })
199 }
200
201 pub fn unity() -> Shape {
203 Shape::new(vec![], Slice::new(0, vec![], vec![]).expect("unity")).expect("unity")
204 }
205
206 pub fn extent(&self) -> Extent {
208 Extent::new(self.labels.clone(), self.slice.sizes().to_vec()).unwrap()
209 }
210
211 pub fn region(&self) -> Region {
213 self.into()
214 }
215}
216
217impl From<Region> for Shape {
218 fn from(region: Region) -> Self {
219 let (labels, slice) = region.into_inner();
220 Shape::new(labels, slice)
221 .expect("Shape::new should not fail because a Region by definition is a valid Shape")
222 }
223}
224
225impl From<&Region> for Shape {
226 fn from(region: &Region) -> Self {
227 Shape::new(region.labels().to_vec(), region.slice().clone())
228 .expect("Shape::new should not fail because a Region by definition is a valid Shape")
229 }
230}
231
/// Iterator returned by [`Shape::select_iter`].
///
/// Each item is a copy of `shape` in which each of the leading dimensions is
/// narrowed to a single index (the dimension is kept, with size 1). `iter`
/// yields one position per item.
pub struct SelectIterator<'a> {
    // The shape being subdivided.
    shape: &'a Shape,
    // Positions over the leading dimensions, from `Slice::dim_iter`.
    iter: DimSliceIterator,
}
262
263impl<'a> Iterator for SelectIterator<'a> {
264 type Item = Shape;
265
266 fn next(&mut self) -> Option<Self::Item> {
267 let pos = self.iter.next()?;
268 let mut shape = self.shape.clone();
269 for (dim, index) in pos.iter().enumerate() {
270 shape = shape.select(&self.shape.labels()[dim], *index).unwrap();
271 }
272 Some(shape)
273 }
274}
275
276impl fmt::Display for Shape {
277 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
278 write!(f, "{{")?;
282 for dim in 0..self.labels.len() {
283 write!(f, "{}={}", self.labels[dim], self.slice.sizes()[dim])?;
284 if dim < self.labels.len() - 1 {
285 write!(f, ",")?;
286 }
287 }
288 write!(f, "}}")
289 }
290}
291
292impl FromStr for Shape {
293 type Err = ShapeError;
294
295 fn from_str(s: &str) -> Result<Self, Self::Err> {
296 let s = s.trim();
297
298 if !s.starts_with('{') || !s.ends_with('}') {
299 return Err(ShapeError::ParseError {
300 reason: "shape string must be enclosed in braces".to_string(),
301 });
302 }
303
304 let inner = &s[1..s.len() - 1].trim();
305
306 if inner.is_empty() {
307 return Ok(Shape::unity());
308 }
309
310 let mut labels = Vec::new();
311 let mut sizes = Vec::new();
312
313 for part in inner.split(',') {
314 let part = part.trim();
315 let mut split = part.split('=');
316
317 let label = split
318 .next()
319 .ok_or_else(|| ShapeError::ParseError {
320 reason: format!("invalid dimension format: '{}'", part),
321 })?
322 .trim();
323
324 let size_str = split
325 .next()
326 .ok_or_else(|| ShapeError::ParseError {
327 reason: format!("missing size for dimension '{}'", label),
328 })?
329 .trim();
330
331 if split.next().is_some() {
332 return Err(ShapeError::ParseError {
333 reason: format!("invalid dimension format: '{}'", part),
334 });
335 }
336
337 if label.is_empty() {
338 return Err(ShapeError::ParseError {
339 reason: format!("missing label in dimension: '{}'", part),
340 });
341 }
342
343 let size = size_str
344 .parse::<usize>()
345 .map_err(|_| ShapeError::ParseError {
346 reason: format!("invalid size '{}' for dimension '{}'", size_str, label),
347 })?;
348
349 labels.push(label.to_string());
350 sizes.push(size);
351 }
352
353 let slice = Slice::new_row_major(sizes);
354 Shape::new(labels, slice)
355 }
356}
357
/// Builds a `Shape` with a row-major slice from `label = size` pairs, e.g.
/// `shape!(host = 2, gpu = 8)`.
///
/// Each label is the identifier's text, and sizes are listed in dimension
/// order. A trailing comma is allowed, and `shape!()` yields a shape with no
/// labels.
///
/// # Panics
///
/// Panics if `Shape::new` rejects the result. This is not expected, since
/// exactly one size is supplied per label.
#[macro_export]
macro_rules! shape {
    ( $( $label:ident = $size:expr ),* $(,)? ) => {
        $crate::shape::Shape::new(
            vec![$( stringify!($label).to_string() ),*],
            $crate::Slice::new_row_major(vec![$( $size ),*]),
        )
        .unwrap()
    };
}
383
/// Applies one or more `label = range` selections to a shape, in order.
///
/// `select!(s, host = 1, gpu = 2..4)` expands to `s.select("host", 1)`
/// followed by `.and_then(|shape| shape.select("gpu", 2..4))`. It therefore
/// evaluates to the `Result` of `Shape::select` and stops at the first error.
/// `$shape` must be an identifier.
#[macro_export]
macro_rules! select {
    ($shape:ident, $label:ident = $range:expr $(, $labels:ident = $ranges:expr)*) => {
        $shape
            .select(stringify!($label), $range)
            $(.and_then(|shape| shape.select(stringify!($labels), $ranges)))*
    };
}
405
/// A strided index range over one dimension: `Range(begin, end, stride)`.
///
/// `end` is exclusive; `None` means "through the end of the dimension" (see
/// `resolve`). The `From` impls build one from `usize`, `a..b`, `a..=b` or `a..`.
#[derive(
    Debug,
    Clone,
    Eq,
    Hash,
    PartialEq,
    Serialize,
    Deserialize,
    PartialOrd,
    Ord
)]
pub struct Range(pub usize, pub Option<usize>, pub usize);
424
425impl Range {
426 pub(crate) fn resolve(&self, size: usize) -> (usize, usize, usize) {
427 match self {
428 Range(begin, Some(end), stride) => (*begin, std::cmp::min(size, *end), *stride),
429 Range(begin, None, stride) => (*begin, size, *stride),
430 }
431 }
432
433 pub(crate) fn is_empty(&self) -> bool {
434 matches!(self, Range(begin, Some(end), _) if end <= begin)
435 }
436}
437
438impl fmt::Display for Range {
439 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
440 match self {
441 Range(begin, None, stride) => write!(f, "{}::{}", begin, stride),
442 Range(begin, Some(end), stride) => write!(f, "{}:{}:{}", begin, end, stride),
443 }
444 }
445}
446
447impl From<std::ops::Range<usize>> for Range {
448 fn from(r: std::ops::Range<usize>) -> Self {
449 Self(r.start, Some(r.end), 1)
450 }
451}
452
453impl From<std::ops::RangeInclusive<usize>> for Range {
454 fn from(r: std::ops::RangeInclusive<usize>) -> Self {
455 Self(*r.start(), Some(*r.end() + 1), 1)
456 }
457}
458
459impl From<std::ops::RangeFrom<usize>> for Range {
460 fn from(r: std::ops::RangeFrom<usize>) -> Self {
461 Self(r.start, None, 1)
462 }
463}
464
465impl From<usize> for Range {
466 fn from(idx: usize) -> Self {
467 Self(idx, Some(idx + 1), 1)
468 }
469}
470
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;

    use super::*;

    #[test]
    fn test_basic() {
        let s = shape!(host = 2, gpu = 8);
        assert_eq!(&s.labels, &["host".to_string(), "gpu".to_string()]);
        assert_eq!(s.slice.offset(), 0);
        assert_eq!(s.slice.sizes(), &[2, 8]);
        assert_eq!(s.slice.strides(), &[8, 1]);

        assert_eq!(s.to_string(), "{host=2,gpu=8}");
    }

    #[test]
    fn test_select() {
        let s = shape!(host = 2, gpu = 8);

        assert_eq!(
            s.slice().iter().collect::<Vec<_>>(),
            &[
                0,
                1,
                2,
                3,
                4,
                5,
                6,
                7,
                8,
                8 + 1,
                8 + 2,
                8 + 3,
                8 + 4,
                8 + 5,
                8 + 6,
                8 + 7
            ]
        );

        assert_eq!(
            select!(s, host = 1)
                .unwrap()
                .slice()
                .iter()
                .collect::<Vec<_>>(),
            &[8, 8 + 1, 8 + 2, 8 + 3, 8 + 4, 8 + 5, 8 + 6, 8 + 7]
        );

        assert_eq!(
            select!(s, gpu = 2..)
                .unwrap()
                .slice()
                .iter()
                .collect::<Vec<_>>(),
            &[2, 3, 4, 5, 6, 7, 8 + 2, 8 + 3, 8 + 4, 8 + 5, 8 + 6, 8 + 7]
        );

        assert_eq!(
            select!(s, gpu = 3..5)
                .unwrap()
                .slice()
                .iter()
                .collect::<Vec<_>>(),
            &[3, 4, 8 + 3, 8 + 4]
        );

        assert_eq!(
            select!(s, gpu = 3..5, host = 1)
                .unwrap()
                .slice()
                .iter()
                .collect::<Vec<_>>(),
            &[8 + 3, 8 + 4]
        );
    }

    #[test]
    fn test_select_iter() {
        let s = shape!(replica = 2, host = 2, gpu = 8);
        let selections: Vec<_> = s.select_iter(2).unwrap().collect();
        // One sub-shape per (replica, host) pair; each selected dimension keeps size 1.
        assert_eq!(selections.len(), 4);
        for selection in &selections {
            assert_eq!(selection.slice().sizes(), &[1, 1, 8]);
        }
        assert_eq!(
            selections,
            &[
                select!(s, replica = 0, host = 0).unwrap(),
                select!(s, replica = 0, host = 1).unwrap(),
                select!(s, replica = 1, host = 0).unwrap(),
                select!(s, replica = 1, host = 1).unwrap()
            ]
        );
    }

    #[test]
    fn test_select_iter_invalid_dims() {
        let s = shape!(replica = 2, host = 2, gpu = 8);
        // `dims` must satisfy 0 < dims < num_dim.
        assert_matches!(
            s.select_iter(0).err(),
            Some(ShapeError::SliceError(SliceError::IndexOutOfRange {
                index: 0,
                total: 3
            }))
        );
        assert_matches!(
            s.select_iter(3).err(),
            Some(ShapeError::SliceError(SliceError::IndexOutOfRange {
                index: 3,
                total: 3
            }))
        );
    }

    #[test]
    fn test_coordinates() {
        let s = shape!(host = 2, gpu = 8);
        assert_eq!(
            s.coordinates(0).unwrap(),
            vec![("host".to_string(), 0), ("gpu".to_string(), 0)]
        );
        assert_eq!(
            s.coordinates(1).unwrap(),
            vec![("host".to_string(), 0), ("gpu".to_string(), 1)]
        );
        assert_eq!(
            s.coordinates(8).unwrap(),
            vec![("host".to_string(), 1), ("gpu".to_string(), 0)]
        );
        assert_eq!(
            s.coordinates(9).unwrap(),
            vec![("host".to_string(), 1), ("gpu".to_string(), 1)]
        );

        assert_matches!(
            s.coordinates(16).unwrap_err(),
            ShapeError::SliceError(SliceError::ValueNotInSlice { value: 16 })
        );
    }

    #[test]
    fn test_select_bad() {
        let s = shape!(host = 2, gpu = 8);

        assert_matches!(
            select!(s, gpu = 1..1).unwrap_err(),
            ShapeError::EmptyRange {
                range: Range(1, Some(1), 1)
            },
        );

        assert_matches!(
            select!(s, gpu = 8).unwrap_err(),
            ShapeError::OutOfRange {
                range: Range(8, Some(9), 1),
                dim,
                size: 8,
            } if dim == "gpu",
        );
    }

    #[test]
    fn test_shape_index() {
        let n_hosts = 5;
        let n_gpus = 7;

        let s = shape!(host = n_hosts, gpu = n_gpus);
        // Fixing host = 0 leaves the gpu dimension at offset 0 with unit stride.
        assert_eq!(
            s.index(vec![("host".to_string(), 0)]).unwrap(),
            Shape::new(
                vec!["gpu".to_string()],
                Slice::new(0, vec![n_gpus], vec![1]).unwrap()
            )
            .unwrap()
        );

        let offset = 1;
        // Fixing gpu = offset leaves the host dimension, strided by n_gpus.
        assert_eq!(
            s.index(vec![("gpu".to_string(), offset)]).unwrap(),
            Shape::new(
                vec!["host".to_string()],
                Slice::new(offset, vec![n_hosts], vec![n_gpus]).unwrap()
            )
            .unwrap()
        );

        let n_zone = 2;
        // Fixing the middle dimension shifts the offset by offset * n_gpus.
        let s = shape!(zone = n_zone, host = n_hosts, gpu = n_gpus);
        let offset = 3;
        assert_eq!(
            s.index(vec![("host".to_string(), offset)]).unwrap(),
            Shape::new(
                vec!["zone".to_string(), "gpu".to_string()],
                Slice::new(
                    offset * n_gpus,
                    vec![n_zone, n_gpus],
                    vec![n_hosts * n_gpus, 1]
                )
                .unwrap()
            )
            .unwrap()
        );

        // An index equal to the dimension size is out of range.
        assert!(
            shape!(gpu = n_gpus)
                .index(vec![("gpu".to_string(), n_gpus)])
                .is_err()
        );
        // An unknown label is rejected.
        assert!(
            shape!(gpu = n_gpus)
                .index(vec![("non-exist-dim".to_string(), 0)])
                .is_err()
        );
    }

    #[test]
    fn test_shape_select_stride_rounding() {
        let shape = shape!(x = 10);
        let sub = shape.select("x", Range(0, Some(10), 3)).unwrap();
        // Indices 0, 3, 6 and 9 fit in a dimension of size 10, so 4 elements remain.
        let slice = sub.slice();
        assert_eq!(
            slice,
            &Slice::new(0, vec![4], vec![3]).unwrap(),
            "Expected offset 0, size 4, stride 3"
        );
    }

    #[test]
    fn test_shape_at_removes_dimension() {
        let labels = vec![
            "batch".to_string(),
            "height".to_string(),
            "width".to_string(),
        ];
        let slice = Slice::new_row_major(vec![2, 3, 4]);
        let shape = Shape::new(labels, slice).unwrap();

        let result = shape.at("batch", 1).unwrap();

        assert_eq!(result.labels(), &["height", "width"]);
        assert_eq!(result.slice().sizes(), &[3, 4]);
        // Row-major stride of `batch` is 3 * 4 = 12, so batch = 1 starts at 12.
        assert_eq!(result.slice().offset(), 12);
    }

    #[test]
    fn test_shape_at_middle_dimension() {
        let labels = vec![
            "batch".to_string(),
            "height".to_string(),
            "width".to_string(),
        ];
        let slice = Slice::new_row_major(vec![2, 3, 4]);
        let shape = Shape::new(labels, slice).unwrap();

        let result = shape.at("height", 1).unwrap();

        assert_eq!(result.labels(), &["batch", "width"]);
        assert_eq!(result.slice().sizes(), &[2, 4]);
        // Row-major stride of `height` is 4, so height = 1 starts at 4.
        assert_eq!(result.slice().offset(), 4);
    }

    #[test]
    fn test_shape_at_invalid_label() {
        let labels = vec!["batch".to_string(), "height".to_string()];
        let slice = Slice::new_row_major(vec![2, 3]);
        let shape = Shape::new(labels, slice).unwrap();

        assert_matches!(
            shape.at("nonexistent", 0),
            Err(ShapeError::InvalidLabels { labels }) if labels == ["nonexistent".to_string()]
        );
    }

    #[test]
    fn test_shape_at_index_out_of_range() {
        let labels = vec!["batch".to_string(), "height".to_string()];
        let slice = Slice::new_row_major(vec![2, 3]);
        let shape = Shape::new(labels, slice).unwrap();

        // `batch` has size 2, so index 5 is out of range.
        assert_matches!(
            shape.at("batch", 5),
            Err(ShapeError::OutOfRange { dim, .. }) if dim == "batch"
        );
    }

    #[test]
    fn test_range_display_and_conversions() {
        assert_eq!(Range(1, Some(5), 2).to_string(), "1:5:2");
        assert_eq!(Range(3, None, 1).to_string(), "3::1");
        assert_eq!(Range::from(2usize..5), Range(2, Some(5), 1));
        assert_eq!(Range::from(2usize..=4), Range(2, Some(5), 1));
        assert_eq!(Range::from(7usize..), Range(7, None, 1));
        assert_eq!(Range::from(7usize), Range(7, Some(8), 1));
    }

    #[test]
    fn test_shape_from_str_round_trip() {
        let test_cases = vec![
            shape!(host = 2, gpu = 8),
            shape!(x = 1),
            shape!(batch = 10, height = 224, width = 224, channels = 3),
            // The zero-dimensional shape displays as `{}`.
            Shape::unity(),
        ];

        for original in test_cases {
            let display_str = original.to_string();
            let parsed: Shape = display_str.parse().unwrap();
            assert_eq!(
                parsed, original,
                "Round-trip failed for shape: {}",
                display_str
            );
        }
    }

    #[test]
    fn test_shape_from_str_valid_cases() {
        let test_cases = vec![
            ("{host=2,gpu=8}", shape!(host = 2, gpu = 8)),
            ("{x=1}", shape!(x = 1)),
            // Whitespace around tokens is ignored.
            ("{ host = 2 , gpu = 8 }", shape!(host = 2, gpu = 8)),
            ("{}", Shape::unity()),
        ];

        for (input, expected) in test_cases {
            let parsed: Shape = input.parse().unwrap();
            assert_eq!(parsed, expected, "Failed to parse: {}", input);
        }
    }

    #[test]
    fn test_shape_from_str_error_cases() {
        let error_cases = vec![
            "host=2,gpu=8",
            "{host=2,gpu=8",
            "host=2,gpu=8}",
            "{host=2,gpu=}",
            "{host=,gpu=8}",
            "{host=2=3,gpu=8}",
            "{host=abc,gpu=8}",
            "{host=2,}",
            "{=8}",
        ];

        for input in error_cases {
            let result: Result<Shape, ShapeError> = input.parse();
            assert_matches!(
                result,
                Err(ShapeError::ParseError { .. }),
                "expected parse error for input: {}",
                input
            );
        }
    }
}