ndslice/
shape.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::fmt;
10
11use serde::Deserialize;
12use serde::Serialize;
13
14use crate::DimSliceIterator;
15use crate::Slice;
16use crate::SliceError;
17use crate::selection::Selection;
18use crate::view::Extent;
19
// `select` always retains dimensions, even when they are narrowed down to a
// single index; only `at` (and `index`, which is built on it) removes them.
21
/// Errors reported by [`Shape`] construction and selection operations.
#[derive(Debug, thiserror::Error)]
pub enum ShapeError {
    /// The number of labels differs from the number of dimensions of the
    /// slice. Returned by [`Shape::new`].
    #[error("label slice dimension mismatch: {labels_dim} != {slice_dim}")]
    DimSliceMismatch { labels_dim: usize, slice_dim: usize },

    /// One or more labels do not name a dimension. [`Shape::dim`] returns
    /// this with the single label that could not be found.
    #[error("invalid labels `{labels:?}`")]
    InvalidLabels { labels: Vec<String> },

    /// The requested range selects nothing. [`Shape::select`] maps
    /// `SliceError::EmptyRange` to this variant.
    #[error("empty range {range}")]
    EmptyRange { range: Range },

    /// The requested range does not fit the named dimension.
    /// [`Shape::select`] and [`Shape::at`] map
    /// `SliceError::IndexOutOfRange` to this variant.
    #[error("out of range {range} for dimension {dim} of size {size}")]
    OutOfRange {
        range: Range,
        dim: String,
        size: usize,
    },

    /// A selection expression addresses more dimensions than are available.
    /// NOTE(review): not produced in this file; described from its message
    /// only.
    #[error("selection `{expr}` exceeds dimensionality {num_dim}")]
    SelectionTooDeep { expr: Selection, num_dim: usize },

    /// A selection expression is dynamic.
    /// NOTE(review): not produced in this file; described from its message
    /// only.
    #[error("dynamic selection `{expr}`")]
    SelectionDynamic { expr: Selection },

    /// A single index does not fit the named dimension.
    /// NOTE(review): not produced in this file; described from its message
    /// only.
    #[error("{index} out of range for dimension {dim} of size {size}")]
    IndexOutOfRange {
        index: usize,
        dim: String,
        size: usize,
    },

    /// Any other error reported by the underlying [`Slice`], passed through
    /// unchanged (`?` converts it via the generated `From` impl).
    #[error(transparent)]
    SliceError(#[from] SliceError),
}
56
/// A shape is a [`Slice`] with labeled dimensions and a selection API.
///
/// [`Shape::new`] only accepts inputs where the number of labels equals the
/// number of dimensions of the slice, so `labels[i]` names dimension `i`.
#[derive(Clone, Deserialize, Serialize, PartialEq, Hash, Debug)]
pub struct Shape {
    /// The labels for each dimension in slice.
    labels: Vec<String>,
    /// The slice itself, which describes the topology of the shape.
    slice: Slice,
}
65
66impl Shape {
67    /// Creates a new shape with the provided labels, which describe the
68    /// provided Slice.
69    ///
70    /// Shapes can also be constructed by way of the [`shape`] macro, which
71    /// creates a by-construction correct slice in row-major order given a set of
72    /// sized dimensions.
73    pub fn new(labels: Vec<String>, slice: Slice) -> Result<Self, ShapeError> {
74        if labels.len() != slice.num_dim() {
75            return Err(ShapeError::DimSliceMismatch {
76                labels_dim: labels.len(),
77                slice_dim: slice.num_dim(),
78            });
79        }
80        Ok(Self { labels, slice })
81    }
82
83    /// Select a single index along a named dimension, removing that
84    /// dimension entirely. This reduces the dimensionality by 1. In
85    /// effect it results in a cross section of the shape at the given
86    /// index in the given dimension.
87    pub fn at(&self, label: &str, index: usize) -> Result<Self, ShapeError> {
88        let dim = self.dim(label)?;
89        let slice = self.slice.at(dim, index).map_err(|err| match err {
90            SliceError::IndexOutOfRange { index, total } => ShapeError::OutOfRange {
91                range: Range(index, Some(index + 1), 1),
92                dim: label.to_string(),
93                size: total,
94            },
95            other => other.into(),
96        })?;
97        let mut labels = self.labels.clone();
98        labels.remove(dim);
99        Ok(Self { labels, slice })
100    }
101
102    /// Restrict this shape along a named dimension using a [`Range`].
103    /// The provided range must be nonempty.
104    ///
105    /// `select` is composable, it can be applied repeatedly, even on
106    /// the same dimension, to refine the view incrementally.
107    pub fn select<R: Into<Range>>(&self, label: &str, range: R) -> Result<Self, ShapeError> {
108        let dim = self.dim(label)?;
109        let range = range.into();
110        let (begin, end, step) = range.resolve(self.slice().sizes()[dim]);
111        let slice = self
112            .slice
113            .select(dim, begin, end, step)
114            .map_err(|err| match err {
115                SliceError::EmptyRange { .. } => ShapeError::EmptyRange { range },
116                SliceError::IndexOutOfRange { total, .. } => ShapeError::OutOfRange {
117                    range,
118                    dim: label.to_string(),
119                    size: total,
120                },
121                other => other.into(),
122            })?;
123        let labels = self.labels.clone();
124        Ok(Self { labels, slice })
125    }
126
127    /// Produces an iterator over subshapes by fixing the first `dims`
128    /// dimensions.
129    ///
130    /// For a shape of rank `n`, this yields `∏ sizes[0..dims]`
131    /// subshapes, each with the first `dims` dimensions restricted to
132    /// size 1. The remaining dimensions are left unconstrained.
133    ///
134    /// This is useful for structured traversal of slices within a
135    /// multidimensional shape. See [`SelectIterator`] for details and
136    /// examples.
137    ///
138    /// # Errors
139    /// Returns an error if `dims == 0` or `dims >= self.rank()`.
140    pub fn select_iter(&self, dims: usize) -> Result<SelectIterator, ShapeError> {
141        let num_dims = self.slice().num_dim();
142        if dims == 0 || dims >= num_dims {
143            return Err(ShapeError::SliceError(SliceError::IndexOutOfRange {
144                index: dims,
145                total: num_dims,
146            }));
147        }
148
149        Ok(SelectIterator {
150            shape: self,
151            iter: self.slice().dim_iter(dims),
152        })
153    }
154
155    /// Sub-set this shape by select a particular row of the given
156    /// indices The resulting shape will no longer have dimensions for
157    /// the given indices Example shape.index(vec![("gpu", 3),
158    /// ("host", 0)])
159    pub fn index(&self, indices: Vec<(String, usize)>) -> Result<Shape, ShapeError> {
160        let mut shape = self.clone();
161        for (label, index) in indices {
162            shape = shape.at(&label, index)?;
163        }
164        Ok(shape)
165    }
166
167    /// The per-dimension labels of this shape.
168    pub fn labels(&self) -> &[String] {
169        &self.labels
170    }
171
172    /// The slice describing the shape.
173    pub fn slice(&self) -> &Slice {
174        &self.slice
175    }
176
177    /// Return a set of labeled coordinates for the given rank.
178    pub fn coordinates(&self, rank: usize) -> Result<Vec<(String, usize)>, ShapeError> {
179        let coords = self.slice.coordinates(rank)?;
180        Ok(coords
181            .iter()
182            .zip(self.labels.iter())
183            .map(|(i, l)| (l.to_string(), *i))
184            .collect())
185    }
186
187    pub fn dim(&self, label: &str) -> Result<usize, ShapeError> {
188        self.labels
189            .iter()
190            .position(|l| l == label)
191            .ok_or_else(|| ShapeError::InvalidLabels {
192                labels: vec![label.to_string()],
193            })
194    }
195
196    /// Return the 0-dimensional single element shape
197    pub fn unity() -> Shape {
198        Shape::new(vec![], Slice::new(0, vec![], vec![]).expect("unity")).expect("unity")
199    }
200
201    /// The extent corresponding to this shape.
202    pub fn extent(&self) -> Extent {
203        Extent::new(self.labels.clone(), self.slice.sizes().to_vec()).unwrap()
204    }
205}
206
/// Iterator over subshapes obtained by fixing a prefix of dimensions.
///
/// This iterator is produced by [`Shape::select_iter`], and
/// yields one `Shape` per coordinate prefix in the first `dims`
/// dimensions.
///
/// For a shape of `n` dimensions, each yielded shape has:
/// - The first `dims` dimensions restricted to size 1 (i.e., fixed
///   via `select`)
/// - The remaining `n - dims` dimensions left unconstrained
///
/// This allows structured iteration over "slices" of the original
/// shape: for example with `n` = 3, `select_iter(1)` walks through 2D
/// planes, while `select_iter(2)` yields 1D subshapes.
///
/// # Example
/// ```ignore
/// let s = shape!(zone = 2, host = 2, gpu = 8);
/// let views: Vec<_> = s.select_iter(2).unwrap().collect();
/// assert_eq!(views.len(), 4);
/// assert_eq!(views[0].slice().sizes(), &[1, 1, 8]);
/// ```
/// The above example can be interpreted as: for each `(zone, host)`
/// pair, `select_iter(2)` yields a `Shape` describing the associated
/// row of GPUs — a view into the `[1, 1, 8]` subregion of the full
/// `[2, 2, 8]` shape.
pub struct SelectIterator<'a> {
    /// The shape being traversed; every yielded subshape is derived from it.
    shape: &'a Shape,
    /// Source of the coordinate prefixes; each item fixes one index per
    /// leading dimension.
    iter: DimSliceIterator,
}
237
238impl<'a> Iterator for SelectIterator<'a> {
239    type Item = Shape;
240
241    fn next(&mut self) -> Option<Self::Item> {
242        let pos = self.iter.next()?;
243        let mut shape = self.shape.clone();
244        for (dim, index) in pos.iter().enumerate() {
245            shape = shape.select(&self.shape.labels()[dim], *index).unwrap();
246        }
247        Some(shape)
248    }
249}
250
251impl fmt::Display for Shape {
252    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
253        // Just display the sizes of each dimension, for now.
254        // Once we have a selection algebra, we can provide a
255        // better Display implementation.
256        write!(f, "{{")?;
257        for dim in 0..self.labels.len() {
258            write!(f, "{}={}", self.labels[dim], self.slice.sizes()[dim])?;
259            if dim < self.labels.len() - 1 {
260                write!(f, ",")?;
261            }
262        }
263        write!(f, "}}")
264    }
265}
266
/// Build a [`Shape`](crate::shape::Shape) from `label = size` pairs, laid
/// out in row-major order. A trailing comma is accepted.
///
/// Panics if `Shape::new` rejects the generated labels and slice.
///
/// ```
/// let s = ndslice::shape!(host = 2, gpu = 8);
/// assert_eq!(s.labels(), &["host".to_string(), "gpu".to_string()]);
/// assert_eq!(s.slice().sizes(), &[2, 8]);
/// assert_eq!(s.slice().strides(), &[8, 1]);
/// ```
#[macro_export]
macro_rules! shape {
    ( $( $label:ident = $size:expr_2021 ),* $(,)? ) => {
        $crate::shape::Shape::new(
            // One label per pair, spelled exactly as the identifier was written.
            vec![ $( stringify!($label).to_string() ),* ],
            // Sizes in the same order as the labels.
            $crate::Slice::new_row_major(vec![ $( $size ),* ]),
        )
        .unwrap()
    };
}
292
/// Perform a sub-selection on the provided [`Shape`] object.
///
/// This macro chains `.select()` calls to apply multiple labeled
/// dimension restrictions in a fluent way. The restrictions are applied
/// left to right, and the first failing `.select()` call determines the
/// error. Like [`shape!`], a trailing comma after the last restriction is
/// accepted.
///
/// The first argument must be an identifier naming the shape.
///
/// ```
/// let s = ndslice::shape!(host = 2, gpu = 8);
/// let s = ndslice::select!(s, host = 1, gpu = 4..).unwrap();
/// assert_eq!(s.labels(), &["host".to_string(), "gpu".to_string()]);
/// assert_eq!(s.slice().sizes(), &[1, 4]);
/// ```
#[macro_export]
macro_rules! select {
    // Base case: a single restriction.
    ($shape:ident, $label:ident = $range:expr_2021 $(,)?) => {
        $shape.select(stringify!($label), $range)
    };

    // Recursive case: apply the first restriction, then the remaining ones
    // to its result.
    ($shape:ident, $label:ident = $range:expr_2021, $($labels:ident = $ranges:expr_2021),+ $(,)?) => {
        $shape
            .select(stringify!($label), $range)
            .and_then(|shape| $crate::select!(shape, $($labels = $ranges),+))
    };
}
314
/// A range of indices, with a stride. Ranges are convertible from
/// native Rust ranges.
///
/// The tuple fields are, in order:
/// - `0`: the first index of the range;
/// - `1`: the exclusive end of the range, or `None` for a range that
///   extends to the end of whatever dimension it is applied to;
/// - `2`: the stride.
///
/// Deriving `Eq`, `Ord` and `Hash` is sound because all fields are
/// `Ord` and comparison is purely structural over `(start, end,
/// step)`.
#[derive(
    Debug,
    Clone,
    Eq,
    Hash,
    PartialEq,
    Serialize,
    Deserialize,
    PartialOrd,
    Ord
)]
pub struct Range(pub usize, pub Option<usize>, pub usize);
333
334impl Range {
335    pub(crate) fn resolve(&self, size: usize) -> (usize, usize, usize) {
336        match self {
337            Range(begin, Some(end), stride) => (*begin, std::cmp::min(size, *end), *stride),
338            Range(begin, None, stride) => (*begin, size, *stride),
339        }
340    }
341
342    pub(crate) fn is_empty(&self) -> bool {
343        matches!(self, Range(begin, Some(end), _) if end <= begin)
344    }
345}
346
347impl fmt::Display for Range {
348    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
349        match self {
350            Range(begin, None, stride) => write!(f, "{}::{}", begin, stride),
351            Range(begin, Some(end), stride) => write!(f, "{}:{}:{}", begin, end, stride),
352        }
353    }
354}
355
356impl From<std::ops::Range<usize>> for Range {
357    fn from(r: std::ops::Range<usize>) -> Self {
358        Self(r.start, Some(r.end), 1)
359    }
360}
361
362impl From<std::ops::RangeInclusive<usize>> for Range {
363    fn from(r: std::ops::RangeInclusive<usize>) -> Self {
364        Self(*r.start(), Some(*r.end() + 1), 1)
365    }
366}
367
368impl From<std::ops::RangeFrom<usize>> for Range {
369    fn from(r: std::ops::RangeFrom<usize>) -> Self {
370        Self(r.start, None, 1)
371    }
372}
373
374impl From<usize> for Range {
375    fn from(idx: usize) -> Self {
376        Self(idx, Some(idx + 1), 1)
377    }
378}
379
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;

    use super::*;

    /// `shape!` produces a row-major slice with the given labels, and
    /// `Display` lists the size of every dimension.
    #[test]
    fn test_basic() {
        let s = shape!(host = 2, gpu = 8);
        assert_eq!(&s.labels, &["host".to_string(), "gpu".to_string()]);
        assert_eq!(s.slice.offset(), 0);
        assert_eq!(s.slice.sizes(), &[2, 8]);
        assert_eq!(s.slice.strides(), &[8, 1]);

        assert_eq!(s.to_string(), "{host=2,gpu=8}");
    }

    /// `select!` narrows one or several dimensions; the ranks left in the
    /// slice are the expected ones.
    #[test]
    fn test_select() {
        let s = shape!(host = 2, gpu = 8);

        // The unrestricted shape covers every rank, host by host.
        assert_eq!(
            s.slice().iter().collect::<Vec<_>>(),
            &[
                0,
                1,
                2,
                3,
                4,
                5,
                6,
                7,
                8,
                8 + 1,
                8 + 2,
                8 + 3,
                8 + 4,
                8 + 5,
                8 + 6,
                8 + 7
            ]
        );

        // A single index.
        assert_eq!(
            select!(s, host = 1)
                .unwrap()
                .slice()
                .iter()
                .collect::<Vec<_>>(),
            &[8, 8 + 1, 8 + 2, 8 + 3, 8 + 4, 8 + 5, 8 + 6, 8 + 7]
        );

        // An open-ended range.
        assert_eq!(
            select!(s, gpu = 2..)
                .unwrap()
                .slice()
                .iter()
                .collect::<Vec<_>>(),
            &[2, 3, 4, 5, 6, 7, 8 + 2, 8 + 3, 8 + 4, 8 + 5, 8 + 6, 8 + 7]
        );

        // A bounded range.
        assert_eq!(
            select!(s, gpu = 3..5)
                .unwrap()
                .slice()
                .iter()
                .collect::<Vec<_>>(),
            &[3, 4, 8 + 3, 8 + 4]
        );

        // Several dimensions at once, in an order that differs from the
        // order of the dimensions.
        assert_eq!(
            select!(s, gpu = 3..5, host = 1)
                .unwrap()
                .slice()
                .iter()
                .collect::<Vec<_>>(),
            &[8 + 3, 8 + 4]
        );
    }

    /// `select_iter(2)` yields one subshape per `(replica, host)` pair, in
    /// row-major order of the fixed dimensions.
    #[test]
    fn test_select_iter() {
        let s = shape!(replica = 2, host = 2, gpu = 8);
        let selections: Vec<_> = s.select_iter(2).unwrap().collect();
        // Check the count first, so that a short result is reported as such
        // rather than as an out-of-bounds index.
        assert_eq!(selections.len(), 4);
        for selection in &selections {
            assert_eq!(selection.slice().sizes(), &[1, 1, 8]);
        }
        assert_eq!(
            selections,
            &[
                select!(s, replica = 0, host = 0).unwrap(),
                select!(s, replica = 0, host = 1).unwrap(),
                select!(s, replica = 1, host = 0).unwrap(),
                select!(s, replica = 1, host = 1).unwrap()
            ]
        );
    }

    /// Ranks map to labeled coordinates; a rank outside the slice is an
    /// error.
    #[test]
    fn test_coordinates() {
        let s = shape!(host = 2, gpu = 8);
        assert_eq!(
            s.coordinates(0).unwrap(),
            vec![("host".to_string(), 0), ("gpu".to_string(), 0)]
        );
        assert_eq!(
            s.coordinates(1).unwrap(),
            vec![("host".to_string(), 0), ("gpu".to_string(), 1)]
        );
        assert_eq!(
            s.coordinates(8).unwrap(),
            vec![("host".to_string(), 1), ("gpu".to_string(), 0)]
        );
        assert_eq!(
            s.coordinates(9).unwrap(),
            vec![("host".to_string(), 1), ("gpu".to_string(), 1)]
        );

        assert_matches!(
            s.coordinates(16).unwrap_err(),
            ShapeError::SliceError(SliceError::ValueNotInSlice { value: 16 })
        );
    }

    /// Empty and out-of-range selections are reported with the offending
    /// range.
    #[test]
    fn test_select_bad() {
        let s = shape!(host = 2, gpu = 8);

        assert_matches!(
            select!(s, gpu = 1..1).unwrap_err(),
            ShapeError::EmptyRange {
                range: Range(1, Some(1), 1)
            },
        );

        assert_matches!(
            select!(s, gpu = 8).unwrap_err(),
            ShapeError::OutOfRange {
                range: Range(8, Some(9), 1),
                dim,
                size: 8,
            } if dim == "gpu",
        );
    }

    /// `index` removes the indexed dimensions and adjusts the offset.
    #[test]
    fn test_shape_index() {
        let n_hosts = 5;
        let n_gpus = 7;

        // Index first dim
        let s = shape!(host = n_hosts, gpu = n_gpus);
        assert_eq!(
            s.index(vec![("host".to_string(), 0)]).unwrap(),
            Shape::new(
                vec!["gpu".to_string()],
                Slice::new(0, vec![n_gpus], vec![1]).unwrap()
            )
            .unwrap()
        );

        // Index last dims
        let offset = 1;
        assert_eq!(
            s.index(vec![("gpu".to_string(), offset)]).unwrap(),
            Shape::new(
                vec!["host".to_string()],
                Slice::new(offset, vec![n_hosts], vec![n_gpus]).unwrap()
            )
            .unwrap()
        );

        // Index middle dim
        let n_zone = 2;
        let s = shape!(zone = n_zone, host = n_hosts, gpu = n_gpus);
        let offset = 3;
        assert_eq!(
            s.index(vec![("host".to_string(), offset)]).unwrap(),
            Shape::new(
                vec!["zone".to_string(), "gpu".to_string()],
                Slice::new(
                    offset * n_gpus,
                    vec![n_zone, n_gpus],
                    vec![n_hosts * n_gpus, 1]
                )
                .unwrap()
            )
            .unwrap()
        );

        // Out of range
        assert!(
            shape!(gpu = n_gpus)
                .index(vec![("gpu".to_string(), n_gpus)])
                .is_err()
        );
        // Invalid dim
        assert_matches!(
            shape!(gpu = n_gpus).index(vec![("non-exist-dim".to_string(), 0)]),
            Err(ShapeError::InvalidLabels { .. })
        );
    }

    /// A strided selection rounds the resulting size up.
    #[test]
    fn test_shape_select_stride_rounding() {
        let shape = shape!(x = 10);
        // Select x = 0..10 step 3 → expect indices [0, 3, 6, 9]
        let sub = shape.select("x", Range(0, Some(10), 3)).unwrap();
        let slice = sub.slice();
        // 10 / 3 = 3.33..., so ceil(10 / 3) = 4
        assert_eq!(
            slice,
            &Slice::new(0, vec![4], vec![3]).unwrap(),
            "Expected offset 0, size 4, stride 3"
        );
    }

    /// `at` on the leading dimension drops it and moves the offset.
    #[test]
    fn test_shape_at_removes_dimension() {
        let labels = vec![
            "batch".to_string(),
            "height".to_string(),
            "width".to_string(),
        ];
        let slice = Slice::new_row_major(vec![2, 3, 4]);
        let shape = Shape::new(labels, slice).unwrap();

        // Select index 1 from "batch" dimension
        let result = shape.at("batch", 1).unwrap();

        // Should have 2 dimensions now
        assert_eq!(result.labels(), &["height", "width"]);
        assert_eq!(result.slice().sizes(), &[3, 4]);
        assert_eq!(result.slice().offset(), 12); // 1 * 12 (batch stride)
    }

    /// `at` on a middle dimension drops exactly that dimension.
    #[test]
    fn test_shape_at_middle_dimension() {
        let labels = vec![
            "batch".to_string(),
            "height".to_string(),
            "width".to_string(),
        ];
        let slice = Slice::new_row_major(vec![2, 3, 4]);
        let shape = Shape::new(labels, slice).unwrap();

        // Select index 1 from "height" dimension (middle)
        let result = shape.at("height", 1).unwrap();

        // Should remove middle label
        assert_eq!(result.labels(), &["batch", "width"]);
        assert_eq!(result.slice().sizes(), &[2, 4]);
        assert_eq!(result.slice().offset(), 4); // 1 * 4 (height stride)
    }

    /// `at` with an unknown label is rejected.
    #[test]
    fn test_shape_at_invalid_label() {
        let labels = vec!["batch".to_string(), "height".to_string()];
        let slice = Slice::new_row_major(vec![2, 3]);
        let shape = Shape::new(labels, slice).unwrap();

        assert_matches!(
            shape.at("nonexistent", 0),
            Err(ShapeError::InvalidLabels { .. })
        );
    }

    /// `at` with an index beyond the dimension reports the dimension's
    /// label.
    #[test]
    fn test_shape_at_index_out_of_range() {
        let labels = vec!["batch".to_string(), "height".to_string()];
        let slice = Slice::new_row_major(vec![2, 3]);
        let shape = Shape::new(labels, slice).unwrap();

        // batch only has size 2
        assert_matches!(
            shape.at("batch", 5).unwrap_err(),
            ShapeError::OutOfRange { dim, .. } if dim == "batch"
        );
    }
}