1use std::fmt;
25
26use crate::shape::Shape;
27use crate::slice::Slice;
28
/// A coordinate into a slice: one index per dimension, outermost first.
pub type Coord = Vec<usize>;
32
33pub struct ReshapedShape {
39 pub shape: Shape,
42
43 pub factors: Vec<(String, Vec<usize>)>,
46}
47
48#[allow(dead_code)]
49const _: () = {
50 fn assert<T: Send + Sync + 'static>() {}
51 let _ = assert::<ReshapedShape>;
52};
53
54impl std::fmt::Debug for ReshapedShape {
55 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56 f.debug_struct("ReshapedShape")
57 .field("labels", &self.shape.labels())
58 .field("sizes", &self.shape.slice().sizes())
59 .field("strides", &self.shape.slice().strides())
60 .field("offset", &self.shape.slice().offset())
61 .field("factors", &self.factors)
62 .finish()
63 }
64}
65
66impl std::fmt::Display for ReshapedShape {
67 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
68 write!(
69 f,
70 "ReshapedShape {{ [off={} sz={:?} st={:?} lab={:?} fac={:?}] }}",
71 self.shape.slice().offset(),
72 self.shape.slice().sizes(),
73 self.shape.slice().strides(),
74 self.shape.labels(),
75 self.factors
76 )
77 }
78}
79
80pub(crate) fn factor_dims(sizes: &[usize], limit: Limit) -> Vec<Vec<usize>> {
94 let limit = limit.get();
95 sizes
96 .iter()
97 .map(|&size| {
98 if size <= limit {
99 return vec![size];
100 }
101 let mut rem = size;
102 let mut factors = Vec::new();
103 for d in (2..=limit).rev() {
104 while rem % d == 0 {
105 factors.push(d);
106 rem /= d;
107 }
108 }
109 if rem > 1 {
110 factors.push(rem);
111 }
112 factors
113 })
114 .collect()
115}
116
117pub fn to_reshaped_coord<'a>(
121 original: &'a Slice,
122 reshaped: &'a Slice,
123) -> impl Fn(&[usize]) -> Vec<usize> + 'a {
124 let original = original.clone();
125 let reshaped = reshaped.clone();
126 move |coord: &[usize]| -> Coord {
127 let flat = original.location(coord).unwrap();
128 reshaped.coordinates(flat).unwrap()
129 }
130}
131
132pub fn to_original_coord<'a>(
136 reshaped: &'a Slice,
137 original: &'a Slice,
138) -> impl Fn(&[usize]) -> Vec<usize> + 'a {
139 let reshaped = reshaped.clone();
140 let original = original.clone();
141 move |coord: &[usize]| -> Coord {
142 let flat = reshaped.location(coord).unwrap();
143 original.coordinates(flat).unwrap()
144 }
145}
146
/// Upper bound on the size of the factors produced when reshaping.
///
/// Used by [`factor_dims`], [`reshape_with_limit`] and [`reshape_shape`]:
/// dimensions of size `<= limit` are left alone; larger ones have every
/// divisor in `2..=limit` split out, and any leftover cofactor (whose prime
/// factors all exceed the limit) is kept as a single factor.
///
/// A `Limit` is always at least 1. The default is 32.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Limit(usize);

impl Limit {
    /// Creates a new limit.
    ///
    /// This is a `const fn`, so a `Limit` can be stored in a `const`.
    ///
    /// # Panics
    ///
    /// Panics if `n` is 0.
    #[must_use]
    pub const fn new(n: usize) -> Self {
        assert!(n >= 1, "Limit must be at least 1");
        Self(n)
    }

    /// Returns the limit as a plain `usize`.
    #[must_use]
    pub const fn get(self) -> usize {
        self.0
    }
}

impl Default for Limit {
    /// The default limit of 32.
    fn default() -> Self {
        Self::new(32)
    }
}

impl From<usize> for Limit {
    /// Equivalent to [`Limit::new`].
    ///
    /// # Panics
    ///
    /// Panics if `n` is 0. (`From` conversions are conventionally infallible;
    /// this one is kept for call-site convenience and panics on the single
    /// invalid input.)
    fn from(n: usize) -> Self {
        Self::new(n)
    }
}
191
192pub trait ReshapeSliceExt {
212 fn reshape_with_limit(&self, limit: Limit) -> Slice;
224}
225
226impl ReshapeSliceExt for Slice {
227 fn reshape_with_limit(&self, limit: Limit) -> Slice {
228 reshape_with_limit(self, limit)
229 }
230}
231
232pub trait ReshapeShapeExt {
234 fn reshape(&self, limit: Limit) -> ReshapedShape;
237}
238
239impl ReshapeShapeExt for Shape {
240 fn reshape(&self, limit: Limit) -> ReshapedShape {
241 reshape_shape(self, limit)
242 }
243}
244
245pub mod prelude {
248 pub use super::ReshapeShapeExt;
249 pub use super::ReshapeSliceExt;
250}
251
252pub fn reshape_with_limit(slice: &Slice, limit: Limit) -> Slice {
282 let orig_sizes = slice.sizes();
283 let orig_strides = slice.strides();
284
285 let factored_sizes = factor_dims(orig_sizes, limit);
287
288 let reshaped_sizes: Vec<usize> = factored_sizes.iter().flatten().cloned().collect();
290 let mut reshaped_strides = Vec::with_capacity(reshaped_sizes.len());
291
292 for (&orig_stride, factors) in orig_strides.iter().zip(&factored_sizes) {
293 let mut sub_strides = Vec::with_capacity(factors.len());
294 let mut stride = orig_stride;
295 for &f in factors.iter().rev() {
296 sub_strides.push(stride);
297 stride *= f;
298 }
299 sub_strides.reverse();
300 reshaped_strides.extend(sub_strides);
301 }
302
303 Slice::new(slice.offset(), reshaped_sizes, reshaped_strides).unwrap()
304}
305
306pub fn reshape_shape(shape: &Shape, limit: Limit) -> ReshapedShape {
326 let reshaped_slice = shape.slice().reshape_with_limit(limit);
327 let original_labels = shape.labels();
328 let original_sizes = shape.slice().sizes();
329
330 let factors = factor_dims(original_sizes, limit);
331 let factored_dims: Vec<(String, Vec<usize>)> =
332 original_labels.iter().cloned().zip(factors).collect();
333
334 let labels = expand_labels(&factored_dims);
335 let shape = Shape::new(labels, reshaped_slice).expect("invalid reshaped shape");
336
337 ReshapedShape {
338 shape,
339 factors: factored_dims,
340 }
341}
342
/// Expands per-dimension labels to one label per factor.
///
/// A dimension with exactly one factor keeps its label unchanged. A dimension
/// split into `k` factors produces `k` labels `"{label}/0"` … `"{label}/{k-1}"`.
/// A dimension with no factors contributes no labels.
///
/// Example: `[("zone", [2]), ("gpu", [2, 2])]` → `["zone", "gpu/0", "gpu/1"]`.
pub fn expand_labels(factors: &[(String, Vec<usize>)]) -> Vec<String> {
    // Exactly one output label per factor, so the size is known up front.
    let total: usize = factors.iter().map(|(_, dims)| dims.len()).sum();
    let mut labels = Vec::with_capacity(total);
    for (label, dims) in factors {
        if dims.len() == 1 {
            labels.push(label.clone());
        } else {
            labels.extend((0..dims.len()).map(|i| format!("{}/{}", label, i)));
        }
    }
    labels
}
376
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Slice;
    use crate::shape;

    #[test]
    fn test_factor_dims_basic() {
        assert_eq!(
            factor_dims(&[6, 8], Limit::from(4)),
            vec![vec![3, 2], vec![4, 2]]
        );
        assert_eq!(factor_dims(&[5], Limit::from(3)), vec![vec![5]]);
        assert_eq!(factor_dims(&[30], Limit::from(5)), vec![vec![5, 3, 2]]);
    }

    /// Asserts that `$reshaped` addresses the same flat locations as
    /// `$original`, checking every coordinate of `$original`:
    ///
    /// - mapping a coordinate forward and back again is the identity;
    /// - the original and reshaped coordinates have the same flat location;
    /// - the reshaped coordinate is recovered from that flat location.
    ///
    /// `to_reshaped_coord` and `to_original_coord` must be in scope at the
    /// call site.
    #[macro_export]
    macro_rules! assert_layout_preserved {
        ($original:expr_2021, $reshaped:expr_2021) => {{
            // Build the coordinate mappings once; each one clones both slices.
            let forward = to_reshaped_coord($original, &$reshaped);
            let inverse = to_original_coord(&$reshaped, $original);
            for coord in $original.dim_iter($original.num_dim()) {
                let reshaped_coord = forward(&coord);
                let roundtrip = inverse(&reshaped_coord);
                assert_eq!(
                    roundtrip, coord,
                    "Inverse mismatch: reshaped {:?} → original {:?}, expected {:?}",
                    reshaped_coord, roundtrip, coord
                );
                let flat_orig = $original.location(&coord).unwrap();
                let flat_reshaped = $reshaped.location(&reshaped_coord).unwrap();
                assert_eq!(
                    flat_orig, flat_reshaped,
                    "Flat index mismatch: original {:?} → reshaped {:?}",
                    coord, reshaped_coord
                );
                let recovered = $reshaped.coordinates(flat_reshaped).unwrap();
                assert_eq!(
                    reshaped_coord, recovered,
                    "Coordinate mismatch: flat index {} → expected {:?}, got {:?}",
                    flat_reshaped, reshaped_coord, recovered
                );
            }
        }};
    }

    #[test]
    fn test_reshape_split_1d_row_major() {
        let s = Slice::new_row_major(vec![1024]);
        let reshaped = s.reshape_with_limit(Limit::from(8));

        assert_eq!(reshaped.offset(), 0);
        assert_eq!(reshaped.sizes(), &vec![8, 8, 8, 2]);
        assert_eq!(reshaped.strides(), &vec![128, 16, 2, 1]);
        assert_eq!(
            factor_dims(s.sizes(), Limit::from(8)),
            vec![vec![8, 8, 8, 2]]
        );

        assert_layout_preserved!(&s, &reshaped);
    }

    #[test]
    fn test_reshape_6_with_limit_2() {
        let s = Slice::new_row_major(vec![6]);
        let reshaped = reshape_with_limit(&s, Limit::from(2));
        assert_eq!(factor_dims(s.sizes(), Limit::from(2)), vec![vec![2, 3]]);
        assert_layout_preserved!(&s, &reshaped);
    }

    #[test]
    fn test_reshape_identity_noop_2d() {
        let original = Slice::new_row_major(vec![4, 8]);
        let reshaped = original.reshape_with_limit(Limit::from(8));

        assert_eq!(reshaped.sizes(), original.sizes());
        assert_eq!(reshaped.strides(), original.strides());
        assert_eq!(reshaped.offset(), original.offset());
        // Every dimension is within the limit, so none is factored.
        assert_eq!(
            factor_dims(original.sizes(), Limit::from(8)),
            vec![vec![4], vec![8]]
        );
        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_empty_slice() {
        let original = Slice::new_row_major(vec![]);
        let reshaped = reshape_with_limit(&original, Limit::from(8));

        assert_eq!(reshaped.sizes(), original.sizes());
        assert_eq!(reshaped.strides(), original.strides());
        assert_eq!(reshaped.offset(), original.offset());

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_mixed_dims_3d() {
        let original = Slice::new_row_major(vec![6, 8, 10]);
        let reshaped = original.reshape_with_limit(Limit::from(4));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(4)),
            vec![vec![3, 2], vec![4, 2], vec![2, 5]]
        );
        assert_eq!(reshaped.sizes(), &[3, 2, 4, 2, 2, 5]);

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_all_large_dims() {
        let original = Slice::new_row_major(vec![12, 18, 20]);
        let reshaped = original.reshape_with_limit(Limit::from(4));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(4)),
            vec![vec![4, 3], vec![3, 3, 2], vec![4, 5]]
        );
        assert_eq!(reshaped.sizes(), &[4, 3, 3, 3, 2, 4, 5]);

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_split_1d_factors_3_3_2_2() {
        let original = Slice::new_row_major(vec![36]);
        let reshaped = reshape_with_limit(&original, Limit::from(3));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(3)),
            vec![vec![3, 3, 2, 2]]
        );
        assert_eq!(reshaped.sizes(), &[3, 3, 2, 2]);
        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_large_prime_dimension() {
        let original = Slice::new_row_major(vec![7]);
        let reshaped = reshape_with_limit(&original, Limit::from(4));

        assert_eq!(factor_dims(original.sizes(), Limit::from(4)), vec![vec![7]]);
        assert_eq!(reshaped.sizes(), &[7]);

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_split_1d_factors_5_3_2() {
        let original = Slice::new_row_major(vec![30]);
        let reshaped = reshape_with_limit(&original, Limit::from(5));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(5)),
            vec![vec![5, 3, 2]]
        );
        assert_eq!(reshaped.sizes(), &[5, 3, 2]);
        assert_eq!(reshaped.strides(), &[6, 2, 1]);

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_factors_2_6_2_8_8() {
        let original = Slice::new_row_major(vec![2, 12, 64]);
        let reshaped = original.reshape_with_limit(Limit::from(8));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(8)),
            vec![vec![2], vec![6, 2], vec![8, 8]]
        );
        assert_eq!(reshaped.sizes(), &[2, 6, 2, 8, 8]);
        assert_eq!(reshaped.strides(), &[768, 128, 64, 8, 1]);

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_all_dims_within_limit() {
        let original = Slice::new_row_major(vec![2, 3, 4]);
        let reshaped = original.reshape_with_limit(Limit::from(4));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(4)),
            vec![vec![2], vec![3], vec![4]]
        );
        assert_eq!(reshaped.sizes(), &[2, 3, 4]);
        assert_eq!(reshaped.strides(), original.strides());
        assert_eq!(reshaped.offset(), original.offset());

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_degenerate_dimension() {
        let original = Slice::new_row_major(vec![1, 12]);
        let reshaped = original.reshape_with_limit(Limit::from(4));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(4)),
            vec![vec![1], vec![4, 3]]
        );
        assert_eq!(reshaped.sizes(), &[1, 4, 3]);

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_select_then_reshape() {
        let original = shape!(zone = 2, host = 3, gpu = 4);

        let selected = original.select("zone", 1).unwrap();
        assert_eq!(selected.slice().offset(), 12);
        assert_eq!(selected.slice().sizes(), &[1, 3, 4]);

        let reshaped = selected.slice().reshape_with_limit(Limit::from(2));

        assert_eq!(
            factor_dims(selected.slice().sizes(), Limit::from(2)),
            vec![vec![1], vec![3], vec![2, 2]]
        );
        assert_eq!(reshaped.sizes(), &[1, 3, 2, 2]);
        assert_eq!(reshaped.strides(), &[12, 4, 2, 1]);
        assert_eq!(reshaped.offset(), 12);
        assert_layout_preserved!(selected.slice(), &reshaped);
    }

    #[test]
    fn test_select_host_plane_then_reshape() {
        let original = shape!(zone = 2, host = 3, gpu = 4);
        let selected = original.select("host", 2).unwrap();
        assert_eq!(selected.slice().sizes(), &[2, 1, 4]);

        let reshaped = selected.slice().reshape_with_limit(Limit::from(2));
        assert_eq!(reshaped.sizes(), &[2, 1, 2, 2]);

        assert_layout_preserved!(selected.slice(), &reshaped);
    }

    #[test]
    fn test_reshape_after_select_no_factoring_due_to_primes() {
        let original = shape!(zone = 3, host = 4, gpu = 5);
        let selected_zone = original.select("zone", 1).unwrap();
        assert_eq!(selected_zone.slice().sizes(), &[1, 4, 5]);
        let selected_host = selected_zone.select("host", 2).unwrap();
        assert_eq!(selected_host.slice().sizes(), &[1, 1, 5]);
        let reshaped = selected_host.slice().reshape_with_limit(Limit::from(2));

        assert_eq!(
            factor_dims(selected_host.slice().sizes(), Limit::from(2)),
            vec![vec![1], vec![1], vec![5]]
        );
        assert_eq!(reshaped.sizes(), &[1, 1, 5]);

        assert_layout_preserved!(selected_host.slice(), &reshaped);
    }

    #[test]
    fn test_reshape_after_multiple_selects_triggers_factoring() {
        let original = shape!(zone = 2, host = 4, gpu = 8);
        let selected_zone = original.select("zone", 1).unwrap();
        assert_eq!(selected_zone.slice().sizes(), &[1, 4, 8]);

        let selected_host = selected_zone.select("host", 2).unwrap();
        assert_eq!(selected_host.slice().sizes(), &[1, 1, 8]);

        let reshaped = selected_host.slice().reshape_with_limit(Limit::from(2));

        assert_eq!(
            factor_dims(selected_host.slice().sizes(), Limit::from(2)),
            vec![vec![1], vec![1], vec![2, 2, 2]]
        );
        assert_eq!(reshaped.sizes(), &[1, 1, 2, 2, 2]);

        assert_layout_preserved!(selected_host.slice(), &reshaped);
    }

    #[test]
    fn test_expand_labels_singleton_dims() {
        let factors = vec![("x".into(), vec![2]), ("y".into(), vec![4])];
        let expected = vec!["x", "y"];
        assert_eq!(expand_labels(&factors), expected);
    }

    #[test]
    fn test_expand_labels_factored_dims() {
        let factors = vec![("gpu".into(), vec![2, 2, 2])];
        let expected = vec!["gpu/0", "gpu/1", "gpu/2"];
        assert_eq!(expand_labels(&factors), expected);
    }

    #[test]
    fn test_expand_labels_mixed_dims() {
        let factors = vec![("zone".into(), vec![2]), ("gpu".into(), vec![2, 2])];
        let expected = vec!["zone", "gpu/0", "gpu/1"];
        assert_eq!(expand_labels(&factors), expected);
    }

    #[test]
    fn test_expand_labels_empty() {
        let factors: Vec<(String, Vec<usize>)> = vec![];
        let expected: Vec<String> = vec![];
        assert_eq!(expand_labels(&factors), expected);
    }

    #[test]
    fn test_reshape_shape_noop() {
        let shape = shape!(x = 4, y = 8);
        let reshaped = reshape_shape(&shape, Limit::from(8));
        assert_eq!(reshaped.shape.labels(), &["x", "y"]);
        assert_eq!(reshaped.shape.slice(), shape.slice());
    }

    #[test]
    fn test_reshape_shape_factored() {
        let shape = shape!(gpu = 8);
        let reshaped = reshape_shape(&shape, Limit::from(2));
        assert_eq!(reshaped.shape.labels(), &["gpu/0", "gpu/1", "gpu/2"]);
        assert_eq!(reshaped.shape.slice().sizes(), &[2, 2, 2]);

        let expected = shape.slice().reshape_with_limit(Limit::from(2));
        assert_eq!(reshaped.shape.slice(), &expected);
    }

    #[test]
    fn test_reshape_shape_singleton() {
        let shape = shape!(x = 3);
        let reshaped = reshape_shape(&shape, Limit::from(8));
        assert_eq!(reshaped.shape.labels(), &["x"]);
        assert_eq!(reshaped.shape.slice(), shape.slice());
    }

    #[test]
    fn test_reshape_shape_prime_exceeds_limit() {
        let shape = shape!(x = 11);
        let reshaped = reshape_shape(&shape, Limit::from(5));
        assert_eq!(reshaped.shape.labels(), &["x"]);
        assert_eq!(reshaped.shape.slice(), shape.slice());
    }

    #[test]
    fn test_reshape_shape_mixed_dims() {
        let shape = shape!(zone = 2, gpu = 8);
        let reshaped = reshape_shape(&shape, Limit::from(2));
        assert_eq!(
            reshaped.shape.labels(),
            &["zone", "gpu/0", "gpu/1", "gpu/2"]
        );
        assert_eq!(reshaped.shape.slice().sizes(), &[2, 2, 2, 2]);

        let expected = shape.slice().reshape_with_limit(Limit::from(2));
        assert_eq!(reshaped.shape.slice(), &expected);
    }

    #[test]
    fn test_reshape_shape_after_selects() {
        let original = shape!(zone = 2, host = 4, gpu = 8);

        let selected_zone = original.select("zone", 1).unwrap();
        assert_eq!(selected_zone.slice().sizes(), &[1, 4, 8]);

        let selected_host = selected_zone.select("host", 2).unwrap();
        assert_eq!(selected_host.slice().sizes(), &[1, 1, 8]);

        let reshaped = reshape_shape(&selected_host, Limit::from(2));

        assert_eq!(
            reshaped.shape.labels(),
            &["zone", "host", "gpu/0", "gpu/1", "gpu/2"]
        );

        assert_eq!(reshaped.shape.slice().sizes(), &[1, 1, 2, 2, 2]);

        let expected = selected_host.slice().reshape_with_limit(Limit::from(2));
        assert_eq!(reshaped.shape.slice(), &expected);
    }
}