ndslice/
slice.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::iter::zip;
10
11use serde::Deserialize;
12use serde::Serialize;
13
/// The type of error for slice operations.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum SliceError {
    /// A list had the wrong number of dimensions, e.g. `strides` vs.
    /// `sizes` in `Slice::new`, or a coordinate passed to
    /// `Slice::location`.
    #[error("invalid dims: expected {expected}, got {got}")]
    InvalidDims { expected: usize, got: usize },

    /// When the strides are sorted, some stride is not a multiple of
    /// the next-smaller stride.
    #[error("nonrectangular shape")]
    NonrectangularShape,

    /// Two dimensions that both have more than one element share the
    /// same stride.
    #[error("nonunique strides")]
    NonuniqueStrides,

    /// `stride` is smaller than `space`, the extent (stride × size) of
    /// the next-smaller-strided dimension, so the dimensions would
    /// overlap. NOTE: `Slice::new` accepts `stride == space`, despite
    /// the "larger than" wording of the message.
    #[error("stride {stride} must be larger than size of previous space {space}")]
    StrideTooSmall { stride: usize, space: usize },

    /// `index` is not below the exclusive bound `total`.
    #[error("index {index} out of range {total}")]
    IndexOutOfRange { index: usize, total: usize },

    /// `value` is not one of the locations addressed by the slice.
    #[error("value {value} not in slice")]
    ValueNotInSlice { value: usize },

    /// A requested reshape/view is not compatible with the base slice;
    /// `reason` describes why.
    #[error("incompatible view: {reason}")]
    IncompatibleView { reason: String },

    /// The slice is not contiguous. (Not constructed in this file;
    /// presumably used by operations elsewhere that need a dense
    /// layout — verify against callers.)
    #[error("noncontiguous shape")]
    NonContiguous,

    /// A requested range selects no elements (e.g. `end <= begin`).
    #[error("empty range: {begin}..{end} (step {step})")]
    EmptyRange {
        begin: usize,
        end: usize,
        step: usize,
    },

    /// `dim` is not a valid dimension of an `ndims`-dimensional slice.
    #[error("dimension {dim} out of range for {ndims}-dimensional slice")]
    DimensionOutOfRange { dim: usize, ndims: usize },
}
52
/// Slice is a compact representation of indices into the flat
/// representation of an n-dimensional array. Given an offset, sizes of
/// each dimension, and strides for each dimension, Slice can compute
/// indices into the flat array: the coordinate `[i₀, …, iₖ₋₁]` is
/// located at `offset + Σ iⱼ × strides[j]`.
///
/// For example, the following describes a dense 4x4x4 array in row-major
/// order:
/// ```
/// # use ndslice::Slice;
/// let s = Slice::new(0, vec![4, 4, 4], vec![16, 4, 1]).unwrap();
/// assert!(s.iter().eq(0..(4 * 4 * 4)));
/// ```
///
/// Slices allow easy slicing by subsetting and striding. For example,
/// we can fix the index of the second dimension by dropping it and
/// adding that index (multiplied by that dimension's stride, here 2)
/// to the offset.
///
/// ```
/// # use ndslice::Slice;
/// let s = Slice::new(0, vec![2, 4, 2], vec![8, 2, 1]).unwrap();
/// let selected_index = 3;
/// let sub = Slice::new(2 * selected_index, vec![2, 2], vec![8, 1]).unwrap();
/// let coords = [[0, 0], [0, 1], [1, 0], [1, 1]];
/// for coord @ [x, y] in coords {
///     assert_eq!(
///         sub.location(&coord).unwrap(),
///         s.location(&[x, 3, y]).unwrap()
///     );
/// }
/// ```
// TODO: Consider representing this by arrays parameterized by the slice
// dimensionality.
#[derive(Serialize, Deserialize, Clone, Eq, PartialEq, Hash, Debug)]
pub struct Slice {
    /// Location of the element at the all-zero coordinate.
    offset: usize,
    /// Number of elements along each dimension.
    sizes: Vec<usize>,
    /// Distance in the flat array between consecutive indices along
    /// each dimension (same length as `sizes`).
    strides: Vec<usize>,
}
91
92impl Slice {
93    /// Create a new Slice with the provided offset, sizes, and
94    /// strides. New performs validation to ensure that sizes and strides
95    /// are compatible:
96    ///   - They have to be the same length (i.e., same number of dimensions)
97    ///   - They have to be rectangular (i.e., stride n+1 has to evenly divide into stride n)
98    ///   - Strides must be nonoverlapping (each stride has to be larger than the previous space)
99    pub fn new(offset: usize, sizes: Vec<usize>, strides: Vec<usize>) -> Result<Self, SliceError> {
100        if sizes.len() != strides.len() {
101            return Err(SliceError::InvalidDims {
102                expected: sizes.len(),
103                got: strides.len(),
104            });
105        }
106        let mut combined: Vec<(usize, usize)> =
107            strides.iter().cloned().zip(sizes.iter().cloned()).collect();
108        combined.sort();
109
110        let mut prev_stride: Option<usize> = None;
111        let mut prev_size: Option<usize> = None;
112        let mut total: usize = 1;
113        for (stride, size) in combined {
114            if let Some(prev_stride) = prev_stride {
115                if stride % prev_stride != 0 {
116                    return Err(SliceError::NonrectangularShape);
117                }
118                // Strides for single element dimensions can repeat, because they are unused
119                if stride == prev_stride && size != 1 && prev_size.unwrap_or(1) != 1 {
120                    return Err(SliceError::NonuniqueStrides);
121                }
122            }
123            if total > stride {
124                return Err(SliceError::StrideTooSmall {
125                    stride,
126                    space: total,
127                });
128            }
129            total = stride * size;
130            prev_stride = Some(stride);
131            prev_size = Some(size);
132        }
133
134        Ok(Slice {
135            offset,
136            sizes,
137            strides,
138        })
139    }
140
141    /// Deconstruct the slice into its offset, sizes, and strides.
142    pub fn into_inner(self) -> (usize, Vec<usize>, Vec<usize>) {
143        let Slice {
144            offset,
145            sizes,
146            strides,
147        } = self;
148        (offset, sizes, strides)
149    }
150
151    /// Create a new slice of the given sizes in row-major order.
152    pub fn new_row_major(sizes: impl Into<Vec<usize>>) -> Self {
153        let sizes = sizes.into();
154        // "flip it and reverse it" --Missy Elliott
155        let mut strides: Vec<usize> = sizes.clone();
156        let _ = strides.iter_mut().rev().fold(1, |acc, n| {
157            let next = *n * acc;
158            *n = acc;
159            next
160        });
161        Self {
162            offset: 0,
163            sizes,
164            strides,
165        }
166    }
167
168    /// Create one celled slice.
169    pub fn new_single_multi_dim_cell(dims: usize) -> Self {
170        Self {
171            offset: 0,
172            sizes: vec![1; dims],
173            strides: vec![1; dims],
174        }
175    }
176
177    /// The number of dimensions in this slice.
178    pub fn num_dim(&self) -> usize {
179        self.sizes.len()
180    }
181
182    /// This is the offset from which the first value in the Slice begins.
183    pub fn offset(&self) -> usize {
184        self.offset
185    }
186
187    /// The shape of the slice; that is, the size of each dimension.
188    pub fn sizes(&self) -> &[usize] {
189        &self.sizes
190    }
191
192    /// The strides of the slice; that is, the distance between each
193    /// element at a given index in the underlying array.
194    pub fn strides(&self) -> &[usize] {
195        &self.strides
196    }
197
198    pub fn is_contiguous(&self) -> bool {
199        let mut expected_stride = 1;
200        for (stride, size) in zip(self.strides.iter(), self.sizes.iter()).rev() {
201            if *stride != expected_stride {
202                return false;
203            }
204            expected_stride *= *size
205        }
206        true
207    }
208
209    /// Select a single index along a dimension, removing that
210    /// dimension entirely.
211    ///
212    /// This reduces the dimensionality by 1 by "fixing" one
213    /// coordinate to a specific value. Think of it like taking a
214    /// cross-section: selecting index 2 from the first dimension of a
215    /// 3D array gives you a 2D slice, like cutting a plane from a 3D
216    /// space at a fixed position.
217    ///
218    /// This reduces the dimensionality by 1 by "fixing" one
219    /// coordinate to a specific value. The fixed coordinate's
220    /// contribution (index × stride) gets absorbed into the base
221    /// offset, while the remaining dimensions keep their original
222    /// strides unchanged - they still describe the same memory
223    /// distances between elements.
224    ///
225    /// # Example intuition
226    /// - 3D array → select `at(dim=0, index=2)` → 2D slice (like a
227    ///   plane)
228    /// - 2D matrix → select `at(dim=1, index=3)` → 1D vector (like a
229    ///   column)
230    /// - 1D vector → select `at(dim=0, index=5)` → 0D scalar (single
231    ///   element)
232    ///
233    /// # Arguments
234    /// * `dim` - The dimension index to select from
235    /// * `index` - The index within that dimension
236    ///
237    /// # Returns
238    /// A new slice with one fewer dimension
239    ///
240    /// # Errors
241    /// * `IndexOutOfRange` if `dim >= self.sizes.len()` or `index >=
242    ///   self.sizes[dim]`
243    pub fn at(&self, dim: usize, index: usize) -> Result<Self, SliceError> {
244        if dim >= self.sizes.len() {
245            return Err(SliceError::DimensionOutOfRange {
246                dim,
247                ndims: self.num_dim(),
248            });
249        }
250        if index >= self.sizes[dim] {
251            return Err(SliceError::IndexOutOfRange {
252                index,
253                total: self.sizes[dim],
254            });
255        }
256
257        let new_offset = self.offset + index * self.strides[dim];
258        let mut new_sizes = self.sizes.clone();
259        let mut new_strides = self.strides.clone();
260        new_sizes.remove(dim);
261        new_strides.remove(dim);
262        let slice = Slice::new(new_offset, new_sizes, new_strides)?;
263        Ok(slice)
264    }
265
266    /// A slice defines a **strided view**; a triple (`offset,
267    /// `sizes`, `strides`). Each coordinate maps to a flat memory
268    /// index using the formula:
269    /// ```text
270    /// index = offset + ∑ iₖ × strides[k]
271    /// ```
272    /// where `iₖ` is the coordinate in dimension `k`.
273    ///
274    /// The `select(dim, range)` operation restricts the view to a
275    /// subrange along a single dimension. It calculates a new slice
276    /// from a base slice by updating the `offset`, `sizes[dim]`, and
277    /// `strides[dim]` to describe a logically reindexed subregion:
278    /// ```text
279    /// offset       += begin × strides[dim]
280    /// sizes[dim]    = ⎡(end - begin) / step⎤
281    /// strides[dim] ×= step
282    /// ```
283    ///
284    /// This transformation preserves the strided layout and avoids
285    /// copying data. After `select`, the view behaves as if indexing
286    /// starts at zero in the selected dimension, with a new length
287    /// and stride. From the user's perspective, nothing changes;
288    /// indexing remains zero-based, and the resulting shape can be
289    /// used like any other. The transformation is internal: the
290    /// view's offset and stride absorb the selection logic.
291    pub fn select(
292        &self,
293        dim: usize,
294        begin: usize,
295        end: usize,
296        step: usize,
297    ) -> Result<Self, SliceError> {
298        if dim >= self.sizes.len() {
299            return Err(SliceError::IndexOutOfRange {
300                index: dim,
301                total: self.sizes.len(),
302            });
303        }
304        if begin >= self.sizes[dim] {
305            return Err(SliceError::IndexOutOfRange {
306                index: begin,
307                total: self.sizes[dim],
308            });
309        }
310        if end <= begin {
311            return Err(SliceError::EmptyRange { begin, end, step });
312        }
313
314        let mut offset = self.offset();
315        let mut sizes = self.sizes().to_vec();
316        let mut strides = self.strides().to_vec();
317
318        offset += begin * strides[dim];
319        // The # of elems in `begin..end` with step `step`. This is
320        // ⌈(end - begin) / stride⌉ — the number of steps that fit in
321        // the half-open interval.
322        sizes[dim] = (end - begin).div_ceil(step);
323        strides[dim] *= step;
324
325        let slice = Slice::new(offset, sizes, strides)?;
326        Ok(slice)
327    }
328
329    /// Return the location of the provided coordinates.
330    pub fn location(&self, coord: &[usize]) -> Result<usize, SliceError> {
331        if coord.len() != self.sizes.len() {
332            return Err(SliceError::InvalidDims {
333                expected: self.sizes.len(),
334                got: coord.len(),
335            });
336        }
337        Ok(self.offset
338            + coord
339                .iter()
340                .zip(&self.strides)
341                .map(|(pos, stride)| pos * stride)
342                .sum::<usize>())
343    }
344
345    /// Return the coordinates of the provided value in the n-d space of this
346    /// Slice.
347    pub fn coordinates(&self, value: usize) -> Result<Vec<usize>, SliceError> {
348        let mut pos = value
349            .checked_sub(self.offset)
350            .ok_or(SliceError::ValueNotInSlice { value })?;
351        let mut result = vec![0; self.sizes.len()];
352        let mut sorted_info: Vec<_> = self
353            .strides
354            .iter()
355            .zip(self.sizes.iter().enumerate())
356            .collect();
357        sorted_info.sort_by_key(|&(stride, _)| *stride);
358        for &(stride, (i, &size)) in sorted_info.iter().rev() {
359            let (index, new_pos) = if size > 1 {
360                (pos / stride, pos % stride)
361            } else {
362                (0, pos)
363            };
364            if index >= size {
365                return Err(SliceError::ValueNotInSlice { value });
366            }
367            result[i] = index;
368            pos = new_pos;
369        }
370        if pos != 0 {
371            return Err(SliceError::ValueNotInSlice { value });
372        }
373        Ok(result)
374    }
375
376    /// Returns whether the provided rank is contained in this slice.
377    pub fn contains(&self, value: usize) -> bool {
378        self.coordinates(value).is_ok()
379    }
380
381    /// The total length of the slice's indices.
382    pub fn len(&self) -> usize {
383        self.sizes.iter().product()
384    }
385
386    pub fn is_empty(&self) -> bool {
387        self.len() == 0
388    }
389
390    /// Iterator over the slice's indices.
391    pub fn iter(&self) -> SliceIterator {
392        SliceIterator {
393            slice: self.clone(),
394            pos: CartesianIterator::new(self.sizes.clone()),
395        }
396    }
397
398    /// Iterator over sub-dimensions of the slice.
399    pub fn dim_iter(&self, dims: usize) -> DimSliceIterator {
400        DimSliceIterator {
401            pos: CartesianIterator::new(self.sizes[0..dims].to_vec()),
402        }
403    }
404
405    /// The linear index formula calculates the logical rank of a
406    /// multidimensional point in a row-major flattened array,
407    /// assuming dense gapless storage with zero offset:
408    ///
409    /// ```plain
410    ///     index := Σ(coordinate[i] × ∏(sizes[j] for j > i))
411    /// ```
412    ///
413    /// For example, given a 3x2 row-major base array B:
414    ///
415    /// ```plain
416    ///       0 1 2         1
417    /// B =   3 4 5    V =  4
418    ///       6 7 8         7
419    /// ```
420    ///
421    /// Let V be the first column of B. Then,
422    ///
423    /// ```plain
424    /// V      | loc   | index
425    /// -------+-------+------
426    /// (0, 0) |  1    | 0
427    /// (1, 0) |  4    | 1
428    /// (2, 0) |  7    | 2
429    /// ```
430    ///
431    /// # Conditions Under Which `loc = index`
432    ///
433    /// The physical offset formula computes the memory location of a
434    /// point `p` as:
435    ///
436    /// ```plain
437    /// loc := offset + Σ(coordinate[i] × stride[i])
438    /// ```
439    ///
440    /// Let the layout be dense row-major and offset = 0.
441    /// Then,
442    /// ```plain
443    /// stride[i] := ∏(sizes[j] for j > i).
444    /// ```
445    /// and substituting into the physical offset formula:
446    /// ```plain
447    ///   loc = Σ(coordinate[i] × stride[i])
448    ///       = Σ(coordinate[i] × ∏(sizes[j] for j > i))
449    ///       = index.
450    /// ```
451    ///
452    /// Thus, ∀ p = (i, j) ∈ B, loc_B(p) = index_B(p).
453    ///
454    /// # See also
455    ///
456    /// The [`get`] function performs an inverse operation: given a
457    /// logical index in row-major order, it computes the physical
458    /// memory offset according to the slice layout. So, if the layout
459    /// is row-major then `s.get(s.index(loc)) = loc`.
460    pub fn index(&self, value: usize) -> Result<usize, SliceError> {
461        let coords = self.coordinates(value)?;
462        let mut stride = 1;
463        let mut result = 0;
464
465        for (idx, size) in coords.iter().rev().zip(self.sizes.iter().rev()) {
466            result += *idx * stride;
467            stride *= size;
468        }
469
470        Ok(result)
471    }
472
473    /// Given a logical index (in row-major order), return the
474    /// physical memory offset of that element according to this
475    /// slice’s layout.
476    ///
477    /// The index is interpreted as a position in row-major traversal
478    /// that is, iterating across columns within rows. This method
479    /// converts logical row-major index to physical offset by:
480    ///
481    /// 1. Decomposing index into multidimensional coordinates
482    /// 2. Computing offset = base + Σ(coordinate[i] × stride[i])
483    ///
484    /// For example, with shape `[3, 4]` (3 rows, 4 columns) and
485    /// column-major layout:
486    ///
487    /// ```text
488    /// sizes  = [3, 4]         // rows, cols
489    /// strides = [1, 3]        // column-major: down, then right
490    ///
491    /// Logical matrix:
492    ///   A  B  C  D
493    ///   E  F  G  H
494    ///   I  J  K  L
495    ///
496    /// Memory layout:
497    /// offset 0  → [0, 0] = A
498    /// offset 1  → [1, 0] = E
499    /// offset 2  → [2, 0] = I
500    /// offset 3  → [0, 1] = B
501    /// offset 4  → [1, 1] = F
502    /// offset 5  → [2, 1] = J
503    /// offset 6  → [0, 2] = C
504    /// offset 7  → [1, 2] = G
505    /// offset 8  → [2, 2] = K
506    /// offset 9  → [0, 3] = D
507    /// offset 10 → [1, 3] = H
508    /// offset 11 → [2, 3] = L
509    ///
510    /// Then:
511    ///   index = 1  → coordinate [0, 1]  → offset = 0*1 + 1*3 = 3
512    /// ```
513    ///
514    /// # Errors
515    ///
516    /// Returns an error if `index >= product(sizes)`.
517    ///
518    /// # See also
519    ///
520    /// The [`index`] function performs an inverse operation: given a
521    /// memory offset, it returns the logical position of that element
522    /// in the slice's row-major iteration order.
523    pub fn get(&self, index: usize) -> Result<usize, SliceError> {
524        let mut val = self.offset;
525        let mut rest = index;
526        let mut total = 1;
527        for (size, stride) in self.sizes.iter().zip(self.strides.iter()).rev() {
528            total *= size;
529            val += (rest % size) * stride;
530            rest /= size;
531        }
532        if index < total {
533            Ok(val)
534        } else {
535            Err(SliceError::IndexOutOfRange { index, total })
536        }
537    }
538
539    /// The returned [`MapSlice`] is a view of this slice, with its elements
540    /// mapped using the provided mapping function.
541    pub fn map<T, F>(&self, mapper: F) -> MapSlice<'_, T, F>
542    where
543        F: Fn(usize) -> T,
544    {
545        MapSlice {
546            slice: self,
547            mapper,
548        }
549    }
550
551    /// Returns a new [`Slice`] with the given shape by reinterpreting
552    /// the layout of this slice.
553    ///
554    /// Constructs a new shape with standard row-major strides, using
555    /// the same base offset. Returns an error if the reshaped view
556    /// would access coordinates not valid in the original slice.
557    ///
558    /// # Requirements
559    ///
560    /// - This slice must be contiguous and have offset == 0.
561    /// - The number of elements must match:
562    ///   `self.sizes().iter().product() == new_sizes.iter().product()`
563    /// - Each flat offset in the proposed view must be valid in `self`.
564    ///
565    /// # Errors
566    ///
567    /// Returns [`SliceError::IncompatibleView`] if:
568    /// - The element count differs
569    /// - The base offset is nonzero
570    /// - Any offset in the view is not reachable in the original slice
571    ///
572    /// # Example
573    ///
574    /// ```rust
575    /// use ndslice::Slice;
576    /// let base = Slice::new_row_major(&[2, 3, 4]);
577    /// let reshaped = base.view(&[6, 4]).unwrap();
578    /// ```
579    pub fn view(&self, new_sizes: &[usize]) -> Result<Slice, SliceError> {
580        let view_elems: usize = new_sizes.iter().product();
581        let base_elems: usize = self.sizes().iter().product();
582
583        // TODO: This version of `view` requires that `self` be
584        // "dense":
585        //
586        //   - `self.offset == 0`
587        //   - `self.strides` match the row-major layout for
588        //     `self.sizes`
589        //   - `self.len() == self.sizes.iter().product::<usize>()`
590        //
591        // Future iterations of this function will aim to relax or
592        // remove the "dense" requirement where possible.
593
594        if view_elems != base_elems {
595            return Err(SliceError::IncompatibleView {
596                reason: format!(
597                    "element count mismatch: base has {}, view wants {}",
598                    base_elems, view_elems
599                ),
600            });
601        }
602        if self.offset != 0 {
603            return Err(SliceError::IncompatibleView {
604                reason: format!("view requires base offset = 0, but found {}", self.offset),
605            });
606        }
607        // Compute row-major strides.
608        let mut new_strides = vec![1; new_sizes.len()];
609        for i in (0..new_sizes.len().saturating_sub(1)).rev() {
610            new_strides[i] = new_strides[i + 1] * new_sizes[i + 1];
611        }
612
613        // Validate that every address in the new view maps to a valid
614        // coordinate in base.
615        for coord in CartesianIterator::new(new_sizes.to_vec()) {
616            #[allow(clippy::identity_op)]
617            let offset_in_view = 0 + coord
618                .iter()
619                .zip(&new_strides)
620                .map(|(i, s)| i * s)
621                .sum::<usize>();
622
623            if self.coordinates(offset_in_view).is_err() {
624                return Err(SliceError::IncompatibleView {
625                    reason: format!("offset {} not reachable in base", offset_in_view),
626                });
627            }
628        }
629
630        Ok(Slice {
631            offset: 0,
632            sizes: new_sizes.to_vec(),
633            strides: new_strides,
634        })
635    }
636
637    /// Returns a sub-slice of `self` starting at `starts`, of size `lens`.
638    pub fn subview(&self, starts: &[usize], lens: &[usize]) -> Result<Slice, SliceError> {
639        if starts.len() != self.num_dim() || lens.len() != self.num_dim() {
640            return Err(SliceError::InvalidDims {
641                expected: self.num_dim(),
642                got: starts.len().max(lens.len()),
643            });
644        }
645
646        for (d, (&start, &len)) in starts.iter().zip(lens).enumerate() {
647            if start + len > self.sizes[d] {
648                return Err(SliceError::IndexOutOfRange {
649                    index: start + len,
650                    total: self.sizes[d],
651                });
652            }
653        }
654
655        let offset = self.location(starts)?;
656        Ok(Slice {
657            offset,
658            sizes: lens.to_vec(),
659            strides: self.strides.clone(),
660        })
661    }
662
663    /// Ensures that every storage offset used by `self` is valid in
664    /// `other`.
665    ///
666    /// That is, for all p ∈ self:
667    /// `other.coordinates(self.location(p))` is defined.
668    ///
669    /// Returns `self` on success, enabling fluent chaining.
670    ///
671    /// # Examples
672    ///
673    /// ```
674    /// use ndslice::Slice;
675    ///
676    /// let base = Slice::new(0, vec![4, 4], vec![4, 1]).unwrap();
677    /// let view = base.subview(&[1, 1], &[2, 2]).unwrap();
678    /// assert_eq!(view.enforce_embedding(&base).unwrap().len(), 4);
679    ///
680    /// let small = Slice::new(0, vec![2, 2], vec![2, 1]).unwrap();
681    /// assert!(view.enforce_embedding(&small).is_err());
682    ///  ```
683    pub fn enforce_embedding<'a>(&'a self, other: &'_ Slice) -> Result<&'a Slice, SliceError> {
684        self.iter()
685            .try_for_each(|loc| other.coordinates(loc).map(|_| ()))?;
686        Ok(self)
687    }
688}
689
690impl std::fmt::Display for Slice {
691    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
692        write!(f, "{:?}", self)
693    }
694}
695
696impl IntoIterator for &Slice {
697    type Item = usize;
698    type IntoIter = SliceIterator;
699    fn into_iter(self) -> Self::IntoIter {
700        self.iter()
701    }
702}
703
/// Iterator over the locations addressed by a [`Slice`], in row-major
/// coordinate order (last dimension varies fastest). Created by
/// [`Slice::iter`].
pub struct SliceIterator {
    /// The slice whose locations are produced.
    pub(crate) slice: Slice,
    /// Generator of the coordinates still to visit; `Slice::iter`
    /// builds it from `slice.sizes()`.
    pos: CartesianIterator,
}
708
709impl Iterator for SliceIterator {
710    type Item = usize;
711
712    fn next(&mut self) -> Option<Self::Item> {
713        match self.pos.next() {
714            None => None,
715            Some(pos) => Some(self.slice.location(&pos).unwrap()),
716        }
717    }
718}
719
/// Iterates over the Cartesian product of a list of dimension sizes.
///
/// Given a list of dimension sizes `[d₀, d₁, ..., dₖ₋₁]`, this yields
/// all coordinate tuples `[i₀, i₁, ..., iₖ₋₁]` where each `iⱼ ∈
/// 0..dⱼ`.
///
/// Coordinates are yielded in row-major order (last dimension varies
/// fastest).
///
/// Created by [`Slice::dim_iter`], which passes the sizes of a
/// slice's leading dimensions.
pub struct DimSliceIterator {
    /// Underlying coordinate generator.
    pos: CartesianIterator,
}
731
732impl Iterator for DimSliceIterator {
733    type Item = Vec<usize>;
734
735    fn next(&mut self) -> Option<Self::Item> {
736        self.pos.next()
737    }
738}
739
/// Iterates over all coordinate tuples in an N-dimensional space.
///
/// Yields each point in row-major order (last dimension varies
/// fastest) for the shape defined by `dims`, where coordinate `i`
/// lies in `0..dims[i]`. If any dimension is 0 nothing is yielded; if
/// `dims` is empty, a single empty coordinate is yielded.
///
/// # Example
/// ```ignore
/// let iter = CartesianIterator::new(vec![2, 3]);
/// let coords: Vec<_> = iter.collect();
/// assert_eq!(coords, vec![
///     vec![0, 0], vec![0, 1], vec![0, 2],
///     vec![1, 0], vec![1, 1], vec![1, 2],
/// ]);
/// ```
pub(crate) struct CartesianIterator {
    /// Extent of each dimension.
    dims: Vec<usize>,
    /// Row-major rank of the next coordinate to yield.
    index: usize,
}
757
758impl CartesianIterator {
759    pub(crate) fn new(dims: Vec<usize>) -> Self {
760        CartesianIterator { dims, index: 0 }
761    }
762}
763
764impl Iterator for CartesianIterator {
765    type Item = Vec<usize>;
766
767    fn next(&mut self) -> Option<Self::Item> {
768        if self.index >= self.dims.iter().product::<usize>() {
769            return None;
770        }
771
772        let mut result: Vec<usize> = vec![0; self.dims.len()];
773        let mut rest = self.index;
774        for (i, dim) in self.dims.iter().enumerate().rev() {
775            result[i] = rest % dim;
776            rest /= dim;
777        }
778        self.index += 1;
779        Some(result)
780    }
781}
782
/// MapSlice is a view of the underlying Slice that maps each rank
/// into a different type. Created by [`Slice::map`]; it borrows the
/// slice and applies `mapper` to each location on access.
pub struct MapSlice<'a, T, F>
where
    F: Fn(usize) -> T,
{
    /// The borrowed slice whose locations are mapped.
    slice: &'a Slice,
    /// Function applied to each location produced by `slice`.
    mapper: F,
}
792
793impl<'a, T, F> MapSlice<'a, T, F>
794where
795    F: Fn(usize) -> T,
796{
797    /// The underlying slice sizes.
798    pub fn sizes(&self) -> &[usize] {
799        &self.slice.sizes
800    }
801
802    /// The underlying slice strides.
803    pub fn strides(&self) -> &[usize] {
804        &self.slice.strides
805    }
806
807    /// The mapped value at the provided coordinates. See [`Slice::location`].
808    pub fn location(&self, coord: &[usize]) -> Result<T, SliceError> {
809        self.slice.location(coord).map(&self.mapper)
810    }
811
812    /// The mapped value at the provided index. See [`Slice::get`].
813    pub fn get(&self, index: usize) -> Result<T, SliceError> {
814        self.slice.get(index).map(&self.mapper)
815    }
816
817    /// The underlying slice length.
818    pub fn len(&self) -> usize {
819        self.slice.len()
820    }
821
822    /// Whether the underlying slice is empty.
823    pub fn is_empty(&self) -> bool {
824        self.slice.is_empty()
825    }
826}
827
828#[cfg(test)]
829mod tests {
830    use std::assert_matches::assert_matches;
831    use std::vec;
832
833    use super::*;
834
835    #[test]
836    fn test_cartesian_iterator() {
837        let dims = vec![2, 2, 2];
838        let iter = CartesianIterator::new(dims);
839        let products: Vec<Vec<usize>> = iter.collect();
840        assert_eq!(
841            products,
842            vec![
843                vec![0, 0, 0],
844                vec![0, 0, 1],
845                vec![0, 1, 0],
846                vec![0, 1, 1],
847                vec![1, 0, 0],
848                vec![1, 0, 1],
849                vec![1, 1, 0],
850                vec![1, 1, 1],
851            ]
852        );
853    }
854
855    #[test]
856    #[allow(clippy::explicit_counter_loop)]
857    fn test_slice() {
858        let s = Slice::new(0, vec![2, 3], vec![3, 1]).unwrap();
859        for i in 0..4 {
860            assert_eq!(s.get(i).unwrap(), i);
861        }
862
863        {
864            // Test IntoIter
865            let mut current = 0;
866            for index in &s {
867                assert_eq!(index, current);
868                current += 1;
869            }
870        }
871
872        let s = Slice::new(0, vec![3, 4, 5], vec![20, 5, 1]).unwrap();
873        assert_eq!(s.get(3 * 4 + 1).unwrap(), 13);
874
875        let s = Slice::new(0, vec![2, 2, 2], vec![4, 32, 1]).unwrap();
876        assert_eq!(s.get(0).unwrap(), 0);
877        assert_eq!(s.get(1).unwrap(), 1);
878        assert_eq!(s.get(2).unwrap(), 32);
879        assert_eq!(s.get(3).unwrap(), 33);
880        assert_eq!(s.get(4).unwrap(), 4);
881        assert_eq!(s.get(5).unwrap(), 5);
882        assert_eq!(s.get(6).unwrap(), 36);
883        assert_eq!(s.get(7).unwrap(), 37);
884
885        let s = Slice::new(0, vec![2, 2, 2], vec![32, 4, 1]).unwrap();
886        assert_eq!(s.get(0).unwrap(), 0);
887        assert_eq!(s.get(1).unwrap(), 1);
888        assert_eq!(s.get(2).unwrap(), 4);
889        assert_eq!(s.get(4).unwrap(), 32);
890    }
891
892    #[test]
893    fn test_slice_iter() {
894        let s = Slice::new(0, vec![2, 3], vec![3, 1]).unwrap();
895        assert!(s.iter().eq(0..6));
896
897        let s = Slice::new(10, vec![10, 2], vec![10, 5]).unwrap();
898        assert!(s.iter().eq((10..=105).step_by(5)));
899
900        // Implementaion corresponds with Slice::get.
901        assert!(s.iter().eq((0..s.len()).map(|i| s.get(i).unwrap())));
902    }
903
904    #[test]
905    fn test_dim_slice_iter() {
906        let s = Slice::new(0, vec![2, 3], vec![3, 1]).unwrap();
907        let sub_dims: Vec<_> = s.dim_iter(1).collect();
908        assert_eq!(sub_dims, vec![vec![0], vec![1]]);
909    }
910
911    #[test]
912    fn test_slice_coordinates() {
913        let s = Slice::new(0, vec![2, 3], vec![3, 1]).unwrap();
914        assert_eq!(s.coordinates(0).unwrap(), vec![0, 0]);
915        assert_eq!(s.coordinates(3).unwrap(), vec![1, 0]);
916        assert_matches!(
917            s.coordinates(6),
918            Err(SliceError::ValueNotInSlice { value: 6 })
919        );
920
921        let s = Slice::new(10, vec![2, 3], vec![3, 1]).unwrap();
922        assert_matches!(
923            s.coordinates(6),
924            Err(SliceError::ValueNotInSlice { value: 6 })
925        );
926        assert_eq!(s.coordinates(10).unwrap(), vec![0, 0]);
927        assert_eq!(s.coordinates(13).unwrap(), vec![1, 0]);
928
929        let s = Slice::new(0, vec![2, 1, 1], vec![1, 1, 1]).unwrap();
930        assert_eq!(s.coordinates(1).unwrap(), vec![1, 0, 0]);
931    }
932
933    #[test]
934    fn test_slice_index() {
935        let s = Slice::new(0, vec![2, 3], vec![3, 1]).unwrap();
936        assert_eq!(s.index(3).unwrap(), 3);
937        assert!(s.index(14).is_err());
938
939        let s = Slice::new(0, vec![2, 2], vec![4, 2]).unwrap();
940        assert_eq!(s.index(2).unwrap(), 1);
941    }
942
943    #[test]
944    fn test_slice_map() {
945        let s = Slice::new(0, vec![2, 3], vec![3, 1]).unwrap();
946        let m = s.map(|i| i * 2);
947        assert_eq!(m.get(0).unwrap(), 0);
948        assert_eq!(m.get(3).unwrap(), 6);
949        assert_eq!(m.get(5).unwrap(), 10);
950    }
951
952    #[test]
953    fn test_slice_size_one() {
954        let s = Slice::new(0, vec![1, 1], vec![1, 1]).unwrap();
955        assert_eq!(s.get(0).unwrap(), 0);
956    }
957
958    #[test]
959    fn test_row_major() {
960        let s = Slice::new_row_major(vec![4, 4, 4]);
961        assert_eq!(s.offset(), 0);
962        assert_eq!(s.sizes(), &[4, 4, 4]);
963        assert_eq!(s.strides(), &[16, 4, 1]);
964    }
965
966    #[test]
967    fn test_slice_view_smoke() {
968        use crate::Slice;
969
970        let base = Slice::new_row_major([2, 3, 4]);
971
972        // Reshape: compatible shape and layout
973        let view = base.view(&[6, 4]).unwrap();
974        assert_eq!(view.sizes(), &[6, 4]);
975        assert_eq!(view.offset(), 0);
976        assert_eq!(view.strides(), &[4, 1]);
977        assert_eq!(
978            view.location(&[5, 3]).unwrap(),
979            base.location(&[1, 2, 3]).unwrap()
980        );
981
982        // Reshape: identity (should succeed)
983        let view = base.view(&[2, 3, 4]).unwrap();
984        assert_eq!(view.sizes(), base.sizes());
985        assert_eq!(view.strides(), base.strides());
986
987        // Reshape: incompatible shape (wrong element count)
988        let err = base.view(&[5, 4]);
989        assert!(err.is_err());
990
991        // Reshape: incompatible layout (simulate select)
992        let selected = Slice::new(1, vec![2, 3], vec![6, 1]).unwrap(); // not offset=0
993        let err = selected.view(&[3, 2]);
994        assert!(err.is_err());
995
996        // Reshape: flat 1D view
997        let flat = base.view(&[24]).unwrap();
998        assert_eq!(flat.sizes(), &[24]);
999        assert_eq!(flat.strides(), &[1]);
1000        assert_eq!(
1001            flat.location(&[23]).unwrap(),
1002            base.location(&[1, 2, 3]).unwrap()
1003        );
1004    }
1005
1006    #[test]
1007    fn test_view_of_view_when_dense() {
1008        // Start with a dense base: 2 × 3 × 4 = 24 elements.
1009        let base = Slice::new_row_major([2, 3, 4]);
1010
1011        // First view: flatten to 1D.
1012        let flat = base.view(&[24]).unwrap();
1013        assert_eq!(flat.sizes(), &[24]);
1014        assert_eq!(flat.strides(), &[1]);
1015        assert_eq!(flat.offset(), 0); // Still dense.
1016
1017        // Second view: reshape 1D to 6 × 4.
1018        let reshaped = flat.view(&[6, 4]).unwrap();
1019        assert_eq!(reshaped.sizes(), &[6, 4]);
1020        assert_eq!(reshaped.strides(), &[4, 1]);
1021        assert_eq!(reshaped.offset(), 0);
1022
1023        // Location agreement check
1024        assert_eq!(
1025            reshaped.location(&[5, 3]).unwrap(),
1026            base.location(&[1, 2, 3]).unwrap()
1027        );
1028    }
1029
1030    #[test]
1031    fn test_at_1d_to_0d() {
1032        let slice = Slice::new_row_major(vec![5]);
1033        assert_eq!(slice.num_dim(), 1);
1034        assert_eq!(slice.sizes(), &[5]);
1035        assert_eq!(slice.strides(), &[1]);
1036
1037        let result = slice.at(0, 3).unwrap();
1038        assert_eq!(result.num_dim(), 0);
1039        assert_eq!(result.sizes(), &[]);
1040        assert_eq!(result.strides(), &[]);
1041        assert_eq!(result.offset(), 3);
1042        assert_eq!(result.location(&[]).unwrap(), 3);
1043    }
1044
1045    #[test]
1046    fn test_at_2d_to_1d() {
1047        let slice = Slice::new_row_major(vec![3, 4]);
1048        assert_eq!(slice.num_dim(), 2);
1049        assert_eq!(slice.sizes(), &[3, 4]);
1050        assert_eq!(slice.strides(), &[4, 1]);
1051
1052        let result = slice.at(0, 1).unwrap();
1053        assert_eq!(result.num_dim(), 1);
1054        assert_eq!(result.sizes(), &[4]);
1055        assert_eq!(result.strides(), &[1]);
1056        assert_eq!(result.offset(), 4);
1057    }
1058
1059    #[test]
1060    fn test_at_3d_to_2d() {
1061        let slice = Slice::new_row_major(vec![2, 3, 4]);
1062        assert_eq!(slice.num_dim(), 3);
1063        assert_eq!(slice.sizes(), &[2, 3, 4]);
1064        assert_eq!(slice.strides(), &[12, 4, 1]);
1065
1066        let result = slice.at(0, 1).unwrap();
1067        assert_eq!(result.num_dim(), 2);
1068        assert_eq!(result.sizes(), &[3, 4]);
1069        assert_eq!(result.strides(), &[4, 1]);
1070        assert_eq!(result.offset(), 12);
1071    }
1072
1073    #[test]
1074    fn test_get_index_inverse_relationship() {
1075        // Start with a 3 x 3 dense row major matrix.
1076        //
1077        // 0 1 2
1078        // 3 4 5
1079        // 6 7 8
1080        let m = Slice::new_row_major([3, 3]);
1081        assert_eq!(m.offset, 0);
1082        assert_eq!(m.sizes(), &[3, 3]);
1083        assert_eq!(m.strides(), &[3, 1]);
1084
1085        // Slice `m` is 0-offset, row-major, dense, gapless.
1086        for loc in m.iter() {
1087            // ∀ `loc` ∈ `m`, `m.index(loc) == loc`.
1088            assert_eq!(m.index(loc).unwrap(), loc);
1089            // ∀ `loc` ∈ `m`, `m.get(m.index(loc)) == loc`.
1090            assert_eq!(m.get(m.index(loc).unwrap()).unwrap(), loc);
1091        }
1092
1093        // Slice out the middle column.
1094        //    1
1095        //    4
1096        //    7
1097        let c = m.select(1, 1, 2, 1).unwrap();
1098        assert_eq!(c.sizes(), &[3, 1]);
1099        assert_eq!(c.strides(), &[3, 1]);
1100
1101        // Slice `c` has a non-zero offset.
1102        for loc in c.iter() {
1103            // Local rank of `loc` in `c` != loc.
1104            assert_ne!(c.index(loc).unwrap(), loc);
1105            // ∀ `loc` ∈ `c`, `c.get(c.index(loc)) == loc`.
1106            assert_eq!(c.get(c.index(loc).unwrap()).unwrap(), loc);
1107        }
1108    }
1109
1110    #[test]
1111    fn embedding_succeeds_for_contained_view() {
1112        let base = Slice::new(0, vec![4, 4], vec![4, 1]).unwrap(); // 4×4 matrix, row-major
1113        let view = Slice::new(5, vec![2, 2], vec![4, 1]).unwrap(); // a 2×2 submatrix starting at (1,1)
1114
1115        assert!(view.enforce_embedding(&base).is_ok());
1116    }
1117
1118    #[test]
1119    fn embedding_fails_for_out_of_bounds_view() {
1120        let base = Slice::new(0, vec![4, 4], vec![4, 1]).unwrap(); // 4×4 matrix
1121        let view = Slice::new(14, vec![2, 2], vec![4, 1]).unwrap(); // starts at (3,2), accesses (4,3)
1122
1123        assert!(view.enforce_embedding(&base).is_err());
1124    }
1125}