1use std::fmt;
25
26use crate::Range;
27use crate::Selection;
28use crate::dsl::union;
29use crate::shape::Shape;
30use crate::slice::Slice;
31
32pub type Coord = Vec<usize>;
35
36pub struct ReshapedShape {
42 pub shape: Shape,
45
46 pub factors: Vec<(String, Vec<usize>)>,
49}
50
51#[allow(dead_code)]
52const _: () = {
53 fn assert<T: Send + Sync + 'static>() {}
54 let _ = assert::<ReshapedShape>;
55};
56
57impl std::fmt::Debug for ReshapedShape {
58 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59 f.debug_struct("ReshapedShape")
60 .field("labels", &self.shape.labels())
61 .field("sizes", &self.shape.slice().sizes())
62 .field("strides", &self.shape.slice().strides())
63 .field("offset", &self.shape.slice().offset())
64 .field("factors", &self.factors)
65 .finish()
66 }
67}
68
69impl std::fmt::Display for ReshapedShape {
70 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
71 write!(
72 f,
73 "ReshapedShape {{ [off={} sz={:?} st={:?} lab={:?} fac={:?}] }}",
74 self.shape.slice().offset(),
75 self.shape.slice().sizes(),
76 self.shape.slice().strides(),
77 self.shape.labels(),
78 self.factors
79 )
80 }
81}
82
83pub(crate) fn factor_dims(sizes: &[usize], limit: Limit) -> Vec<Vec<usize>> {
97 let limit = limit.get();
98 sizes
99 .iter()
100 .map(|&size| {
101 if size <= limit {
102 return vec![size];
103 }
104 let mut rem = size;
105 let mut factors = Vec::new();
106 for d in (2..=limit).rev() {
107 while rem % d == 0 {
108 factors.push(d);
109 rem /= d;
110 }
111 }
112 if rem > 1 {
113 factors.push(rem);
114 }
115 factors
116 })
117 .collect()
118}
119
120pub fn to_reshaped_coord<'a>(
124 original: &'a Slice,
125 reshaped: &'a Slice,
126) -> impl Fn(&[usize]) -> Vec<usize> + 'a {
127 let original = original.clone();
128 let reshaped = reshaped.clone();
129 move |coord: &[usize]| -> Coord {
130 let flat = original.location(coord).unwrap();
131 reshaped.coordinates(flat).unwrap()
132 }
133}
134
135pub fn to_original_coord<'a>(
139 reshaped: &'a Slice,
140 original: &'a Slice,
141) -> impl Fn(&[usize]) -> Vec<usize> + 'a {
142 let reshaped = reshaped.clone();
143 let original = original.clone();
144 move |coord: &[usize]| -> Coord {
145 let flat = reshaped.location(coord).unwrap();
146 original.coordinates(flat).unwrap()
147 }
148}
149
/// Upper bound on the size of factors produced when splitting dimensions.
///
/// A `Limit` is always at least 1; the default is 32.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Limit(usize);

impl Limit {
    /// Creates a limit of `n`.
    ///
    /// # Panics
    /// Panics if `n` is 0.
    pub fn new(n: usize) -> Self {
        if n == 0 {
            panic!("Limit must be at least 1");
        }
        Limit(n)
    }

    /// Returns the wrapped bound.
    pub fn get(self) -> usize {
        let Limit(bound) = self;
        bound
    }
}

impl Default for Limit {
    /// A limit of 32.
    fn default() -> Self {
        Limit(32)
    }
}

impl From<usize> for Limit {
    /// Same as [`Limit::new`]: panics when `n` is 0.
    fn from(n: usize) -> Self {
        Limit::new(n)
    }
}
194
195pub trait ReshapeSliceExt {
215 fn reshape_with_limit(&self, limit: Limit) -> Slice;
227}
228
229impl ReshapeSliceExt for Slice {
230 fn reshape_with_limit(&self, limit: Limit) -> Slice {
231 reshape_with_limit(self, limit)
232 }
233}
234
235pub trait ReshapeShapeExt {
237 fn reshape(&self, limit: Limit) -> ReshapedShape;
240}
241
242impl ReshapeShapeExt for Shape {
243 fn reshape(&self, limit: Limit) -> ReshapedShape {
244 reshape_shape(self, limit)
245 }
246}
247
248pub mod prelude {
251 pub use super::ReshapeShapeExt;
252 pub use super::ReshapeSliceExt;
253}
254
255pub fn reshape_with_limit(slice: &Slice, limit: Limit) -> Slice {
285 let orig_sizes = slice.sizes();
286 let orig_strides = slice.strides();
287
288 let factored_sizes = factor_dims(orig_sizes, limit);
290
291 let reshaped_sizes: Vec<usize> = factored_sizes.iter().flatten().cloned().collect();
293 let mut reshaped_strides = Vec::with_capacity(reshaped_sizes.len());
294
295 for (&orig_stride, factors) in orig_strides.iter().zip(&factored_sizes) {
296 let mut sub_strides = Vec::with_capacity(factors.len());
297 let mut stride = orig_stride;
298 for &f in factors.iter().rev() {
299 sub_strides.push(stride);
300 stride *= f;
301 }
302 sub_strides.reverse();
303 reshaped_strides.extend(sub_strides);
304 }
305
306 Slice::new(slice.offset(), reshaped_sizes, reshaped_strides).unwrap()
307}
308
309pub fn reshape_shape(shape: &Shape, limit: Limit) -> ReshapedShape {
329 let reshaped_slice = shape.slice().reshape_with_limit(limit);
330 let original_labels = shape.labels();
331 let original_sizes = shape.slice().sizes();
332
333 let factors = factor_dims(original_sizes, limit);
334 let factored_dims: Vec<(String, Vec<usize>)> =
335 original_labels.iter().cloned().zip(factors).collect();
336
337 let labels = expand_labels(&factored_dims);
338 let shape = Shape::new(labels, reshaped_slice).expect("invalid reshaped shape");
339
340 ReshapedShape {
341 shape,
342 factors: factored_dims,
343 }
344}
345
/// Expands per-dimension labels to one label per factored dimension.
///
/// A dimension with exactly one factor keeps its label unchanged. Otherwise it
/// produces `label/0`, `label/1`, … — one per factor (so an empty factor list
/// produces no labels).
///
/// Example: `[("zone", [2]), ("gpu", [2, 2])]` → `["zone", "gpu/0", "gpu/1"]`.
pub fn expand_labels(factors: &[(String, Vec<usize>)]) -> Vec<String> {
    factors
        .iter()
        .flat_map(|(label, dims)| -> Vec<String> {
            match dims.len() {
                1 => vec![label.clone()],
                count => (0..count).map(|i| format!("{}/{}", label, i)).collect(),
            }
        })
        .collect()
}
379
/// Errors returned by `reshape_selection`.
#[derive(Debug, thiserror::Error)]
pub enum ReshapeError {
    /// A selection variant that `reshape_selection` has no rewrite rule for
    /// was encountered. The offending (sub-)selection is carried in
    /// `selection`.
    #[error("unsupported selection kind {selection}")]
    UnsupportedSelection { selection: Selection },
}
/// Rewrites `selection`, written against the dimensions of `original_slice`,
/// into a selection over the dimensions of `reshaped_slice`.
///
/// Each original dimension is located as a run of consecutive reshaped
/// dimensions whose sizes multiply up to (at least) the original size — the
/// layout produced by `reshape_with_limit`.
///
/// Per-dimension operators are mapped as follows:
/// - `All`, `Any` and `First` are repeated once per reshaped sub-dimension.
/// - `Range` is split into a union of nested ranges over the sub-dimensions.
/// - `Union` and `Intersection` are rewritten operand-wise and simplified
///   against `True`/`False`.
///
/// # Errors
/// Returns [`ReshapeError::UnsupportedSelection`] when a variant other than
/// `True`, `False`, `Union`, `Intersection`, `All`, `Any`, `First` or `Range`
/// is encountered at a dimension of `original_slice`.
///
/// # Panics
/// Panics (via `unwrap`) if `reshaped_slice` runs out of dimensions before the
/// product of its sizes reaches an original dimension's size.
pub fn reshape_selection(
    selection: Selection,
    original_slice: &Slice,
    reshaped_slice: &Slice,
) -> Result<Selection, ReshapeError> {
    // Rewrites `selection`, whose outermost operator addresses original
    // dimension `original_size_index`; that dimension starts at reshaped
    // dimension `reshaped_size_index`.
    fn recursive_fold(
        selection: Selection,
        original_slice: &Slice,
        original_size_index: usize,
        reshaped_slice: &Slice,
        reshaped_size_index: usize,
    ) -> Result<Selection, ReshapeError> {
        if matches!(selection, Selection::True | Selection::False) {
            return Ok(selection);
        }

        // Past the last original dimension there is nothing left to re-map.
        let Some(&original_dim_size) = original_slice.sizes().get(original_size_index) else {
            return Ok(selection);
        };

        // Multiply consecutive reshaped sizes until they cover the original
        // dimension. Afterwards the original dimension corresponds to reshaped
        // dimensions `reshaped_size_index..next_reshaped_dimension_start`.
        let mut accum = *reshaped_slice.sizes().get(reshaped_size_index).unwrap();
        let mut next_reshaped_dimension_start = reshaped_size_index + 1;

        while accum < original_dim_size {
            accum *= *reshaped_slice
                .sizes()
                .get(next_reshaped_dimension_start)
                .unwrap();
            next_reshaped_dimension_start += 1;
        }

        match selection {
            Selection::True | Selection::False => Ok(selection),
            // Both operands address the same dimension; rewrite each and fold
            // away constant operands (True absorbs, False is the identity).
            Selection::Union(left, right) => {
                let left = recursive_fold(
                    *left,
                    original_slice,
                    original_size_index,
                    reshaped_slice,
                    reshaped_size_index,
                )?;

                match left {
                    Selection::True => return Ok(Selection::True),
                    Selection::False => {
                        return recursive_fold(
                            *right,
                            original_slice,
                            original_size_index,
                            reshaped_slice,
                            reshaped_size_index,
                        );
                    }
                    _ => {}
                }

                let right = recursive_fold(
                    *right,
                    original_slice,
                    original_size_index,
                    reshaped_slice,
                    reshaped_size_index,
                )?;

                Ok(match right {
                    Selection::True => Selection::True,
                    Selection::False => left,
                    _ => Selection::Union(Box::new(left), Box::new(right)),
                })
            }
            // Dual of Union: False absorbs, True is the identity.
            Selection::Intersection(left, right) => {
                let left = recursive_fold(
                    *left,
                    original_slice,
                    original_size_index,
                    reshaped_slice,
                    reshaped_size_index,
                )?;
                match left {
                    Selection::False => return Ok(Selection::False),
                    Selection::True => {
                        return recursive_fold(
                            *right,
                            original_slice,
                            original_size_index,
                            reshaped_slice,
                            reshaped_size_index,
                        );
                    }
                    _ => {}
                }

                let right = recursive_fold(
                    *right,
                    original_slice,
                    original_size_index,
                    reshaped_slice,
                    reshaped_size_index,
                )?;
                Ok(match right {
                    Selection::False => Selection::False,
                    Selection::True => left,
                    _ => Selection::Intersection(Box::new(left), Box::new(right)),
                })
            }
            // `All` over one original dimension becomes one `All` per reshaped
            // sub-dimension (k sub-dimensions → k nested `All`s).
            Selection::All(inner) => {
                let inner = recursive_fold(
                    *inner,
                    original_slice,
                    original_size_index + 1,
                    reshaped_slice,
                    next_reshaped_dimension_start,
                )?;

                // All(True) / All(False) collapse to the constant itself.
                if matches!(inner, Selection::True | Selection::False) {
                    return Ok(inner);
                }

                Ok((reshaped_size_index..next_reshaped_dimension_start - 1)
                    .fold(Selection::All(Box::new(inner)), |result, _| {
                        Selection::All(Box::new(result))
                    }))
            }
            // Same nesting scheme as `All`; only a False inner collapses.
            Selection::Any(inner) => {
                let inner = recursive_fold(
                    *inner,
                    original_slice,
                    original_size_index + 1,
                    reshaped_slice,
                    next_reshaped_dimension_start,
                )?;

                if matches!(inner, Selection::False) {
                    return Ok(inner);
                }

                Ok((reshaped_size_index..next_reshaped_dimension_start - 1)
                    .fold(Selection::Any(Box::new(inner)), |result, _| {
                        Selection::Any(Box::new(result))
                    }))
            }
            // Same nesting scheme as `All`; only a False inner collapses.
            Selection::First(inner) => {
                let inner = recursive_fold(
                    *inner,
                    original_slice,
                    original_size_index + 1,
                    reshaped_slice,
                    next_reshaped_dimension_start,
                )?;

                if matches!(inner, Selection::False) {
                    return Ok(inner);
                }

                Ok((reshaped_size_index..next_reshaped_dimension_start - 1)
                    .fold(Selection::First(Box::new(inner)), |result, _| {
                        Selection::First(Box::new(result))
                    }))
            }
            Selection::Range(range, inner) => {
                // Splits a range over a dimension of size
                // `original_dimension_n_size` into ranges over two dimensions:
                // - an outer "row" dimension ("dimension n+1", index / M) of
                //   size `original_dimension_n_size / M`;
                // - an inner "column" dimension ("new dimension n",
                //   index % M) of size M,
                // where M = `new_dimension_n_size`. The returned pieces are
                // `Range(rows, Range(cols, inner))` selections whose union
                // covers the original range.
                //
                // NOTE(review): there is no guard for `end < start`. For
                // example, with step 1, start=5, end=Some(2) and M=4 the
                // pieces select columns in rows 0 and 1, while `start..end` is
                // empty. Confirm whether such ranges can reach this code.
                //
                // NOTE(review): assumes `step >= 1`. With `step == 0` the
                // strided branch divides by zero once the row/col iterator is
                // advanced.
                fn fold_once(
                    Range(start, end, step): Range,
                    inner: Selection,
                    original_dimension_n_size: usize,
                    new_dimension_n_size: usize,
                ) -> Vec<Selection> {
                    let end = end.map(|e| e.min(original_dimension_n_size));

                    // Row of the first element, exclusive row bound
                    // (ceil(end / M)), and total number of rows.
                    let dimension_n_plus_one_start = start / new_dimension_n_size;
                    let dimension_n_plus_one_end = end.map(|end| {
                        if end % new_dimension_n_size == 0 {
                            end / new_dimension_n_size
                        } else {
                            end / new_dimension_n_size + 1
                        }
                    });
                    let dimension_n_plus_one_size =
                        original_dimension_n_size / new_dimension_n_size;

                    // Column of the first element, and exclusive column bound
                    // in the last row (M when `end` is a positive multiple of
                    // M, i.e. the last row is full).
                    let new_dimension_n_start = start % new_dimension_n_size;
                    let new_dimension_n_end = end.map(|end| {
                        if end % new_dimension_n_size == 0 && end > new_dimension_n_size - 1 {
                            new_dimension_n_size
                        } else {
                            end % new_dimension_n_size
                        }
                    });

                    let mut result = vec![];

                    // The whole range lies within one row (or, for an
                    // open-ended range, starts exactly at the row count):
                    // a single row range with the column range nested inside
                    // suffices.
                    if dimension_n_plus_one_end
                        .is_some_and(|end| dimension_n_plus_one_start + 1 == end)
                        || (end.is_none()
                            && dimension_n_plus_one_start == dimension_n_plus_one_size)
                    {
                        return vec![Selection::Range(
                            Range(dimension_n_plus_one_start, dimension_n_plus_one_end, 1),
                            Box::new(Selection::Range(
                                Range(new_dimension_n_start, new_dimension_n_end, step),
                                Box::new(inner.clone()),
                            )),
                        )];
                    }

                    if step == 1 {
                        // Contiguous range: a partial leading row, a partial
                        // trailing row, and the full rows in between (mapped
                        // with `All` over the columns).
                        let middle_start = match start % new_dimension_n_size {
                            // Starts on a row boundary: no partial leading row.
                            0 => dimension_n_plus_one_start,
                            _ => {
                                result.push(Selection::Range(
                                    Range(
                                        dimension_n_plus_one_start,
                                        Some(dimension_n_plus_one_start + 1),
                                        1,
                                    ),
                                    Box::new(Selection::Range(
                                        Range(
                                            new_dimension_n_start,
                                            Some(new_dimension_n_size),
                                            step,
                                        ),
                                        Box::new(inner.clone()),
                                    )),
                                ));
                                dimension_n_plus_one_start + 1
                            }
                        };

                        // A bounded range ending mid-row gets a partial
                        // trailing row; otherwise the middle runs up to the
                        // row bound (or is open-ended).
                        let middle_end = match (end, dimension_n_plus_one_end) {
                            (Some(end), Some(dimension_n_plus_one_end))
                                if end % new_dimension_n_size != 0 =>
                            {
                                result.push(Selection::Range(
                                    Range(
                                        dimension_n_plus_one_end - 1,
                                        Some(dimension_n_plus_one_end),
                                        1,
                                    ),
                                    Box::new(Selection::Range(
                                        Range(0, new_dimension_n_end, step),
                                        Box::new(inner.clone()),
                                    )),
                                ));
                                Some(dimension_n_plus_one_end - 1)
                            }
                            _ => dimension_n_plus_one_end,
                        };

                        // Full middle rows, only when the span is non-empty.
                        if middle_end.is_some_and(|end| end > middle_start)
                            || (middle_end.is_none() && middle_start < dimension_n_plus_one_size)
                        {
                            result.push(Selection::Range(
                                Range(middle_start, middle_end, 1),
                                Box::new(Selection::All(Box::new(inner.clone()))),
                            ));
                        }
                    } else {
                        // Strided range: the starting column in each row
                        // repeats every lcm(step, M) / M = step / gcd(step, M)
                        // rows. So one row range with stride
                        // `row_pattern_period` per residue covers the full
                        // rows, plus explicit leading/trailing partial rows.
                        fn gcd(a: usize, b: usize) -> usize {
                            if b == 0 { a } else { gcd(b, a % b) }
                        }

                        let row_pattern_period = step / gcd(step, new_dimension_n_size);

                        // Yields (row, first column) of each successive row
                        // containing an element. From (row, col) it advances
                        // to the last element in that row, then steps once
                        // more into the next occupied row.
                        let mut row_col_iter = std::iter::successors(
                            Some((dimension_n_plus_one_start, start % new_dimension_n_size)),
                            |&(row, col)| {
                                let cols_before_end = new_dimension_n_size - 1 - col;
                                let steps_before_end = cols_before_end / step;
                                let last_col_before_end = col + step * steps_before_end;

                                let next_row =
                                    ((row * new_dimension_n_size) + last_col_before_end + step)
                                        / new_dimension_n_size;
                                let next_col = (last_col_before_end + step) % new_dimension_n_size;

                                Some((next_row, next_col))
                            },
                        )
                        .peekable();

                        // Leading row when the range does not start on a row
                        // boundary: emitted on its own.
                        if start % new_dimension_n_size != 0 {
                            let (row, col) = row_col_iter.next().unwrap();

                            result.push(Selection::Range(
                                Range(row, Some(row + 1), 1),
                                Box::new(Selection::Range(
                                    Range(col, None, step),
                                    Box::new(inner.clone()),
                                )),
                            ));
                        };

                        // One periodic row range per occupied row in the first
                        // period, bounded by the first row that `end` only
                        // partially covers (or by the row count when
                        // unbounded).
                        for _ in 0..row_pattern_period {
                            let end_row = end.map(|end| end / new_dimension_n_size);

                            if match end_row {
                                Some(end_row) => row_col_iter.peek().unwrap().0 >= end_row,
                                None => row_col_iter.peek().unwrap().0 >= dimension_n_plus_one_size,
                            } {
                                break;
                            }
                            let (row_index, col) = row_col_iter.next().unwrap();

                            result.push(Selection::Range(
                                Range(row_index, end_row, row_pattern_period),
                                Box::new(Selection::Range(
                                    Range(col, None, step),
                                    Box::new(inner.clone()),
                                )),
                            ));
                        }

                        // Trailing partial row `end / M`: find the occupied
                        // row with the same residue (hence the same starting
                        // column) and emit its columns below `end % M`, if
                        // any.
                        if let Some(end) = end {
                            let end_row = end / new_dimension_n_size;

                            for (row, col) in row_col_iter {
                                if row > end_row {
                                    break;
                                }

                                if row % row_pattern_period == end_row % row_pattern_period
                                    && col < end % new_dimension_n_size
                                {
                                    result.push(Selection::Range(
                                        Range(end_row, Some(end_row + 1), 1),
                                        Box::new(Selection::Range(
                                            Range(col, Some(end % new_dimension_n_size), step),
                                            Box::new(inner.clone()),
                                        )),
                                    ));
                                    break;
                                }
                            }
                        }
                    }
                    result
                }

                let inner = recursive_fold(
                    *inner,
                    original_slice,
                    original_size_index + 1,
                    reshaped_slice,
                    next_reshaped_dimension_start,
                )?;
                if matches!(inner, Selection::False) {
                    return Ok(inner);
                }
                let mut pieces = vec![Selection::Range(range, Box::new(inner))];

                // Peel sub-dimensions off from the innermost outwards. Each
                // pass splits the outermost range of every piece over the
                // current dimension into (rows, cols) with cols of size
                // `dimension`; the rows then form the dimension for the next
                // pass.
                let reversed_dimensions = reshaped_slice.sizes()
                    [reshaped_size_index + 1..next_reshaped_dimension_start]
                    .iter()
                    .copied()
                    .rev();

                let mut original_dimension_size = original_dim_size;
                for dimension in reversed_dimensions {
                    pieces = pieces
                        .into_iter()
                        .flat_map(|piece| {
                            // `fold_once` only returns `Range` pieces, so the
                            // `else` arm is not expected to be taken.
                            if let Selection::Range(range, inner) = piece {
                                fold_once(range, *inner, original_dimension_size, dimension)
                            } else {
                                vec![]
                            }
                        })
                        .collect();
                    original_dimension_size /= dimension;
                }

                // Union of all pieces; no pieces means nothing is selected.
                Ok(pieces.into_iter().fold(Selection::False, |x, y| match x {
                    Selection::False => y,
                    _ => union(x, y),
                }))
            }
            _ => Err(ReshapeError::UnsupportedSelection { selection }),
        }
    }

    recursive_fold(selection, original_slice, 0, reshaped_slice, 0)
}
825
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;

    use proptest::prelude::*;

    use super::*;
    use crate::Slice;
    use crate::selection::EvalOpts;
    use crate::shape;
    use crate::strategy::gen_selection;
    use crate::strategy::gen_slice;

    #[test]
    fn test_factor_dims_basic() {
        assert_eq!(
            factor_dims(&[6, 8], Limit::from(4)),
            vec![vec![3, 2], vec![4, 2]]
        );
        assert_eq!(factor_dims(&[5], Limit::from(3)), vec![vec![5]]);
        assert_eq!(factor_dims(&[30], Limit::from(5)), vec![vec![5, 3, 2]]);
    }

    /// Asserts that `$reshaped` addresses exactly the same locations as
    /// `$original`. For every coordinate of `$original`:
    /// - mapping it into `$reshaped` and back is the identity;
    /// - both coordinates resolve to the same flat location;
    /// - `$reshaped.coordinates` inverts `$reshaped.location`.
    #[macro_export]
    macro_rules! assert_layout_preserved {
        ($original:expr, $reshaped:expr) => {{
            // The mappings do not depend on the coordinate; build them once
            // instead of once per coordinate.
            let forward = to_reshaped_coord($original, &$reshaped);
            let inverse = to_original_coord(&$reshaped, $original);
            for coord in $original.dim_iter($original.num_dim()) {
                let reshaped_coord = forward(&coord);
                let roundtrip = inverse(&reshaped_coord);
                assert_eq!(
                    roundtrip, coord,
                    "Inverse mismatch: reshaped {:?} → original {:?}, expected {:?}",
                    reshaped_coord, roundtrip, coord
                );
                let flat_orig = $original.location(&coord).unwrap();
                let flat_reshaped = $reshaped.location(&reshaped_coord).unwrap();
                assert_eq!(
                    flat_orig, flat_reshaped,
                    "Flat index mismatch: original {:?} → reshaped {:?}",
                    coord, reshaped_coord
                );
                let recovered = $reshaped.coordinates(flat_reshaped).unwrap();
                assert_eq!(
                    reshaped_coord, recovered,
                    "Coordinate mismatch: flat index {} → expected {:?}, got {:?}",
                    flat_reshaped, reshaped_coord, recovered
                );
            }
        }};
    }

    #[test]
    fn test_reshape_split_1d_row_major() {
        let s = Slice::new_row_major(vec![1024]);
        let reshaped = s.reshape_with_limit(Limit::from(8));

        assert_eq!(reshaped.offset(), 0);
        assert_eq!(reshaped.sizes(), &vec![8, 8, 8, 2]);
        assert_eq!(reshaped.strides(), &vec![128, 16, 2, 1]);
        assert_eq!(
            factor_dims(s.sizes(), Limit::from(8)),
            vec![vec![8, 8, 8, 2]]
        );

        assert_layout_preserved!(&s, &reshaped);
    }

    #[test]
    fn test_reshape_6_with_limit_2() {
        let s = Slice::new_row_major(vec![6]);
        let reshaped = reshape_with_limit(&s, Limit::from(2));
        assert_eq!(factor_dims(s.sizes(), Limit::from(2)), vec![vec![2, 3]]);
        assert_eq!(reshaped.sizes(), &[2, 3]);
        assert_eq!(reshaped.strides(), &[3, 1]);
        assert_layout_preserved!(&s, &reshaped);
    }

    #[test]
    fn test_reshape_identity_noop_2d() {
        let original = Slice::new_row_major(vec![4, 8]);
        let reshaped = original.reshape_with_limit(Limit::from(8));

        assert_eq!(reshaped.sizes(), original.sizes());
        assert_eq!(reshaped.strides(), original.strides());
        assert_eq!(reshaped.offset(), original.offset());
        // Both dimensions fit within the limit, so neither is factored.
        assert_eq!(
            factor_dims(original.sizes(), Limit::from(8)),
            vec![vec![4], vec![8]]
        );
        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_empty_slice() {
        let original = Slice::new_row_major(vec![]);
        let reshaped = reshape_with_limit(&original, Limit::from(8));

        assert_eq!(reshaped.sizes(), original.sizes());
        assert_eq!(reshaped.strides(), original.strides());
        assert_eq!(reshaped.offset(), original.offset());

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_mixed_dims_3d() {
        let original = Slice::new_row_major(vec![6, 8, 10]);
        let reshaped = original.reshape_with_limit(Limit::from(4));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(4)),
            vec![vec![3, 2], vec![4, 2], vec![2, 5]]
        );
        assert_eq!(reshaped.sizes(), &[3, 2, 4, 2, 2, 5]);

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_all_large_dims() {
        let original = Slice::new_row_major(vec![12, 18, 20]);
        let reshaped = original.reshape_with_limit(Limit::from(4));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(4)),
            vec![vec![4, 3], vec![3, 3, 2], vec![4, 5]]
        );
        assert_eq!(reshaped.sizes(), &[4, 3, 3, 3, 2, 4, 5]);

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_split_1d_factors_3_3_2_2() {
        let original = Slice::new_row_major(vec![36]);
        let reshaped = reshape_with_limit(&original, Limit::from(3));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(3)),
            vec![vec![3, 3, 2, 2]]
        );
        assert_eq!(reshaped.sizes(), &[3, 3, 2, 2]);
        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_large_prime_dimension() {
        let original = Slice::new_row_major(vec![7]);
        let reshaped = reshape_with_limit(&original, Limit::from(4));

        assert_eq!(factor_dims(original.sizes(), Limit::from(4)), vec![vec![7]]);
        assert_eq!(reshaped.sizes(), &[7]);

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_split_1d_factors_5_3_2() {
        let original = Slice::new_row_major(vec![30]);
        let reshaped = reshape_with_limit(&original, Limit::from(5));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(5)),
            vec![vec![5, 3, 2]]
        );
        assert_eq!(reshaped.sizes(), &[5, 3, 2]);
        assert_eq!(reshaped.strides(), &[6, 2, 1]);

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_factors_2_6_2_8_8() {
        let original = Slice::new_row_major(vec![2, 12, 64]);
        let reshaped = original.reshape_with_limit(Limit::from(8));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(8)),
            vec![vec![2], vec![6, 2], vec![8, 8]]
        );
        assert_eq!(reshaped.sizes(), &[2, 6, 2, 8, 8]);
        assert_eq!(reshaped.strides(), &[768, 128, 64, 8, 1]);

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_all_dims_within_limit() {
        let original = Slice::new_row_major(vec![2, 3, 4]);
        let reshaped = original.reshape_with_limit(Limit::from(4));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(4)),
            vec![vec![2], vec![3], vec![4]]
        );
        assert_eq!(reshaped.sizes(), &[2, 3, 4]);
        assert_eq!(reshaped.strides(), original.strides());
        assert_eq!(reshaped.offset(), original.offset());

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_degenerate_dimension() {
        let original = Slice::new_row_major(vec![1, 12]);
        let reshaped = original.reshape_with_limit(Limit::from(4));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(4)),
            vec![vec![1], vec![4, 3]]
        );
        assert_eq!(reshaped.sizes(), &[1, 4, 3]);

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_select_then_reshape() {
        let original = shape!(zone = 2, host = 3, gpu = 4);

        let selected = original.select("zone", 1).unwrap();
        assert_eq!(selected.slice().offset(), 12);
        assert_eq!(selected.slice().sizes(), &[1, 3, 4]);

        let reshaped = selected.slice().reshape_with_limit(Limit::from(2));

        assert_eq!(
            factor_dims(selected.slice().sizes(), Limit::from(2)),
            vec![vec![1], vec![3], vec![2, 2]]
        );
        assert_eq!(reshaped.sizes(), &[1, 3, 2, 2]);
        assert_eq!(reshaped.strides(), &[12, 4, 2, 1]);
        assert_eq!(reshaped.offset(), 12);
        assert_layout_preserved!(selected.slice(), &reshaped);
    }

    #[test]
    fn test_select_host_plane_then_reshape() {
        let original = shape!(zone = 2, host = 3, gpu = 4);
        let selected = original.select("host", 2).unwrap();
        let reshaped = selected.slice().reshape_with_limit(Limit::from(2));

        assert_layout_preserved!(selected.slice(), &reshaped);
    }

    #[test]
    fn test_reshape_after_select_no_factoring_due_to_primes() {
        let original = shape!(zone = 3, host = 4, gpu = 5);
        let selected_zone = original.select("zone", 1).unwrap();
        assert_eq!(selected_zone.slice().sizes(), &[1, 4, 5]);
        let selected_host = selected_zone.select("host", 2).unwrap();
        assert_eq!(selected_host.slice().sizes(), &[1, 1, 5]);
        let reshaped = selected_host.slice().reshape_with_limit(Limit::from(2));

        assert_eq!(
            factor_dims(selected_host.slice().sizes(), Limit::from(2)),
            vec![vec![1], vec![1], vec![5]]
        );
        assert_eq!(reshaped.sizes(), &[1, 1, 5]);

        assert_layout_preserved!(selected_host.slice(), &reshaped);
    }

    #[test]
    fn test_reshape_after_multiple_selects_triggers_factoring() {
        let original = shape!(zone = 2, host = 4, gpu = 8);
        let selected_zone = original.select("zone", 1).unwrap();
        assert_eq!(selected_zone.slice().sizes(), &[1, 4, 8]);

        let selected_host = selected_zone.select("host", 2).unwrap();
        assert_eq!(selected_host.slice().sizes(), &[1, 1, 8]);

        let reshaped = selected_host.slice().reshape_with_limit(Limit::from(2));

        assert_eq!(
            factor_dims(selected_host.slice().sizes(), Limit::from(2)),
            vec![vec![1], vec![1], vec![2, 2, 2]]
        );
        assert_eq!(reshaped.sizes(), &[1, 1, 2, 2, 2]);

        assert_layout_preserved!(selected_host.slice(), &reshaped);
    }

    #[test]
    fn test_expand_labels_singleton_dims() {
        let factors = vec![("x".into(), vec![2]), ("y".into(), vec![4])];
        let expected = vec!["x", "y"];
        assert_eq!(expand_labels(&factors), expected);
    }

    #[test]
    fn test_expand_labels_factored_dims() {
        let factors = vec![("gpu".into(), vec![2, 2, 2])];
        let expected = vec!["gpu/0", "gpu/1", "gpu/2"];
        assert_eq!(expand_labels(&factors), expected);
    }

    #[test]
    fn test_expand_labels_mixed_dims() {
        let factors = vec![("zone".into(), vec![2]), ("gpu".into(), vec![2, 2])];
        let expected = vec!["zone", "gpu/0", "gpu/1"];
        assert_eq!(expand_labels(&factors), expected);
    }

    #[test]
    fn test_expand_labels_empty() {
        let factors: Vec<(String, Vec<usize>)> = vec![];
        let expected: Vec<String> = vec![];
        assert_eq!(expand_labels(&factors), expected);
    }

    #[test]
    fn test_reshape_shape_noop() {
        let shape = shape!(x = 4, y = 8);
        let reshaped = reshape_shape(&shape, Limit::from(8));
        assert_eq!(reshaped.shape.labels(), &["x", "y"]);
        assert_eq!(reshaped.shape.slice(), shape.slice());
    }

    #[test]
    fn test_reshape_shape_factored() {
        let shape = shape!(gpu = 8);
        let reshaped = reshape_shape(&shape, Limit::from(2));
        assert_eq!(reshaped.shape.labels(), &["gpu/0", "gpu/1", "gpu/2"]);
        assert_eq!(reshaped.shape.slice().sizes(), &[2, 2, 2]);

        let expected = shape.slice().reshape_with_limit(Limit::from(2));
        assert_eq!(reshaped.shape.slice(), &expected);
    }

    #[test]
    fn test_reshape_shape_singleton() {
        let shape = shape!(x = 3);
        let reshaped = reshape_shape(&shape, Limit::from(8));
        assert_eq!(reshaped.shape.labels(), &["x"]);
        assert_eq!(reshaped.shape.slice(), shape.slice());
    }

    #[test]
    fn test_reshape_shape_prime_exceeds_limit() {
        let shape = shape!(x = 11);
        let reshaped = reshape_shape(&shape, Limit::from(5));
        assert_eq!(reshaped.shape.labels(), &["x"]);
        assert_eq!(reshaped.shape.slice(), shape.slice());
    }

    #[test]
    fn test_reshape_shape_mixed_dims() {
        let shape = shape!(zone = 2, gpu = 8);
        let reshaped = reshape_shape(&shape, Limit::from(2));
        assert_eq!(
            reshaped.shape.labels(),
            &["zone", "gpu/0", "gpu/1", "gpu/2"]
        );
        assert_eq!(reshaped.shape.slice().sizes(), &[2, 2, 2, 2]);

        let expected = shape.slice().reshape_with_limit(Limit::from(2));
        assert_eq!(reshaped.shape.slice(), &expected);
    }

    #[test]
    fn test_reshape_shape_after_selects() {
        let original = shape!(zone = 2, host = 4, gpu = 8);

        let selected_zone = original.select("zone", 1).unwrap();
        assert_eq!(selected_zone.slice().sizes(), &[1, 4, 8]);

        let selected_host = selected_zone.select("host", 2).unwrap();
        assert_eq!(selected_host.slice().sizes(), &[1, 1, 8]);

        let reshaped = reshape_shape(&selected_host, Limit::from(2));

        assert_eq!(
            reshaped.shape.labels(),
            &["zone", "host", "gpu/0", "gpu/1", "gpu/2"]
        );

        assert_eq!(reshaped.shape.slice().sizes(), &[1, 1, 2, 2, 2]);

        let expected = selected_host.slice().reshape_with_limit(Limit::from(2));
        assert_eq!(reshaped.shape.slice(), &expected);
    }

    proptest! {
        #![proptest_config(ProptestConfig {
            cases: 100,
            ..ProptestConfig::default()
        })]
        #[test]
        #[cfg_attr(not(fbcode_build), ignore)]
        fn test_reshape_selection((slice, fanout_limit, selection) in gen_slice(4, 64).prop_flat_map(|slice| {
            let shape = slice.sizes().to_vec();
            let max_dimension_size = *slice.sizes().iter().max().unwrap();
            (Just(slice), 1..=max_dimension_size, gen_selection(4, shape, 0))
        })) {
            let original_selected_ranks = selection
                .eval(&EvalOpts::strict(), &slice)
                .unwrap()
                .collect::<BTreeSet<_>>();

            let reshaped_slice = reshape_with_limit(&slice, Limit::from(fanout_limit));
            // `unwrap` directly (rather than `.ok().unwrap()`) so a failure
            // reports the `ReshapeError` instead of a bare `None`.
            let reshaped_selection = reshape_selection(selection, &slice, &reshaped_slice).unwrap();

            let folded_selected_ranks = reshaped_selection
                .eval(&EvalOpts::strict(), &reshaped_slice)?
                .collect::<BTreeSet<_>>();

            prop_assert_eq!(original_selected_ranks, folded_selected_ranks);
        }
    }
}