ndslice/
slice.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::iter::zip;
10
11use serde::Deserialize;
12use serde::Serialize;
13
14/// The type of error for slice operations.
15#[derive(Debug, thiserror::Error)]
16#[non_exhaustive]
17pub enum SliceError {
18    #[error("invalid dims: expected {expected}, got {got}")]
19    InvalidDims { expected: usize, got: usize },
20
21    #[error("nonrectangular shape")]
22    NonrectangularShape,
23
24    #[error("nonunique strides")]
25    NonuniqueStrides,
26
27    #[error("stride {stride} must be larger than size of previous space {space}")]
28    StrideTooSmall { stride: usize, space: usize },
29
30    #[error("index {index} out of range {total}")]
31    IndexOutOfRange { index: usize, total: usize },
32
33    #[error("value {value} not in slice")]
34    ValueNotInSlice { value: usize },
35
36    #[error("incompatible view: {reason}")]
37    IncompatibleView { reason: String },
38
39    #[error("noncontiguous shape")]
40    NonContiguous,
41
42    #[error("empty range: {begin}..{end} (step {step})")]
43    EmptyRange {
44        begin: usize,
45        end: usize,
46        step: usize,
47    },
48
49    #[error("dimension {dim} out of range for {ndims}-dimensional slice")]
50    DimensionOutOfRange { dim: usize, ndims: usize },
51}
52
53/// Slice is a compact representation of indices into the flat
54/// representation of an n-dimensional array. Given an offset, sizes of
55/// each dimension, and strides for each dimension, Slice can compute
56/// indices into the flat array.
57///
58/// For example, the following describes a dense 4x4x4 array in row-major
59/// order:
60/// ```
61/// # use ndslice::Slice;
62/// let s = Slice::new(0, vec![4, 4, 4], vec![16, 4, 1]).unwrap();
63/// assert!(s.iter().eq(0..(4 * 4 * 4)));
64/// ```
65///
66/// Slices allow easy slicing by subsetting and striding. For example,
67/// we can fix the index of the second dimension by dropping it and
68/// adding that index (multiplied by the previous size) to the offset.
69///
70/// ```
71/// # use ndslice::Slice;
72/// let s = Slice::new(0, vec![2, 4, 2], vec![8, 2, 1]).unwrap();
73/// let selected_index = 3;
74/// let sub = Slice::new(2 * selected_index, vec![2, 2], vec![8, 1]).unwrap();
75/// let coords = [[0, 0], [0, 1], [1, 0], [1, 1]];
76/// for coord @ [x, y] in coords {
77///     assert_eq!(
78///         sub.location(&coord).unwrap(),
79///         s.location(&[x, 3, y]).unwrap()
80///     );
81/// }
82/// ```
83// TODO: Consider representing this by arrays parameterized by the slice
84// dimensionality.
85#[derive(Serialize, Deserialize, Clone, Eq, PartialEq, Hash, Debug)]
86pub struct Slice {
87    offset: usize,
88    sizes: Vec<usize>,
89    strides: Vec<usize>,
90}
91
92impl Slice {
93    /// Create a new Slice with the provided offset, sizes, and
94    /// strides. New performs validation to ensure that sizes and strides
95    /// are compatible:
96    ///   - They have to be the same length (i.e., same number of dimensions)
97    ///   - They have to be rectangular (i.e., stride n+1 has to evenly divide into stride n)
98    ///   - Strides must be nonoverlapping (each stride has to be larger than the previous space)
99    pub fn new(offset: usize, sizes: Vec<usize>, strides: Vec<usize>) -> Result<Self, SliceError> {
100        if sizes.len() != strides.len() {
101            return Err(SliceError::InvalidDims {
102                expected: sizes.len(),
103                got: strides.len(),
104            });
105        }
106        let mut combined: Vec<(usize, usize)> =
107            strides.iter().cloned().zip(sizes.iter().cloned()).collect();
108        combined.sort();
109
110        let mut prev_stride: Option<usize> = None;
111        let mut prev_size: Option<usize> = None;
112        let mut total: usize = 1;
113        for (stride, size) in combined {
114            if let Some(prev_stride) = prev_stride {
115                if stride % prev_stride != 0 {
116                    return Err(SliceError::NonrectangularShape);
117                }
118                // Strides for single element dimensions can repeat, because they are unused
119                if stride == prev_stride && size != 1 && prev_size.unwrap_or(1) != 1 {
120                    return Err(SliceError::NonuniqueStrides);
121                }
122            }
123            if total > stride {
124                return Err(SliceError::StrideTooSmall {
125                    stride,
126                    space: total,
127                });
128            }
129            total = stride * size;
130            prev_stride = Some(stride);
131            prev_size = Some(size);
132        }
133
134        Ok(Slice {
135            offset,
136            sizes,
137            strides,
138        })
139    }
140
141    /// Deconstruct the slice into its offset, sizes, and strides.
142    pub fn into_inner(self) -> (usize, Vec<usize>, Vec<usize>) {
143        let Slice {
144            offset,
145            sizes,
146            strides,
147        } = self;
148        (offset, sizes, strides)
149    }
150
151    /// Create a new slice of the given sizes in row-major order.
152    pub fn new_row_major(sizes: impl Into<Vec<usize>>) -> Self {
153        let sizes = sizes.into();
154        // "flip it and reverse it" --Missy Elliott
155        let mut strides: Vec<usize> = sizes.clone();
156        let _ = strides.iter_mut().rev().fold(1, |acc, n| {
157            let next = *n * acc;
158            *n = acc;
159            next
160        });
161        Self {
162            offset: 0,
163            sizes,
164            strides,
165        }
166    }
167
168    /// Create one celled slice.
169    pub fn new_single_multi_dim_cell(dims: usize) -> Self {
170        Self {
171            offset: 0,
172            sizes: vec![1; dims],
173            strides: vec![1; dims],
174        }
175    }
176
177    /// The number of dimensions in this slice.
178    pub fn num_dim(&self) -> usize {
179        self.sizes.len()
180    }
181
182    /// This is the offset from which the first value in the Slice begins.
183    pub fn offset(&self) -> usize {
184        self.offset
185    }
186
187    /// The shape of the slice; that is, the size of each dimension.
188    pub fn sizes(&self) -> &[usize] {
189        &self.sizes
190    }
191
192    /// The strides of the slice; that is, the distance between each
193    /// element at a given index in the underlying array.
194    pub fn strides(&self) -> &[usize] {
195        &self.strides
196    }
197
198    pub fn is_contiguous(&self) -> bool {
199        let mut expected_stride = 1;
200        for (stride, size) in zip(self.strides.iter(), self.sizes.iter()).rev() {
201            if *stride != expected_stride {
202                return false;
203            }
204            expected_stride *= *size
205        }
206        true
207    }
208
209    /// Select a single index along a dimension, removing that
210    /// dimension entirely.
211    ///
212    /// This reduces the dimensionality by 1 by "fixing" one
213    /// coordinate to a specific value. Think of it like taking a
214    /// cross-section: selecting index 2 from the first dimension of a
215    /// 3D array gives you a 2D slice, like cutting a plane from a 3D
216    /// space at a fixed position.
217    ///
218    /// This reduces the dimensionality by 1 by "fixing" one
219    /// coordinate to a specific value. The fixed coordinate's
220    /// contribution (index × stride) gets absorbed into the base
221    /// offset, while the remaining dimensions keep their original
222    /// strides unchanged - they still describe the same memory
223    /// distances between elements.
224    ///
225    /// # Example intuition
226    /// - 3D array → select `at(dim=0, index=2)` → 2D slice (like a
227    ///   plane)
228    /// - 2D matrix → select `at(dim=1, index=3)` → 1D vector (like a
229    ///   column)
230    /// - 1D vector → select `at(dim=0, index=5)` → 0D scalar (single
231    ///   element)
232    ///
233    /// # Arguments
234    /// * `dim` - The dimension index to select from
235    /// * `index` - The index within that dimension
236    ///
237    /// # Returns
238    /// A new slice with one fewer dimension
239    ///
240    /// # Errors
241    /// * `IndexOutOfRange` if `dim >= self.sizes.len()` or `index >=
242    ///   self.sizes[dim]`
243    pub fn at(&self, dim: usize, index: usize) -> Result<Self, SliceError> {
244        if dim >= self.sizes.len() {
245            return Err(SliceError::DimensionOutOfRange {
246                dim,
247                ndims: self.num_dim(),
248            });
249        }
250        if index >= self.sizes[dim] {
251            return Err(SliceError::IndexOutOfRange {
252                index,
253                total: self.sizes[dim],
254            });
255        }
256
257        let new_offset = self.offset + index * self.strides[dim];
258        let mut new_sizes = self.sizes.clone();
259        let mut new_strides = self.strides.clone();
260        new_sizes.remove(dim);
261        new_strides.remove(dim);
262        let slice = Slice::new(new_offset, new_sizes, new_strides)?;
263        Ok(slice)
264    }
265
266    /// A slice defines a **strided view**; a triple (`offset,
267    /// `sizes`, `strides`). Each coordinate maps to a flat memory
268    /// index using the formula:
269    /// ```text
270    /// index = offset + ∑ iₖ × strides[k]
271    /// ```
272    /// where `iₖ` is the coordinate in dimension `k`.
273    ///
274    /// The `select(dim, range)` operation restricts the view to a
275    /// subrange along a single dimension. It calculates a new slice
276    /// from a base slice by updating the `offset`, `sizes[dim]`, and
277    /// `strides[dim]` to describe a logically reindexed subregion:
278    /// ```text
279    /// offset       += begin × strides[dim]
280    /// sizes[dim]    = ⎡(end - begin) / step⎤
281    /// strides[dim] ×= step
282    /// ```
283    ///
284    /// This transformation preserves the strided layout and avoids
285    /// copying data. After `select`, the view behaves as if indexing
286    /// starts at zero in the selected dimension, with a new length
287    /// and stride. From the user's perspective, nothing changes;
288    /// indexing remains zero-based, and the resulting shape can be
289    /// used like any other. The transformation is internal: the
290    /// view's offset and stride absorb the selection logic.
291    pub fn select(
292        &self,
293        dim: usize,
294        begin: usize,
295        end: usize,
296        step: usize,
297    ) -> Result<Self, SliceError> {
298        if dim >= self.sizes.len() {
299            return Err(SliceError::IndexOutOfRange {
300                index: dim,
301                total: self.sizes.len(),
302            });
303        }
304        if begin >= self.sizes[dim] {
305            return Err(SliceError::IndexOutOfRange {
306                index: begin,
307                total: self.sizes[dim],
308            });
309        }
310        if end <= begin {
311            return Err(SliceError::EmptyRange { begin, end, step });
312        }
313
314        let mut offset = self.offset();
315        let mut sizes = self.sizes().to_vec();
316        let mut strides = self.strides().to_vec();
317
318        offset += begin * strides[dim];
319        // The # of elems in `begin..end` with step `step`. This is
320        // ⌈(end - begin) / stride⌉ — the number of steps that fit in
321        // the half-open interval.
322        sizes[dim] = (end - begin).div_ceil(step);
323        strides[dim] *= step;
324
325        let slice = Slice::new(offset, sizes, strides)?;
326        Ok(slice)
327    }
328
329    /// Return the location of the provided coordinates.
330    pub fn location(&self, coord: &[usize]) -> Result<usize, SliceError> {
331        if coord.len() != self.sizes.len() {
332            return Err(SliceError::InvalidDims {
333                expected: self.sizes.len(),
334                got: coord.len(),
335            });
336        }
337        Ok(self.offset
338            + coord
339                .iter()
340                .zip(&self.strides)
341                .map(|(pos, stride)| pos * stride)
342                .sum::<usize>())
343    }
344
345    /// Return the coordinates of the provided value in the n-d space of this
346    /// Slice.
347    pub fn coordinates(&self, value: usize) -> Result<Vec<usize>, SliceError> {
348        let mut pos = value
349            .checked_sub(self.offset)
350            .ok_or(SliceError::ValueNotInSlice { value })?;
351        let mut result = vec![0; self.sizes.len()];
352        let mut sorted_info: Vec<_> = self
353            .strides
354            .iter()
355            .zip(self.sizes.iter().enumerate())
356            .collect();
357        sorted_info.sort_by_key(|&(stride, _)| *stride);
358        for &(stride, (i, &size)) in sorted_info.iter().rev() {
359            let (index, new_pos) = if size > 1 {
360                (pos / stride, pos % stride)
361            } else {
362                (0, pos)
363            };
364            if index >= size {
365                return Err(SliceError::ValueNotInSlice { value });
366            }
367            result[i] = index;
368            pos = new_pos;
369        }
370        if pos != 0 {
371            return Err(SliceError::ValueNotInSlice { value });
372        }
373        Ok(result)
374    }
375
376    /// The total length of the slice's indices.
377    pub fn len(&self) -> usize {
378        self.sizes.iter().product()
379    }
380
381    pub fn is_empty(&self) -> bool {
382        self.len() == 0
383    }
384
385    /// Iterator over the slice's indices.
386    pub fn iter(&self) -> SliceIterator {
387        SliceIterator {
388            slice: self.clone(),
389            pos: CartesianIterator::new(self.sizes.clone()),
390        }
391    }
392
393    /// Iterator over sub-dimensions of the slice.
394    pub fn dim_iter(&self, dims: usize) -> DimSliceIterator {
395        DimSliceIterator {
396            pos: CartesianIterator::new(self.sizes[0..dims].to_vec()),
397        }
398    }
399
400    /// The linear index formula calculates the logical rank of a
401    /// multidimensional point in a row-major flattened array,
402    /// assuming dense gapless storage with zero offset:
403    ///
404    /// ```plain
405    ///     index := Σ(coordinate[i] × ∏(sizes[j] for j > i))
406    /// ```
407    ///
408    /// For example, given a 3x2 row-major base array B:
409    ///
410    /// ```plain
411    ///       0 1 2         1
412    /// B =   3 4 5    V =  4
413    ///       6 7 8         7
414    /// ```
415    ///
416    /// Let V be the first column of B. Then,
417    ///
418    /// ```plain
419    /// V      | loc   | index
420    /// -------+-------+------
421    /// (0, 0) |  1    | 0
422    /// (1, 0) |  4    | 1
423    /// (2, 0) |  7    | 2
424    /// ```
425    ///
426    /// # Conditions Under Which `loc = index`
427    ///
428    /// The physical offset formula computes the memory location of a
429    /// point `p` as:
430    ///
431    /// ```plain
432    /// loc := offset + Σ(coordinate[i] × stride[i])
433    /// ```
434    ///
435    /// Let the layout be dense row-major and offset = 0.
436    /// Then,
437    /// ```plain
438    /// stride[i] := ∏(sizes[j] for j > i).
439    /// ```
440    /// and substituting into the physical offset formula:
441    /// ```plain
442    ///   loc = Σ(coordinate[i] × stride[i])
443    ///       = Σ(coordinate[i] × ∏(sizes[j] for j > i))
444    ///       = index.
445    /// ```
446    ///
447    /// Thus, ∀ p = (i, j) ∈ B, loc_B(p) = index_B(p).
448    ///
449    /// # See also
450    ///
451    /// The [`get`] function performs an inverse operation: given a
452    /// logical index in row-major order, it computes the physical
453    /// memory offset according to the slice layout. So, if the layout
454    /// is row-major then `s.get(s.index(loc)) = loc`.
455    pub fn index(&self, value: usize) -> Result<usize, SliceError> {
456        let coords = self.coordinates(value)?;
457        let mut stride = 1;
458        let mut result = 0;
459
460        for (idx, size) in coords.iter().rev().zip(self.sizes.iter().rev()) {
461            result += *idx * stride;
462            stride *= size;
463        }
464
465        Ok(result)
466    }
467
468    /// Given a logical index (in row-major order), return the
469    /// physical memory offset of that element according to this
470    /// slice’s layout.
471    ///
472    /// The index is interpreted as a position in row-major traversal
473    /// that is, iterating across columns within rows. This method
474    /// converts logical row-major index to physical offset by:
475    ///
476    /// 1. Decomposing index into multidimensional coordinates
477    /// 2. Computing offset = base + Σ(coordinate[i] × stride[i])
478    ///
479    /// For example, with shape `[3, 4]` (3 rows, 4 columns) and
480    /// column-major layout:
481    ///
482    /// ```text
483    /// sizes  = [3, 4]         // rows, cols
484    /// strides = [1, 3]        // column-major: down, then right
485    ///
486    /// Logical matrix:
487    ///   A  B  C  D
488    ///   E  F  G  H
489    ///   I  J  K  L
490    ///
491    /// Memory layout:
492    /// offset 0  → [0, 0] = A
493    /// offset 1  → [1, 0] = E
494    /// offset 2  → [2, 0] = I
495    /// offset 3  → [0, 1] = B
496    /// offset 4  → [1, 1] = F
497    /// offset 5  → [2, 1] = J
498    /// offset 6  → [0, 2] = C
499    /// offset 7  → [1, 2] = G
500    /// offset 8  → [2, 2] = K
501    /// offset 9  → [0, 3] = D
502    /// offset 10 → [1, 3] = H
503    /// offset 11 → [2, 3] = L
504    ///
505    /// Then:
506    ///   index = 1  → coordinate [0, 1]  → offset = 0*1 + 1*3 = 3
507    /// ```
508    ///
509    /// # Errors
510    ///
511    /// Returns an error if `index >= product(sizes)`.
512    ///
513    /// # See also
514    ///
515    /// The [`index`] function performs an inverse operation: given a
516    /// memory offset, it returns the logical position of that element
517    /// in the slice's row-major iteration order.
518    pub fn get(&self, index: usize) -> Result<usize, SliceError> {
519        let mut val = self.offset;
520        let mut rest = index;
521        let mut total = 1;
522        for (size, stride) in self.sizes.iter().zip(self.strides.iter()).rev() {
523            total *= size;
524            val += (rest % size) * stride;
525            rest /= size;
526        }
527        if index < total {
528            Ok(val)
529        } else {
530            Err(SliceError::IndexOutOfRange { index, total })
531        }
532    }
533
534    /// The returned [`MapSlice`] is a view of this slice, with its elements
535    /// mapped using the provided mapping function.
536    pub fn map<T, F>(&self, mapper: F) -> MapSlice<'_, T, F>
537    where
538        F: Fn(usize) -> T,
539    {
540        MapSlice {
541            slice: self,
542            mapper,
543        }
544    }
545
546    /// Returns a new [`Slice`] with the given shape by reinterpreting
547    /// the layout of this slice.
548    ///
549    /// Constructs a new shape with standard row-major strides, using
550    /// the same base offset. Returns an error if the reshaped view
551    /// would access coordinates not valid in the original slice.
552    ///
553    /// # Requirements
554    ///
555    /// - This slice must be contiguous and have offset == 0.
556    /// - The number of elements must match:
557    ///   `self.sizes().iter().product() == new_sizes.iter().product()`
558    /// - Each flat offset in the proposed view must be valid in `self`.
559    ///
560    /// # Errors
561    ///
562    /// Returns [`SliceError::IncompatibleView`] if:
563    /// - The element count differs
564    /// - The base offset is nonzero
565    /// - Any offset in the view is not reachable in the original slice
566    ///
567    /// # Example
568    ///
569    /// ```rust
570    /// use ndslice::Slice;
571    /// let base = Slice::new_row_major(&[2, 3, 4]);
572    /// let reshaped = base.view(&[6, 4]).unwrap();
573    /// ```
574    pub fn view(&self, new_sizes: &[usize]) -> Result<Slice, SliceError> {
575        let view_elems: usize = new_sizes.iter().product();
576        let base_elems: usize = self.sizes().iter().product();
577
578        // TODO: This version of `view` requires that `self` be
579        // "dense":
580        //
581        //   - `self.offset == 0`
582        //   - `self.strides` match the row-major layout for
583        //     `self.sizes`
584        //   - `self.len() == self.sizes.iter().product::<usize>()`
585        //
586        // Future iterations of this function will aim to relax or
587        // remove the "dense" requirement where possible.
588
589        if view_elems != base_elems {
590            return Err(SliceError::IncompatibleView {
591                reason: format!(
592                    "element count mismatch: base has {}, view wants {}",
593                    base_elems, view_elems
594                ),
595            });
596        }
597        if self.offset != 0 {
598            return Err(SliceError::IncompatibleView {
599                reason: format!("view requires base offset = 0, but found {}", self.offset),
600            });
601        }
602        // Compute row-major strides.
603        let mut new_strides = vec![1; new_sizes.len()];
604        for i in (0..new_sizes.len().saturating_sub(1)).rev() {
605            new_strides[i] = new_strides[i + 1] * new_sizes[i + 1];
606        }
607
608        // Validate that every address in the new view maps to a valid
609        // coordinate in base.
610        for coord in CartesianIterator::new(new_sizes.to_vec()) {
611            #[allow(clippy::identity_op)]
612            let offset_in_view = 0 + coord
613                .iter()
614                .zip(&new_strides)
615                .map(|(i, s)| i * s)
616                .sum::<usize>();
617
618            if self.coordinates(offset_in_view).is_err() {
619                return Err(SliceError::IncompatibleView {
620                    reason: format!("offset {} not reachable in base", offset_in_view),
621                });
622            }
623        }
624
625        Ok(Slice {
626            offset: 0,
627            sizes: new_sizes.to_vec(),
628            strides: new_strides,
629        })
630    }
631
632    /// Returns a sub-slice of `self` starting at `starts`, of size `lens`.
633    pub fn subview(&self, starts: &[usize], lens: &[usize]) -> Result<Slice, SliceError> {
634        if starts.len() != self.num_dim() || lens.len() != self.num_dim() {
635            return Err(SliceError::InvalidDims {
636                expected: self.num_dim(),
637                got: starts.len().max(lens.len()),
638            });
639        }
640
641        for (d, (&start, &len)) in starts.iter().zip(lens).enumerate() {
642            if start + len > self.sizes[d] {
643                return Err(SliceError::IndexOutOfRange {
644                    index: start + len,
645                    total: self.sizes[d],
646                });
647            }
648        }
649
650        let offset = self.location(starts)?;
651        Ok(Slice {
652            offset,
653            sizes: lens.to_vec(),
654            strides: self.strides.clone(),
655        })
656    }
657
658    /// Ensures that every storage offset used by `self` is valid in
659    /// `other`.
660    ///
661    /// That is, for all p ∈ self:
662    /// `other.coordinates(self.location(p))` is defined.
663    ///
664    /// Returns `self` on success, enabling fluent chaining.
665    ///
666    /// # Examples
667    ///
668    /// ```
669    /// use ndslice::Slice;
670    ///
671    /// let base = Slice::new(0, vec![4, 4], vec![4, 1]).unwrap();
672    /// let view = base.subview(&[1, 1], &[2, 2]).unwrap();
673    /// assert_eq!(view.enforce_embedding(&base).unwrap().len(), 4);
674    ///
675    /// let small = Slice::new(0, vec![2, 2], vec![2, 1]).unwrap();
676    /// assert!(view.enforce_embedding(&small).is_err());
677    ///  ```
678    pub fn enforce_embedding<'a>(&'a self, other: &'_ Slice) -> Result<&'a Slice, SliceError> {
679        self.iter()
680            .try_for_each(|loc| other.coordinates(loc).map(|_| ()))?;
681        Ok(self)
682    }
683}
684
685impl std::fmt::Display for Slice {
686    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
687        write!(f, "{:?}", self)
688    }
689}
690
691impl IntoIterator for &Slice {
692    type Item = usize;
693    type IntoIter = SliceIterator;
694    fn into_iter(self) -> Self::IntoIter {
695        self.iter()
696    }
697}
698
699pub struct SliceIterator {
700    pub(crate) slice: Slice,
701    pos: CartesianIterator,
702}
703
704impl Iterator for SliceIterator {
705    type Item = usize;
706
707    fn next(&mut self) -> Option<Self::Item> {
708        match self.pos.next() {
709            None => None,
710            Some(pos) => Some(self.slice.location(&pos).unwrap()),
711        }
712    }
713}
714
715/// Iterates over the Cartesian product of a list of dimension sizes.
716///
717/// Given a list of dimension sizes `[d₀, d₁, ..., dₖ₋₁]`, this yields
718/// all coordinate tuples `[i₀, i₁, ..., iₖ₋₁]` where each `iⱼ ∈
719/// 0..dⱼ`.
720///
721/// Coordinates are yielded in row-major order (last dimension varies
722/// fastest).
723pub struct DimSliceIterator {
724    pos: CartesianIterator,
725}
726
727impl Iterator for DimSliceIterator {
728    type Item = Vec<usize>;
729
730    fn next(&mut self) -> Option<Self::Item> {
731        self.pos.next()
732    }
733}
734
/// Iterates over all coordinate tuples in an N-dimensional space.
///
/// Yields each point in row-major order (last dimension varies fastest)
/// for the shape defined by `dims`, where each coordinate lies in
/// `0..dims[i]`. An empty `dims` yields exactly one empty coordinate; a
/// shape with any zero-sized dimension yields nothing.
///
/// # Example
/// ```ignore
/// let iter = CartesianIterator::new(vec![2, 3]);
/// let coords: Vec<_> = iter.collect();
/// assert_eq!(coords, vec![
///     vec![0, 0], vec![0, 1], vec![0, 2],
///     vec![1, 0], vec![1, 1], vec![1, 2],
/// ]);
/// ```
pub(crate) struct CartesianIterator {
    dims: Vec<usize>,
    // Row-major (flat) index of the next coordinate to yield.
    index: usize,
    // Number of coordinates in the space, computed once in `new`.
    total: usize,
}

impl CartesianIterator {
    pub(crate) fn new(dims: Vec<usize>) -> Self {
        let total = dims.iter().product();
        CartesianIterator {
            dims,
            index: 0,
            total,
        }
    }
}

impl Iterator for CartesianIterator {
    type Item = Vec<usize>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.index >= self.total {
            return None;
        }

        // Decompose the flat row-major index, innermost dimension first.
        // Every dim is nonzero here because `total > 0`.
        let mut result: Vec<usize> = vec![0; self.dims.len()];
        let mut rest = self.index;
        for (coord, &dim) in result.iter_mut().zip(&self.dims).rev() {
            *coord = rest % dim;
            rest /= dim;
        }
        self.index += 1;
        Some(result)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.total.saturating_sub(self.index);
        (remaining, Some(remaining))
    }
}

impl ExactSizeIterator for CartesianIterator {}
777
778/// MapSlice is a view of the underlying Slice that maps each rank
779/// into a different type.
780pub struct MapSlice<'a, T, F>
781where
782    F: Fn(usize) -> T,
783{
784    slice: &'a Slice,
785    mapper: F,
786}
787
788impl<'a, T, F> MapSlice<'a, T, F>
789where
790    F: Fn(usize) -> T,
791{
792    /// The underlying slice sizes.
793    pub fn sizes(&self) -> &[usize] {
794        &self.slice.sizes
795    }
796
797    /// The underlying slice strides.
798    pub fn strides(&self) -> &[usize] {
799        &self.slice.strides
800    }
801
802    /// The mapped value at the provided coordinates. See [`Slice::location`].
803    pub fn location(&self, coord: &[usize]) -> Result<T, SliceError> {
804        self.slice.location(coord).map(&self.mapper)
805    }
806
807    /// The mapped value at the provided index. See [`Slice::get`].
808    pub fn get(&self, index: usize) -> Result<T, SliceError> {
809        self.slice.get(index).map(&self.mapper)
810    }
811
812    /// The underlying slice length.
813    pub fn len(&self) -> usize {
814        self.slice.len()
815    }
816
817    /// Whether the underlying slice is empty.
818    pub fn is_empty(&self) -> bool {
819        self.slice.is_empty()
820    }
821}
822
823#[cfg(test)]
824mod tests {
825    use std::assert_matches::assert_matches;
826    use std::vec;
827
828    use super::*;
829
830    #[test]
831    fn test_cartesian_iterator() {
832        let dims = vec![2, 2, 2];
833        let iter = CartesianIterator::new(dims);
834        let products: Vec<Vec<usize>> = iter.collect();
835        assert_eq!(
836            products,
837            vec![
838                vec![0, 0, 0],
839                vec![0, 0, 1],
840                vec![0, 1, 0],
841                vec![0, 1, 1],
842                vec![1, 0, 0],
843                vec![1, 0, 1],
844                vec![1, 1, 0],
845                vec![1, 1, 1],
846            ]
847        );
848    }
849
850    #[test]
851    #[allow(clippy::explicit_counter_loop)]
852    fn test_slice() {
853        let s = Slice::new(0, vec![2, 3], vec![3, 1]).unwrap();
854        for i in 0..4 {
855            assert_eq!(s.get(i).unwrap(), i);
856        }
857
858        {
859            // Test IntoIter
860            let mut current = 0;
861            for index in &s {
862                assert_eq!(index, current);
863                current += 1;
864            }
865        }
866
867        let s = Slice::new(0, vec![3, 4, 5], vec![20, 5, 1]).unwrap();
868        assert_eq!(s.get(3 * 4 + 1).unwrap(), 13);
869
870        let s = Slice::new(0, vec![2, 2, 2], vec![4, 32, 1]).unwrap();
871        assert_eq!(s.get(0).unwrap(), 0);
872        assert_eq!(s.get(1).unwrap(), 1);
873        assert_eq!(s.get(2).unwrap(), 32);
874        assert_eq!(s.get(3).unwrap(), 33);
875        assert_eq!(s.get(4).unwrap(), 4);
876        assert_eq!(s.get(5).unwrap(), 5);
877        assert_eq!(s.get(6).unwrap(), 36);
878        assert_eq!(s.get(7).unwrap(), 37);
879
880        let s = Slice::new(0, vec![2, 2, 2], vec![32, 4, 1]).unwrap();
881        assert_eq!(s.get(0).unwrap(), 0);
882        assert_eq!(s.get(1).unwrap(), 1);
883        assert_eq!(s.get(2).unwrap(), 4);
884        assert_eq!(s.get(4).unwrap(), 32);
885    }
886
887    #[test]
888    fn test_slice_iter() {
889        let s = Slice::new(0, vec![2, 3], vec![3, 1]).unwrap();
890        assert!(s.iter().eq(0..6));
891
892        let s = Slice::new(10, vec![10, 2], vec![10, 5]).unwrap();
893        assert!(s.iter().eq((10..=105).step_by(5)));
894
895        // Implementaion corresponds with Slice::get.
896        assert!(s.iter().eq((0..s.len()).map(|i| s.get(i).unwrap())));
897    }
898
899    #[test]
900    fn test_dim_slice_iter() {
901        let s = Slice::new(0, vec![2, 3], vec![3, 1]).unwrap();
902        let sub_dims: Vec<_> = s.dim_iter(1).collect();
903        assert_eq!(sub_dims, vec![vec![0], vec![1]]);
904    }
905
906    #[test]
907    fn test_slice_coordinates() {
908        let s = Slice::new(0, vec![2, 3], vec![3, 1]).unwrap();
909        assert_eq!(s.coordinates(0).unwrap(), vec![0, 0]);
910        assert_eq!(s.coordinates(3).unwrap(), vec![1, 0]);
911        assert_matches!(
912            s.coordinates(6),
913            Err(SliceError::ValueNotInSlice { value: 6 })
914        );
915
916        let s = Slice::new(10, vec![2, 3], vec![3, 1]).unwrap();
917        assert_matches!(
918            s.coordinates(6),
919            Err(SliceError::ValueNotInSlice { value: 6 })
920        );
921        assert_eq!(s.coordinates(10).unwrap(), vec![0, 0]);
922        assert_eq!(s.coordinates(13).unwrap(), vec![1, 0]);
923
924        let s = Slice::new(0, vec![2, 1, 1], vec![1, 1, 1]).unwrap();
925        assert_eq!(s.coordinates(1).unwrap(), vec![1, 0, 0]);
926    }
927
928    #[test]
929    fn test_slice_index() {
930        let s = Slice::new(0, vec![2, 3], vec![3, 1]).unwrap();
931        assert_eq!(s.index(3).unwrap(), 3);
932        assert!(s.index(14).is_err());
933
934        let s = Slice::new(0, vec![2, 2], vec![4, 2]).unwrap();
935        assert_eq!(s.index(2).unwrap(), 1);
936    }
937
938    #[test]
939    fn test_slice_map() {
940        let s = Slice::new(0, vec![2, 3], vec![3, 1]).unwrap();
941        let m = s.map(|i| i * 2);
942        assert_eq!(m.get(0).unwrap(), 0);
943        assert_eq!(m.get(3).unwrap(), 6);
944        assert_eq!(m.get(5).unwrap(), 10);
945    }
946
947    #[test]
948    fn test_slice_size_one() {
949        let s = Slice::new(0, vec![1, 1], vec![1, 1]).unwrap();
950        assert_eq!(s.get(0).unwrap(), 0);
951    }
952
953    #[test]
954    fn test_row_major() {
955        let s = Slice::new_row_major(vec![4, 4, 4]);
956        assert_eq!(s.offset(), 0);
957        assert_eq!(s.sizes(), &[4, 4, 4]);
958        assert_eq!(s.strides(), &[16, 4, 1]);
959    }
960
961    #[test]
962    fn test_slice_view_smoke() {
963        use crate::Slice;
964
965        let base = Slice::new_row_major([2, 3, 4]);
966
967        // Reshape: compatible shape and layout
968        let view = base.view(&[6, 4]).unwrap();
969        assert_eq!(view.sizes(), &[6, 4]);
970        assert_eq!(view.offset(), 0);
971        assert_eq!(view.strides(), &[4, 1]);
972        assert_eq!(
973            view.location(&[5, 3]).unwrap(),
974            base.location(&[1, 2, 3]).unwrap()
975        );
976
977        // Reshape: identity (should succeed)
978        let view = base.view(&[2, 3, 4]).unwrap();
979        assert_eq!(view.sizes(), base.sizes());
980        assert_eq!(view.strides(), base.strides());
981
982        // Reshape: incompatible shape (wrong element count)
983        let err = base.view(&[5, 4]);
984        assert!(err.is_err());
985
986        // Reshape: incompatible layout (simulate select)
987        let selected = Slice::new(1, vec![2, 3], vec![6, 1]).unwrap(); // not offset=0
988        let err = selected.view(&[3, 2]);
989        assert!(err.is_err());
990
991        // Reshape: flat 1D view
992        let flat = base.view(&[24]).unwrap();
993        assert_eq!(flat.sizes(), &[24]);
994        assert_eq!(flat.strides(), &[1]);
995        assert_eq!(
996            flat.location(&[23]).unwrap(),
997            base.location(&[1, 2, 3]).unwrap()
998        );
999    }
1000
1001    #[test]
1002    fn test_view_of_view_when_dense() {
1003        // Start with a dense base: 2 × 3 × 4 = 24 elements.
1004        let base = Slice::new_row_major([2, 3, 4]);
1005
1006        // First view: flatten to 1D.
1007        let flat = base.view(&[24]).unwrap();
1008        assert_eq!(flat.sizes(), &[24]);
1009        assert_eq!(flat.strides(), &[1]);
1010        assert_eq!(flat.offset(), 0); // Still dense.
1011
1012        // Second view: reshape 1D to 6 × 4.
1013        let reshaped = flat.view(&[6, 4]).unwrap();
1014        assert_eq!(reshaped.sizes(), &[6, 4]);
1015        assert_eq!(reshaped.strides(), &[4, 1]);
1016        assert_eq!(reshaped.offset(), 0);
1017
1018        // Location agreement check
1019        assert_eq!(
1020            reshaped.location(&[5, 3]).unwrap(),
1021            base.location(&[1, 2, 3]).unwrap()
1022        );
1023    }
1024
1025    #[test]
1026    fn test_at_1d_to_0d() {
1027        let slice = Slice::new_row_major(vec![5]);
1028        assert_eq!(slice.num_dim(), 1);
1029        assert_eq!(slice.sizes(), &[5]);
1030        assert_eq!(slice.strides(), &[1]);
1031
1032        let result = slice.at(0, 3).unwrap();
1033        assert_eq!(result.num_dim(), 0);
1034        assert_eq!(result.sizes(), &[]);
1035        assert_eq!(result.strides(), &[]);
1036        assert_eq!(result.offset(), 3);
1037        assert_eq!(result.location(&[]).unwrap(), 3);
1038    }
1039
1040    #[test]
1041    fn test_at_2d_to_1d() {
1042        let slice = Slice::new_row_major(vec![3, 4]);
1043        assert_eq!(slice.num_dim(), 2);
1044        assert_eq!(slice.sizes(), &[3, 4]);
1045        assert_eq!(slice.strides(), &[4, 1]);
1046
1047        let result = slice.at(0, 1).unwrap();
1048        assert_eq!(result.num_dim(), 1);
1049        assert_eq!(result.sizes(), &[4]);
1050        assert_eq!(result.strides(), &[1]);
1051        assert_eq!(result.offset(), 4);
1052    }
1053
1054    #[test]
1055    fn test_at_3d_to_2d() {
1056        let slice = Slice::new_row_major(vec![2, 3, 4]);
1057        assert_eq!(slice.num_dim(), 3);
1058        assert_eq!(slice.sizes(), &[2, 3, 4]);
1059        assert_eq!(slice.strides(), &[12, 4, 1]);
1060
1061        let result = slice.at(0, 1).unwrap();
1062        assert_eq!(result.num_dim(), 2);
1063        assert_eq!(result.sizes(), &[3, 4]);
1064        assert_eq!(result.strides(), &[4, 1]);
1065        assert_eq!(result.offset(), 12);
1066    }
1067
1068    #[test]
1069    fn test_get_index_inverse_relationship() {
1070        // Start with a 3 x 3 dense row major matrix.
1071        //
1072        // 0 1 2
1073        // 3 4 5
1074        // 6 7 8
1075        let m = Slice::new_row_major([3, 3]);
1076        assert_eq!(m.offset, 0);
1077        assert_eq!(m.sizes(), &[3, 3]);
1078        assert_eq!(m.strides(), &[3, 1]);
1079
1080        // Slice `m` is 0-offset, row-major, dense, gapless.
1081        for loc in m.iter() {
1082            // ∀ `loc` ∈ `m`, `m.index(loc) == loc`.
1083            assert_eq!(m.index(loc).unwrap(), loc);
1084            // ∀ `loc` ∈ `m`, `m.get(m.index(loc)) == loc`.
1085            assert_eq!(m.get(m.index(loc).unwrap()).unwrap(), loc);
1086        }
1087
1088        // Slice out the middle column.
1089        //    1
1090        //    4
1091        //    7
1092        let c = m.select(1, 1, 2, 1).unwrap();
1093        assert_eq!(c.sizes(), &[3, 1]);
1094        assert_eq!(c.strides(), &[3, 1]);
1095
1096        // Slice `c` has a non-zero offset.
1097        for loc in c.iter() {
1098            // Local rank of `loc` in `c` != loc.
1099            assert_ne!(c.index(loc).unwrap(), loc);
1100            // ∀ `loc` ∈ `c`, `c.get(c.index(loc)) == loc`.
1101            assert_eq!(c.get(c.index(loc).unwrap()).unwrap(), loc);
1102        }
1103    }
1104
1105    #[test]
1106    fn embedding_succeeds_for_contained_view() {
1107        let base = Slice::new(0, vec![4, 4], vec![4, 1]).unwrap(); // 4×4 matrix, row-major
1108        let view = Slice::new(5, vec![2, 2], vec![4, 1]).unwrap(); // a 2×2 submatrix starting at (1,1)
1109
1110        assert!(view.enforce_embedding(&base).is_ok());
1111    }
1112
1113    #[test]
1114    fn embedding_fails_for_out_of_bounds_view() {
1115        let base = Slice::new(0, vec![4, 4], vec![4, 1]).unwrap(); // 4×4 matrix
1116        let view = Slice::new(14, vec![2, 2], vec![4, 1]).unwrap(); // starts at (3,2), accesses (4,3)
1117
1118        assert!(view.enforce_embedding(&base).is_err());
1119    }
1120}