1use std::fmt;
25
26use crate::Range;
27use crate::Selection;
28use crate::dsl::union;
29use crate::shape::Shape;
30use crate::slice::Slice;
31
/// A coordinate in a slice: one index per dimension, as accepted by
/// `Slice::location` and returned by `Slice::coordinates`.
pub type Coord = Vec<usize>;
35
/// The result of reshaping a [`Shape`] (see `reshape_shape`): the
/// reshaped shape plus a record of how each original dimension was split.
pub struct ReshapedShape {
    /// The reshaped shape. Its labels are the original labels, expanded to
    /// `label/i` for every dimension that was split into several factors.
    pub shape: Shape,

    /// One entry per original dimension: its label and the sizes it was
    /// factored into, outermost factor first (a single-element vector
    /// means the dimension was left unsplit).
    pub factors: Vec<(String, Vec<usize>)>,
}
50
// Compile-time check that `ReshapedShape` is `Send + Sync + 'static`:
// this item fails to compile if a field ever stops satisfying those bounds.
#[allow(dead_code)]
const _: () = {
    fn assert<T: Send + Sync + 'static>() {}
    let _ = assert::<ReshapedShape>;
};
56
57impl std::fmt::Debug for ReshapedShape {
58 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59 f.debug_struct("ReshapedShape")
60 .field("labels", &self.shape.labels())
61 .field("sizes", &self.shape.slice().sizes())
62 .field("strides", &self.shape.slice().strides())
63 .field("offset", &self.shape.slice().offset())
64 .field("factors", &self.factors)
65 .finish()
66 }
67}
68
69impl std::fmt::Display for ReshapedShape {
70 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
71 write!(
72 f,
73 "ReshapedShape {{ [off={} sz={:?} st={:?} lab={:?} fac={:?}] }}",
74 self.shape.slice().offset(),
75 self.shape.slice().sizes(),
76 self.shape.slice().strides(),
77 self.shape.labels(),
78 self.factors
79 )
80 }
81}
82
83pub(crate) fn factor_dims(sizes: &[usize], limit: Limit) -> Vec<Vec<usize>> {
97 let limit = limit.get();
98 sizes
99 .iter()
100 .map(|&size| {
101 if size <= limit {
102 return vec![size];
103 }
104 let mut rem = size;
105 let mut factors = Vec::new();
106 for d in (2..=limit).rev() {
107 while rem % d == 0 {
108 factors.push(d);
109 rem /= d;
110 }
111 }
112 if rem > 1 {
113 factors.push(rem);
114 }
115 factors
116 })
117 .collect()
118}
119
120pub fn to_reshaped_coord<'a>(
124 original: &'a Slice,
125 reshaped: &'a Slice,
126) -> impl Fn(&[usize]) -> Vec<usize> + 'a {
127 let original = original.clone();
128 let reshaped = reshaped.clone();
129 move |coord: &[usize]| -> Coord {
130 let flat = original.location(coord).unwrap();
131 reshaped.coordinates(flat).unwrap()
132 }
133}
134
135pub fn to_original_coord<'a>(
139 reshaped: &'a Slice,
140 original: &'a Slice,
141) -> impl Fn(&[usize]) -> Vec<usize> + 'a {
142 let reshaped = reshaped.clone();
143 let original = original.clone();
144 move |coord: &[usize]| -> Coord {
145 let flat = reshaped.location(coord).unwrap();
146 original.coordinates(flat).unwrap()
147 }
148}
149
/// Upper bound on the size of the dimensions produced by reshaping.
///
/// During factoring (`factor_dims`), a dimension larger than the limit is
/// split into factors no larger than the limit where possible. A leftover
/// factor with no divisor `<= limit` can still exceed it. The value is
/// always at least 1.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Limit(usize);

impl Limit {
    /// Creates a limit of `n`.
    ///
    /// # Panics
    ///
    /// Panics if `n` is zero.
    pub fn new(n: usize) -> Self {
        if n == 0 {
            panic!("Limit must be at least 1");
        }
        Limit(n)
    }

    /// Returns the numeric value of the limit.
    pub fn get(self) -> usize {
        let Limit(n) = self;
        n
    }
}

impl Default for Limit {
    /// The default limit is 32.
    fn default() -> Self {
        Limit(32)
    }
}

impl From<usize> for Limit {
    /// Same as [`Limit::new`]; panics when `n` is zero.
    fn from(n: usize) -> Self {
        Limit::new(n)
    }
}
194
/// Extension trait that adds limit-based reshaping to [`Slice`].
pub trait ReshapeSliceExt {
    /// Returns a reshaped copy of `self` in which dimensions larger than
    /// `limit` are split into consecutive smaller dimensions.
    ///
    /// The `Slice` implementation delegates to the free function
    /// `reshape_with_limit`, whose result addresses the same locations in
    /// the same order as the input.
    fn reshape_with_limit(&self, limit: Limit) -> Slice;
}
228
229impl ReshapeSliceExt for Slice {
230 fn reshape_with_limit(&self, limit: Limit) -> Slice {
231 reshape_with_limit(self, limit)
232 }
233}
234
/// Extension trait that adds limit-based reshaping to [`Shape`].
pub trait ReshapeShapeExt {
    /// Reshapes `self` with the given `limit`, returning the new shape
    /// (with expanded labels) together with the per-dimension factors.
    ///
    /// The `Shape` implementation delegates to `reshape_shape`.
    fn reshape(&self, limit: Limit) -> ReshapedShape;
}
241
242impl ReshapeShapeExt for Shape {
243 fn reshape(&self, limit: Limit) -> ReshapedShape {
244 reshape_shape(self, limit)
245 }
246}
247
/// Convenience re-exports: glob-importing this module brings the
/// `.reshape_with_limit()` (on `Slice`) and `.reshape()` (on `Shape`)
/// extension methods into scope.
pub mod prelude {
    pub use super::ReshapeShapeExt;
    pub use super::ReshapeSliceExt;
}
254
255pub fn reshape_with_limit(slice: &Slice, limit: Limit) -> Slice {
285 let orig_sizes = slice.sizes();
286 let orig_strides = slice.strides();
287
288 let factored_sizes = factor_dims(orig_sizes, limit);
290
291 let reshaped_sizes: Vec<usize> = factored_sizes.iter().flatten().cloned().collect();
293 let mut reshaped_strides = Vec::with_capacity(reshaped_sizes.len());
294
295 for (&orig_stride, factors) in orig_strides.iter().zip(&factored_sizes) {
296 let mut sub_strides = Vec::with_capacity(factors.len());
297 let mut stride = orig_stride;
298 for &f in factors.iter().rev() {
299 sub_strides.push(stride);
300 stride *= f;
301 }
302 sub_strides.reverse();
303 reshaped_strides.extend(sub_strides);
304 }
305
306 Slice::new(slice.offset(), reshaped_sizes, reshaped_strides).unwrap()
307}
308
309pub fn reshape_shape(shape: &Shape, limit: Limit) -> ReshapedShape {
329 let reshaped_slice = shape.slice().reshape_with_limit(limit);
330 let original_labels = shape.labels();
331 let original_sizes = shape.slice().sizes();
332
333 let factors = factor_dims(original_sizes, limit);
334 let factored_dims: Vec<(String, Vec<usize>)> =
335 original_labels.iter().cloned().zip(factors).collect();
336
337 let labels = expand_labels(&factored_dims);
338 let shape = Shape::new(labels, reshaped_slice).expect("invalid reshaped shape");
339
340 ReshapedShape {
341 shape,
342 factors: factored_dims,
343 }
344}
345
/// Expands per-dimension labels to match a factored layout.
///
/// For each `(label, dims)` pair:
/// - exactly one factor: the label is kept as-is;
/// - otherwise: one label `label/i` is produced per factor, with `i`
///   counting from 0 (an empty `dims` therefore produces no labels).
pub fn expand_labels(factors: &[(String, Vec<usize>)]) -> Vec<String> {
    factors
        .iter()
        .flat_map(|(label, dims)| -> Vec<String> {
            match dims.len() {
                1 => vec![label.clone()],
                count => (0..count).map(|i| format!("{}/{}", label, i)).collect(),
            }
        })
        .collect()
}
379
/// Error returned by `reshape_selection`.
#[derive(Debug, thiserror::Error)]
pub enum ReshapeError {
    /// The selection contains a combinator that `reshape_selection` cannot
    /// translate; the offending sub-selection is carried in the error.
    #[error("unsupported selection kind {selection}")]
    UnsupportedSelection { selection: Selection },
}
/// Rewrites `selection`, expressed over the dimensions of `original_slice`,
/// into a selection over `reshaped_slice`, where each original dimension has
/// been split into one or more consecutive reshaped dimensions whose sizes
/// multiply to the original size (as produced by `reshape_with_limit`).
///
/// Handled selection kinds: `True`, `False`, `Union`, `Intersection`, `All`,
/// `Any`, `First` and `Range`. `True`/`False` operands of `Union` and
/// `Intersection` are simplified away.
///
/// # Errors
///
/// Returns [`ReshapeError::UnsupportedSelection`] for any other selection
/// kind reached during the traversal.
///
/// # Panics
///
/// May panic (via `unwrap`) when `reshaped_slice` has no run of dimensions
/// covering an original dimension. NOTE(review): presumably callers always
/// pass a slice produced by `reshape_with_limit(original_slice, _)`; confirm.
pub fn reshape_selection(
    selection: Selection,
    original_slice: &Slice,
    reshaped_slice: &Slice,
) -> Result<Selection, ReshapeError> {
    // Translates `selection`, which applies starting at original dimension
    // `original_size_index`, whose first reshaped dimension is
    // `reshaped_size_index`.
    fn recursive_fold(
        selection: Selection,
        original_slice: &Slice,
        original_size_index: usize,
        reshaped_slice: &Slice,
        reshaped_size_index: usize,
    ) -> Result<Selection, ReshapeError> {
        // Constants need no translation.
        if matches!(selection, Selection::True | Selection::False) {
            return Ok(selection);
        }

        // Past the last original dimension: nothing left to translate.
        let Some(&original_dim_size) = original_slice.sizes().get(original_size_index) else {
            return Ok(selection);
        };

        // Find the run of reshaped dimensions
        // [reshaped_size_index, next_reshaped_dimension_start) whose product
        // covers the current original dimension.
        let mut accum = *reshaped_slice.sizes().get(reshaped_size_index).unwrap();
        let mut next_reshaped_dimension_start = reshaped_size_index + 1;

        while accum < original_dim_size {
            accum *= *reshaped_slice
                .sizes()
                .get(next_reshaped_dimension_start)
                .unwrap();
            next_reshaped_dimension_start += 1;
        }

        match selection {
            Selection::True | Selection::False => Ok(selection),
            // Both operands act on the same dimension, so they are folded
            // at the same indices. True ∪ x = True; False ∪ x = x.
            Selection::Union(left, right) => {
                let left = recursive_fold(
                    *left,
                    original_slice,
                    original_size_index,
                    reshaped_slice,
                    reshaped_size_index,
                )?;

                match left {
                    Selection::True => return Ok(Selection::True),
                    Selection::False => {
                        return recursive_fold(
                            *right,
                            original_slice,
                            original_size_index,
                            reshaped_slice,
                            reshaped_size_index,
                        );
                    }
                    _ => {}
                }

                let right = recursive_fold(
                    *right,
                    original_slice,
                    original_size_index,
                    reshaped_slice,
                    reshaped_size_index,
                )?;

                Ok(match right {
                    Selection::True => Selection::True,
                    Selection::False => left,
                    _ => Selection::Union(Box::new(left), Box::new(right)),
                })
            }
            // False ∩ x = False; True ∩ x = x.
            Selection::Intersection(left, right) => {
                let left = recursive_fold(
                    *left,
                    original_slice,
                    original_size_index,
                    reshaped_slice,
                    reshaped_size_index,
                )?;
                match left {
                    Selection::False => return Ok(Selection::False),
                    Selection::True => {
                        return recursive_fold(
                            *right,
                            original_slice,
                            original_size_index,
                            reshaped_slice,
                            reshaped_size_index,
                        );
                    }
                    _ => {}
                }

                let right = recursive_fold(
                    *right,
                    original_slice,
                    original_size_index,
                    reshaped_slice,
                    reshaped_size_index,
                )?;
                Ok(match right {
                    Selection::False => Selection::False,
                    Selection::True => left,
                    _ => Selection::Intersection(Box::new(left), Box::new(right)),
                })
            }
            // One `All` per reshaped dimension that replaces this original
            // dimension: the initial wrap plus
            // (next_reshaped_dimension_start - 1 - reshaped_size_index) more.
            Selection::All(inner) => {
                let inner = recursive_fold(
                    *inner,
                    original_slice,
                    original_size_index + 1,
                    reshaped_slice,
                    next_reshaped_dimension_start,
                )?;

                if matches!(inner, Selection::True | Selection::False) {
                    return Ok(inner);
                }

                Ok((reshaped_size_index..next_reshaped_dimension_start - 1)
                    .fold(Selection::All(Box::new(inner)), |result, _| {
                        Selection::All(Box::new(result))
                    }))
            }
            // Same nesting scheme as `All`; only an inner `False` collapses.
            Selection::Any(inner) => {
                let inner = recursive_fold(
                    *inner,
                    original_slice,
                    original_size_index + 1,
                    reshaped_slice,
                    next_reshaped_dimension_start,
                )?;

                if matches!(inner, Selection::False) {
                    return Ok(inner);
                }

                Ok((reshaped_size_index..next_reshaped_dimension_start - 1)
                    .fold(Selection::Any(Box::new(inner)), |result, _| {
                        Selection::Any(Box::new(result))
                    }))
            }
            // Same nesting scheme as `All`; only an inner `False` collapses.
            Selection::First(inner) => {
                let inner = recursive_fold(
                    *inner,
                    original_slice,
                    original_size_index + 1,
                    reshaped_slice,
                    next_reshaped_dimension_start,
                )?;

                if matches!(inner, Selection::False) {
                    return Ok(inner);
                }

                Ok((reshaped_size_index..next_reshaped_dimension_start - 1)
                    .fold(Selection::First(Box::new(inner)), |result, _| {
                        Selection::First(Box::new(result))
                    }))
            }
            Selection::Range(range, inner) => {
                // Splits `Range(start, end, step)` over a dimension of size
                // `original_dimension_n_size` into a set of pieces over
                // (outer "row" dimension of size
                //  original_dimension_n_size / new_dimension_n_size)
                // x (inner "column" dimension of size new_dimension_n_size).
                // Position p maps to row p / new_dimension_n_size and
                // column p % new_dimension_n_size. Every returned piece has
                // the shape Range(rows, Range(cols, inner)).
                fn fold_once(
                    Range(start, end, step): Range,
                    inner: Selection,
                    original_dimension_n_size: usize,
                    new_dimension_n_size: usize,
                ) -> Vec<Selection> {
                    // Row of `start`, and exclusive row bound: ceil(end / n).
                    let dimension_n_plus_one_start = start / new_dimension_n_size;
                    let dimension_n_plus_one_end = end.map(|end| {
                        if end % new_dimension_n_size == 0 {
                            end / new_dimension_n_size
                        } else {
                            end / new_dimension_n_size + 1
                        }
                    });
                    let dimension_n_plus_one_size =
                        original_dimension_n_size / new_dimension_n_size;

                    // Column of `start`. Column bound of `end`: a non-zero
                    // row-aligned `end` maps to the full row width.
                    let new_dimension_n_start = start % new_dimension_n_size;
                    let new_dimension_n_end = end.map(|end| {
                        if end % new_dimension_n_size == 0 && end > new_dimension_n_size - 1 {
                            new_dimension_n_size
                        } else {
                            end % new_dimension_n_size
                        }
                    });

                    let mut result = vec![];

                    // The range fits in a single row (or, with an open end,
                    // starts at the row bound): one piece suffices.
                    if dimension_n_plus_one_end
                        .is_some_and(|end| dimension_n_plus_one_start + 1 == end)
                        || (end.is_none()
                            && dimension_n_plus_one_start == dimension_n_plus_one_size)
                    {
                        return vec![Selection::Range(
                            Range(dimension_n_plus_one_start, dimension_n_plus_one_end, 1),
                            Box::new(Selection::Range(
                                Range(new_dimension_n_start, new_dimension_n_end, step),
                                Box::new(inner.clone()),
                            )),
                        )];
                    }

                    if step == 1 {
                        // Contiguous range: leading partial row, whole middle
                        // rows, trailing partial row.
                        let middle_start = match start % new_dimension_n_size {
                            0 => dimension_n_plus_one_start,
                            _ => {
                                // Leading partial row: from `start`'s column
                                // to the end of that row.
                                result.push(Selection::Range(
                                    Range(
                                        dimension_n_plus_one_start,
                                        Some(dimension_n_plus_one_start + 1),
                                        1,
                                    ),
                                    Box::new(Selection::Range(
                                        Range(
                                            new_dimension_n_start,
                                            Some(new_dimension_n_size),
                                            step,
                                        ),
                                        Box::new(inner.clone()),
                                    )),
                                ));
                                dimension_n_plus_one_start + 1
                            }
                        };

                        let middle_end = match (end, dimension_n_plus_one_end) {
                            // Trailing partial row: columns [0, end's column)
                            // of the last row.
                            (Some(end), Some(dimension_n_plus_one_end))
                                if end % new_dimension_n_size != 0 =>
                            {
                                result.push(Selection::Range(
                                    Range(
                                        dimension_n_plus_one_end - 1,
                                        Some(dimension_n_plus_one_end),
                                        1,
                                    ),
                                    Box::new(Selection::Range(
                                        Range(0, new_dimension_n_end, step),
                                        Box::new(inner.clone()),
                                    )),
                                ));
                                Some(dimension_n_plus_one_end - 1)
                            }
                            _ => dimension_n_plus_one_end,
                        };

                        // Whole rows in between select every column.
                        if middle_end.is_some_and(|end| end > middle_start)
                            || (middle_end.is_none() && middle_start < dimension_n_plus_one_size)
                        {
                            result.push(Selection::Range(
                                Range(middle_start, middle_end, 1),
                                Box::new(Selection::All(Box::new(inner.clone()))),
                            ));
                        }
                    } else {
                        // Strided range.
                        fn gcd(a: usize, b: usize) -> usize {
                            if b == 0 { a } else { gcd(b, a % b) }
                        }

                        // Shifting by k rows moves k * n positions. The
                        // smallest k with k * n a multiple of `step` is
                        // step / gcd(step, n), so the per-row column pattern
                        // repeats with this period.
                        let row_pattern_period = step / gcd(step, new_dimension_n_size);

                        // Yields, for each successive row containing a
                        // selected position, (row, column of its first
                        // selected position). The next item is derived from
                        // the last selected column in the current row plus
                        // `step`.
                        let mut row_col_iter = std::iter::successors(
                            Some((dimension_n_plus_one_start, start % new_dimension_n_size)),
                            |&(row, col)| {
                                let cols_before_end = new_dimension_n_size - 1 - col;
                                let steps_before_end = cols_before_end / step;
                                let last_col_before_end = col + step * steps_before_end;

                                let next_row =
                                    ((row * new_dimension_n_size) + last_col_before_end + step)
                                        / new_dimension_n_size;
                                let next_col = (last_col_before_end + step) % new_dimension_n_size;

                                Some((next_row, next_col))
                            },
                        )
                        .peekable();

                        // `start` is not row-aligned: emit its row on its own.
                        if start % new_dimension_n_size != 0 {
                            let (row, col) = row_col_iter.next().unwrap();

                            result.push(Selection::Range(
                                Range(row, Some(row + 1), 1),
                                Box::new(Selection::Range(
                                    Range(col, None, step),
                                    Box::new(inner.clone()),
                                )),
                            ));
                        };

                        // One piece per residue class of rows (at most
                        // `row_pattern_period` of them). Each covers every
                        // `row_pattern_period`-th row up to `end`'s row
                        // (exclusive) or the row bound.
                        for _ in 0..row_pattern_period {
                            let end_row = end.map(|end| end / new_dimension_n_size);

                            if match end_row {
                                Some(end_row) => row_col_iter.peek().unwrap().0 >= end_row,
                                None => row_col_iter.peek().unwrap().0 >= dimension_n_plus_one_size,
                            } {
                                break;
                            }
                            let (row_index, col) = row_col_iter.next().unwrap();

                            result.push(Selection::Range(
                                Range(row_index, end_row, row_pattern_period),
                                Box::new(Selection::Range(
                                    Range(col, None, step),
                                    Box::new(inner.clone()),
                                )),
                            ));
                        }

                        // Trailing partial row at `end`'s row: find a row in
                        // the same residue class (hence the same column
                        // pattern) with a selected column before `end`'s
                        // column.
                        if let Some(end) = end {
                            let end_row = end / new_dimension_n_size;

                            for (row, col) in row_col_iter {
                                if row > end_row {
                                    break;
                                }

                                if row % row_pattern_period == end_row % row_pattern_period
                                    && col < end % new_dimension_n_size
                                {
                                    result.push(Selection::Range(
                                        Range(end_row, Some(end_row + 1), 1),
                                        Box::new(Selection::Range(
                                            Range(col, Some(end % new_dimension_n_size), step),
                                            Box::new(inner.clone()),
                                        )),
                                    ));
                                    break;
                                }
                            }
                        }
                    }
                    result
                }

                let inner = recursive_fold(
                    *inner,
                    original_slice,
                    original_size_index + 1,
                    reshaped_slice,
                    next_reshaped_dimension_start,
                )?;
                if matches!(inner, Selection::False) {
                    return Ok(inner);
                }
                let mut pieces = vec![Selection::Range(range, Box::new(inner))];

                // Reshaped factors for this dimension other than the
                // outermost, innermost first.
                let reversed_dimensions = reshaped_slice.sizes()
                    [reshaped_size_index + 1..next_reshaped_dimension_start]
                    .iter()
                    .copied()
                    .rev();

                // Peel one factor per pass: split each piece's outer range
                // into (remaining outer dimension) x (this factor). The
                // remaining outer size shrinks by that factor each time.
                let mut original_dimension_size = original_dim_size;
                for dimension in reversed_dimensions {
                    pieces = pieces
                        .into_iter()
                        .flat_map(|piece| {
                            // fold_once only produces `Range` pieces.
                            if let Selection::Range(range, inner) = piece {
                                fold_once(range, *inner, original_dimension_size, dimension)
                            } else {
                                vec![]
                            }
                        })
                        .collect();
                    original_dimension_size /= dimension;
                }

                // Union of all pieces; no pieces at all means `False`.
                Ok(pieces.into_iter().fold(Selection::False, |x, y| match x {
                    Selection::False => y,
                    _ => union(x, y),
                }))
            }
            _ => Err(ReshapeError::UnsupportedSelection { selection }),
        }
    }

    recursive_fold(selection, original_slice, 0, reshaped_slice, 0)
}
819
/// Unit tests for dimension factoring, slice/shape reshaping and label
/// expansion, plus a property test checking that `reshape_selection`
/// selects the same ranks before and after reshaping.
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;

    use proptest::prelude::*;
    use proptest::test_runner::TestRunner;

    use super::*;
    use crate::Slice;
    use crate::selection::EvalOpts;
    use crate::shape;
    use crate::strategy::gen_selection;
    use crate::strategy::gen_slice;

    #[test]
    fn test_factor_dims_basic() {
        assert_eq!(
            factor_dims(&[6, 8], Limit::from(4)),
            vec![vec![3, 2], vec![4, 2]]
        );
        assert_eq!(factor_dims(&[5], Limit::from(3)), vec![vec![5]]);
        assert_eq!(factor_dims(&[30], Limit::from(5)), vec![vec![5, 3, 2]]);
    }

    /// Asserts that `$reshaped` is a layout-preserving reshape of `$original`.
    ///
    /// For every coordinate of `$original` it checks three things:
    /// - mapping the coordinate into `$reshaped` and back yields it again;
    /// - both coordinates have the same flat location;
    /// - `$reshaped.coordinates` recovers the reshaped coordinate from that
    ///   location.
    #[macro_export]
    macro_rules! assert_layout_preserved {
        ($original:expr, $reshaped:expr) => {{
            // The coordinate mappings do not depend on the loop variable;
            // build them once.
            let forward = to_reshaped_coord($original, &$reshaped);
            let inverse = to_original_coord(&$reshaped, $original);
            for coord in $original.dim_iter($original.num_dim()) {
                let reshaped_coord = forward(&coord);
                let roundtrip = inverse(&reshaped_coord);
                assert_eq!(
                    roundtrip, coord,
                    "Inverse mismatch: reshaped {:?} → original {:?}, expected {:?}",
                    reshaped_coord, roundtrip, coord
                );
                let flat_orig = $original.location(&coord).unwrap();
                let flat_reshaped = $reshaped.location(&reshaped_coord).unwrap();
                assert_eq!(
                    flat_orig, flat_reshaped,
                    "Flat index mismatch: original {:?} → reshaped {:?}",
                    coord, reshaped_coord
                );
                let recovered = $reshaped.coordinates(flat_reshaped).unwrap();
                assert_eq!(
                    reshaped_coord, recovered,
                    "Coordinate mismatch: flat index {} → expected {:?}, got {:?}",
                    flat_reshaped, reshaped_coord, recovered
                );
            }
        }};
    }

    #[test]
    fn test_reshape_split_1d_row_major() {
        let s = Slice::new_row_major(vec![1024]);
        let reshaped = s.reshape_with_limit(Limit::from(8));

        assert_eq!(reshaped.offset(), 0);
        assert_eq!(reshaped.sizes(), &[8, 8, 8, 2]);
        assert_eq!(reshaped.strides(), &[128, 16, 2, 1]);
        assert_eq!(
            factor_dims(s.sizes(), Limit::from(8)),
            vec![vec![8, 8, 8, 2]]
        );

        assert_layout_preserved!(&s, &reshaped);
    }

    #[test]
    fn test_reshape_6_with_limit_2() {
        let s = Slice::new_row_major(vec![6]);
        let reshaped = reshape_with_limit(&s, Limit::from(2));
        assert_eq!(factor_dims(s.sizes(), Limit::from(2)), vec![vec![2, 3]]);
        assert_layout_preserved!(&s, &reshaped);
    }

    #[test]
    fn test_reshape_identity_noop_2d() {
        let original = Slice::new_row_major(vec![4, 8]);
        let reshaped = original.reshape_with_limit(Limit::from(8));

        assert_eq!(reshaped.sizes(), original.sizes());
        assert_eq!(reshaped.strides(), original.strides());
        assert_eq!(reshaped.offset(), original.offset());
        // Every dimension is within the limit, so nothing is factored.
        assert_eq!(
            factor_dims(original.sizes(), Limit::from(8)),
            vec![vec![4], vec![8]]
        );
        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_empty_slice() {
        let original = Slice::new_row_major(vec![]);
        let reshaped = reshape_with_limit(&original, Limit::from(8));

        assert_eq!(reshaped.sizes(), original.sizes());
        assert_eq!(reshaped.strides(), original.strides());
        assert_eq!(reshaped.offset(), original.offset());

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_mixed_dims_3d() {
        let original = Slice::new_row_major(vec![6, 8, 10]);
        let reshaped = original.reshape_with_limit(Limit::from(4));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(4)),
            vec![vec![3, 2], vec![4, 2], vec![2, 5]]
        );
        assert_eq!(reshaped.sizes(), &[3, 2, 4, 2, 2, 5]);

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_all_large_dims() {
        let original = Slice::new_row_major(vec![12, 18, 20]);
        let reshaped = original.reshape_with_limit(Limit::from(4));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(4)),
            vec![vec![4, 3], vec![3, 3, 2], vec![4, 5]]
        );
        assert_eq!(reshaped.sizes(), &[4, 3, 3, 3, 2, 4, 5]);

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_split_1d_factors_3_3_2_2() {
        let original = Slice::new_row_major(vec![36]);
        let reshaped = reshape_with_limit(&original, Limit::from(3));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(3)),
            vec![vec![3, 3, 2, 2]]
        );
        assert_eq!(reshaped.sizes(), &[3, 3, 2, 2]);
        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_large_prime_dimension() {
        let original = Slice::new_row_major(vec![7]);
        let reshaped = reshape_with_limit(&original, Limit::from(4));

        assert_eq!(factor_dims(original.sizes(), Limit::from(4)), vec![vec![7]]);
        assert_eq!(reshaped.sizes(), &[7]);

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_split_1d_factors_5_3_2() {
        let original = Slice::new_row_major(vec![30]);
        let reshaped = reshape_with_limit(&original, Limit::from(5));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(5)),
            vec![vec![5, 3, 2]]
        );
        assert_eq!(reshaped.sizes(), &[5, 3, 2]);
        assert_eq!(reshaped.strides(), &[6, 2, 1]);

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_factors_2_6_2_8_8() {
        let original = Slice::new_row_major(vec![2, 12, 64]);
        let reshaped = original.reshape_with_limit(Limit::from(8));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(8)),
            vec![vec![2], vec![6, 2], vec![8, 8]]
        );
        assert_eq!(reshaped.sizes(), &[2, 6, 2, 8, 8]);
        assert_eq!(reshaped.strides(), &[768, 128, 64, 8, 1]);

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_all_dims_within_limit() {
        let original = Slice::new_row_major(vec![2, 3, 4]);
        let reshaped = original.reshape_with_limit(Limit::from(4));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(4)),
            vec![vec![2], vec![3], vec![4]]
        );
        assert_eq!(reshaped.sizes(), &[2, 3, 4]);
        assert_eq!(reshaped.strides(), original.strides());
        assert_eq!(reshaped.offset(), original.offset());

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_degenerate_dimension() {
        let original = Slice::new_row_major(vec![1, 12]);
        let reshaped = original.reshape_with_limit(Limit::from(4));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(4)),
            vec![vec![1], vec![4, 3]]
        );
        assert_eq!(reshaped.sizes(), &[1, 4, 3]);

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_select_then_reshape() {
        let original = shape!(zone = 2, host = 3, gpu = 4);

        let selected = original.select("zone", 1).unwrap();
        assert_eq!(selected.slice().offset(), 12);
        assert_eq!(selected.slice().sizes(), &[1, 3, 4]);

        let reshaped = selected.slice().reshape_with_limit(Limit::from(2));

        assert_eq!(
            factor_dims(selected.slice().sizes(), Limit::from(2)),
            vec![vec![1], vec![3], vec![2, 2]]
        );
        assert_eq!(reshaped.sizes(), &[1, 3, 2, 2]);
        assert_eq!(reshaped.strides(), &[12, 4, 2, 1]);
        assert_eq!(reshaped.offset(), 12);
        assert_layout_preserved!(selected.slice(), &reshaped);
    }

    #[test]
    fn test_select_host_plane_then_reshape() {
        let original = shape!(zone = 2, host = 3, gpu = 4);
        let selected = original.select("host", 2).unwrap();
        let reshaped = selected.slice().reshape_with_limit(Limit::from(2));

        assert_layout_preserved!(selected.slice(), &reshaped);
    }

    #[test]
    fn test_reshape_after_select_no_factoring_due_to_primes() {
        let original = shape!(zone = 3, host = 4, gpu = 5);
        let selected_zone = original.select("zone", 1).unwrap();
        assert_eq!(selected_zone.slice().sizes(), &[1, 4, 5]);
        let selected_host = selected_zone.select("host", 2).unwrap();
        assert_eq!(selected_host.slice().sizes(), &[1, 1, 5]);
        let reshaped = selected_host.slice().reshape_with_limit(Limit::from(2));

        assert_eq!(
            factor_dims(selected_host.slice().sizes(), Limit::from(2)),
            vec![vec![1], vec![1], vec![5]]
        );
        assert_eq!(reshaped.sizes(), &[1, 1, 5]);

        assert_layout_preserved!(selected_host.slice(), &reshaped);
    }

    #[test]
    fn test_reshape_after_multiple_selects_triggers_factoring() {
        let original = shape!(zone = 2, host = 4, gpu = 8);
        let selected_zone = original.select("zone", 1).unwrap();
        assert_eq!(selected_zone.slice().sizes(), &[1, 4, 8]);

        let selected_host = selected_zone.select("host", 2).unwrap();
        assert_eq!(selected_host.slice().sizes(), &[1, 1, 8]);

        let reshaped = selected_host.slice().reshape_with_limit(Limit::from(2));

        assert_eq!(
            factor_dims(selected_host.slice().sizes(), Limit::from(2)),
            vec![vec![1], vec![1], vec![2, 2, 2]]
        );
        assert_eq!(reshaped.sizes(), &[1, 1, 2, 2, 2]);

        assert_layout_preserved!(selected_host.slice(), &reshaped);
    }

    #[test]
    fn test_expand_labels_singleton_dims() {
        let factors = vec![("x".into(), vec![2]), ("y".into(), vec![4])];
        let expected = vec!["x", "y"];
        assert_eq!(expand_labels(&factors), expected);
    }

    #[test]
    fn test_expand_labels_factored_dims() {
        let factors = vec![("gpu".into(), vec![2, 2, 2])];
        let expected = vec!["gpu/0", "gpu/1", "gpu/2"];
        assert_eq!(expand_labels(&factors), expected);
    }

    #[test]
    fn test_expand_labels_mixed_dims() {
        let factors = vec![("zone".into(), vec![2]), ("gpu".into(), vec![2, 2])];
        let expected = vec!["zone", "gpu/0", "gpu/1"];
        assert_eq!(expand_labels(&factors), expected);
    }

    #[test]
    fn test_expand_labels_empty() {
        let factors: Vec<(String, Vec<usize>)> = vec![];
        let expected: Vec<String> = vec![];
        assert_eq!(expand_labels(&factors), expected);
    }

    #[test]
    fn test_reshape_shape_noop() {
        let shape = shape!(x = 4, y = 8);
        let reshaped = reshape_shape(&shape, Limit::from(8));
        assert_eq!(reshaped.shape.labels(), &["x", "y"]);
        assert_eq!(reshaped.shape.slice(), shape.slice());
    }

    #[test]
    fn test_reshape_shape_factored() {
        let shape = shape!(gpu = 8);
        let reshaped = reshape_shape(&shape, Limit::from(2));
        assert_eq!(reshaped.shape.labels(), &["gpu/0", "gpu/1", "gpu/2"]);
        assert_eq!(reshaped.shape.slice().sizes(), &[2, 2, 2]);

        let expected = shape.slice().reshape_with_limit(Limit::from(2));
        assert_eq!(reshaped.shape.slice(), &expected);
    }

    #[test]
    fn test_reshape_shape_singleton() {
        let shape = shape!(x = 3);
        let reshaped = reshape_shape(&shape, Limit::from(8));
        assert_eq!(reshaped.shape.labels(), &["x"]);
        assert_eq!(reshaped.shape.slice(), shape.slice());
    }

    #[test]
    fn test_reshape_shape_prime_exceeds_limit() {
        let shape = shape!(x = 11);
        let reshaped = reshape_shape(&shape, Limit::from(5));
        assert_eq!(reshaped.shape.labels(), &["x"]);
        assert_eq!(reshaped.shape.slice(), shape.slice());
    }

    #[test]
    fn test_reshape_shape_mixed_dims() {
        let shape = shape!(zone = 2, gpu = 8);
        let reshaped = reshape_shape(&shape, Limit::from(2));
        assert_eq!(
            reshaped.shape.labels(),
            &["zone", "gpu/0", "gpu/1", "gpu/2"]
        );
        assert_eq!(reshaped.shape.slice().sizes(), &[2, 2, 2, 2]);

        let expected = shape.slice().reshape_with_limit(Limit::from(2));
        assert_eq!(reshaped.shape.slice(), &expected);
    }

    #[test]
    fn test_reshape_shape_after_selects() {
        let original = shape!(zone = 2, host = 4, gpu = 8);

        let selected_zone = original.select("zone", 1).unwrap();
        assert_eq!(selected_zone.slice().sizes(), &[1, 4, 8]);

        let selected_host = selected_zone.select("host", 2).unwrap();
        assert_eq!(selected_host.slice().sizes(), &[1, 1, 8]);

        let reshaped = reshape_shape(&selected_host, Limit::from(2));

        assert_eq!(
            reshaped.shape.labels(),
            &["zone", "host", "gpu/0", "gpu/1", "gpu/2"]
        );

        assert_eq!(reshaped.shape.slice().sizes(), &[1, 1, 2, 2, 2]);

        let expected = selected_host.slice().reshape_with_limit(Limit::from(2));
        assert_eq!(reshaped.shape.slice(), &expected);
    }

    proptest! {
        #![proptest_config(ProptestConfig {
            cases: 20, ..ProptestConfig::default()
        })]
        #[test]
        fn test_reshape_selection((slice, fanout_limit) in gen_slice(4, 64).prop_flat_map(|slice| {
            let max_dimension_size = slice.sizes().iter().max().unwrap();
            (1..=*max_dimension_size).prop_map(move |fanout_limit| (slice.clone(), fanout_limit))
        })) {
            let shape = slice.sizes().to_vec();

            let mut runner = TestRunner::default();
            let selection = gen_selection(4, shape.clone(), 0).new_tree(&mut runner).unwrap().current();

            let original_selected_ranks = selection
                .eval(&EvalOpts::strict(), &slice)
                .unwrap()
                .collect::<BTreeSet<_>>();

            let reshaped_slice = reshape_with_limit(&slice, Limit::from(fanout_limit));
            // `unwrap` (rather than `.ok().unwrap()`) keeps the
            // `ReshapeError` details in the failure message.
            let reshaped_selection = reshape_selection(selection, &slice, &reshaped_slice).unwrap();

            let folded_selected_ranks = reshaped_selection
                .eval(&EvalOpts::strict(), &reshaped_slice)?
                .collect::<BTreeSet<_>>();

            prop_assert_eq!(original_selected_ranks, folded_selected_ranks);
        }
    }
}