ndslice/
reshape.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
//! Dimensional reshaping of slices and shapes.
//!
//! This module defines utilities for transforming a [`Slice`] or
//! [`Shape`] by factoring large extents into smaller ones under a
//! given limit. The result is a reshaped view with increased
//! dimensionality and preserved memory layout.
//!
//! This is useful for hierarchical routing, structured fanout, and
//! other multidimensional layout transformations.
//!
//! For [`Shape`]s, reshaping also expands dimension labels using a
//! `label/N` naming convention, preserving the semantics of the
//! original shape in the reshaped view.
//!
//! See [`reshape_with_limit`] and [`reshape_shape`] for entry points.
24use std::fmt;
25
26use crate::shape::Shape;
27use crate::slice::Slice;
28
/// Coordinate vector used throughout reshape logic. Semantically
/// represents a point in multidimensional space.
///
/// This is the type produced by the closures returned from
/// [`to_reshaped_coord`] and [`to_original_coord`].
pub type Coord = Vec<usize>;
32
/// A reshaped version of a `Shape`, with factored dimensions and
/// updated labels.
///
/// This type preserves coordinate bijections with the original shape
/// and provides access to the transformed layout and label mappings.
/// Values are produced by [`reshape_shape`].
pub struct ReshapedShape {
    /// The reshaped shape, with new labels and underlying factored
    /// slice.
    pub shape: Shape,

    /// For each original dimension label, the list of sizes it was
    /// split into, in original dimension order. A dimension that was
    /// not split carries a single-element list holding its original
    /// size.
    pub factors: Vec<(String, Vec<usize>)>,
}
47
// Compile-time assertion that `ReshapedShape` is `Send + Sync +
// 'static`. Naming `assert::<ReshapedShape>` only type-checks when the
// bounds hold; the helper is never called, hence `allow(dead_code)`.
#[allow(dead_code)]
const _: () = {
    fn assert<T: Send + Sync + 'static>() {}
    let _ = assert::<ReshapedShape>;
};
53
54impl std::fmt::Debug for ReshapedShape {
55    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56        f.debug_struct("ReshapedShape")
57            .field("labels", &self.shape.labels())
58            .field("sizes", &self.shape.slice().sizes())
59            .field("strides", &self.shape.slice().strides())
60            .field("offset", &self.shape.slice().offset())
61            .field("factors", &self.factors)
62            .finish()
63    }
64}
65
66impl std::fmt::Display for ReshapedShape {
67    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
68        write!(
69            f,
70            "ReshapedShape {{ [off={} sz={:?} st={:?} lab={:?} fac={:?}] }}",
71            self.shape.slice().offset(),
72            self.shape.slice().sizes(),
73            self.shape.slice().strides(),
74            self.shape.labels(),
75            self.factors
76        )
77    }
78}
79
80/// Returns, for each size, a list of factors that respect the given
81/// limit. If a size is ≤ limit, it is returned as a singleton.
82/// Otherwise, it is factored greedily using divisors ≤ limit, from
83/// largest to smallest.
84///
85/// For best results, dimensions should be chosen to allow factoring
86/// into small values under the selected limit (e.g., ≤ 32).
87/// Large prime numbers cannot be broken down and will remain as-is,
88/// limiting reshaping potential.
89///
90/// Prefer powers of 2 or other highly composite numbers
91/// (e.g., 8, 16, 32, 60, 120) over large primes (e.g., 17, 37, 113)
92/// when designing shapes intended for reshaping.
93pub(crate) fn factor_dims(sizes: &[usize], limit: Limit) -> Vec<Vec<usize>> {
94    let limit = limit.get();
95    sizes
96        .iter()
97        .map(|&size| {
98            if size <= limit {
99                return vec![size];
100            }
101            let mut rem = size;
102            let mut factors = Vec::new();
103            for d in (2..=limit).rev() {
104                while rem % d == 0 {
105                    factors.push(d);
106                    rem /= d;
107                }
108            }
109            if rem > 1 {
110                factors.push(rem);
111            }
112            factors
113        })
114        .collect()
115}
116
117/// Constructs a function that maps coordinates from the original
118/// slice to equivalent coordinates in the reshaped slice, preserving
119/// their flat (linear) position.
120pub fn to_reshaped_coord<'a>(
121    original: &'a Slice,
122    reshaped: &'a Slice,
123) -> impl Fn(&[usize]) -> Vec<usize> + 'a {
124    let original = original.clone();
125    let reshaped = reshaped.clone();
126    move |coord: &[usize]| -> Coord {
127        let flat = original.location(coord).unwrap();
128        reshaped.coordinates(flat).unwrap()
129    }
130}
131
132/// Constructs a function that maps coordinates from the reshaped
133/// slice back to equivalent coordinates in the original slice,
134/// preserving their flat (linear) position.
135pub fn to_original_coord<'a>(
136    reshaped: &'a Slice,
137    original: &'a Slice,
138) -> impl Fn(&[usize]) -> Vec<usize> + 'a {
139    let reshaped = reshaped.clone();
140    let original = original.clone();
141    move |coord: &[usize]| -> Coord {
142        let flat = reshaped.location(coord).unwrap();
143        original.coordinates(flat).unwrap()
144    }
145}
146
/// A shaping constraint that sets the target maximum extent for any
/// reshaped dimension.
///
/// This limit controls how a given dimension is factored during
/// reshaping. Values larger than `limit` are greedily decomposed
/// into divisors no larger than the limit (e.g.,
/// `reshape_with_limit([1024], Limit::new(32))` → `[32, 32]`).
/// A remainder that has no divisor ≤ `limit` (such as a large prime)
/// is kept whole, so a reshaped extent can still exceed the limit.
///
/// The default limit is `32`, which balances fanout depth and layout
/// regularity.
///
/// # Example
/// ```
/// use ndslice::reshape::Limit;
/// let limit = Limit::new(64);
/// assert_eq!(limit.get(), 64);
/// ```
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Limit(usize);
166
167impl Limit {
168    /// Creates a new `Limit`. Panics if less than 1.
169    pub fn new(n: usize) -> Self {
170        assert!(n >= 1, "Limit must be at least 1");
171        Self(n)
172    }
173
174    /// Returns the inner value.
175    pub fn get(self) -> usize {
176        self.0
177    }
178}
179
180impl Default for Limit {
181    fn default() -> Self {
182        Self(32)
183    }
184}
185
186impl From<usize> for Limit {
187    fn from(n: usize) -> Self {
188        Self::new(n)
189    }
190}
191
/// A trait for types that can be reshaped into a higher-dimensional
/// view by factoring large extents into smaller ones.
///
/// This is implemented for [`Slice`], enabling ergonomic access to
/// [`reshape_with_limit`] as a method that returns a reshaped
/// [`Slice`] with increased dimensionality and preserved layout.
///
/// # Example
/// ```
/// use ndslice::Slice;
/// use ndslice::reshape::Limit;
/// use ndslice::reshape::ReshapeSliceExt;
///
/// let slice = Slice::new_row_major(vec![1024]);
/// let reshaped = slice.reshape_with_limit(Limit::new(32));
/// assert_eq!(reshaped.sizes(), &[32, 32]);
/// ```
pub trait ReshapeSliceExt {
    /// Returns a reshaped version of this structure by factoring each
    /// dimension into smaller extents guided by `limit`, preserving
    /// memory layout and flat index semantics. See
    /// [`reshape_with_limit`] for full behavior and rationale.
    ///
    /// # Arguments
    /// - `limit`: target maximum size for any reshaped dimension
    ///
    /// # Returns
    /// A reshaped [`Slice`] with increased dimensionality and a
    /// bijective mapping to the original.
    fn reshape_with_limit(&self, limit: Limit) -> Slice;
}
225
/// Method-call sugar for the free function [`reshape_with_limit`].
impl ReshapeSliceExt for Slice {
    fn reshape_with_limit(&self, limit: Limit) -> Slice {
        // The unqualified call resolves to the module-level free
        // function, not to this method, so this does not recurse.
        reshape_with_limit(self, limit)
    }
}
231
/// Extension trait for reshaping `Shape`s by factoring large dimensions.
///
/// Implemented for [`Shape`], where it forwards to [`reshape_shape`].
pub trait ReshapeShapeExt {
    /// Produces a reshaped version of the shape with expanded
    /// dimensions under the given size limit.
    ///
    /// # Arguments
    /// - `limit`: target maximum extent per factored dimension
    ///
    /// # Returns
    /// A [`ReshapedShape`] holding the new shape and the factors each
    /// original dimension was split into.
    fn reshape(&self, limit: Limit) -> ReshapedShape;
}
238
/// Method-call sugar for the free function [`reshape_shape`].
impl ReshapeShapeExt for Shape {
    fn reshape(&self, limit: Limit) -> ReshapedShape {
        reshape_shape(self, limit)
    }
}
244
/// Re-exports the reshape extension traits.
///
/// For convenient `slice.reshape_with_limit()`, `shape.reshape()`
/// syntax, `use reshape::prelude::*`.
pub mod prelude {
    pub use super::ReshapeShapeExt;
    pub use super::ReshapeSliceExt;
}
251
252/// Reshapes a slice by factoring each dimension into smaller extents
253/// under the given limit.
254///
255/// This transformation increases dimensionality by breaking large
256/// sizes into products of smaller factors (e.g., `[1024]` with limit
257/// 32 becomes `[32, 32]`). The result is a new [`Slice`] that
258/// preserves memory layout and flat index semantics.
259///
260/// Factoring is greedy, starting from the largest divisors ≤ `limit`.
261/// Dimensions that cannot be factored under the limit are left
262/// unchanged.
263///
264/// # Arguments
265/// - `slice`: the original multidimensional slice
266/// - `limit`: maximum extent allowed in any factored subdimension
267///
268/// # Returns
269/// A reshaped [`Slice`] with updated sizes and strides.
270///
271/// # Example
272/// ```
273/// use ndslice::Slice;
274/// use ndslice::reshape::Limit;
275/// use ndslice::reshape::reshape_with_limit;
276///
277/// let slice = Slice::new_row_major(vec![1024]);
278/// let reshaped = reshape_with_limit(&slice, Limit::new(32));
279/// assert_eq!(reshaped.sizes(), &[32, 32]);
280/// ```
281pub fn reshape_with_limit(slice: &Slice, limit: Limit) -> Slice {
282    let orig_sizes = slice.sizes();
283    let orig_strides = slice.strides();
284
285    // Step 1: Factor each size into subdimensions ≤ limit.
286    let factored_sizes = factor_dims(orig_sizes, limit);
287
288    // Step 2: Compute reshaped sizes and strides (row-major only).
289    let reshaped_sizes: Vec<usize> = factored_sizes.iter().flatten().cloned().collect();
290    let mut reshaped_strides = Vec::with_capacity(reshaped_sizes.len());
291
292    for (&orig_stride, factors) in orig_strides.iter().zip(&factored_sizes) {
293        let mut sub_strides = Vec::with_capacity(factors.len());
294        let mut stride = orig_stride;
295        for &f in factors.iter().rev() {
296            sub_strides.push(stride);
297            stride *= f;
298        }
299        sub_strides.reverse();
300        reshaped_strides.extend(sub_strides);
301    }
302
303    Slice::new(slice.offset(), reshaped_sizes, reshaped_strides).unwrap()
304}
305
306/// Reshapes a labeled [`Shape`] by factoring large extents into
307/// smaller ones, producing a new shape with expanded dimensionality
308/// and updated labels.
309///
310/// This uses [`reshape_with_limit`] on the underlying slice and [`expand_labels`]
311/// to generate labels for each factored dimension.
312///
313/// # Arguments
314/// - `shape`: the labeled shape to reshape
315/// - `limit`: maximum extent allowed per factored dimension
316///
317/// # Returns
318/// A new [`ReshapedShape`] with an updated [`Shape`] and dimension
319/// factoring metadata.
320///
321/// # Panics
322/// Panics if constructing the new `Shape` fails. This should not
323/// occur unless the reshaped slice and labels are inconsistent (a
324/// programming logic error).
325pub fn reshape_shape(shape: &Shape, limit: Limit) -> ReshapedShape {
326    let reshaped_slice = shape.slice().reshape_with_limit(limit);
327    let original_labels = shape.labels();
328    let original_sizes = shape.slice().sizes();
329
330    let factors = factor_dims(original_sizes, limit);
331    let factored_dims: Vec<(String, Vec<usize>)> =
332        original_labels.iter().cloned().zip(factors).collect();
333
334    let labels = expand_labels(&factored_dims);
335    let shape = Shape::new(labels, reshaped_slice).expect("invalid reshaped shape");
336
337    ReshapedShape {
338        shape,
339        factors: factored_dims,
340    }
341}
342
/// Expands factored dimension labels into one label per subdimension.
///
/// Each input pair `(label, factors)` describes an original dimension
/// and the extents it was factored into:
///
/// - exactly one factor: the label is kept unchanged;
/// - any other number of factors: one label `label/i` is emitted per
///   factor, with `i` counting from 0 (an empty factor list therefore
///   emits no labels).
///
/// For example:
/// - `[("zone", vec![2]), ("gpu", vec![2, 2, 2])]`
///   becomes `["zone", "gpu/0", "gpu/1", "gpu/2"]`
///
/// This is used to generate new labels for reshaped shapes, where the
/// dimensionality increases due to factoring.
///
/// # Arguments
/// - `factors`: a list of factored dimension extents, paired with
///   their labels
///
/// # Returns
/// - A `Vec<String>` of expanded labels, one for each reshaped
///   dimension, in input order.
pub fn expand_labels(factors: &[(String, Vec<usize>)]) -> Vec<String> {
    factors
        .iter()
        .flat_map(|(label, dims)| -> Vec<String> {
            match dims.len() {
                // Unfactored dimension: keep the original label.
                1 => vec![label.clone()],
                // Factored dimension: suffix each part with its index.
                count => (0..count).map(|i| format!("{}/{}", label, i)).collect(),
            }
        })
        .collect()
}
376
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Slice;
    use crate::shape;

    #[test]
    fn test_factor_dims_basic() {
        assert_eq!(
            factor_dims(&[6, 8], Limit::from(4)),
            vec![vec![3, 2], vec![4, 2]]
        );
        assert_eq!(factor_dims(&[5], Limit::from(3)), vec![vec![5]]);
        assert_eq!(factor_dims(&[30], Limit::from(5)), vec![vec![5, 3, 2]]);
    }

    // Verify that reshaping preserves memory layout by checking:
    // 1. Coordinate round-tripping: original → reshaped → original
    // 2. Flat index equality: original and reshaped coordinates map
    //    to the same linear index
    // 3. Index inversion: reshaped flat index maps back to the same
    //    reshaped coordinate
    //
    // Together, these checks ensure that the reshaped view is
    // layout-preserving and provides a bijective mapping between
    // coordinate systems.
    //
    // Both arguments are expected to be `&Slice` expressions.
    #[macro_export]
    macro_rules! assert_layout_preserved {
        ($original:expr_2021, $reshaped:expr_2021) => {{
            // The coordinate mappings do not depend on the coordinate
            // under test, so build them once rather than per
            // iteration (each construction clones both slices).
            let forward = to_reshaped_coord($original, &$reshaped);
            let inverse = to_original_coord(&$reshaped, $original);
            // Iterate over all coordinates in the original slice.
            for coord in $original.dim_iter($original.num_dim()) {
                // Apply the forward coordinate mapping from original
                // to reshaped space.
                let reshaped_coord = forward(&coord);
                // Inverse mapping: reshaped coord → original coord.
                let roundtrip = inverse(&reshaped_coord);
                assert_eq!(
                    roundtrip, coord,
                    "Inverse mismatch: reshaped {:?} → original {:?}, expected {:?}",
                    reshaped_coord, roundtrip, coord
                );
                // Compute flat index in the original slice.
                let flat_orig = $original.location(&coord).unwrap();
                // Compute flat index in the reshaped slice.
                let flat_reshaped = $reshaped.location(&reshaped_coord).unwrap();
                // Check that the flat index is preserved by the
                // reshaping.
                assert_eq!(
                    flat_orig, flat_reshaped,
                    "Flat index mismatch: original {:?} → reshaped {:?}",
                    coord, reshaped_coord
                );
                // Invert the reshaped flat index back to coordinates.
                let recovered = $reshaped.coordinates(flat_reshaped).unwrap();
                // Ensure coordinate inversion is correct (round
                // trip).
                assert_eq!(
                    reshaped_coord, recovered,
                    "Coordinate mismatch: flat index {} → expected {:?}, got {:?}",
                    flat_reshaped, reshaped_coord, recovered
                );
            }
        }};
    }

    #[test]
    fn test_reshape_split_1d_row_major() {
        let s = Slice::new_row_major(vec![1024]);
        let reshaped = s.reshape_with_limit(Limit::from(8));

        assert_eq!(reshaped.offset(), 0);
        assert_eq!(reshaped.sizes(), &vec![8, 8, 8, 2]);
        assert_eq!(reshaped.strides(), &vec![128, 16, 2, 1]);
        assert_eq!(
            factor_dims(s.sizes(), Limit::from(8)),
            vec![vec![8, 8, 8, 2]]
        );

        assert_layout_preserved!(&s, &reshaped);
    }

    #[test]
    fn test_reshape_6_with_limit_2() {
        // 6 = 2 × 3; the 3 has no divisor ≤ 2, so it is kept whole as
        // the trailing factor even though it exceeds the limit.
        let s = Slice::new_row_major(vec![6]);
        let reshaped = reshape_with_limit(&s, Limit::from(2));
        assert_eq!(factor_dims(s.sizes(), Limit::from(2)), vec![vec![2, 3]]);
        assert_layout_preserved!(&s, &reshaped);
    }

    #[test]
    fn test_reshape_identity_noop_2d() {
        // All dimensions ≤ limit.
        let original = Slice::new_row_major(vec![4, 8]);
        let reshaped = original.reshape_with_limit(Limit::from(8));

        assert_eq!(reshaped.sizes(), original.sizes());
        assert_eq!(reshaped.strides(), original.strides());
        assert_eq!(reshaped.offset(), original.offset());
        // Every extent must come back as a singleton factor list.
        assert_eq!(
            factor_dims(original.sizes(), Limit::from(8)),
            vec![vec![4], vec![8]]
        );
        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_empty_slice() {
        // 0-dimensional slice.
        let original = Slice::new_row_major(vec![]);
        let reshaped = reshape_with_limit(&original, Limit::from(8));

        assert_eq!(reshaped.sizes(), original.sizes());
        assert_eq!(reshaped.strides(), original.strides());
        assert_eq!(reshaped.offset(), original.offset());

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_mixed_dims_3d() {
        // 3D slice whose dimensions all exceed the limit; 10 leaves a
        // remainder (5) that itself exceeds the limit.
        let original = Slice::new_row_major(vec![6, 8, 10]);
        let reshaped = original.reshape_with_limit(Limit::from(4));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(4)),
            vec![vec![3, 2], vec![4, 2], vec![2, 5]]
        );
        assert_eq!(reshaped.sizes(), &[3, 2, 4, 2, 2, 5]);

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_all_large_dims() {
        // 3D slice with all dimensions exceeding the limit.
        let original = Slice::new_row_major(vec![12, 18, 20]);
        let reshaped = original.reshape_with_limit(Limit::from(4));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(4)),
            vec![vec![4, 3], vec![3, 3, 2], vec![4, 5]]
        );
        assert_eq!(reshaped.sizes(), &[4, 3, 3, 3, 2, 4, 5]);

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_split_1d_factors_3_3_2_2() {
        // 36 = 3 × 3 × 2 × 2.
        let original = Slice::new_row_major(vec![36]);
        let reshaped = reshape_with_limit(&original, Limit::from(3));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(3)),
            vec![vec![3, 3, 2, 2]]
        );
        assert_eq!(reshaped.sizes(), &[3, 3, 2, 2]);
        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_large_prime_dimension() {
        // Prime larger than limit, cannot be factored.
        let original = Slice::new_row_major(vec![7]);
        let reshaped = reshape_with_limit(&original, Limit::from(4));

        // Should remain as-is since 7 is prime > 4
        assert_eq!(factor_dims(original.sizes(), Limit::from(4)), vec![vec![7]]);
        assert_eq!(reshaped.sizes(), &[7]);

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_split_1d_factors_5_3_2() {
        // 30 = 5 × 3 × 2, all ≤ limit.
        let original = Slice::new_row_major(vec![30]);
        let reshaped = reshape_with_limit(&original, Limit::from(5));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(5)),
            vec![vec![5, 3, 2]]
        );
        assert_eq!(reshaped.sizes(), &[5, 3, 2]);
        assert_eq!(reshaped.strides(), &[6, 2, 1]);

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_factors_2_6_2_8_8() {
        // 12 = 6 × 2, 64 = 8 × 8 — all ≤ 8
        let original = Slice::new_row_major(vec![2, 12, 64]);
        let reshaped = original.reshape_with_limit(Limit::from(8));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(8)),
            vec![vec![2], vec![6, 2], vec![8, 8]]
        );
        assert_eq!(reshaped.sizes(), &[2, 6, 2, 8, 8]);
        assert_eq!(reshaped.strides(), &[768, 128, 64, 8, 1]);

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_all_dims_within_limit() {
        // Original shape: [2, 3, 4] — all ≤ limit (4).
        let original = Slice::new_row_major(vec![2, 3, 4]);
        let reshaped = original.reshape_with_limit(Limit::from(4));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(4)),
            vec![vec![2], vec![3], vec![4]]
        );
        assert_eq!(reshaped.sizes(), &[2, 3, 4]);
        assert_eq!(reshaped.strides(), original.strides());
        assert_eq!(reshaped.offset(), original.offset());

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_reshape_degenerate_dimension() {
        // Degenerate dimension should remain unchanged.
        let original = Slice::new_row_major(vec![1, 12]);
        let reshaped = original.reshape_with_limit(Limit::from(4));

        assert_eq!(
            factor_dims(original.sizes(), Limit::from(4)),
            vec![vec![1], vec![4, 3]]
        );
        assert_eq!(reshaped.sizes(), &[1, 4, 3]);

        assert_layout_preserved!(&original, &reshaped);
    }

    #[test]
    fn test_select_then_reshape() {
        // Original shape: 2 zones, 3 hosts, 4 gpus
        let original = shape!(zone = 2, host = 3, gpu = 4);

        // Select the zone=1 plane: shape becomes [1, 3, 4]
        let selected = original.select("zone", 1).unwrap();
        assert_eq!(selected.slice().offset(), 12); // Nonzero offset.
        assert_eq!(selected.slice().sizes(), &[1, 3, 4]);

        // Reshape the selected slice using limit=2 in row-major
        // layout.
        let reshaped = selected.slice().reshape_with_limit(Limit::from(2));

        assert_eq!(
            factor_dims(selected.slice().sizes(), Limit::from(2)),
            vec![vec![1], vec![3], vec![2, 2]]
        );
        assert_eq!(reshaped.sizes(), &[1, 3, 2, 2]);
        assert_eq!(reshaped.strides(), &[12, 4, 2, 1]);
        assert_eq!(reshaped.offset(), 12); // Offset verified preserved.

        assert_layout_preserved!(selected.slice(), &reshaped);
    }

    #[test]
    fn test_select_host_plane_then_reshape() {
        // Original shape: 2 zones, 3 hosts, 4 gpus.
        let original = shape!(zone = 2, host = 3, gpu = 4);
        // Select the host=2 plane: shape becomes [2, 1, 4].
        let selected = original.select("host", 2).unwrap();
        // Reshape the selected slice using limit=2 in row-major
        // layout.
        let reshaped = selected.slice().reshape_with_limit(Limit::from(2));

        assert_layout_preserved!(selected.slice(), &reshaped);
    }

    #[test]
    fn test_reshape_after_select_no_factoring_due_to_primes() {
        // Original shape: 3 zones, 4 hosts, 5 gpus
        let original = shape!(zone = 3, host = 4, gpu = 5);
        // First select: fix zone = 1 → shape: [1, 4, 5].
        let selected_zone = original.select("zone", 1).unwrap();
        assert_eq!(selected_zone.slice().sizes(), &[1, 4, 5]);
        // Second select: fix host = 2 → shape: [1, 1, 5].
        let selected_host = selected_zone.select("host", 2).unwrap();
        assert_eq!(selected_host.slice().sizes(), &[1, 1, 5]);
        // Reshape with limit = 2.
        let reshaped = selected_host.slice().reshape_with_limit(Limit::from(2));

        assert_eq!(
            factor_dims(selected_host.slice().sizes(), Limit::from(2)),
            vec![vec![1], vec![1], vec![5]]
        );
        assert_eq!(reshaped.sizes(), &[1, 1, 5]);

        assert_layout_preserved!(selected_host.slice(), &reshaped);
    }

    #[test]
    fn test_reshape_after_multiple_selects_triggers_factoring() {
        // Original shape: 2 zones, 4 hosts, 8 gpus
        let original = shape!(zone = 2, host = 4, gpu = 8);
        // Select zone=1 → shape: [1, 4, 8]
        let selected_zone = original.select("zone", 1).unwrap();
        assert_eq!(selected_zone.slice().sizes(), &[1, 4, 8]);

        // Select host=2 → shape: [1, 1, 8]
        let selected_host = selected_zone.select("host", 2).unwrap();
        assert_eq!(selected_host.slice().sizes(), &[1, 1, 8]);

        // Reshape with limit = 2 → gpu=8 should factor
        let reshaped = selected_host.slice().reshape_with_limit(Limit::from(2));

        assert_eq!(
            factor_dims(selected_host.slice().sizes(), Limit::from(2)),
            vec![vec![1], vec![1], vec![2, 2, 2]]
        );
        assert_eq!(reshaped.sizes(), &[1, 1, 2, 2, 2]);

        assert_layout_preserved!(selected_host.slice(), &reshaped);
    }

    #[test]
    fn test_expand_labels_singleton_dims() {
        let factors = vec![("x".into(), vec![2]), ("y".into(), vec![4])];
        let expected = vec!["x", "y"];
        assert_eq!(expand_labels(&factors), expected);
    }

    #[test]
    fn test_expand_labels_factored_dims() {
        let factors = vec![("gpu".into(), vec![2, 2, 2])];
        let expected = vec!["gpu/0", "gpu/1", "gpu/2"];
        assert_eq!(expand_labels(&factors), expected);
    }

    #[test]
    fn test_expand_labels_mixed_dims() {
        let factors = vec![("zone".into(), vec![2]), ("gpu".into(), vec![2, 2])];
        let expected = vec!["zone", "gpu/0", "gpu/1"];
        assert_eq!(expand_labels(&factors), expected);
    }

    #[test]
    fn test_expand_labels_empty() {
        let factors: Vec<(String, Vec<usize>)> = vec![];
        let expected: Vec<String> = vec![];
        assert_eq!(expand_labels(&factors), expected);
    }

    #[test]
    fn test_reshape_shape_noop() {
        let shape = shape!(x = 4, y = 8);
        let reshaped = reshape_shape(&shape, Limit::from(8));
        assert_eq!(reshaped.shape.labels(), &["x", "y"]);
        assert_eq!(reshaped.shape.slice(), shape.slice());
    }

    #[test]
    fn test_reshape_shape_factored() {
        let shape = shape!(gpu = 8);
        let reshaped = reshape_shape(&shape, Limit::from(2));
        assert_eq!(reshaped.shape.labels(), &["gpu/0", "gpu/1", "gpu/2"]);
        assert_eq!(reshaped.shape.slice().sizes(), &[2, 2, 2]);

        let expected = shape.slice().reshape_with_limit(Limit::from(2));
        assert_eq!(reshaped.shape.slice(), &expected);
    }

    #[test]
    fn test_reshape_shape_singleton() {
        let shape = shape!(x = 3);
        let reshaped = reshape_shape(&shape, Limit::from(8));
        assert_eq!(reshaped.shape.labels(), &["x"]);
        assert_eq!(reshaped.shape.slice(), shape.slice());
    }

    #[test]
    fn test_reshape_shape_prime_exceeds_limit() {
        let shape = shape!(x = 11);
        let reshaped = reshape_shape(&shape, Limit::from(5));
        assert_eq!(reshaped.shape.labels(), &["x"]);
        assert_eq!(reshaped.shape.slice(), shape.slice());
    }

    #[test]
    fn test_reshape_shape_mixed_dims() {
        let shape = shape!(zone = 2, gpu = 8);
        let reshaped = reshape_shape(&shape, Limit::from(2));
        assert_eq!(
            reshaped.shape.labels(),
            &["zone", "gpu/0", "gpu/1", "gpu/2"]
        );
        assert_eq!(reshaped.shape.slice().sizes(), &[2, 2, 2, 2]);

        let expected = shape.slice().reshape_with_limit(Limit::from(2));
        assert_eq!(reshaped.shape.slice(), &expected);
    }

    #[test]
    fn test_reshape_shape_after_selects() {
        // Original shape: 2 zones, 4 hosts, 8 gpus
        let original = shape!(zone = 2, host = 4, gpu = 8);

        // Select zone=1 → shape: [1, 4, 8]
        let selected_zone = original.select("zone", 1).unwrap();
        assert_eq!(selected_zone.slice().sizes(), &[1, 4, 8]);

        // Select host=2 → shape: [1, 1, 8]
        let selected_host = selected_zone.select("host", 2).unwrap();
        assert_eq!(selected_host.slice().sizes(), &[1, 1, 8]);

        // Reshape shape through high-level API
        let reshaped = reshape_shape(&selected_host, Limit::from(2));

        // Labels should be: zone, host, gpu/0, gpu/1, gpu/2
        assert_eq!(
            reshaped.shape.labels(),
            &["zone", "host", "gpu/0", "gpu/1", "gpu/2"]
        );

        // Sizes should reflect factored GPU dimension
        assert_eq!(reshaped.shape.slice().sizes(), &[1, 1, 2, 2, 2]);

        // Check against low-level equivalent reshaped slice
        let expected = selected_host.slice().reshape_with_limit(Limit::from(2));
        assert_eq!(reshaped.shape.slice(), &expected);
    }
}