ndslice/
reshape.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
//! Dimensional reshaping of slices and shapes.
//!
//! This module defines utilities for transforming a [`Slice`] or
//! [`Shape`] by factoring large extents into smaller ones under a
//! given limit. The result is a reshaped view with increased
//! dimensionality and preserved memory layout.
//!
//! This is useful for hierarchical routing, structured fanout, and
//! other multidimensional layout transformations.
//!
//! For [`Shape`]s, reshaping also expands dimension labels using a
//! `label/N` naming convention, preserving the semantics of the
//! original shape in the reshaped result.
//!
//! See [`reshape_with_limit`] and [`reshape_shape`] for entry points.
24use std::fmt;
25
26use crate::Range;
27use crate::Selection;
28use crate::dsl::union;
29use crate::shape::Shape;
30use crate::slice::Slice;
31
32/// Coordinate vector used throughout reshape logic. Semantically
33/// represents a point in multidimensional space.
34pub type Coord = Vec<usize>;
35
36/// A reshaped version of a `Shape`, with factored dimensions and
37/// updated labels.
38///
39/// This type preserves coordinate bijections with the original shape
40/// and provides access to the transformed layout and label mappings.
41pub struct ReshapedShape {
42    /// The reshaped shape, with new labels and underlying factored
43    /// slice.
44    pub shape: Shape,
45
46    /// For each original dimension label, the list of sizes it was
47    /// split into.
48    pub factors: Vec<(String, Vec<usize>)>,
49}
50
51#[allow(dead_code)]
52const _: () = {
53    fn assert<T: Send + Sync + 'static>() {}
54    let _ = assert::<ReshapedShape>;
55};
56
57impl std::fmt::Debug for ReshapedShape {
58    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59        f.debug_struct("ReshapedShape")
60            .field("labels", &self.shape.labels())
61            .field("sizes", &self.shape.slice().sizes())
62            .field("strides", &self.shape.slice().strides())
63            .field("offset", &self.shape.slice().offset())
64            .field("factors", &self.factors)
65            .finish()
66    }
67}
68
69impl std::fmt::Display for ReshapedShape {
70    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
71        write!(
72            f,
73            "ReshapedShape {{ [off={} sz={:?} st={:?} lab={:?} fac={:?}] }}",
74            self.shape.slice().offset(),
75            self.shape.slice().sizes(),
76            self.shape.slice().strides(),
77            self.shape.labels(),
78            self.factors
79        )
80    }
81}
82
83/// Returns, for each size, a list of factors that respect the given
84/// limit. If a size is ≤ limit, it is returned as a singleton.
85/// Otherwise, it is factored greedily using divisors ≤ limit, from
86/// largest to smallest.
87///
88/// For best results, dimensions should be chosen to allow factoring
89/// into small values under the selected limit (e.g., ≤ 32).
90/// Large prime numbers cannot be broken down and will remain as-is,
91/// limiting reshaping potential.
92///
93/// Prefer powers of 2 or other highly composite numbers
94/// (e.g., 8, 16, 32, 60, 120) over large primes (e.g., 17, 37, 113)
95/// when designing shapes intended for reshaping.
96pub(crate) fn factor_dims(sizes: &[usize], limit: Limit) -> Vec<Vec<usize>> {
97    let limit = limit.get();
98    sizes
99        .iter()
100        .map(|&size| {
101            if size <= limit {
102                return vec![size];
103            }
104            let mut rem = size;
105            let mut factors = Vec::new();
106            for d in (2..=limit).rev() {
107                while rem % d == 0 {
108                    factors.push(d);
109                    rem /= d;
110                }
111            }
112            if rem > 1 {
113                factors.push(rem);
114            }
115            factors
116        })
117        .collect()
118}
119
120/// Constructs a function that maps coordinates from the original
121/// slice to equivalent coordinates in the reshaped slice, preserving
122/// their flat (linear) position.
123pub fn to_reshaped_coord<'a>(
124    original: &'a Slice,
125    reshaped: &'a Slice,
126) -> impl Fn(&[usize]) -> Vec<usize> + 'a {
127    let original = original.clone();
128    let reshaped = reshaped.clone();
129    move |coord: &[usize]| -> Coord {
130        let flat = original.location(coord).unwrap();
131        reshaped.coordinates(flat).unwrap()
132    }
133}
134
135/// Constructs a function that maps coordinates from the reshaped
136/// slice back to equivalent coordinates in the original slice,
137/// preserving their flat (linear) position.
138pub fn to_original_coord<'a>(
139    reshaped: &'a Slice,
140    original: &'a Slice,
141) -> impl Fn(&[usize]) -> Vec<usize> + 'a {
142    let reshaped = reshaped.clone();
143    let original = original.clone();
144    move |coord: &[usize]| -> Coord {
145        let flat = reshaped.location(coord).unwrap();
146        original.coordinates(flat).unwrap()
147    }
148}
149
/// Upper bound on the extent of any dimension produced by reshaping.
///
/// During reshaping, every dimension whose size exceeds this bound is
/// split greedily into factors drawn from `2..=limit`, largest first.
/// For example, `reshape_with_limit([1024], Limit::new(32))` yields
/// `[32, 32]`. A residual factor with no divisor in that range, such as
/// a large prime, cannot be split and is kept as-is.
///
/// `Limit::default()` is `32`, which balances fanout depth and layout
/// regularity.
///
/// # Example
/// ```
/// use ndslice::reshape::Limit;
/// let limit = Limit::new(64);
/// assert_eq!(limit.get(), 64);
/// ```
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Limit(usize);

impl Limit {
    /// Builds a `Limit` of `n`.
    ///
    /// # Panics
    /// Panics when `n` is zero.
    pub fn new(n: usize) -> Self {
        assert!(n > 0, "Limit must be at least 1");
        Self(n)
    }

    /// Returns the bound as a plain `usize`.
    pub fn get(self) -> usize {
        let Self(n) = self;
        n
    }
}

impl Default for Limit {
    /// The default bound, `32`.
    fn default() -> Self {
        Self::new(32)
    }
}

impl From<usize> for Limit {
    /// Same as [`Limit::new`]; panics when `n` is zero.
    fn from(n: usize) -> Self {
        Limit::new(n)
    }
}
194
195/// A trait for types that can be reshaped into a higher-dimensional
196/// view by factoring large extents into smaller ones.
197///
198/// This is implemented for [`Slice`], enabling ergonomic access to
199/// [`reshape_with_limit`] as a method.
200///
201/// # Example
202/// ```
203/// use ndslice::Slice;
204/// use ndslice::reshape::Limit;
205/// use ndslice::reshape::ReshapeSliceExt;
206///
207/// let slice = Slice::new_row_major(vec![1024]);
208/// let reshaped = slice.reshape_with_limit(Limit::new(32));
209/// assert_eq!(reshaped.sizes(), &[32, 32]);
210/// ```
211/// # Returns
212/// A reshaped [`Slice`] with increased dimensionality and preserved
213/// layout.
214pub trait ReshapeSliceExt {
215    /// Returns a reshaped version of this structure by factoring each
216    /// dimension into smaller extents no greater than `limit`,
217    /// preserving memory layout and flat index semantics. See
218    /// [`reshape_with_limit`] for full behavior and rationale.
219    ///
220    /// # Arguments
221    /// - `limit`: maximum size allowed in any reshaped dimension
222    ///
223    /// # Returns
224    /// A reshaped [`Slice`] with increased dimensionality and a
225    /// bijective mapping to the original.
226    fn reshape_with_limit(&self, limit: Limit) -> Slice;
227}
228
229impl ReshapeSliceExt for Slice {
230    fn reshape_with_limit(&self, limit: Limit) -> Slice {
231        reshape_with_limit(self, limit)
232    }
233}
234
235/// Extension trait for reshaping `Shape`s by factoring large dimensions.
236pub trait ReshapeShapeExt {
237    /// Produces a reshaped version of the shape with expanded
238    /// dimensions under the given size limit.
239    fn reshape(&self, limit: Limit) -> ReshapedShape;
240}
241
242impl ReshapeShapeExt for Shape {
243    fn reshape(&self, limit: Limit) -> ReshapedShape {
244        reshape_shape(self, limit)
245    }
246}
247
248/// For convenient `slice.reshape_with_limit()`, `shape.reshape()`
249/// syntax, `use reshape::prelude::*`.
250pub mod prelude {
251    pub use super::ReshapeShapeExt;
252    pub use super::ReshapeSliceExt;
253}
254
255/// Reshapes a slice by factoring each dimension into smaller extents
256/// under the given limit.
257///
258/// This transformation increases dimensionality by breaking large
259/// sizes into products of smaller factors (e.g., `[1024]` with limit
260/// 32 becomes `[32, 32]`). The result is a new [`Slice`] that
261/// preserves memory layout and flat index semantics.
262///
263/// Factoring is greedy, starting from the largest divisors ≤ `limit`.
264/// Dimensions that cannot be factored under the limit are left
265/// unchanged.
266///
267/// # Arguments
268/// - `slice`: the original multidimensional slice
269/// - `limit`: maximum extent allowed in any factored subdimension
270///
271/// # Returns
272/// A reshaped [`Slice`] with updated sizes and strides.
273///
274/// # Example
275/// ```
276/// use ndslice::Slice;
277/// use ndslice::reshape::Limit;
278/// use ndslice::reshape::reshape_with_limit;
279///
280/// let slice = Slice::new_row_major(vec![1024]);
281/// let reshaped = reshape_with_limit(&slice, Limit::new(32));
282/// assert_eq!(reshaped.sizes(), &[32, 32]);
283/// ```
284pub fn reshape_with_limit(slice: &Slice, limit: Limit) -> Slice {
285    let orig_sizes = slice.sizes();
286    let orig_strides = slice.strides();
287
288    // Step 1: Factor each size into subdimensions ≤ limit.
289    let factored_sizes = factor_dims(orig_sizes, limit);
290
291    // Step 2: Compute reshaped sizes and strides (row-major only).
292    let reshaped_sizes: Vec<usize> = factored_sizes.iter().flatten().cloned().collect();
293    let mut reshaped_strides = Vec::with_capacity(reshaped_sizes.len());
294
295    for (&orig_stride, factors) in orig_strides.iter().zip(&factored_sizes) {
296        let mut sub_strides = Vec::with_capacity(factors.len());
297        let mut stride = orig_stride;
298        for &f in factors.iter().rev() {
299            sub_strides.push(stride);
300            stride *= f;
301        }
302        sub_strides.reverse();
303        reshaped_strides.extend(sub_strides);
304    }
305
306    Slice::new(slice.offset(), reshaped_sizes, reshaped_strides).unwrap()
307}
308
309/// Reshapes a labeled [`Shape`] by factoring large extents into
310/// smaller ones, producing a new shape with expanded dimensionality
311/// and updated labels.
312///
313/// This uses [`reshape_with_limit`] on the underlying slice and [`expand_labels`]
314/// to generate labels for each factored dimension.
315///
316/// # Arguments
317/// - `shape`: the labeled shape to reshape
318/// - `limit`: maximum extent allowed per factored dimension
319///
320/// # Returns
321/// A new [`ReshapedShape`] with an updated [`Shape`] and dimension
322/// factoring metadata.
323///
324/// # Panics
325/// Panics if constructing the new `Shape` fails. This should not
326/// occur unless the reshaped slice and labels are inconsistent (a
327/// programming logic error).
328pub fn reshape_shape(shape: &Shape, limit: Limit) -> ReshapedShape {
329    let reshaped_slice = shape.slice().reshape_with_limit(limit);
330    let original_labels = shape.labels();
331    let original_sizes = shape.slice().sizes();
332
333    let factors = factor_dims(original_sizes, limit);
334    let factored_dims: Vec<(String, Vec<usize>)> =
335        original_labels.iter().cloned().zip(factors).collect();
336
337    let labels = expand_labels(&factored_dims);
338    let shape = Shape::new(labels, reshaped_slice).expect("invalid reshaped shape");
339
340    ReshapedShape {
341        shape,
342        factors: factored_dims,
343    }
344}
345
/// Expands factored dimension labels into one label per subdimension.
///
/// Each input pair `(label, factors)` describes one original dimension
/// and the extents it was factored into:
/// - exactly one extent (not factored): the label is kept verbatim;
/// - any other count `n`: labels `label/0` .. `label/{n-1}` are emitted.
///
/// For example, `[("zone", vec![2]), ("gpu", vec![2, 2, 2])]` becomes
/// `["zone", "gpu/0", "gpu/1", "gpu/2"]`.
///
/// # Arguments
/// - `factors`: factored dimension extents, each paired with its label
///
/// # Returns
/// - A `Vec<String>` with one label per reshaped dimension, in order.
pub fn expand_labels(factors: &[(String, Vec<usize>)]) -> Vec<String> {
    factors
        .iter()
        .flat_map(|(label, dims)| match dims.len() {
            1 => vec![label.clone()],
            count => (0..count).map(|i| format!("{}/{}", label, i)).collect(),
        })
        .collect()
}
379
380#[derive(Debug, thiserror::Error)]
381pub enum ReshapeError {
382    #[error("unsupported selection kind {selection}")]
383    UnsupportedSelection { selection: Selection },
384}
/// Translates a [`Selection`] over `original_slice` into an equivalent
/// [`Selection`] over `reshaped_slice`.
///
/// The result selects, in the reshaped slice, the ranks that `selection`
/// selects in the original slice. `reshaped_slice` is expected to refine
/// `original_slice` dimension by dimension. That is, each original
/// dimension corresponds to a consecutive run of reshaped dimensions
/// whose sizes multiply to the original size, as produced by
/// [`reshape_with_limit`].
///
/// Per original dimension:
/// - `All`, `Any` and `First` are replicated once per reshaped dimension
///   in the corresponding run.
/// - `Range` is folded into a union of rectangular range pieces (see the
///   comments on `fold_once`).
/// - `Union` / `Intersection` recurse into both operands at the same
///   depth, simplifying away `True` / `False` operands.
///
/// Sub-selections nested deeper than the number of original dimensions
/// are returned unchanged.
///
/// # Errors
/// Returns [`ReshapeError::UnsupportedSelection`] if the selection
/// contains any variant other than `True`, `False`, `All`, `Any`,
/// `First`, `Range`, `Union` or `Intersection`.
///
/// # Panics
/// Panics (via `unwrap`) if `reshaped_slice` runs out of dimensions while
/// accumulating the factors of an original dimension, i.e. when it is
/// not a refinement of `original_slice`.
pub fn reshape_selection(
    selection: Selection,
    original_slice: &Slice,
    reshaped_slice: &Slice,
) -> Result<Selection, ReshapeError> {
    // Rewrites `selection` one original dimension at a time.
    // `original_size_index` is the original dimension the current
    // sub-selection applies to. `reshaped_size_index` is the first reshaped
    // dimension that original dimension was factored into.
    fn recursive_fold(
        selection: Selection,
        original_slice: &Slice,
        original_size_index: usize,
        reshaped_slice: &Slice,
        reshaped_size_index: usize,
    ) -> Result<Selection, ReshapeError> {
        if matches!(selection, Selection::True | Selection::False) {
            return Ok(selection);
        }

        // Past the last original dimension there is nothing left to remap.
        let Some(&original_dim_size) = original_slice.sizes().get(original_size_index) else {
            return Ok(selection);
        };

        // Determine the run of reshaped dimensions
        // `reshaped_size_index..next_reshaped_dimension_start` whose sizes
        // multiply up to (at least) `original_dim_size`. The `unwrap`s
        // panic if `reshaped_slice` has too few dimensions.
        let mut accum = *reshaped_slice.sizes().get(reshaped_size_index).unwrap();
        let mut next_reshaped_dimension_start = reshaped_size_index + 1;

        while accum < original_dim_size {
            accum *= *reshaped_slice
                .sizes()
                .get(next_reshaped_dimension_start)
                .unwrap();
            next_reshaped_dimension_start += 1;
        }

        match selection {
            // base case
            Selection::True | Selection::False => Ok(selection),
            // For these cases we are not drilling down any dimensions when we recurse
            Selection::Union(left, right) => {
                let left = recursive_fold(
                    *left,
                    original_slice,
                    original_size_index,
                    reshaped_slice,
                    reshaped_size_index,
                )?;

                // Short-circuit: `True ∪ x = True`, `False ∪ x = x`.
                match left {
                    Selection::True => return Ok(Selection::True),
                    Selection::False => {
                        return recursive_fold(
                            *right,
                            original_slice,
                            original_size_index,
                            reshaped_slice,
                            reshaped_size_index,
                        );
                    }
                    _ => {}
                }

                let right = recursive_fold(
                    *right,
                    original_slice,
                    original_size_index,
                    reshaped_slice,
                    reshaped_size_index,
                )?;

                Ok(match right {
                    Selection::True => Selection::True,
                    Selection::False => left,
                    _ => Selection::Union(Box::new(left), Box::new(right)),
                })
            }
            Selection::Intersection(left, right) => {
                let left = recursive_fold(
                    *left,
                    original_slice,
                    original_size_index,
                    reshaped_slice,
                    reshaped_size_index,
                )?;
                // Short-circuit: `False ∩ x = False`, `True ∩ x = x`.
                match left {
                    Selection::False => return Ok(Selection::False),
                    Selection::True => {
                        return recursive_fold(
                            *right,
                            original_slice,
                            original_size_index,
                            reshaped_slice,
                            reshaped_size_index,
                        );
                    }
                    _ => {}
                }

                let right = recursive_fold(
                    *right,
                    original_slice,
                    original_size_index,
                    reshaped_slice,
                    reshaped_size_index,
                )?;
                Ok(match right {
                    Selection::False => Selection::False,
                    Selection::True => left,
                    _ => Selection::Intersection(Box::new(left), Box::new(right)),
                })
            }
            Selection::All(inner) => {
                let inner = recursive_fold(
                    *inner,
                    original_slice,
                    original_size_index + 1,
                    reshaped_slice,
                    next_reshaped_dimension_start,
                )?;

                // `All(True)` is `True` and `All(False)` is `False`.
                if matches!(inner, Selection::True | Selection::False) {
                    return Ok(inner);
                }

                // One `All` per reshaped dimension in this run: the initial
                // wrapper plus (run length - 1) more from the fold.
                Ok((reshaped_size_index..next_reshaped_dimension_start - 1)
                    .fold(Selection::All(Box::new(inner)), |result, _| {
                        Selection::All(Box::new(result))
                    }))
            }
            Selection::Any(inner) => {
                let inner = recursive_fold(
                    *inner,
                    original_slice,
                    original_size_index + 1,
                    reshaped_slice,
                    next_reshaped_dimension_start,
                )?;

                if matches!(inner, Selection::False) {
                    return Ok(inner);
                }

                // One `Any` per reshaped dimension in this run.
                Ok((reshaped_size_index..next_reshaped_dimension_start - 1)
                    .fold(Selection::Any(Box::new(inner)), |result, _| {
                        Selection::Any(Box::new(result))
                    }))
            }
            Selection::First(inner) => {
                let inner = recursive_fold(
                    *inner,
                    original_slice,
                    original_size_index + 1,
                    reshaped_slice,
                    next_reshaped_dimension_start,
                )?;

                if matches!(inner, Selection::False) {
                    return Ok(inner);
                }

                // One `First` per reshaped dimension in this run; together
                // they pick coordinate 0 in every factor.
                Ok((reshaped_size_index..next_reshaped_dimension_start - 1)
                    .fold(Selection::First(Box::new(inner)), |result, _| {
                        Selection::First(Box::new(result))
                    }))
            }
            Selection::Range(range, inner) => {
                // We can fold a rectangle along a dimension by factoring it into up to 3 pieces:
                // A starting piece if there is a region that begins after the start of a fold but spans to the end of that fold
                // A middle piece that spans the entire fold for n folds
                // An ending piece if there is a region that begins at the start of the fold but ends before the end of that fold
                //
                // To visualize: range(2:8, true)
                // 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8
                // o | o | x | x | x | x | x | x | o
                //
                // => fold into 3 pieces:
                //
                // start: range(0:1, range(2:3, true))
                // 0 | 1 | 2 |
                // o | o | x |
                //
                // middle: range(1:2, all(true))
                // 3 | 4 | 5 |
                // x | x | x |
                //
                // end: range(2:3, range(0:2, true))
                // 6 | 7 | 8
                // x | x | o
                //
                // When step size is larger than 1 this does get a bit more hairy where the middle piece must be
                // split into several pieces and the end must be aligned
                //
                // `fold_once` folds a single `Range` over a dimension of size
                // `original_dimension_n_size` into pieces over two dimensions:
                // an outer "dim n+1" of size
                // `original_dimension_n_size / new_dimension_n_size` and an
                // inner "dim n" of size `new_dimension_n_size`. Index `i` maps
                // to `(i / new_dimension_n_size, i % new_dimension_n_size)`.
                // The returned pieces are meant to be unioned together.
                fn fold_once(
                    Range(start, end, step): Range,
                    inner: Selection,
                    original_dimension_n_size: usize,
                    new_dimension_n_size: usize,
                ) -> Vec<Selection> {
                    // Outer row range: floor of start, ceiling of the exclusive end.
                    let dimension_n_plus_one_start = start / new_dimension_n_size;
                    let dimension_n_plus_one_end = end.map(|end| {
                        if end % new_dimension_n_size == 0 {
                            end / new_dimension_n_size
                        } else {
                            end / new_dimension_n_size + 1
                        }
                    });
                    let dimension_n_plus_one_size =
                        original_dimension_n_size / new_dimension_n_size;

                    let new_dimension_n_start = start % new_dimension_n_size;
                    let new_dimension_n_end = end.map(|end| {
                        if end % new_dimension_n_size == 0 && end > new_dimension_n_size - 1 {
                            // If the original end was 3 and the new dimension size is 3
                            // Then the new end should be 3 as opposed to 0 as end represents an upper bound
                            new_dimension_n_size
                        } else {
                            end % new_dimension_n_size
                        }
                    });

                    let mut result = vec![];

                    // Simplest case, where the entire rectangle is contained in a single row (fold) of dim n+1
                    if dimension_n_plus_one_end
                        .is_some_and(|end| dimension_n_plus_one_start + 1 == end)
                        || (end.is_none()
                            && dimension_n_plus_one_start == dimension_n_plus_one_size)
                    {
                        return vec![Selection::Range(
                            Range(dimension_n_plus_one_start, dimension_n_plus_one_end, 1),
                            Box::new(Selection::Range(
                                Range(new_dimension_n_start, new_dimension_n_end, step),
                                Box::new(inner.clone()),
                            )),
                        )];
                    }

                    // Simpler case first where the middle piece can be represented with a single range
                    if step == 1 {
                        // Starting piece
                        // ex: range(0:1, range(2:3, true))
                        // 0 | 1 | 2 |
                        // o | o | x |
                        let middle_start = match start % new_dimension_n_size {
                            0 => dimension_n_plus_one_start,
                            _ => {
                                result.push(Selection::Range(
                                    Range(
                                        dimension_n_plus_one_start,
                                        Some(dimension_n_plus_one_start + 1),
                                        1,
                                    ),
                                    Box::new(Selection::Range(
                                        Range(
                                            new_dimension_n_start,
                                            Some(new_dimension_n_size),
                                            step,
                                        ),
                                        Box::new(inner.clone()),
                                    )),
                                ));
                                dimension_n_plus_one_start + 1
                            }
                        };

                        // Ending piece
                        // ex: range(2:3, range(0:2, true))
                        // 6 | 7 | 8
                        // x | x | o
                        let middle_end = match (end, dimension_n_plus_one_end) {
                            (Some(end), Some(dimension_n_plus_one_end))
                                if end % new_dimension_n_size != 0 =>
                            {
                                result.push(Selection::Range(
                                    Range(
                                        dimension_n_plus_one_end - 1,
                                        Some(dimension_n_plus_one_end),
                                        1,
                                    ),
                                    Box::new(Selection::Range(
                                        Range(0, new_dimension_n_end, step),
                                        Box::new(inner.clone()),
                                    )),
                                ));
                                Some(dimension_n_plus_one_end - 1)
                            }
                            _ => dimension_n_plus_one_end,
                        };

                        // Middle pieces
                        // ex: range(1:2, all(true))
                        // 3 | 4 | 5 |
                        // x | x | x |
                        // Only emitted when at least one full row lies between
                        // the start and end pieces.
                        if middle_end.is_some_and(|end| end > middle_start)
                            || (middle_end.is_none() && middle_start < dimension_n_plus_one_size)
                        {
                            result.push(Selection::Range(
                                Range(middle_start, middle_end, 1),
                                Box::new(Selection::All(Box::new(inner.clone()))),
                            ));
                        }
                    // Complicated case where step size is larger than 1 that involves splitting up
                    // the middle piece
                    } else {
                        // Greatest common divisor
                        fn gcd(a: usize, b: usize) -> usize {
                            if b == 0 { a } else { gcd(b, a % b) }
                        }

                        // Rows `r` and `r + row_pattern_period` select the same
                        // columns, because `row_pattern_period * new_dimension_n_size`
                        // is a multiple of `step`.
                        let row_pattern_period = step / gcd(step, new_dimension_n_size);

                        // get the coordinates of the first item on the next row
                        // i.e. an unbounded iterator yielding, in increasing order,
                        // `(row, col)` of the first selected index in each row that
                        // contains one, starting from the range start itself.
                        let mut row_col_iter = std::iter::successors(
                            Some((dimension_n_plus_one_start, start % new_dimension_n_size)),
                            |&(row, col)| {
                                let cols_before_end = new_dimension_n_size - 1 - col;
                                let steps_before_end = cols_before_end / step;
                                let last_col_before_end = col + step * steps_before_end;

                                let next_row =
                                    ((row * new_dimension_n_size) + last_col_before_end + step)
                                        / new_dimension_n_size;
                                let next_col = (last_col_before_end + step) % new_dimension_n_size;

                                Some((next_row, next_col))
                            },
                        )
                        .peekable();

                        // Needs start piece
                        // (the start row runs from the start column to the end of the row).
                        if start % new_dimension_n_size != 0 {
                            let (row, col) = row_col_iter.next().unwrap();

                            result.push(Selection::Range(
                                Range(row, Some(row + 1), 1),
                                Box::new(Selection::Range(
                                    Range(col, None, step),
                                    Box::new(inner.clone()),
                                )),
                            ));
                        };

                        // Middle pieces
                        // One row-strided piece per phase of the row pattern, covering
                        // every full row before `end / new_dimension_n_size` (the
                        // partial end row is handled separately below).
                        for _ in 0..row_pattern_period {
                            let end_row = end.map(|end| end / new_dimension_n_size);

                            if match end_row {
                                Some(end_row) => row_col_iter.peek().unwrap().0 >= end_row,
                                None => row_col_iter.peek().unwrap().0 >= dimension_n_plus_one_size,
                            } {
                                break;
                            }
                            let (row_index, col) = row_col_iter.next().unwrap();

                            result.push(Selection::Range(
                                Range(row_index, end_row, row_pattern_period),
                                Box::new(Selection::Range(
                                    Range(col, None, step),
                                    Box::new(inner.clone()),
                                )),
                            ));
                        }

                        // Needs end piece
                        // Find the first column of `end_row` via any row in the same
                        // pattern phase, and select its columns below
                        // `end % new_dimension_n_size`.
                        if let Some(end) = end {
                            let end_row = end / new_dimension_n_size;

                            for (row, col) in row_col_iter {
                                if row > end_row {
                                    break;
                                }

                                if row % row_pattern_period == end_row % row_pattern_period
                                    && col < end % new_dimension_n_size
                                {
                                    result.push(Selection::Range(
                                        Range(end_row, Some(end_row + 1), 1),
                                        Box::new(Selection::Range(
                                            Range(col, Some(end % new_dimension_n_size), step),
                                            Box::new(inner.clone()),
                                        )),
                                    ));
                                    break;
                                }
                            }
                        }
                    }
                    result
                }

                let inner = recursive_fold(
                    *inner,
                    original_slice,
                    original_size_index + 1,
                    reshaped_slice,
                    next_reshaped_dimension_start,
                )?;
                if matches!(inner, Selection::False) {
                    return Ok(inner);
                }
                let mut pieces = vec![Selection::Range(range, Box::new(inner))];

                // If [24] is being reshaped to [4, 3, 2] this will yield [2, 3] (dropping the first dimension and reversed)
                // This is because we need to first fold by 2 to get [12, 3], then fold by 3 to get [4, 3, 2]
                let reversed_dimensions = reshaped_slice.sizes()
                    [reshaped_size_index + 1..next_reshaped_dimension_start]
                    .iter()
                    .copied()
                    .rev();

                // Each round folds the outermost `Range` of every piece.
                // `fold_once` only ever emits `Range` pieces, so the `else`
                // branch below is not expected to be taken.
                let mut original_dimension_size = original_dim_size;
                for dimension in reversed_dimensions {
                    pieces = pieces
                        .into_iter()
                        .flat_map(|piece| {
                            if let Selection::Range(range, inner) = piece {
                                fold_once(range, *inner, original_dimension_size, dimension)
                            } else {
                                vec![]
                            }
                        })
                        .collect();
                    original_dimension_size /= dimension;
                }

                // Union all pieces; with no pieces, nothing is selected.
                Ok(pieces.into_iter().fold(Selection::False, |x, y| match x {
                    Selection::False => y,
                    _ => union(x, y),
                }))
            }
            _ => Err(ReshapeError::UnsupportedSelection { selection }),
        }
    }

    recursive_fold(selection, original_slice, 0, reshaped_slice, 0)
}
819
#[cfg(test)]
mod tests {
    // Unit tests for dimension factoring, slice/shape reshaping, label
    // expansion, and property tests for selection folding.
    use std::collections::BTreeSet;

    use proptest::prelude::*;

    use super::*;
    use crate::Slice;
    use crate::selection::EvalOpts;
    use crate::shape;
    use crate::strategy::gen_selection;
    use crate::strategy::gen_slice;

    #[test]
    fn test_factor_dims_basic() {
        assert_eq!(
            factor_dims(&[6, 8], Limit::from(4)),
            vec![vec![3, 2], vec![4, 2]]
        );
        assert_eq!(factor_dims(&[5], Limit::from(3)), vec![vec![5]]);
        assert_eq!(factor_dims(&[30], Limit::from(5)), vec![vec![5, 3, 2]]);
    }

    // Verify that reshaping preserves memory layout by checking:
    // 1. Coordinate round-tripping: original → reshaped → original
    // 2. Flat index equality: original and reshaped coordinates map
    //    to the same linear index
    // 3. Index inversion: reshaped flat index maps back to the same
    //    reshaped coordinate
    //
    // Together, these checks ensure that the reshaped view is
    // layout-preserving and provides a bijective mapping between
    // coordinate systems.
    #[macro_export]
    macro_rules! assert_layout_preserved {
        ($original:expr, $reshaped:expr) => {{
            // Iterate over all coordinates in the original slice.
            for coord in $original.dim_iter($original.num_dim()) {
                let forward = to_reshaped_coord($original, &$reshaped);
                let inverse = to_original_coord(&$reshaped, $original);
                // Apply the forward coordinate mapping from original
                // to reshaped space.
                let reshaped_coord = forward(&coord);
                // Inverse mapping: reshaped coord → original coord.
                let roundtrip = inverse(&reshaped_coord);
                assert_eq!(
                    roundtrip, coord,
                    "Inverse mismatch: reshaped {:?} → original {:?}, expected {:?}",
                    reshaped_coord, roundtrip, coord
                );
                // Compute flat index in the original slice.
                let flat_orig = $original.location(&coord).unwrap();
                // Compute flat index in the reshaped slice.
                let flat_reshaped = $reshaped.location(&reshaped_coord).unwrap();
                // Check that the flat index is preserved by the
                // reshaping.
                assert_eq!(
                    flat_orig, flat_reshaped,
                    "Flat index mismatch: original {:?} → reshaped {:?}",
                    coord, reshaped_coord
                );
                // Invert the reshaped flat index back to coordinates.
                let recovered = $reshaped.coordinates(flat_reshaped).unwrap();
                // Ensure coordinate inversion is correct (round
                // trip).
                assert_eq!(
                    reshaped_coord, recovered,
                    "Coordinate mismatch: flat index {} → expected {:?}, got {:?}",
                    flat_reshaped, reshaped_coord, recovered
                );
            }
        }};
    }

    #[test]
    fn test_reshape_split_1d_row_major() {
        let s = Slice::new_row_major(vec![1024]);
        let reshaped = s.reshape_with_limit(Limit::from(8));

        assert_eq!(reshaped.offset(), 0);
        assert_eq!(reshaped.sizes(), &vec![8, 8, 8, 2]);
        assert_eq!(reshaped.strides(), &vec![128, 16, 2, 1]);
        assert_eq!(
            factor_dims(s.sizes(), Limit::from(8)),
            vec![vec![8, 8, 8, 2]]
        );

        assert_layout_preserved!(&s, &reshaped);
    }

    #[test]
    fn test_reshape_6_with_limit_2() {
        let s = Slice::new_row_major(vec![6]);
        let reshaped = reshape_with_limit(&s, Limit::from(2));
        assert_eq!(factor_dims(s.sizes(), Limit::from(2)), vec![vec![2, 3]]);
        assert_layout_preserved!(&s, &reshaped);
    }

    #[test]
    fn test_reshape_identity_noop_2d() {
        // All dimensions ≤ limit.
        let original = Slice::new_row_major(vec![4, 8]);
        let reshaped = original.reshape_with_limit(Limit::from(8));

        assert_eq!(reshaped.sizes(), original.sizes());
        assert_eq!(reshaped.strides(), original.strides());
        assert_eq!(reshaped.offset(), original.offset());
        // Every dimension (including one exactly equal to the limit)
        // factors into itself.
        assert_eq!(
            factor_dims(original.sizes(), Limit::from(8)),
            vec![vec![4], vec![8]]
        );
        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_empty_slice() {
        // 0-dimensional slice.
        let original = Slice::new_row_major(vec![]);
        let reshaped = reshape_with_limit(&original, Limit::from(8));

        assert_eq!(reshaped.sizes(), original.sizes());
        assert_eq!(reshaped.strides(), original.strides());
        assert_eq!(reshaped.offset(), original.offset());

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_mixed_dims_3d() {
        // 3D slice with one dimension exceeding the limit.
        let original = Slice::new_row_major(vec![6, 8, 10]);
        let reshaped = original.reshape_with_limit(Limit::from(4));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(4)),
            vec![vec![3, 2], vec![4, 2], vec![2, 5]]
        );
        assert_eq!(reshaped.sizes(), &[3, 2, 4, 2, 2, 5]);

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_all_large_dims() {
        // 3D slice with all dimensions exceeding the limit.
        let original = Slice::new_row_major(vec![12, 18, 20]);
        let reshaped = original.reshape_with_limit(Limit::from(4));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(4)),
            vec![vec![4, 3], vec![3, 3, 2], vec![4, 5]]
        );
        assert_eq!(reshaped.sizes(), &[4, 3, 3, 3, 2, 4, 5]);

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_split_1d_factors_3_3_2_2() {
        // 36 = 3 × 3 × 2 × 2.
        let original = Slice::new_row_major(vec![36]);
        let reshaped = reshape_with_limit(&original, Limit::from(3));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(3)),
            vec![vec![3, 3, 2, 2]]
        );
        assert_eq!(reshaped.sizes(), &[3, 3, 2, 2]);
        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_large_prime_dimension() {
        // Prime larger than limit, cannot be factored.
        let original = Slice::new_row_major(vec![7]);
        let reshaped = reshape_with_limit(&original, Limit::from(4));

        // Should remain as-is since 7 is prime > 4
        assert_eq!(factor_dims(original.sizes(), Limit::from(4)), vec![vec![7]]);
        assert_eq!(reshaped.sizes(), &[7]);

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_split_1d_factors_5_3_2() {
        // 30 = 5 × 3 × 2, all ≤ limit.
        let original = Slice::new_row_major(vec![30]);
        let reshaped = reshape_with_limit(&original, Limit::from(5));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(5)),
            vec![vec![5, 3, 2]]
        );
        assert_eq!(reshaped.sizes(), &[5, 3, 2]);
        assert_eq!(reshaped.strides(), &[6, 2, 1]);

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_factors_2_6_2_8_8() {
        // 12 = 6 × 2, 64 = 8 × 8 — all ≤ 8
        let original = Slice::new_row_major(vec![2, 12, 64]);
        let reshaped = original.reshape_with_limit(Limit::from(8));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(8)),
            vec![vec![2], vec![6, 2], vec![8, 8]]
        );
        assert_eq!(reshaped.sizes(), &[2, 6, 2, 8, 8]);
        assert_eq!(reshaped.strides(), &[768, 128, 64, 8, 1]);

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_all_dims_within_limit() {
        // Original shape: [2, 3, 4] — all ≤ limit (4).
        let original = Slice::new_row_major(vec![2, 3, 4]);
        let reshaped = original.reshape_with_limit(Limit::from(4));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(4)),
            vec![vec![2], vec![3], vec![4]]
        );
        assert_eq!(reshaped.sizes(), &[2, 3, 4]);
        assert_eq!(reshaped.strides(), original.strides());
        assert_eq!(reshaped.offset(), original.offset());

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_degenerate_dimension() {
        // Degenerate dimension should remain unchanged.
        let original = Slice::new_row_major(vec![1, 12]);
        let reshaped = original.reshape_with_limit(Limit::from(4));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(4)),
            vec![vec![1], vec![4, 3]]
        );
        assert_eq!(reshaped.sizes(), &[1, 4, 3]);

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_select_then_reshape() {
        // Original shape: 2 zones, 3 hosts, 4 gpus
        let original = shape!(zone = 2, host = 3, gpu = 4);

        // Select the zone=1 plane: shape becomes [1, 3, 4]
        let selected = original.select("zone", 1).unwrap();
        assert_eq!(selected.slice().offset(), 12); // Nonzero offset.
        assert_eq!(selected.slice().sizes(), &[1, 3, 4]);

        // Reshape the selected slice using limit=2 in row-major
        // layout.
        let reshaped = selected.slice().reshape_with_limit(Limit::from(2));

        assert_eq!(
            factor_dims(selected.slice().sizes(), Limit::from(2)),
            vec![vec![1], vec![3], vec![2, 2]]
        );
        assert_eq!(reshaped.sizes(), &[1, 3, 2, 2]);
        assert_eq!(reshaped.strides(), &[12, 4, 2, 1]);
        assert_eq!(reshaped.offset(), 12); // Offset verified preserved.

        assert_layout_preserved!(selected.slice(), &reshaped);
    }

    #[test]
    fn test_select_host_plane_then_reshape() {
        // Original shape: 2 zones, 3 hosts, 4 gpus.
        let original = shape!(zone = 2, host = 3, gpu = 4);
        // Select the host=2 plane: shape becomes [2, 1, 4], and the
        // offset moves to host index 2 × host stride 4 = 8.
        let selected = original.select("host", 2).unwrap();
        assert_eq!(selected.slice().sizes(), &[2, 1, 4]);
        assert_eq!(selected.slice().offset(), 8);

        // Reshape the selected slice using limit=2 in row-major
        // layout.
        let reshaped = selected.slice().reshape_with_limit(Limit::from(2));

        assert_eq!(
            factor_dims(selected.slice().sizes(), Limit::from(2)),
            vec![vec![2], vec![1], vec![2, 2]]
        );
        assert_eq!(reshaped.sizes(), &[2, 1, 2, 2]);
        assert_eq!(reshaped.strides(), &[12, 4, 2, 1]);
        assert_eq!(reshaped.offset(), 8); // Offset preserved.

        assert_layout_preserved!(selected.slice(), &reshaped);
    }

    #[test]
    fn test_reshape_after_select_no_factoring_due_to_primes() {
        // Original shape: 3 zones, 4 hosts, 5 gpus
        let original = shape!(zone = 3, host = 4, gpu = 5);
        // First select: fix zone = 1 → shape: [1, 4, 5].
        let selected_zone = original.select("zone", 1).unwrap();
        assert_eq!(selected_zone.slice().sizes(), &[1, 4, 5]);
        // Second select: fix host = 2 → shape: [1, 1, 5].
        let selected_host = selected_zone.select("host", 2).unwrap();
        assert_eq!(selected_host.slice().sizes(), &[1, 1, 5]);
        // Reshape with limit = 2.
        let reshaped = selected_host.slice().reshape_with_limit(Limit::from(2));

        assert_eq!(
            factor_dims(selected_host.slice().sizes(), Limit::from(2)),
            vec![vec![1], vec![1], vec![5]]
        );
        assert_eq!(reshaped.sizes(), &[1, 1, 5]);

        assert_layout_preserved!(selected_host.slice(), &reshaped);
    }

    #[test]
    fn test_reshape_after_multiple_selects_triggers_factoring() {
        // Original shape: 2 zones, 4 hosts, 8 gpus
        let original = shape!(zone = 2, host = 4, gpu = 8);
        // Select zone=1 → shape: [1, 4, 8]
        let selected_zone = original.select("zone", 1).unwrap();
        assert_eq!(selected_zone.slice().sizes(), &[1, 4, 8]);

        // Select host=2 → shape: [1, 1, 8]
        let selected_host = selected_zone.select("host", 2).unwrap();
        assert_eq!(selected_host.slice().sizes(), &[1, 1, 8]);

        // Reshape with limit = 2 → gpu=8 should factor
        let reshaped = selected_host.slice().reshape_with_limit(Limit::from(2));

        assert_eq!(
            factor_dims(selected_host.slice().sizes(), Limit::from(2)),
            vec![vec![1], vec![1], vec![2, 2, 2]]
        );
        assert_eq!(reshaped.sizes(), &[1, 1, 2, 2, 2]);

        assert_layout_preserved!(selected_host.slice(), &reshaped);
    }

    #[test]
    fn test_expand_labels_singleton_dims() {
        let factors = vec![("x".into(), vec![2]), ("y".into(), vec![4])];
        let expected = vec!["x", "y"];
        assert_eq!(expand_labels(&factors), expected);
    }

    #[test]
    fn test_expand_labels_factored_dims() {
        let factors = vec![("gpu".into(), vec![2, 2, 2])];
        let expected = vec!["gpu/0", "gpu/1", "gpu/2"];
        assert_eq!(expand_labels(&factors), expected);
    }

    #[test]
    fn test_expand_labels_mixed_dims() {
        let factors = vec![("zone".into(), vec![2]), ("gpu".into(), vec![2, 2])];
        let expected = vec!["zone", "gpu/0", "gpu/1"];
        assert_eq!(expand_labels(&factors), expected);
    }

    #[test]
    fn test_expand_labels_empty() {
        let factors: Vec<(String, Vec<usize>)> = vec![];
        let expected: Vec<String> = vec![];
        assert_eq!(expand_labels(&factors), expected);
    }

    #[test]
    fn test_reshape_shape_noop() {
        let shape = shape!(x = 4, y = 8);
        let reshaped = reshape_shape(&shape, Limit::from(8));
        assert_eq!(reshaped.shape.labels(), &["x", "y"]);
        assert_eq!(reshaped.shape.slice(), shape.slice());
    }

    #[test]
    fn test_reshape_shape_factored() {
        let shape = shape!(gpu = 8);
        let reshaped = reshape_shape(&shape, Limit::from(2));
        assert_eq!(reshaped.shape.labels(), &["gpu/0", "gpu/1", "gpu/2"]);
        assert_eq!(reshaped.shape.slice().sizes(), &[2, 2, 2]);

        let expected = shape.slice().reshape_with_limit(Limit::from(2));
        assert_eq!(reshaped.shape.slice(), &expected);
    }

    #[test]
    fn test_reshape_shape_singleton() {
        let shape = shape!(x = 3);
        let reshaped = reshape_shape(&shape, Limit::from(8));
        assert_eq!(reshaped.shape.labels(), &["x"]);
        assert_eq!(reshaped.shape.slice(), shape.slice());
    }

    #[test]
    fn test_reshape_shape_prime_exceeds_limit() {
        let shape = shape!(x = 11);
        let reshaped = reshape_shape(&shape, Limit::from(5));
        assert_eq!(reshaped.shape.labels(), &["x"]);
        assert_eq!(reshaped.shape.slice(), shape.slice());
    }

    #[test]
    fn test_reshape_shape_mixed_dims() {
        let shape = shape!(zone = 2, gpu = 8);
        let reshaped = reshape_shape(&shape, Limit::from(2));
        assert_eq!(
            reshaped.shape.labels(),
            &["zone", "gpu/0", "gpu/1", "gpu/2"]
        );
        assert_eq!(reshaped.shape.slice().sizes(), &[2, 2, 2, 2]);

        let expected = shape.slice().reshape_with_limit(Limit::from(2));
        assert_eq!(reshaped.shape.slice(), &expected);
    }

    #[test]
    fn test_reshape_shape_after_selects() {
        // Original shape: 2 zones, 4 hosts, 8 gpus
        let original = shape!(zone = 2, host = 4, gpu = 8);

        // Select zone=1 → shape: [1, 4, 8]
        let selected_zone = original.select("zone", 1).unwrap();
        assert_eq!(selected_zone.slice().sizes(), &[1, 4, 8]);

        // Select host=2 → shape: [1, 1, 8]
        let selected_host = selected_zone.select("host", 2).unwrap();
        assert_eq!(selected_host.slice().sizes(), &[1, 1, 8]);

        // Reshape shape through high-level API
        let reshaped = reshape_shape(&selected_host, Limit::from(2));

        // Labels should be: zone, host, gpu/0, gpu/1, gpu/2
        assert_eq!(
            reshaped.shape.labels(),
            &["zone", "host", "gpu/0", "gpu/1", "gpu/2"]
        );

        // Sizes should reflect factored GPU dimension
        assert_eq!(reshaped.shape.slice().sizes(), &[1, 1, 2, 2, 2]);

        // Check against low-level equivalent reshaped slice
        let expected = selected_host.slice().reshape_with_limit(Limit::from(2));
        assert_eq!(reshaped.shape.slice(), &expected);
    }

    proptest! {
        #![proptest_config(ProptestConfig {
            cases: 20, ..ProptestConfig::default()
        })]
        // Folding a selection onto the reshaped slice must select exactly
        // the same ranks as evaluating it on the original slice.
        //
        // The selection is drawn from the same strategy as the slice and
        // fanout limit (rather than from a separate, randomly seeded
        // `TestRunner`), so failing cases are reproducible from proptest's
        // persisted seed and the selection participates in shrinking.
        #[test]
        fn test_reshape_selection((slice, fanout_limit, selection) in gen_slice(4, 64).prop_flat_map(|slice| {
            let max_dimension_size = *slice.sizes().iter().max().unwrap();
            let shape = slice.sizes().to_vec();
            (Just(slice), 1..=max_dimension_size, gen_selection(4, shape, 0))
        })) {
            let original_selected_ranks = selection
                .eval(&EvalOpts::strict(), &slice)
                .unwrap()
                .collect::<BTreeSet<_>>();

            let reshaped_slice = reshape_with_limit(&slice, Limit::from(fanout_limit));
            let Ok(reshaped_selection) = reshape_selection(selection, &slice, &reshaped_slice) else {
                return Err(TestCaseError::fail(
                    "reshape_selection rejected a generated selection",
                ));
            };

            let folded_selected_ranks = reshaped_selection
                .eval(&EvalOpts::strict(), &reshaped_slice)?
                .collect::<BTreeSet<_>>();

            prop_assert_eq!(original_selected_ranks, folded_selected_ranks);
        }
    }
}