ndslice/
reshape.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! Dimensional reshaping of slices and shapes.
10//!
11//! This module defines utilities for transforming a [`Slice`] or
12//! [`Shape`] by factoring large extents into smaller ones under a
13//! given limit. The result is a reshaped view with increased
14//! dimensionality and preserved memory layout.
15//!
16//! This is useful for hierarchical routing, structured fanout, and
17//! other multidimensional layout transformations.
18//!
19//! For [`Shape`]s, reshaping also expands dimension labels using a
20//! `label/N` naming convention, preserving the semantics of the
//! original shape in the reshaped layout.
22//!
23//! See [`reshape_with_limit`] and [`reshape_shape`] for entry points.
24use std::fmt;
25
26use crate::Range;
27use crate::Selection;
28use crate::dsl::union;
29use crate::shape::Shape;
30use crate::slice::Slice;
31
/// A point in multidimensional index space: one `usize` index per
/// dimension. Used as the coordinate type throughout the reshape logic.
pub type Coord = Vec<usize>;
35
36/// A reshaped version of a `Shape`, with factored dimensions and
37/// updated labels.
38///
39/// This type preserves coordinate bijections with the original shape
40/// and provides access to the transformed layout and label mappings.
41pub struct ReshapedShape {
42    /// The reshaped shape, with new labels and underlying factored
43    /// slice.
44    pub shape: Shape,
45
46    /// For each original dimension label, the list of sizes it was
47    /// split into.
48    pub factors: Vec<(String, Vec<usize>)>,
49}
50
51#[allow(dead_code)]
52const _: () = {
53    fn assert<T: Send + Sync + 'static>() {}
54    let _ = assert::<ReshapedShape>;
55};
56
57impl std::fmt::Debug for ReshapedShape {
58    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59        f.debug_struct("ReshapedShape")
60            .field("labels", &self.shape.labels())
61            .field("sizes", &self.shape.slice().sizes())
62            .field("strides", &self.shape.slice().strides())
63            .field("offset", &self.shape.slice().offset())
64            .field("factors", &self.factors)
65            .finish()
66    }
67}
68
69impl std::fmt::Display for ReshapedShape {
70    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
71        write!(
72            f,
73            "ReshapedShape {{ [off={} sz={:?} st={:?} lab={:?} fac={:?}] }}",
74            self.shape.slice().offset(),
75            self.shape.slice().sizes(),
76            self.shape.slice().strides(),
77            self.shape.labels(),
78            self.factors
79        )
80    }
81}
82
83/// Returns, for each size, a list of factors that respect the given
84/// limit. If a size is ≤ limit, it is returned as a singleton.
85/// Otherwise, it is factored greedily using divisors ≤ limit, from
86/// largest to smallest.
87///
88/// For best results, dimensions should be chosen to allow factoring
89/// into small values under the selected limit (e.g., ≤ 32).
90/// Large prime numbers cannot be broken down and will remain as-is,
91/// limiting reshaping potential.
92///
93/// Prefer powers of 2 or other highly composite numbers
94/// (e.g., 8, 16, 32, 60, 120) over large primes (e.g., 17, 37, 113)
95/// when designing shapes intended for reshaping.
96pub(crate) fn factor_dims(sizes: &[usize], limit: Limit) -> Vec<Vec<usize>> {
97    let limit = limit.get();
98    sizes
99        .iter()
100        .map(|&size| {
101            if size <= limit {
102                return vec![size];
103            }
104            let mut rem = size;
105            let mut factors = Vec::new();
106            for d in (2..=limit).rev() {
107                while rem % d == 0 {
108                    factors.push(d);
109                    rem /= d;
110                }
111            }
112            if rem > 1 {
113                factors.push(rem);
114            }
115            factors
116        })
117        .collect()
118}
119
120/// Constructs a function that maps coordinates from the original
121/// slice to equivalent coordinates in the reshaped slice, preserving
122/// their flat (linear) position.
123pub fn to_reshaped_coord<'a>(
124    original: &'a Slice,
125    reshaped: &'a Slice,
126) -> impl Fn(&[usize]) -> Vec<usize> + 'a {
127    let original = original.clone();
128    let reshaped = reshaped.clone();
129    move |coord: &[usize]| -> Coord {
130        let flat = original.location(coord).unwrap();
131        reshaped.coordinates(flat).unwrap()
132    }
133}
134
135/// Constructs a function that maps coordinates from the reshaped
136/// slice back to equivalent coordinates in the original slice,
137/// preserving their flat (linear) position.
138pub fn to_original_coord<'a>(
139    reshaped: &'a Slice,
140    original: &'a Slice,
141) -> impl Fn(&[usize]) -> Vec<usize> + 'a {
142    let reshaped = reshaped.clone();
143    let original = original.clone();
144    move |coord: &[usize]| -> Coord {
145        let flat = reshaped.location(coord).unwrap();
146        original.coordinates(flat).unwrap()
147    }
148}
149
/// A shaping constraint that bounds the extent of any single dimension
/// produced by reshaping.
///
/// During reshaping, each dimension larger than the limit is factored
/// greedily into divisors no larger than the limit (e.g.,
/// `reshape_with_limit([1024], Limit::new(32))` → `[32, 32]`). Any part
/// of an extent that has no divisor ≤ the limit (such as a large prime)
/// is kept as a single factor and may therefore exceed the limit.
///
/// The default limit is `32`, which balances fanout depth and layout
/// regularity.
///
/// # Example
/// ```
/// use ndslice::reshape::Limit;
/// let limit = Limit::new(64);
/// assert_eq!(limit.get(), 64);
/// ```
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Limit(usize);

impl Limit {
    /// Creates a new `Limit`.
    ///
    /// This is a `const fn`, so limits can be declared as constants
    /// (`const FANOUT: Limit = Limit::new(16);`).
    ///
    /// # Panics
    /// Panics if `n` is 0. When evaluated in a const context, this is
    /// reported as a compile-time error instead.
    pub const fn new(n: usize) -> Self {
        assert!(n >= 1, "Limit must be at least 1");
        Self(n)
    }

    /// Returns the inner value.
    pub const fn get(self) -> usize {
        self.0
    }
}

impl Default for Limit {
    /// Returns a limit of 32.
    fn default() -> Self {
        Self::new(32)
    }
}

impl From<usize> for Limit {
    /// Equivalent to [`Limit::new`].
    ///
    /// # Panics
    /// Panics if `n` is 0.
    fn from(n: usize) -> Self {
        Self::new(n)
    }
}
194
195/// A trait for types that can be reshaped into a higher-dimensional
196/// view by factoring large extents into smaller ones.
197///
198/// This is implemented for [`Slice`], enabling ergonomic access to
199/// [`reshape_with_limit`] as a method.
200///
201/// # Example
202/// ```
203/// use ndslice::Slice;
204/// use ndslice::reshape::Limit;
205/// use ndslice::reshape::ReshapeSliceExt;
206///
207/// let slice = Slice::new_row_major(vec![1024]);
208/// let reshaped = slice.reshape_with_limit(Limit::new(32));
209/// assert_eq!(reshaped.sizes(), &[32, 32]);
210/// ```
211/// # Returns
212/// A reshaped [`Slice`] with increased dimensionality and preserved
213/// layout.
214pub trait ReshapeSliceExt {
215    /// Returns a reshaped version of this structure by factoring each
216    /// dimension into smaller extents no greater than `limit`,
217    /// preserving memory layout and flat index semantics. See
218    /// [`reshape_with_limit`] for full behavior and rationale.
219    ///
220    /// # Arguments
221    /// - `limit`: maximum size allowed in any reshaped dimension
222    ///
223    /// # Returns
224    /// A reshaped [`Slice`] with increased dimensionality and a
225    /// bijective mapping to the original.
226    fn reshape_with_limit(&self, limit: Limit) -> Slice;
227}
228
229impl ReshapeSliceExt for Slice {
230    fn reshape_with_limit(&self, limit: Limit) -> Slice {
231        reshape_with_limit(self, limit)
232    }
233}
234
235/// Extension trait for reshaping `Shape`s by factoring large dimensions.
236pub trait ReshapeShapeExt {
237    /// Produces a reshaped version of the shape with expanded
238    /// dimensions under the given size limit.
239    fn reshape(&self, limit: Limit) -> ReshapedShape;
240}
241
242impl ReshapeShapeExt for Shape {
243    fn reshape(&self, limit: Limit) -> ReshapedShape {
244        reshape_shape(self, limit)
245    }
246}
247
248/// For convenient `slice.reshape_with_limit()`, `shape.reshape()`
249/// syntax, `use reshape::prelude::*`.
250pub mod prelude {
251    pub use super::ReshapeShapeExt;
252    pub use super::ReshapeSliceExt;
253}
254
255/// Reshapes a slice by factoring each dimension into smaller extents
256/// under the given limit.
257///
258/// This transformation increases dimensionality by breaking large
259/// sizes into products of smaller factors (e.g., `[1024]` with limit
260/// 32 becomes `[32, 32]`). The result is a new [`Slice`] that
261/// preserves memory layout and flat index semantics.
262///
263/// Factoring is greedy, starting from the largest divisors ≤ `limit`.
264/// Dimensions that cannot be factored under the limit are left
265/// unchanged.
266///
267/// # Arguments
268/// - `slice`: the original multidimensional slice
269/// - `limit`: maximum extent allowed in any factored subdimension
270///
271/// # Returns
272/// A reshaped [`Slice`] with updated sizes and strides.
273///
274/// # Example
275/// ```
276/// use ndslice::Slice;
277/// use ndslice::reshape::Limit;
278/// use ndslice::reshape::reshape_with_limit;
279///
280/// let slice = Slice::new_row_major(vec![1024]);
281/// let reshaped = reshape_with_limit(&slice, Limit::new(32));
282/// assert_eq!(reshaped.sizes(), &[32, 32]);
283/// ```
284pub fn reshape_with_limit(slice: &Slice, limit: Limit) -> Slice {
285    let orig_sizes = slice.sizes();
286    let orig_strides = slice.strides();
287
288    // Step 1: Factor each size into subdimensions ≤ limit.
289    let factored_sizes = factor_dims(orig_sizes, limit);
290
291    // Step 2: Compute reshaped sizes and strides (row-major only).
292    let reshaped_sizes: Vec<usize> = factored_sizes.iter().flatten().cloned().collect();
293    let mut reshaped_strides = Vec::with_capacity(reshaped_sizes.len());
294
295    for (&orig_stride, factors) in orig_strides.iter().zip(&factored_sizes) {
296        let mut sub_strides = Vec::with_capacity(factors.len());
297        let mut stride = orig_stride;
298        for &f in factors.iter().rev() {
299            sub_strides.push(stride);
300            stride *= f;
301        }
302        sub_strides.reverse();
303        reshaped_strides.extend(sub_strides);
304    }
305
306    Slice::new(slice.offset(), reshaped_sizes, reshaped_strides).unwrap()
307}
308
309/// Reshapes a labeled [`Shape`] by factoring large extents into
310/// smaller ones, producing a new shape with expanded dimensionality
311/// and updated labels.
312///
313/// This uses [`reshape_with_limit`] on the underlying slice and [`expand_labels`]
314/// to generate labels for each factored dimension.
315///
316/// # Arguments
317/// - `shape`: the labeled shape to reshape
318/// - `limit`: maximum extent allowed per factored dimension
319///
320/// # Returns
321/// A new [`ReshapedShape`] with an updated [`Shape`] and dimension
322/// factoring metadata.
323///
324/// # Panics
325/// Panics if constructing the new `Shape` fails. This should not
326/// occur unless the reshaped slice and labels are inconsistent (a
327/// programming logic error).
328pub fn reshape_shape(shape: &Shape, limit: Limit) -> ReshapedShape {
329    let reshaped_slice = shape.slice().reshape_with_limit(limit);
330    let original_labels = shape.labels();
331    let original_sizes = shape.slice().sizes();
332
333    let factors = factor_dims(original_sizes, limit);
334    let factored_dims: Vec<(String, Vec<usize>)> =
335        original_labels.iter().cloned().zip(factors).collect();
336
337    let labels = expand_labels(&factored_dims);
338    let shape = Shape::new(labels, reshaped_slice).expect("invalid reshaped shape");
339
340    ReshapedShape {
341        shape,
342        factors: factored_dims,
343    }
344}
345
/// Produces one label per reshaped dimension from a list of factored
/// original dimensions.
///
/// Each `(label, factors)` entry describes one original dimension and
/// the extents it was split into:
/// - exactly one extent: the label is kept verbatim;
/// - otherwise: one label `label/i` per extent, with `i` counting from 0.
///
/// For example, `[("zone", vec![2]), ("gpu", vec![2, 2, 2])]` becomes
/// `["zone", "gpu/0", "gpu/1", "gpu/2"]`.
///
/// # Arguments
/// - `factors`: factored dimension extents, each paired with its label
///
/// # Returns
/// The expanded labels, one per reshaped dimension, in order.
pub fn expand_labels(factors: &[(String, Vec<usize>)]) -> Vec<String> {
    factors
        .iter()
        .flat_map(|(label, dims)| -> Vec<String> {
            match dims.len() {
                1 => vec![label.clone()],
                n => (0..n).map(|i| format!("{}/{}", label, i)).collect(),
            }
        })
        .collect()
}
379
380#[derive(Debug, thiserror::Error)]
381pub enum ReshapeError {
382    #[error("unsupported selection kind {selection}")]
383    UnsupportedSelection { selection: Selection },
384}
385/// Maps a `Selection` on a `Slice` to a new `Selection` that selects all
386/// ranks in the reshaped `Slice` that the original `Selection` selected in the
387/// original `Slice`
388pub fn reshape_selection(
389    selection: Selection,
390    original_slice: &Slice,
391    reshaped_slice: &Slice,
392) -> Result<Selection, ReshapeError> {
393    fn recursive_fold(
394        selection: Selection,
395        original_slice: &Slice,
396        original_size_index: usize,
397        reshaped_slice: &Slice,
398        reshaped_size_index: usize,
399    ) -> Result<Selection, ReshapeError> {
400        if matches!(selection, Selection::True | Selection::False) {
401            return Ok(selection);
402        }
403
404        let Some(&original_dim_size) = original_slice.sizes().get(original_size_index) else {
405            return Ok(selection);
406        };
407
408        let mut accum = *reshaped_slice.sizes().get(reshaped_size_index).unwrap();
409        let mut next_reshaped_dimension_start = reshaped_size_index + 1;
410
411        while accum < original_dim_size {
412            accum *= *reshaped_slice
413                .sizes()
414                .get(next_reshaped_dimension_start)
415                .unwrap();
416            next_reshaped_dimension_start += 1;
417        }
418
419        match selection {
420            // base case
421            Selection::True | Selection::False => Ok(selection),
422            // For these cases we are not drilling down any dimensions when we recurse
423            Selection::Union(left, right) => {
424                let left = recursive_fold(
425                    *left,
426                    original_slice,
427                    original_size_index,
428                    reshaped_slice,
429                    reshaped_size_index,
430                )?;
431
432                match left {
433                    Selection::True => return Ok(Selection::True),
434                    Selection::False => {
435                        return recursive_fold(
436                            *right,
437                            original_slice,
438                            original_size_index,
439                            reshaped_slice,
440                            reshaped_size_index,
441                        );
442                    }
443                    _ => {}
444                }
445
446                let right = recursive_fold(
447                    *right,
448                    original_slice,
449                    original_size_index,
450                    reshaped_slice,
451                    reshaped_size_index,
452                )?;
453
454                Ok(match right {
455                    Selection::True => Selection::True,
456                    Selection::False => left,
457                    _ => Selection::Union(Box::new(left), Box::new(right)),
458                })
459            }
460            Selection::Intersection(left, right) => {
461                let left = recursive_fold(
462                    *left,
463                    original_slice,
464                    original_size_index,
465                    reshaped_slice,
466                    reshaped_size_index,
467                )?;
468                match left {
469                    Selection::False => return Ok(Selection::False),
470                    Selection::True => {
471                        return recursive_fold(
472                            *right,
473                            original_slice,
474                            original_size_index,
475                            reshaped_slice,
476                            reshaped_size_index,
477                        );
478                    }
479                    _ => {}
480                }
481
482                let right = recursive_fold(
483                    *right,
484                    original_slice,
485                    original_size_index,
486                    reshaped_slice,
487                    reshaped_size_index,
488                )?;
489                Ok(match right {
490                    Selection::False => Selection::False,
491                    Selection::True => left,
492                    _ => Selection::Intersection(Box::new(left), Box::new(right)),
493                })
494            }
495            Selection::All(inner) => {
496                let inner = recursive_fold(
497                    *inner,
498                    original_slice,
499                    original_size_index + 1,
500                    reshaped_slice,
501                    next_reshaped_dimension_start,
502                )?;
503
504                if matches!(inner, Selection::True | Selection::False) {
505                    return Ok(inner);
506                }
507
508                Ok((reshaped_size_index..next_reshaped_dimension_start - 1)
509                    .fold(Selection::All(Box::new(inner)), |result, _| {
510                        Selection::All(Box::new(result))
511                    }))
512            }
513            Selection::Any(inner) => {
514                let inner = recursive_fold(
515                    *inner,
516                    original_slice,
517                    original_size_index + 1,
518                    reshaped_slice,
519                    next_reshaped_dimension_start,
520                )?;
521
522                if matches!(inner, Selection::False) {
523                    return Ok(inner);
524                }
525
526                Ok((reshaped_size_index..next_reshaped_dimension_start - 1)
527                    .fold(Selection::Any(Box::new(inner)), |result, _| {
528                        Selection::Any(Box::new(result))
529                    }))
530            }
531            Selection::First(inner) => {
532                let inner = recursive_fold(
533                    *inner,
534                    original_slice,
535                    original_size_index + 1,
536                    reshaped_slice,
537                    next_reshaped_dimension_start,
538                )?;
539
540                if matches!(inner, Selection::False) {
541                    return Ok(inner);
542                }
543
544                Ok((reshaped_size_index..next_reshaped_dimension_start - 1)
545                    .fold(Selection::First(Box::new(inner)), |result, _| {
546                        Selection::First(Box::new(result))
547                    }))
548            }
549            Selection::Range(range, inner) => {
550                // We can fold a rectangle along a dimension by factoring it into up to 3 pieces:
551                // A starting piece if there is a region that begins after the start of a fold but spans to the end of that fold
552                // A middle piece that spans the entire fold for n folds
553                // An ending piece if there is a region that begins at the start of the fold but ends before the end of that fold
554                //
555                // To visualize: range(2:8, true)
556                // 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8
557                // o | o | x | x | x | x | x | x | o
558                //
559                // => fold into 3 pieces:
560                //
561                // start: range(0:1, range(2:3, true))
562                // 0 | 1 | 2 |
563                // o | o | x |
564                //
565                // middle: range(1:2, all(true))
566                // 3 | 4 | 5 |
567                // x | x | x |
568                //
569                // end: range(2:3, range(0:2, true))
570                // 6 | 7 | 8
571                // x | x | o
572                //
573                // When step size is larger than 1 this does get a bit more hairy where the middle piece must be
574                // split into several pieces and the end must be aligned
575                fn fold_once(
576                    Range(start, end, step): Range,
577                    inner: Selection,
578                    original_dimension_n_size: usize,
579                    new_dimension_n_size: usize,
580                ) -> Vec<Selection> {
581                    // Clamp end to the original dimension size to avoid computing
582                    // reshaped coordinates for out-of-bounds indices.
583                    // This is necessary because Selection allows Range end > dim_size
584                    // (the indices beyond the dimension are simply not selected).
585                    let end = end.map(|e| e.min(original_dimension_n_size));
586
587                    let dimension_n_plus_one_start = start / new_dimension_n_size;
588                    let dimension_n_plus_one_end = end.map(|end| {
589                        if end % new_dimension_n_size == 0 {
590                            end / new_dimension_n_size
591                        } else {
592                            end / new_dimension_n_size + 1
593                        }
594                    });
595                    let dimension_n_plus_one_size =
596                        original_dimension_n_size / new_dimension_n_size;
597
598                    let new_dimension_n_start = start % new_dimension_n_size;
599                    let new_dimension_n_end = end.map(|end| {
600                        if end % new_dimension_n_size == 0 && end > new_dimension_n_size - 1 {
601                            // If the original end was 3 and the new dimension size is 3
602                            // Then the new end should be 3 as opposed to 0 as end represents an upper bound
603                            new_dimension_n_size
604                        } else {
605                            end % new_dimension_n_size
606                        }
607                    });
608
609                    let mut result = vec![];
610
611                    // Simplest case where the entire rectangle is contains on a single fold of dim n+1
612                    if dimension_n_plus_one_end
613                        .is_some_and(|end| dimension_n_plus_one_start + 1 == end)
614                        || (end.is_none()
615                            && dimension_n_plus_one_start == dimension_n_plus_one_size)
616                    {
617                        return vec![Selection::Range(
618                            Range(dimension_n_plus_one_start, dimension_n_plus_one_end, 1),
619                            Box::new(Selection::Range(
620                                Range(new_dimension_n_start, new_dimension_n_end, step),
621                                Box::new(inner.clone()),
622                            )),
623                        )];
624                    }
625
626                    // Simpler case first where the middle piece can be represented with a single range
627                    if step == 1 {
628                        // Starting piece
629                        // ex: range(0:1, range(2:3, true))
630                        // 0 | 1 | 2 |
631                        // o | o | x |
632                        let middle_start = match start % new_dimension_n_size {
633                            0 => dimension_n_plus_one_start,
634                            _ => {
635                                result.push(Selection::Range(
636                                    Range(
637                                        dimension_n_plus_one_start,
638                                        Some(dimension_n_plus_one_start + 1),
639                                        1,
640                                    ),
641                                    Box::new(Selection::Range(
642                                        Range(
643                                            new_dimension_n_start,
644                                            Some(new_dimension_n_size),
645                                            step,
646                                        ),
647                                        Box::new(inner.clone()),
648                                    )),
649                                ));
650                                dimension_n_plus_one_start + 1
651                            }
652                        };
653
654                        // Ending piece
655                        // ex: range(2:3, range(0:2, true))
656                        // 6 | 7 | 8
657                        // x | x | o
658                        let middle_end = match (end, dimension_n_plus_one_end) {
659                            (Some(end), Some(dimension_n_plus_one_end))
660                                if end % new_dimension_n_size != 0 =>
661                            {
662                                result.push(Selection::Range(
663                                    Range(
664                                        dimension_n_plus_one_end - 1,
665                                        Some(dimension_n_plus_one_end),
666                                        1,
667                                    ),
668                                    Box::new(Selection::Range(
669                                        Range(0, new_dimension_n_end, step),
670                                        Box::new(inner.clone()),
671                                    )),
672                                ));
673                                Some(dimension_n_plus_one_end - 1)
674                            }
675                            _ => dimension_n_plus_one_end,
676                        };
677
678                        // Middle pieces
679                        // ex: range(1:2, all(true))
680                        // 3 | 4 | 5 |
681                        // x | x | x |
682                        if middle_end.is_some_and(|end| end > middle_start)
683                            || (middle_end.is_none() && middle_start < dimension_n_plus_one_size)
684                        {
685                            result.push(Selection::Range(
686                                Range(middle_start, middle_end, 1),
687                                Box::new(Selection::All(Box::new(inner.clone()))),
688                            ));
689                        }
690                    // Complicated case where step size is larger than 1 that involves splitting up
691                    // the middle piece
692                    } else {
693                        // Greatest common divisor
694                        fn gcd(a: usize, b: usize) -> usize {
695                            if b == 0 { a } else { gcd(b, a % b) }
696                        }
697
698                        let row_pattern_period = step / gcd(step, new_dimension_n_size);
699
700                        // get the coordinates of the first item on the next row
701                        let mut row_col_iter = std::iter::successors(
702                            Some((dimension_n_plus_one_start, start % new_dimension_n_size)),
703                            |&(row, col)| {
704                                let cols_before_end = new_dimension_n_size - 1 - col;
705                                let steps_before_end = cols_before_end / step;
706                                let last_col_before_end = col + step * steps_before_end;
707
708                                let next_row =
709                                    ((row * new_dimension_n_size) + last_col_before_end + step)
710                                        / new_dimension_n_size;
711                                let next_col = (last_col_before_end + step) % new_dimension_n_size;
712
713                                Some((next_row, next_col))
714                            },
715                        )
716                        .peekable();
717
718                        // Needs start piece
719                        if start % new_dimension_n_size != 0 {
720                            let (row, col) = row_col_iter.next().unwrap();
721
722                            result.push(Selection::Range(
723                                Range(row, Some(row + 1), 1),
724                                Box::new(Selection::Range(
725                                    Range(col, None, step),
726                                    Box::new(inner.clone()),
727                                )),
728                            ));
729                        };
730
731                        // Middle pieces
732                        for _ in 0..row_pattern_period {
733                            let end_row = end.map(|end| end / new_dimension_n_size);
734
735                            if match end_row {
736                                Some(end_row) => row_col_iter.peek().unwrap().0 >= end_row,
737                                None => row_col_iter.peek().unwrap().0 >= dimension_n_plus_one_size,
738                            } {
739                                break;
740                            }
741                            let (row_index, col) = row_col_iter.next().unwrap();
742
743                            result.push(Selection::Range(
744                                Range(row_index, end_row, row_pattern_period),
745                                Box::new(Selection::Range(
746                                    Range(col, None, step),
747                                    Box::new(inner.clone()),
748                                )),
749                            ));
750                        }
751
752                        // Needs end piece
753                        if let Some(end) = end {
754                            let end_row = end / new_dimension_n_size;
755
756                            for (row, col) in row_col_iter {
757                                if row > end_row {
758                                    break;
759                                }
760
761                                if row % row_pattern_period == end_row % row_pattern_period
762                                    && col < end % new_dimension_n_size
763                                {
764                                    result.push(Selection::Range(
765                                        Range(end_row, Some(end_row + 1), 1),
766                                        Box::new(Selection::Range(
767                                            Range(col, Some(end % new_dimension_n_size), step),
768                                            Box::new(inner.clone()),
769                                        )),
770                                    ));
771                                    break;
772                                }
773                            }
774                        }
775                    }
776                    result
777                }
778
779                let inner = recursive_fold(
780                    *inner,
781                    original_slice,
782                    original_size_index + 1,
783                    reshaped_slice,
784                    next_reshaped_dimension_start,
785                )?;
786                if matches!(inner, Selection::False) {
787                    return Ok(inner);
788                }
789                let mut pieces = vec![Selection::Range(range, Box::new(inner))];
790
791                // If [24] is being reshaped to [4, 3, 2] this will yield [2, 3] (dropping the first dimension and reversed)
792                // This is because we need to first fold by 2 to get [12, 3], then fold by 3 to get [4, 3, 2]
793                let reversed_dimensions = reshaped_slice.sizes()
794                    [reshaped_size_index + 1..next_reshaped_dimension_start]
795                    .iter()
796                    .copied()
797                    .rev();
798
799                let mut original_dimension_size = original_dim_size;
800                for dimension in reversed_dimensions {
801                    pieces = pieces
802                        .into_iter()
803                        .flat_map(|piece| {
804                            if let Selection::Range(range, inner) = piece {
805                                fold_once(range, *inner, original_dimension_size, dimension)
806                            } else {
807                                vec![]
808                            }
809                        })
810                        .collect();
811                    original_dimension_size /= dimension;
812                }
813
814                Ok(pieces.into_iter().fold(Selection::False, |x, y| match x {
815                    Selection::False => y,
816                    _ => union(x, y),
817                }))
818            }
819            _ => Err(ReshapeError::UnsupportedSelection { selection }),
820        }
821    }
822
823    recursive_fold(selection, original_slice, 0, reshaped_slice, 0)
824}
825
#[cfg(test)]
mod tests {
    // Unit tests for slice/shape reshaping, plus a property test
    // checking that `reshape_selection` selects the same ranks on the
    // reshaped slice as the original selection does on the original.

    use std::collections::BTreeSet;

    use proptest::prelude::*;

    use super::*;
    use crate::Slice;
    use crate::selection::EvalOpts;
    use crate::shape;
    use crate::strategy::gen_selection;
    use crate::strategy::gen_slice;

    #[test]
    fn test_factor_dims_basic() {
        assert_eq!(
            factor_dims(&[6, 8], Limit::from(4)),
            vec![vec![3, 2], vec![4, 2]]
        );
        assert_eq!(factor_dims(&[5], Limit::from(3)), vec![vec![5]]);
        assert_eq!(factor_dims(&[30], Limit::from(5)), vec![vec![5, 3, 2]]);
    }

    // Verify that reshaping preserves memory layout by checking:
    // 1. Coordinate round-tripping: original → reshaped → original
    // 2. Flat index equality: original and reshaped coordinates map
    //    to the same linear index
    // 3. Index inversion: reshaped flat index maps back to the same
    //    reshaped coordinate
    //
    // Together, these checks ensure that the reshaped view is
    // layout-preserving and provides a bijective mapping between
    // coordinate systems.
    //
    // NOTE: the macro body resolves `to_reshaped_coord` and
    // `to_original_coord` at the call site, so callers must have
    // them in scope.
    #[macro_export]
    macro_rules! assert_layout_preserved {
        ($original:expr, $reshaped:expr) => {{
            // The coordinate mappings depend only on the two slices,
            // not on the coordinate under test, so build them once
            // instead of once per coordinate.
            let forward = to_reshaped_coord($original, &$reshaped);
            let inverse = to_original_coord(&$reshaped, $original);
            // Iterate over all coordinates in the original slice.
            for coord in $original.dim_iter($original.num_dim()) {
                // Apply the forward coordinate mapping from original
                // to reshaped space.
                let reshaped_coord = forward(&coord);
                // Inverse mapping: reshaped coord → original coord.
                let roundtrip = inverse(&reshaped_coord);
                assert_eq!(
                    roundtrip, coord,
                    "Inverse mismatch: reshaped {:?} → original {:?}, expected {:?}",
                    reshaped_coord, roundtrip, coord
                );
                // Compute flat index in the original slice.
                let flat_orig = $original.location(&coord).unwrap();
                // Compute flat index in the reshaped slice.
                let flat_reshaped = $reshaped.location(&reshaped_coord).unwrap();
                // Check that the flat index is preserved by the
                // reshaping.
                assert_eq!(
                    flat_orig, flat_reshaped,
                    "Flat index mismatch: original {:?} → reshaped {:?}",
                    coord, reshaped_coord
                );
                // Invert the reshaped flat index back to coordinates.
                let recovered = $reshaped.coordinates(flat_reshaped).unwrap();
                // Ensure coordinate inversion is correct (round
                // trip).
                assert_eq!(
                    reshaped_coord, recovered,
                    "Coordinate mismatch: flat index {} → expected {:?}, got {:?}",
                    flat_reshaped, reshaped_coord, recovered
                );
            }
        }};
    }

    #[test]
    fn test_reshape_split_1d_row_major() {
        let s = Slice::new_row_major(vec![1024]);
        let reshaped = s.reshape_with_limit(Limit::from(8));

        assert_eq!(reshaped.offset(), 0);
        assert_eq!(reshaped.sizes(), &vec![8, 8, 8, 2]);
        assert_eq!(reshaped.strides(), &vec![128, 16, 2, 1]);
        assert_eq!(
            factor_dims(s.sizes(), Limit::from(8)),
            vec![vec![8, 8, 8, 2]]
        );

        assert_layout_preserved!(&s, &reshaped);
    }

    #[test]
    fn test_reshape_6_with_limit_2() {
        let s = Slice::new_row_major(vec![6]);
        let reshaped = reshape_with_limit(&s, Limit::from(2));
        assert_eq!(factor_dims(s.sizes(), Limit::from(2)), vec![vec![2, 3]]);
        assert_layout_preserved!(&s, &reshaped);
    }

    #[test]
    fn test_reshape_identity_noop_2d() {
        // All dimensions ≤ limit.
        let original = Slice::new_row_major(vec![4, 8]);
        let reshaped = original.reshape_with_limit(Limit::from(8));

        assert_eq!(reshaped.sizes(), original.sizes());
        assert_eq!(reshaped.strides(), original.strides());
        assert_eq!(reshaped.offset(), original.offset());
        // Every dimension fits under the limit, so factoring must
        // leave each one as a single, unsplit factor.
        assert_eq!(
            factor_dims(original.sizes(), Limit::from(8)),
            vec![vec![4], vec![8]]
        );
        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_empty_slice() {
        // 0-dimensional slice.
        let original = Slice::new_row_major(vec![]);
        let reshaped = reshape_with_limit(&original, Limit::from(8));

        assert_eq!(reshaped.sizes(), original.sizes());
        assert_eq!(reshaped.strides(), original.strides());
        assert_eq!(reshaped.offset(), original.offset());

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_mixed_dims_3d() {
        // 3D slice with one dimension exceeding the limit.
        let original = Slice::new_row_major(vec![6, 8, 10]);
        let reshaped = original.reshape_with_limit(Limit::from(4));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(4)),
            vec![vec![3, 2], vec![4, 2], vec![2, 5]]
        );
        assert_eq!(reshaped.sizes(), &[3, 2, 4, 2, 2, 5]);

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_all_large_dims() {
        // 3D slice with all dimensions exceeding the limit.
        let original = Slice::new_row_major(vec![12, 18, 20]);
        let reshaped = original.reshape_with_limit(Limit::from(4));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(4)),
            vec![vec![4, 3], vec![3, 3, 2], vec![4, 5]]
        );
        assert_eq!(reshaped.sizes(), &[4, 3, 3, 3, 2, 4, 5]);

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_split_1d_factors_3_3_2_2() {
        // 36 = 3 × 3 × 2 × 2.
        let original = Slice::new_row_major(vec![36]);
        let reshaped = reshape_with_limit(&original, Limit::from(3));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(3)),
            vec![vec![3, 3, 2, 2]]
        );
        assert_eq!(reshaped.sizes(), &[3, 3, 2, 2]);
        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_large_prime_dimension() {
        // Prime larger than limit, cannot be factored.
        let original = Slice::new_row_major(vec![7]);
        let reshaped = reshape_with_limit(&original, Limit::from(4));

        // Should remain as-is since 7 is prime > 4
        assert_eq!(factor_dims(original.sizes(), Limit::from(4)), vec![vec![7]]);
        assert_eq!(reshaped.sizes(), &[7]);

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_split_1d_factors_5_3_2() {
        // 30 = 5 × 3 × 2, all ≤ limit.
        let original = Slice::new_row_major(vec![30]);
        let reshaped = reshape_with_limit(&original, Limit::from(5));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(5)),
            vec![vec![5, 3, 2]]
        );
        assert_eq!(reshaped.sizes(), &[5, 3, 2]);
        assert_eq!(reshaped.strides(), &[6, 2, 1]);

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_factors_2_6_2_8_8() {
        // 12 = 6 × 2, 64 = 8 × 8 — all ≤ 8
        let original = Slice::new_row_major(vec![2, 12, 64]);
        let reshaped = original.reshape_with_limit(Limit::from(8));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(8)),
            vec![vec![2], vec![6, 2], vec![8, 8]]
        );
        assert_eq!(reshaped.sizes(), &[2, 6, 2, 8, 8]);
        assert_eq!(reshaped.strides(), &[768, 128, 64, 8, 1]);

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_all_dims_within_limit() {
        // Original shape: [2, 3, 4] — all ≤ limit (4).
        let original = Slice::new_row_major(vec![2, 3, 4]);
        let reshaped = original.reshape_with_limit(Limit::from(4));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(4)),
            vec![vec![2], vec![3], vec![4]]
        );
        assert_eq!(reshaped.sizes(), &[2, 3, 4]);
        assert_eq!(reshaped.strides(), original.strides());
        assert_eq!(reshaped.offset(), original.offset());

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_degenerate_dimension() {
        // Degenerate dimension should remain unchanged.
        let original = Slice::new_row_major(vec![1, 12]);
        let reshaped = original.reshape_with_limit(Limit::from(4));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(4)),
            vec![vec![1], vec![4, 3]]
        );
        assert_eq!(reshaped.sizes(), &[1, 4, 3]);

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_select_then_reshape() {
        // Original shape: 2 zones, 3 hosts, 4 gpus
        let original = shape!(zone = 2, host = 3, gpu = 4);

        // Select the zone=1 plane: shape becomes [1, 3, 4]
        let selected = original.select("zone", 1).unwrap();
        assert_eq!(selected.slice().offset(), 12); // Nonzero offset.
        assert_eq!(selected.slice().sizes(), &[1, 3, 4]);

        // Reshape the selected slice using limit=2 in row-major
        // layout.
        let reshaped = selected.slice().reshape_with_limit(Limit::from(2));

        assert_eq!(
            factor_dims(selected.slice().sizes(), Limit::from(2)),
            vec![vec![1], vec![3], vec![2, 2]]
        );
        assert_eq!(reshaped.sizes(), &[1, 3, 2, 2]);
        assert_eq!(reshaped.strides(), &[12, 4, 2, 1]);
        assert_eq!(reshaped.offset(), 12); // Offset verified preserved.

        assert_layout_preserved!(selected.slice(), &reshaped);
    }

    #[test]
    fn test_select_host_plane_then_reshape() {
        // Original shape: 2 zones, 3 hosts, 4 gpus.
        let original = shape!(zone = 2, host = 3, gpu = 4);
        // Select the host=2 plane: shape becomes [2, 1, 4].
        let selected = original.select("host", 2).unwrap();
        assert_eq!(selected.slice().offset(), 8); // host 2 × host stride 4.
        assert_eq!(selected.slice().sizes(), &[2, 1, 4]);
        // Reshape the selected slice using limit=2 in row-major
        // layout.
        let reshaped = selected.slice().reshape_with_limit(Limit::from(2));

        assert_eq!(
            factor_dims(selected.slice().sizes(), Limit::from(2)),
            vec![vec![2], vec![1], vec![2, 2]]
        );
        assert_eq!(reshaped.sizes(), &[2, 1, 2, 2]);
        assert_eq!(reshaped.strides(), &[12, 4, 2, 1]);
        assert_eq!(reshaped.offset(), 8); // Offset verified preserved.

        assert_layout_preserved!(selected.slice(), &reshaped);
    }

    #[test]
    fn test_reshape_after_select_no_factoring_due_to_primes() {
        // Original shape: 3 zones, 4 hosts, 5 gpus
        let original = shape!(zone = 3, host = 4, gpu = 5);
        // First select: fix zone = 1 → shape: [1, 4, 5].
        let selected_zone = original.select("zone", 1).unwrap();
        assert_eq!(selected_zone.slice().sizes(), &[1, 4, 5]);
        // Second select: fix host = 2 → shape: [1, 1, 5].
        let selected_host = selected_zone.select("host", 2).unwrap();
        assert_eq!(selected_host.slice().sizes(), &[1, 1, 5]);
        // Reshape with limit = 2.
        let reshaped = selected_host.slice().reshape_with_limit(Limit::from(2));

        assert_eq!(
            factor_dims(selected_host.slice().sizes(), Limit::from(2)),
            vec![vec![1], vec![1], vec![5]]
        );
        assert_eq!(reshaped.sizes(), &[1, 1, 5]);

        assert_layout_preserved!(selected_host.slice(), &reshaped);
    }

    #[test]
    fn test_reshape_after_multiple_selects_triggers_factoring() {
        // Original shape: 2 zones, 4 hosts, 8 gpus
        let original = shape!(zone = 2, host = 4, gpu = 8);
        // Select zone=1 → shape: [1, 4, 8]
        let selected_zone = original.select("zone", 1).unwrap();
        assert_eq!(selected_zone.slice().sizes(), &[1, 4, 8]);

        // Select host=2 → shape: [1, 1, 8]
        let selected_host = selected_zone.select("host", 2).unwrap();
        assert_eq!(selected_host.slice().sizes(), &[1, 1, 8]);

        // Reshape with limit = 2 → gpu=8 should factor
        let reshaped = selected_host.slice().reshape_with_limit(Limit::from(2));

        assert_eq!(
            factor_dims(selected_host.slice().sizes(), Limit::from(2)),
            vec![vec![1], vec![1], vec![2, 2, 2]]
        );
        assert_eq!(reshaped.sizes(), &[1, 1, 2, 2, 2]);

        assert_layout_preserved!(selected_host.slice(), &reshaped);
    }

    #[test]
    fn test_expand_labels_singleton_dims() {
        let factors = vec![("x".into(), vec![2]), ("y".into(), vec![4])];
        let expected = vec!["x", "y"];
        assert_eq!(expand_labels(&factors), expected);
    }

    #[test]
    fn test_expand_labels_factored_dims() {
        let factors = vec![("gpu".into(), vec![2, 2, 2])];
        let expected = vec!["gpu/0", "gpu/1", "gpu/2"];
        assert_eq!(expand_labels(&factors), expected);
    }

    #[test]
    fn test_expand_labels_mixed_dims() {
        let factors = vec![("zone".into(), vec![2]), ("gpu".into(), vec![2, 2])];
        let expected = vec!["zone", "gpu/0", "gpu/1"];
        assert_eq!(expand_labels(&factors), expected);
    }

    #[test]
    fn test_expand_labels_empty() {
        let factors: Vec<(String, Vec<usize>)> = vec![];
        let expected: Vec<String> = vec![];
        assert_eq!(expand_labels(&factors), expected);
    }

    #[test]
    fn test_reshape_shape_noop() {
        let shape = shape!(x = 4, y = 8);
        let reshaped = reshape_shape(&shape, Limit::from(8));
        assert_eq!(reshaped.shape.labels(), &["x", "y"]);
        assert_eq!(reshaped.shape.slice(), shape.slice());
    }

    #[test]
    fn test_reshape_shape_factored() {
        let shape = shape!(gpu = 8);
        let reshaped = reshape_shape(&shape, Limit::from(2));
        assert_eq!(reshaped.shape.labels(), &["gpu/0", "gpu/1", "gpu/2"]);
        assert_eq!(reshaped.shape.slice().sizes(), &[2, 2, 2]);

        let expected = shape.slice().reshape_with_limit(Limit::from(2));
        assert_eq!(reshaped.shape.slice(), &expected);
    }

    #[test]
    fn test_reshape_shape_singleton() {
        let shape = shape!(x = 3);
        let reshaped = reshape_shape(&shape, Limit::from(8));
        assert_eq!(reshaped.shape.labels(), &["x"]);
        assert_eq!(reshaped.shape.slice(), shape.slice());
    }

    #[test]
    fn test_reshape_shape_prime_exceeds_limit() {
        let shape = shape!(x = 11);
        let reshaped = reshape_shape(&shape, Limit::from(5));
        assert_eq!(reshaped.shape.labels(), &["x"]);
        assert_eq!(reshaped.shape.slice(), shape.slice());
    }

    #[test]
    fn test_reshape_shape_mixed_dims() {
        let shape = shape!(zone = 2, gpu = 8);
        let reshaped = reshape_shape(&shape, Limit::from(2));
        assert_eq!(
            reshaped.shape.labels(),
            &["zone", "gpu/0", "gpu/1", "gpu/2"]
        );
        assert_eq!(reshaped.shape.slice().sizes(), &[2, 2, 2, 2]);

        let expected = shape.slice().reshape_with_limit(Limit::from(2));
        assert_eq!(reshaped.shape.slice(), &expected);
    }

    #[test]
    fn test_reshape_shape_after_selects() {
        // Original shape: 2 zones, 4 hosts, 8 gpus
        let original = shape!(zone = 2, host = 4, gpu = 8);

        // Select zone=1 → shape: [1, 4, 8]
        let selected_zone = original.select("zone", 1).unwrap();
        assert_eq!(selected_zone.slice().sizes(), &[1, 4, 8]);

        // Select host=2 → shape: [1, 1, 8]
        let selected_host = selected_zone.select("host", 2).unwrap();
        assert_eq!(selected_host.slice().sizes(), &[1, 1, 8]);

        // Reshape shape through high-level API
        let reshaped = reshape_shape(&selected_host, Limit::from(2));

        // Labels should be: zone, host, gpu/0, gpu/1, gpu/2
        assert_eq!(
            reshaped.shape.labels(),
            &["zone", "host", "gpu/0", "gpu/1", "gpu/2"]
        );

        // Sizes should reflect factored GPU dimension
        assert_eq!(reshaped.shape.slice().sizes(), &[1, 1, 2, 2, 2]);

        // Check against low-level equivalent reshaped slice
        let expected = selected_host.slice().reshape_with_limit(Limit::from(2));
        assert_eq!(reshaped.shape.slice(), &expected);
    }

    proptest! {
        #![proptest_config(ProptestConfig {
            cases: 100,
            ..ProptestConfig::default()
        })]
        // Property: evaluating the reshaped selection on the reshaped
        // slice yields exactly the ranks the original selection
        // yields on the original slice.
        #[test]
        #[cfg_attr(not(fbcode_build), ignore)]
        fn test_reshape_selection((slice, fanout_limit, selection) in gen_slice(4, 64).prop_flat_map(|slice| {
                let shape = slice.sizes().to_vec();
                let max_dimension_size = *slice.sizes().iter().max().unwrap();
                (Just(slice), 1..=max_dimension_size, gen_selection(4, shape, 0))
        })) {
            let original_selected_ranks = selection
                .eval(&EvalOpts::strict(), &slice)
                .unwrap()
                .collect::<BTreeSet<_>>();

            let reshaped_slice = reshape_with_limit(&slice, Limit::from(fanout_limit));
            let reshaped_selection = reshape_selection(selection, &slice, &reshaped_slice).ok().unwrap();

            let folded_selected_ranks = reshaped_selection
                .eval(&EvalOpts::strict(), &reshaped_slice)?
                .collect::<BTreeSet<_>>();

            prop_assert_eq!(original_selected_ranks, folded_selected_ranks);
        }
    }
}