1use std::fmt;
10
11use serde::Deserialize;
12use serde::Serialize;
13
14use crate::DimSliceIterator;
15use crate::Slice;
16use crate::SliceError;
17use crate::selection::Selection;
18use crate::view::Extent;
19
/// Errors returned when constructing, indexing, or narrowing a [`Shape`].
#[derive(Debug, thiserror::Error)]
pub enum ShapeError {
    /// The number of labels differs from the slice's number of dimensions
    /// (returned by `Shape::new`).
    #[error("label slice dimension mismatch: {labels_dim} != {slice_dim}")]
    DimSliceMismatch { labels_dim: usize, slice_dim: usize },

    /// A label does not name any dimension of the shape. `Shape::dim` returns
    /// this with the single offending label.
    #[error("invalid labels `{labels:?}`")]
    InvalidLabels { labels: Vec<String> },

    /// The underlying slice rejected a selection as empty (returned by
    /// `Shape::select` with the range the caller asked for).
    #[error("empty range {range}")]
    EmptyRange { range: Range },

    /// A requested range or index falls outside dimension `dim`, whose length
    /// is `size` (returned by `Shape::at` and `Shape::select`).
    #[error("out of range {range} for dimension {dim} of size {size}")]
    OutOfRange {
        range: Range,
        dim: String,
        size: usize,
    },

    /// A selection expression addresses more dimensions than `num_dim`.
    /// Not constructed in the code shown here.
    #[error("selection `{expr}` exceeds dimensionality {num_dim}")]
    SelectionTooDeep { expr: Selection, num_dim: usize },

    /// NOTE(review): presumably a selection that cannot be evaluated
    /// statically; not constructed in the code shown here — confirm at the
    /// call site.
    #[error("dynamic selection `{expr}`")]
    SelectionDynamic { expr: Selection },

    /// `index` is not a valid position in dimension `dim` of length `size`.
    /// Not constructed in the code shown here.
    #[error("{index} out of range for dimension {dim} of size {size}")]
    IndexOutOfRange {
        index: usize,
        dim: String,
        size: usize,
    },

    /// Any other error from the underlying [`Slice`], converted via `From`
    /// (the `other => other.into()` arms in `Shape::at` / `Shape::select`).
    #[error(transparent)]
    SliceError(#[from] SliceError),
}
56
57#[derive(Clone, Deserialize, Serialize, PartialEq, Hash, Debug)]
59pub struct Shape {
60 labels: Vec<String>,
62 slice: Slice,
64}
65
66impl Shape {
67 pub fn new(labels: Vec<String>, slice: Slice) -> Result<Self, ShapeError> {
74 if labels.len() != slice.num_dim() {
75 return Err(ShapeError::DimSliceMismatch {
76 labels_dim: labels.len(),
77 slice_dim: slice.num_dim(),
78 });
79 }
80 Ok(Self { labels, slice })
81 }
82
83 pub fn at(&self, label: &str, index: usize) -> Result<Self, ShapeError> {
88 let dim = self.dim(label)?;
89 let slice = self.slice.at(dim, index).map_err(|err| match err {
90 SliceError::IndexOutOfRange { index, total } => ShapeError::OutOfRange {
91 range: Range(index, Some(index + 1), 1),
92 dim: label.to_string(),
93 size: total,
94 },
95 other => other.into(),
96 })?;
97 let mut labels = self.labels.clone();
98 labels.remove(dim);
99 Ok(Self { labels, slice })
100 }
101
102 pub fn select<R: Into<Range>>(&self, label: &str, range: R) -> Result<Self, ShapeError> {
108 let dim = self.dim(label)?;
109 let range = range.into();
110 let (begin, end, step) = range.resolve(self.slice().sizes()[dim]);
111 let slice = self
112 .slice
113 .select(dim, begin, end, step)
114 .map_err(|err| match err {
115 SliceError::EmptyRange { .. } => ShapeError::EmptyRange { range },
116 SliceError::IndexOutOfRange { total, .. } => ShapeError::OutOfRange {
117 range,
118 dim: label.to_string(),
119 size: total,
120 },
121 other => other.into(),
122 })?;
123 let labels = self.labels.clone();
124 Ok(Self { labels, slice })
125 }
126
127 pub fn select_iter(&self, dims: usize) -> Result<SelectIterator, ShapeError> {
141 let num_dims = self.slice().num_dim();
142 if dims == 0 || dims >= num_dims {
143 return Err(ShapeError::SliceError(SliceError::IndexOutOfRange {
144 index: dims,
145 total: num_dims,
146 }));
147 }
148
149 Ok(SelectIterator {
150 shape: self,
151 iter: self.slice().dim_iter(dims),
152 })
153 }
154
155 pub fn index(&self, indices: Vec<(String, usize)>) -> Result<Shape, ShapeError> {
160 let mut shape = self.clone();
161 for (label, index) in indices {
162 shape = shape.at(&label, index)?;
163 }
164 Ok(shape)
165 }
166
167 pub fn labels(&self) -> &[String] {
169 &self.labels
170 }
171
172 pub fn slice(&self) -> &Slice {
174 &self.slice
175 }
176
177 pub fn coordinates(&self, rank: usize) -> Result<Vec<(String, usize)>, ShapeError> {
179 let coords = self.slice.coordinates(rank)?;
180 Ok(coords
181 .iter()
182 .zip(self.labels.iter())
183 .map(|(i, l)| (l.to_string(), *i))
184 .collect())
185 }
186
187 pub fn dim(&self, label: &str) -> Result<usize, ShapeError> {
188 self.labels
189 .iter()
190 .position(|l| l == label)
191 .ok_or_else(|| ShapeError::InvalidLabels {
192 labels: vec![label.to_string()],
193 })
194 }
195
196 pub fn unity() -> Shape {
198 Shape::new(vec![], Slice::new(0, vec![], vec![]).expect("unity")).expect("unity")
199 }
200
201 pub fn extent(&self) -> Extent {
203 Extent::new(self.labels.clone(), self.slice.sizes().to_vec()).unwrap()
204 }
205}
206
207pub struct SelectIterator<'a> {
234 shape: &'a Shape,
235 iter: DimSliceIterator,
236}
237
238impl<'a> Iterator for SelectIterator<'a> {
239 type Item = Shape;
240
241 fn next(&mut self) -> Option<Self::Item> {
242 let pos = self.iter.next()?;
243 let mut shape = self.shape.clone();
244 for (dim, index) in pos.iter().enumerate() {
245 shape = shape.select(&self.shape.labels()[dim], *index).unwrap();
246 }
247 Some(shape)
248 }
249}
250
251impl fmt::Display for Shape {
252 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
253 write!(f, "{{")?;
257 for dim in 0..self.labels.len() {
258 write!(f, "{}={}", self.labels[dim], self.slice.sizes()[dim])?;
259 if dim < self.labels.len() - 1 {
260 write!(f, ",")?;
261 }
262 }
263 write!(f, "}}")
264 }
265}
266
/// Builds a row-major `Shape` from `label = size` pairs.
///
/// ```ignore
/// let s = shape!(host = 2, gpu = 8); // labels ["host", "gpu"], sizes [2, 8]
/// ```
///
/// # Panics
///
/// Panics if `Shape::new` rejects the resulting labels/slice pair.
#[macro_export]
macro_rules! shape {
    ( $( $label:ident = $size:expr_2021 ),* $(,)? ) => {{
        // Labels come from the identifiers themselves; sizes are evaluated
        // left to right, in the order they were written.
        let labels: Vec<String> = vec![$( stringify!($label).to_string() ),*];
        let sizes = vec![$( $size ),*];
        $crate::shape::Shape::new(labels, $crate::Slice::new_row_major(sizes)).unwrap()
    }};
}
292
/// Narrows a shape by one or more `label = range` selections, applied left
/// to right.
///
/// The shape operand may be any expression (e.g. the result of another call);
/// it is evaluated once. Each range may be anything accepted by
/// `Shape::select` (an index, `a..b`, `a..=b`, `a..`, or a `Range`).
///
/// The macro evaluates to `Result<Shape, ShapeError>`. The first failing
/// selection short-circuits the rest, and later range expressions are then
/// not evaluated.
///
/// ```ignore
/// let sub = select!(s, gpu = 3..5, host = 1)?;
/// ```
#[macro_export]
macro_rules! select {
    ($shape:expr_2021, $label:ident = $range:expr_2021) => {
        $shape.select(stringify!($label), $range)
    };

    ($shape:expr_2021, $label:ident = $range:expr_2021, $($labels:ident = $ranges:expr_2021),+) => {
        $shape
            .select(stringify!($label), $range)
            .and_then(|shape| $crate::select!(shape, $($labels = $ranges),+))
    };
}
314
315#[derive(
322 Debug,
323 Clone,
324 Eq,
325 Hash,
326 PartialEq,
327 Serialize,
328 Deserialize,
329 PartialOrd,
330 Ord
331)]
332pub struct Range(pub usize, pub Option<usize>, pub usize);
333
334impl Range {
335 pub(crate) fn resolve(&self, size: usize) -> (usize, usize, usize) {
336 match self {
337 Range(begin, Some(end), stride) => (*begin, std::cmp::min(size, *end), *stride),
338 Range(begin, None, stride) => (*begin, size, *stride),
339 }
340 }
341
342 pub(crate) fn is_empty(&self) -> bool {
343 matches!(self, Range(begin, Some(end), _) if end <= begin)
344 }
345}
346
347impl fmt::Display for Range {
348 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
349 match self {
350 Range(begin, None, stride) => write!(f, "{}::{}", begin, stride),
351 Range(begin, Some(end), stride) => write!(f, "{}:{}:{}", begin, end, stride),
352 }
353 }
354}
355
356impl From<std::ops::Range<usize>> for Range {
357 fn from(r: std::ops::Range<usize>) -> Self {
358 Self(r.start, Some(r.end), 1)
359 }
360}
361
362impl From<std::ops::RangeInclusive<usize>> for Range {
363 fn from(r: std::ops::RangeInclusive<usize>) -> Self {
364 Self(*r.start(), Some(*r.end() + 1), 1)
365 }
366}
367
368impl From<std::ops::RangeFrom<usize>> for Range {
369 fn from(r: std::ops::RangeFrom<usize>) -> Self {
370 Self(r.start, None, 1)
371 }
372}
373
374impl From<usize> for Range {
375 fn from(idx: usize) -> Self {
376 Self(idx, Some(idx + 1), 1)
377 }
378}
379
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;

    use super::*;

    /// `shape!` builds a row-major shape; `Display` renders `{label=size,...}`.
    #[test]
    fn test_basic() {
        let s = shape!(host = 2, gpu = 8);
        assert_eq!(&s.labels, &["host".to_string(), "gpu".to_string()]);
        assert_eq!(s.slice.offset(), 0);
        assert_eq!(s.slice.sizes(), &[2, 8]);
        assert_eq!(s.slice.strides(), &[8, 1]);

        assert_eq!(s.to_string(), "{host=2,gpu=8}");
    }

    /// `select!` narrows one or more dimensions while keeping them in the shape.
    #[test]
    fn test_select() {
        let s = shape!(host = 2, gpu = 8);

        assert_eq!(
            s.slice().iter().collect::<Vec<_>>(),
            &[
                0,
                1,
                2,
                3,
                4,
                5,
                6,
                7,
                8,
                8 + 1,
                8 + 2,
                8 + 3,
                8 + 4,
                8 + 5,
                8 + 6,
                8 + 7
            ]
        );

        assert_eq!(
            select!(s, host = 1)
                .unwrap()
                .slice()
                .iter()
                .collect::<Vec<_>>(),
            &[8, 8 + 1, 8 + 2, 8 + 3, 8 + 4, 8 + 5, 8 + 6, 8 + 7]
        );

        assert_eq!(
            select!(s, gpu = 2..)
                .unwrap()
                .slice()
                .iter()
                .collect::<Vec<_>>(),
            &[2, 3, 4, 5, 6, 7, 8 + 2, 8 + 3, 8 + 4, 8 + 5, 8 + 6, 8 + 7]
        );

        assert_eq!(
            select!(s, gpu = 3..5)
                .unwrap()
                .slice()
                .iter()
                .collect::<Vec<_>>(),
            &[3, 4, 8 + 3, 8 + 4]
        );

        assert_eq!(
            select!(s, gpu = 3..5, host = 1)
                .unwrap()
                .slice()
                .iter()
                .collect::<Vec<_>>(),
            &[8 + 3, 8 + 4]
        );
    }

    /// `select_iter(2)` yields one single-index selection per combination of
    /// the two leading dimensions, in row-major order.
    #[test]
    fn test_select_iter() {
        let s = shape!(replica = 2, host = 2, gpu = 8);
        let selections: Vec<_> = s.select_iter(2).unwrap().collect();
        assert_eq!(selections[0].slice().sizes(), &[1, 1, 8]);
        assert_eq!(selections[1].slice().sizes(), &[1, 1, 8]);
        assert_eq!(selections[2].slice().sizes(), &[1, 1, 8]);
        assert_eq!(selections[3].slice().sizes(), &[1, 1, 8]);
        assert_eq!(
            selections,
            &[
                select!(s, replica = 0, host = 0).unwrap(),
                select!(s, replica = 0, host = 1).unwrap(),
                select!(s, replica = 1, host = 0).unwrap(),
                select!(s, replica = 1, host = 1).unwrap()
            ]
        );
    }

    /// `select_iter` rejects `dims == 0` and `dims >= num_dim`.
    #[test]
    fn test_select_iter_bounds() {
        let s = shape!(replica = 2, host = 2, gpu = 8);
        assert!(s.select_iter(0).is_err());
        assert!(s.select_iter(3).is_err());
        assert_eq!(s.select_iter(1).unwrap().count(), 2);
    }

    /// Ranks map to labeled row-major coordinates; ranks outside the slice fail.
    #[test]
    fn test_coordinates() {
        let s = shape!(host = 2, gpu = 8);
        assert_eq!(
            s.coordinates(0).unwrap(),
            vec![("host".to_string(), 0), ("gpu".to_string(), 0)]
        );
        assert_eq!(
            s.coordinates(1).unwrap(),
            vec![("host".to_string(), 0), ("gpu".to_string(), 1)]
        );
        assert_eq!(
            s.coordinates(8).unwrap(),
            vec![("host".to_string(), 1), ("gpu".to_string(), 0)]
        );
        assert_eq!(
            s.coordinates(9).unwrap(),
            vec![("host".to_string(), 1), ("gpu".to_string(), 1)]
        );

        assert_matches!(
            s.coordinates(16).unwrap_err(),
            ShapeError::SliceError(SliceError::ValueNotInSlice { value: 16 })
        );
    }

    /// Empty and out-of-range selections surface as typed `ShapeError`s.
    #[test]
    fn test_select_bad() {
        let s = shape!(host = 2, gpu = 8);

        assert_matches!(
            select!(s, gpu = 1..1).unwrap_err(),
            ShapeError::EmptyRange {
                range: Range(1, Some(1), 1)
            },
        );

        assert_matches!(
            select!(s, gpu = 8).unwrap_err(),
            ShapeError::OutOfRange {
                range: Range(8, Some(9), 1),
                dim,
                size: 8,
            } if dim == "gpu",
        );
    }

    /// `index` removes each named dimension and adjusts offset/strides.
    #[test]
    fn test_shape_index() {
        let n_hosts = 5;
        let n_gpus = 7;

        // Fixing `host` at 0 leaves a unit-stride `gpu` dimension at offset 0.
        let s = shape!(host = n_hosts, gpu = n_gpus);
        assert_eq!(
            s.index(vec![("host".to_string(), 0)]).unwrap(),
            Shape::new(
                vec!["gpu".to_string()],
                Slice::new(0, vec![n_gpus], vec![1]).unwrap()
            )
            .unwrap()
        );

        // Fixing `gpu` at `offset` leaves `host` strided by `n_gpus`.
        let offset = 1;
        assert_eq!(
            s.index(vec![("gpu".to_string(), offset)]).unwrap(),
            Shape::new(
                vec!["host".to_string()],
                Slice::new(offset, vec![n_hosts], vec![n_gpus]).unwrap()
            )
            .unwrap()
        );

        // Fixing a middle dimension keeps the outer and inner strides.
        let n_zone = 2;
        let s = shape!(zone = n_zone, host = n_hosts, gpu = n_gpus);
        let offset = 3;
        assert_eq!(
            s.index(vec![("host".to_string(), offset)]).unwrap(),
            Shape::new(
                vec!["zone".to_string(), "gpu".to_string()],
                Slice::new(
                    offset * n_gpus,
                    vec![n_zone, n_gpus],
                    vec![n_hosts * n_gpus, 1]
                )
                .unwrap()
            )
            .unwrap()
        );

        // An index equal to the dimension size is out of range.
        assert!(
            shape!(gpu = n_gpus)
                .index(vec![("gpu".to_string(), n_gpus)])
                .is_err()
        );
        // An unknown label is rejected.
        assert!(
            shape!(gpu = n_gpus)
                .index(vec![("non-exist-dim".to_string(), 0)])
                .is_err()
        );
    }

    /// A strided selection rounds the element count up (0, 3, 6, 9).
    #[test]
    fn test_shape_select_stride_rounding() {
        let shape = shape!(x = 10);
        let sub = shape.select("x", Range(0, Some(10), 3)).unwrap();
        let slice = sub.slice();
        assert_eq!(
            slice,
            &Slice::new(0, vec![4], vec![3]).unwrap(),
            "Expected offset 0, size 4, stride 3"
        );
    }

    /// `at` on the outermost dimension drops it and offsets by one block.
    #[test]
    fn test_shape_at_removes_dimension() {
        let labels = vec![
            "batch".to_string(),
            "height".to_string(),
            "width".to_string(),
        ];
        let slice = Slice::new_row_major(vec![2, 3, 4]);
        let shape = Shape::new(labels, slice).unwrap();

        let result = shape.at("batch", 1).unwrap();

        assert_eq!(result.labels(), &["height", "width"]);
        assert_eq!(result.slice().sizes(), &[3, 4]);
        // 1 * (3 * 4)
        assert_eq!(result.slice().offset(), 12);
    }

    /// `at` on a middle dimension drops it and offsets by its stride.
    #[test]
    fn test_shape_at_middle_dimension() {
        let labels = vec![
            "batch".to_string(),
            "height".to_string(),
            "width".to_string(),
        ];
        let slice = Slice::new_row_major(vec![2, 3, 4]);
        let shape = Shape::new(labels, slice).unwrap();

        let result = shape.at("height", 1).unwrap();

        assert_eq!(result.labels(), &["batch", "width"]);
        assert_eq!(result.slice().sizes(), &[2, 4]);
        // 1 * 4
        assert_eq!(result.slice().offset(), 4);
    }

    /// `at` with an unknown label reports `InvalidLabels`.
    #[test]
    fn test_shape_at_invalid_label() {
        let labels = vec!["batch".to_string(), "height".to_string()];
        let slice = Slice::new_row_major(vec![2, 3]);
        let shape = Shape::new(labels, slice).unwrap();

        let result = shape.at("nonexistent", 0);
        assert!(matches!(result, Err(ShapeError::InvalidLabels { .. })));
    }

    /// `at` with an index past the dimension size reports `OutOfRange`.
    #[test]
    fn test_shape_at_index_out_of_range() {
        let labels = vec!["batch".to_string(), "height".to_string()];
        let slice = Slice::new_row_major(vec![2, 3]);
        let shape = Shape::new(labels, slice).unwrap();

        // `batch` has size 2.
        let result = shape.at("batch", 5);
        assert!(matches!(result, Err(ShapeError::OutOfRange { .. })));
    }

    /// `Shape::new` rejects a label count that differs from the slice rank.
    #[test]
    fn test_new_dim_mismatch() {
        assert_matches!(
            Shape::new(vec!["host".to_string()], Slice::new_row_major(vec![2, 8])).unwrap_err(),
            ShapeError::DimSliceMismatch {
                labels_dim: 1,
                slice_dim: 2
            }
        );
    }

    /// `dim` maps labels to positions and reports unknown labels.
    #[test]
    fn test_dim_lookup() {
        let s = shape!(host = 2, gpu = 8);
        assert_eq!(s.dim("host").unwrap(), 0);
        assert_eq!(s.dim("gpu").unwrap(), 1);
        assert_matches!(
            s.dim("rack").unwrap_err(),
            ShapeError::InvalidLabels { labels } if labels == vec!["rack".to_string()]
        );
    }

    /// `Range` renders as `begin:end:stride`, or `begin::stride` when open-ended.
    #[test]
    fn test_range_display() {
        assert_eq!(Range(1, Some(4), 2).to_string(), "1:4:2");
        assert_eq!(Range(3, None, 1).to_string(), "3::1");
    }

    /// Conversions into `Range` produce unit-stride ranges with exclusive ends.
    #[test]
    fn test_range_conversions() {
        assert_eq!(Range::from(2..5), Range(2, Some(5), 1));
        assert_eq!(Range::from(2..=4), Range(2, Some(5), 1));
        assert_eq!(Range::from(3..), Range(3, None, 1));
        assert_eq!(Range::from(7), Range(7, Some(8), 1));
    }
}