1use std::fmt;
10use std::str::FromStr;
11
12use serde::Deserialize;
13use serde::Serialize;
14
15use crate::DimSliceIterator;
16use crate::Region;
17use crate::Slice;
18use crate::SliceError;
19use crate::selection::Selection;
20use crate::view::Extent;
21
22#[derive(Debug, thiserror::Error)]
25pub enum ShapeError {
26 #[error("label slice dimension mismatch: {labels_dim} != {slice_dim}")]
27 DimSliceMismatch { labels_dim: usize, slice_dim: usize },
28
29 #[error("invalid labels `{labels:?}`")]
30 InvalidLabels { labels: Vec<String> },
31
32 #[error("empty range {range}")]
33 EmptyRange { range: Range },
34
35 #[error("out of range {range} for dimension {dim} of size {size}")]
36 OutOfRange {
37 range: Range,
38 dim: String,
39 size: usize,
40 },
41
42 #[error("selection `{expr}` exceeds dimensionality {num_dim}")]
43 SelectionTooDeep { expr: Selection, num_dim: usize },
44
45 #[error("dynamic selection `{expr}`")]
46 SelectionDynamic { expr: Selection },
47
48 #[error("{index} out of range for dimension {dim} of size {size}")]
49 IndexOutOfRange {
50 index: usize,
51 dim: String,
52 size: usize,
53 },
54
55 #[error("failed to parse shape: {reason}")]
56 ParseError { reason: String },
57
58 #[error(transparent)]
59 SliceError(#[from] SliceError),
60}
61
62#[derive(Clone, Deserialize, Serialize, PartialEq, Hash, Debug)]
64pub struct Shape {
65 labels: Vec<String>,
67 slice: Slice,
69}
70
71impl Shape {
72 pub fn new(labels: Vec<String>, slice: Slice) -> Result<Self, ShapeError> {
79 if labels.len() != slice.num_dim() {
80 return Err(ShapeError::DimSliceMismatch {
81 labels_dim: labels.len(),
82 slice_dim: slice.num_dim(),
83 });
84 }
85 Ok(Self { labels, slice })
86 }
87
88 pub fn at(&self, label: &str, index: usize) -> Result<Self, ShapeError> {
93 let dim = self.dim(label)?;
94 let slice = self.slice.at(dim, index).map_err(|err| match err {
95 SliceError::IndexOutOfRange { index, total } => ShapeError::OutOfRange {
96 range: Range(index, Some(index + 1), 1),
97 dim: label.to_string(),
98 size: total,
99 },
100 other => other.into(),
101 })?;
102 let mut labels = self.labels.clone();
103 labels.remove(dim);
104 Ok(Self { labels, slice })
105 }
106
107 pub fn select<R: Into<Range>>(&self, label: &str, range: R) -> Result<Self, ShapeError> {
113 let dim = self.dim(label)?;
114 let range = range.into();
115 let (begin, end, step) = range.resolve(self.slice().sizes()[dim]);
116 let slice = self
117 .slice
118 .select(dim, begin, end, step)
119 .map_err(|err| match err {
120 SliceError::EmptyRange { .. } => ShapeError::EmptyRange { range },
121 SliceError::IndexOutOfRange { total, .. } => ShapeError::OutOfRange {
122 range,
123 dim: label.to_string(),
124 size: total,
125 },
126 other => other.into(),
127 })?;
128 let labels = self.labels.clone();
129 Ok(Self { labels, slice })
130 }
131
132 pub fn select_iter(&self, dims: usize) -> Result<SelectIterator<'_>, ShapeError> {
146 let num_dims = self.slice().num_dim();
147 if dims == 0 || dims >= num_dims {
148 return Err(ShapeError::SliceError(SliceError::IndexOutOfRange {
149 index: dims,
150 total: num_dims,
151 }));
152 }
153
154 Ok(SelectIterator {
155 shape: self,
156 iter: self.slice().dim_iter(dims),
157 })
158 }
159
160 pub fn index(&self, indices: Vec<(String, usize)>) -> Result<Shape, ShapeError> {
165 let mut shape = self.clone();
166 for (label, index) in indices {
167 shape = shape.at(&label, index)?;
168 }
169 Ok(shape)
170 }
171
172 pub fn labels(&self) -> &[String] {
174 &self.labels
175 }
176
177 pub fn slice(&self) -> &Slice {
179 &self.slice
180 }
181
182 pub fn coordinates(&self, rank: usize) -> Result<Vec<(String, usize)>, ShapeError> {
184 let coords = self.slice.coordinates(rank)?;
185 Ok(coords
186 .iter()
187 .zip(self.labels.iter())
188 .map(|(i, l)| (l.to_string(), *i))
189 .collect())
190 }
191
192 pub fn dim(&self, label: &str) -> Result<usize, ShapeError> {
193 self.labels
194 .iter()
195 .position(|l| l == label)
196 .ok_or_else(|| ShapeError::InvalidLabels {
197 labels: vec![label.to_string()],
198 })
199 }
200
201 pub fn unity() -> Shape {
203 Shape::new(vec![], Slice::new(0, vec![], vec![]).expect("unity")).expect("unity")
204 }
205
206 pub fn extent(&self) -> Extent {
208 Extent::new(self.labels.clone(), self.slice.sizes().to_vec()).unwrap()
209 }
210
211 pub fn region(&self) -> Region {
213 self.into()
214 }
215}
216
217impl From<&Region> for Shape {
218 fn from(region: &Region) -> Self {
219 Shape::new(region.labels().to_vec(), region.slice().clone())
220 .expect("Shape::new should not fail because a Region by definition is a valid Shape")
221 }
222}
223
224pub struct SelectIterator<'a> {
251 shape: &'a Shape,
252 iter: DimSliceIterator,
253}
254
255impl<'a> Iterator for SelectIterator<'a> {
256 type Item = Shape;
257
258 fn next(&mut self) -> Option<Self::Item> {
259 let pos = self.iter.next()?;
260 let mut shape = self.shape.clone();
261 for (dim, index) in pos.iter().enumerate() {
262 shape = shape.select(&self.shape.labels()[dim], *index).unwrap();
263 }
264 Some(shape)
265 }
266}
267
268impl fmt::Display for Shape {
269 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
270 write!(f, "{{")?;
274 for dim in 0..self.labels.len() {
275 write!(f, "{}={}", self.labels[dim], self.slice.sizes()[dim])?;
276 if dim < self.labels.len() - 1 {
277 write!(f, ",")?;
278 }
279 }
280 write!(f, "}}")
281 }
282}
283
284impl FromStr for Shape {
285 type Err = ShapeError;
286
287 fn from_str(s: &str) -> Result<Self, Self::Err> {
288 let s = s.trim();
289
290 if !s.starts_with('{') || !s.ends_with('}') {
291 return Err(ShapeError::ParseError {
292 reason: "shape string must be enclosed in braces".to_string(),
293 });
294 }
295
296 let inner = &s[1..s.len() - 1].trim();
297
298 if inner.is_empty() {
299 return Ok(Shape::unity());
300 }
301
302 let mut labels = Vec::new();
303 let mut sizes = Vec::new();
304
305 for part in inner.split(',') {
306 let part = part.trim();
307 let mut split = part.split('=');
308
309 let label = split
310 .next()
311 .ok_or_else(|| ShapeError::ParseError {
312 reason: format!("invalid dimension format: '{}'", part),
313 })?
314 .trim();
315
316 let size_str = split
317 .next()
318 .ok_or_else(|| ShapeError::ParseError {
319 reason: format!("missing size for dimension '{}'", label),
320 })?
321 .trim();
322
323 if split.next().is_some() {
324 return Err(ShapeError::ParseError {
325 reason: format!("invalid dimension format: '{}'", part),
326 });
327 }
328
329 if label.is_empty() {
330 return Err(ShapeError::ParseError {
331 reason: format!("missing label in dimension: '{}'", part),
332 });
333 }
334
335 let size = size_str
336 .parse::<usize>()
337 .map_err(|_| ShapeError::ParseError {
338 reason: format!("invalid size '{}' for dimension '{}'", size_str, label),
339 })?;
340
341 labels.push(label.to_string());
342 sizes.push(size);
343 }
344
345 let slice = Slice::new_row_major(sizes);
346 Shape::new(labels, slice)
347 }
348}
349
/// Builds a `Shape` from `label = size` pairs, laid out row-major.
///
/// Each label is the literal identifier text, and each size is any expression
/// accepted by `Slice::new_row_major`. A trailing comma is allowed, and an
/// empty invocation produces a zero-dimensional shape.
///
/// # Panics
///
/// Panics if `Shape::new` rejects the resulting labels and slice.
#[macro_export]
macro_rules! shape {
    ( $( $label:ident = $size:expr ),* $(,)? ) => {{
        let labels: ::std::vec::Vec<::std::string::String> =
            ::std::vec![ $( ::std::string::String::from(stringify!($label)) ),* ];
        let sizes = ::std::vec![ $( $size ),* ];
        $crate::shape::Shape::new(labels, $crate::Slice::new_row_major(sizes)).unwrap()
    }};
}
375
/// Applies `select(label, range)` for each `label = range` pair, left to right,
/// short-circuiting on the first error.
///
/// `$shape` may be any expression (a variable, a call, a field access, ...)
/// whose value has a `select(&str, range)` method returning a `Result`, such
/// as `Shape::select`. Each range is anything that method accepts, for
/// example `1`, `2..`, or `3..5`. A trailing comma is allowed.
///
/// Expands to a `Result` holding the final selection.
#[macro_export]
macro_rules! select {
    ($shape:expr, $label:ident = $range:expr $(,)?) => {
        $shape.select(stringify!($label), $range)
    };

    ($shape:expr, $label:ident = $range:expr, $($labels:ident = $ranges:expr),+ $(,)?) => {
        $shape
            .select(stringify!($label), $range)
            .and_then(|shape| $crate::select!(shape, $($labels = $ranges),+))
    };
}
397
398#[derive(
405 Debug,
406 Clone,
407 Eq,
408 Hash,
409 PartialEq,
410 Serialize,
411 Deserialize,
412 PartialOrd,
413 Ord
414)]
415pub struct Range(pub usize, pub Option<usize>, pub usize);
416
417impl Range {
418 pub(crate) fn resolve(&self, size: usize) -> (usize, usize, usize) {
419 match self {
420 Range(begin, Some(end), stride) => (*begin, std::cmp::min(size, *end), *stride),
421 Range(begin, None, stride) => (*begin, size, *stride),
422 }
423 }
424
425 pub(crate) fn is_empty(&self) -> bool {
426 matches!(self, Range(begin, Some(end), _) if end <= begin)
427 }
428}
429
430impl fmt::Display for Range {
431 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
432 match self {
433 Range(begin, None, stride) => write!(f, "{}::{}", begin, stride),
434 Range(begin, Some(end), stride) => write!(f, "{}:{}:{}", begin, end, stride),
435 }
436 }
437}
438
439impl From<std::ops::Range<usize>> for Range {
440 fn from(r: std::ops::Range<usize>) -> Self {
441 Self(r.start, Some(r.end), 1)
442 }
443}
444
445impl From<std::ops::RangeInclusive<usize>> for Range {
446 fn from(r: std::ops::RangeInclusive<usize>) -> Self {
447 Self(*r.start(), Some(*r.end() + 1), 1)
448 }
449}
450
451impl From<std::ops::RangeFrom<usize>> for Range {
452 fn from(r: std::ops::RangeFrom<usize>) -> Self {
453 Self(r.start, None, 1)
454 }
455}
456
457impl From<usize> for Range {
458 fn from(idx: usize) -> Self {
459 Self(idx, Some(idx + 1), 1)
460 }
461}
462
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;

    use super::*;

    /// Builds a row-major shape from borrowed label names and sizes.
    fn labelled(labels: &[&str], sizes: Vec<usize>) -> Shape {
        let labels = labels.iter().map(|l| l.to_string()).collect();
        Shape::new(labels, Slice::new_row_major(sizes)).unwrap()
    }

    #[test]
    fn test_basic() {
        let s = shape!(host = 2, gpu = 8);
        assert_eq!(s.labels, vec!["host".to_string(), "gpu".to_string()]);
        assert_eq!(s.slice.offset(), 0);
        assert_eq!(s.slice.sizes(), &[2, 8]);
        assert_eq!(s.slice.strides(), &[8, 1]);

        assert_eq!(format!("{}", s), "{host=2,gpu=8}");
    }

    #[test]
    fn test_select() {
        let s = shape!(host = 2, gpu = 8);
        // Linear ranks addressed by a (possibly narrowed) shape.
        let ranks = |shape: &Shape| shape.slice().iter().collect::<Vec<_>>();

        assert_eq!(ranks(&s), (0..16).collect::<Vec<_>>());

        let second_host = select!(s, host = 1).unwrap();
        assert_eq!(ranks(&second_host), (8..16).collect::<Vec<_>>());

        let gpu_tail = select!(s, gpu = 2..).unwrap();
        assert_eq!(ranks(&gpu_tail), (2..8).chain(10..16).collect::<Vec<_>>());

        let gpu_middle = select!(s, gpu = 3..5).unwrap();
        assert_eq!(ranks(&gpu_middle), vec![3, 4, 8 + 3, 8 + 4]);

        let both = select!(s, gpu = 3..5, host = 1).unwrap();
        assert_eq!(ranks(&both), vec![8 + 3, 8 + 4]);
    }

    #[test]
    fn test_select_iter() {
        let s = shape!(replica = 2, host = 2, gpu = 8);
        let selections: Vec<Shape> = s.select_iter(2).unwrap().collect();
        for selection in &selections[..4] {
            assert_eq!(selection.slice().sizes(), &[1, 1, 8]);
        }

        // Expected order: replica-major, then host.
        let expected: Vec<Shape> = (0..2usize)
            .flat_map(|replica| (0..2usize).map(move |host| (replica, host)))
            .map(|(replica, host)| select!(s, replica = replica, host = host).unwrap())
            .collect();
        assert_eq!(selections, expected);
    }

    #[test]
    fn test_coordinates() {
        let s = shape!(host = 2, gpu = 8);
        for (rank, host, gpu) in [(0, 0, 0), (1, 0, 1), (8, 1, 0), (9, 1, 1)] {
            assert_eq!(
                s.coordinates(rank).unwrap(),
                vec![("host".to_string(), host), ("gpu".to_string(), gpu)]
            );
        }

        assert_matches!(
            s.coordinates(16).unwrap_err(),
            ShapeError::SliceError(SliceError::ValueNotInSlice { value: 16 })
        );
    }

    #[test]
    fn test_select_bad() {
        let s = shape!(host = 2, gpu = 8);

        assert_matches!(
            select!(s, gpu = 1..1).unwrap_err(),
            ShapeError::EmptyRange {
                range: Range(1, Some(1), 1)
            },
        );

        assert_matches!(
            select!(s, gpu = 8).unwrap_err(),
            ShapeError::OutOfRange {
                range: Range(8, Some(9), 1),
                dim,
                size: 8,
            } if dim == "gpu",
        );
    }

    #[test]
    fn test_shape_index() {
        let (n_hosts, n_gpus) = (5, 7);
        let s = shape!(host = n_hosts, gpu = n_gpus);

        // Fixing `host` leaves the contiguous `gpu` dimension.
        let expected = Shape::new(
            vec!["gpu".to_string()],
            Slice::new(0, vec![n_gpus], vec![1]).unwrap(),
        )
        .unwrap();
        assert_eq!(s.index(vec![("host".to_string(), 0)]).unwrap(), expected);

        // Fixing `gpu` leaves `host`, strided by the gpu count.
        let offset = 1;
        let expected = Shape::new(
            vec!["host".to_string()],
            Slice::new(offset, vec![n_hosts], vec![n_gpus]).unwrap(),
        )
        .unwrap();
        assert_eq!(
            s.index(vec![("gpu".to_string(), offset)]).unwrap(),
            expected
        );

        // Fixing the middle dimension of a 3-D shape.
        let n_zone = 2;
        let s = shape!(zone = n_zone, host = n_hosts, gpu = n_gpus);
        let offset = 3;
        let expected = Shape::new(
            vec!["zone".to_string(), "gpu".to_string()],
            Slice::new(
                offset * n_gpus,
                vec![n_zone, n_gpus],
                vec![n_hosts * n_gpus, 1],
            )
            .unwrap(),
        )
        .unwrap();
        assert_eq!(
            s.index(vec![("host".to_string(), offset)]).unwrap(),
            expected
        );

        let single = shape!(gpu = n_gpus);
        assert!(single.index(vec![("gpu".to_string(), n_gpus)]).is_err());
        assert!(single
            .index(vec![("non-exist-dim".to_string(), 0)])
            .is_err());
    }

    #[test]
    fn test_shape_select_stride_rounding() {
        // 0..10 with stride 3 picks 0, 3, 6, 9: four elements.
        let expected = Slice::new(0, vec![4], vec![3]).unwrap();
        let sub = shape!(x = 10).select("x", Range(0, Some(10), 3)).unwrap();
        assert_eq!(
            sub.slice(),
            &expected,
            "Expected offset 0, size 4, stride 3"
        );
    }

    #[test]
    fn test_shape_at_removes_dimension() {
        let shape = labelled(&["batch", "height", "width"], vec![2, 3, 4]);
        let result = shape.at("batch", 1).unwrap();

        assert_eq!(result.labels(), &["height", "width"]);
        assert_eq!(result.slice().sizes(), &[3, 4]);
        // batch=1 skips one full 3x4 block.
        assert_eq!(result.slice().offset(), 12);
    }

    #[test]
    fn test_shape_at_middle_dimension() {
        let shape = labelled(&["batch", "height", "width"], vec![2, 3, 4]);
        let result = shape.at("height", 1).unwrap();

        assert_eq!(result.labels(), &["batch", "width"]);
        assert_eq!(result.slice().sizes(), &[2, 4]);
        // height=1 skips one row of width 4.
        assert_eq!(result.slice().offset(), 4);
    }

    #[test]
    fn test_shape_at_invalid_label() {
        let shape = labelled(&["batch", "height"], vec![2, 3]);
        assert_matches!(
            shape.at("nonexistent", 0),
            Err(ShapeError::InvalidLabels { .. })
        );
    }

    #[test]
    fn test_shape_at_index_out_of_range() {
        let shape = labelled(&["batch", "height"], vec![2, 3]);
        // batch has only 2 entries.
        assert_matches!(shape.at("batch", 5), Err(ShapeError::OutOfRange { .. }));
    }

    #[test]
    fn test_shape_from_str_round_trip() {
        let originals = [
            shape!(host = 2, gpu = 8),
            shape!(x = 1),
            shape!(batch = 10, height = 224, width = 224, channels = 3),
            Shape::unity(),
        ];

        for original in originals {
            let text = original.to_string();
            let reparsed = text.parse::<Shape>().unwrap();
            assert_eq!(reparsed, original, "Round-trip failed for shape: {}", text);
        }
    }

    #[test]
    fn test_shape_from_str_valid_cases() {
        let cases = [
            ("{host=2,gpu=8}", shape!(host = 2, gpu = 8)),
            ("{x=1}", shape!(x = 1)),
            ("{ host = 2 , gpu = 8 }", shape!(host = 2, gpu = 8)),
            ("{}", Shape::unity()),
        ];

        for (input, expected) in cases {
            let parsed = input.parse::<Shape>().unwrap();
            assert_eq!(parsed, expected, "Failed to parse: {}", input);
        }
    }

    #[test]
    fn test_shape_from_str_error_cases() {
        let malformed = [
            "host=2,gpu=8",
            "{host=2,gpu=8",
            "host=2,gpu=8}",
            "{host=2,gpu=}",
            "{host=,gpu=8}",
            "{host=2=3,gpu=8}",
            "{host=abc,gpu=8}",
            "{host=2,}",
            "{=8}",
        ];

        for input in malformed {
            let result = input.parse::<Shape>();
            assert!(result.is_err(), "expected error for input: {}", input);
        }
    }
}