ndslice/
view.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::ops::Index;
10use std::sync::Arc;
11
12use serde::Deserialize;
13use serde::Serialize;
14use thiserror::Error;
15
16use crate::Range;
17use crate::Slice;
18use crate::SliceIterator;
19use crate::slice::CartesianIterator;
20
21/// Errors that can occur when constructing or validating an `Extent`.
22#[derive(Debug, thiserror::Error)]
23pub enum ExtentError {
24    /// The number of labels does not match the number of sizes.
25    ///
26    /// This occurs when constructing an `Extent` from parallel
27    /// `Vec<String>` and `Vec<usize>` inputs that are not the same
28    /// length.
29    #[error("label/sizes dimension mismatch: {num_labels} != {num_sizes}")]
30    DimMismatch {
31        /// Number of dimension labels provided.
32        num_labels: usize,
33        /// Number of dimension sizes provided.
34        num_sizes: usize,
35    },
36}
37
38/// `Extent` defines the logical shape of a multidimensional space by
39/// assigning a size to each named dimension. It abstracts away memory
40/// layout and focuses solely on structure — what dimensions exist and
41/// how many elements each contains.
42///
43/// Conceptually, it corresponds to a coordinate space in the
44/// mathematical sense.
45#[derive(Clone, Deserialize, Serialize, PartialEq, Eq, Hash, Debug)]
46pub struct Extent {
47    inner: Arc<ExtentData>,
48}
49
50fn _assert_extent_traits()
51where
52    Extent: Send + Sync + 'static,
53{
54}
55
56// `ExtentData` is represented as:
57// - `labels`: dimension names like `"zone"`, `"host"`, `"gpu"`
58// - `sizes`: number of elements in each dimension, independent of
59//   stride or storage layout
60#[derive(Clone, Deserialize, Serialize, PartialEq, Eq, Hash, Debug)]
61struct ExtentData {
62    labels: Vec<String>,
63    sizes: Vec<usize>,
64}
65
66impl Extent {
67    /// Creates a new `Extent` from the given labels and sizes.
68    pub fn new(labels: Vec<String>, sizes: Vec<usize>) -> Result<Self, ExtentError> {
69        if labels.len() != sizes.len() {
70            return Err(ExtentError::DimMismatch {
71                num_labels: labels.len(),
72                num_sizes: sizes.len(),
73            });
74        }
75
76        Ok(Self {
77            inner: Arc::new(ExtentData { labels, sizes }),
78        })
79    }
80
81    /// Returns the ordered list of dimension labels in this extent.
82    pub fn labels(&self) -> &[String] {
83        &self.inner.labels
84    }
85
86    /// Returns the dimension sizes, ordered to match the labels.
87    pub fn sizes(&self) -> &[usize] {
88        &self.inner.sizes
89    }
90
91    /// Returns the size of the dimension with the given label, if it
92    /// exists.
93    pub fn size(&self, label: &str) -> Option<usize> {
94        self.position(label).map(|pos| self.sizes()[pos])
95    }
96
97    /// Returns the position of the dimension with the given label, if
98    /// it exists exists.
99    pub fn position(&self, label: &str) -> Option<usize> {
100        self.labels().iter().position(|l| l == label)
101    }
102
103    /// Creates a `Point` in this extent with the given coordinates.
104    ///
105    /// Returns an error if the coordinate dimensionality does not
106    /// match.
107    pub fn point(&self, coords: Vec<usize>) -> Result<Point, PointError> {
108        if coords.len() != self.len() {
109            return Err(PointError::DimMismatch {
110                expected: self.len(),
111                actual: coords.len(),
112            });
113        }
114
115        Ok(Point {
116            coords,
117            extent: Extent {
118                inner: Arc::clone(&self.inner),
119            },
120        })
121    }
122
123    /// Returns the point corresponding to the provided rank in this extent.
124    pub fn point_of_rank(&self, mut rank: usize) -> Result<Point, PointError> {
125        if rank >= self.num_ranks() {
126            return Err(PointError::OutOfRange {
127                size: self.len(),
128                rank,
129            });
130        }
131
132        let mut stride: usize = self.sizes().iter().product();
133        let mut coords = vec![0; self.len()];
134        for (i, size) in self.sizes().iter().enumerate() {
135            stride /= size;
136            coords[i] = rank / stride;
137            rank %= stride;
138        }
139
140        Ok(Point {
141            coords,
142            extent: self.clone(),
143        })
144    }
145
146    /// The number of dimensions in the extent.
147    pub fn len(&self) -> usize {
148        self.sizes().len()
149    }
150
151    /// Whether the extent has zero dimensionbs.
152    pub fn is_empty(&self) -> bool {
153        self.sizes().is_empty()
154    }
155
156    /// The number of ranks in the extent.
157    pub fn num_ranks(&self) -> usize {
158        self.sizes().iter().product()
159    }
160
161    /// Convert this extent into its labels and sizes.
162    pub fn into_inner(self) -> (Vec<String>, Vec<usize>) {
163        match Arc::try_unwrap(self.inner) {
164            Ok(data) => (data.labels, data.sizes),
165            Err(shared) => (shared.labels.clone(), shared.sizes.clone()),
166        }
167    }
168
169    /// Creates a slice representing the full extent.
170    pub fn to_slice(&self) -> Slice {
171        Slice::new_row_major(self.sizes())
172    }
173
174    /// Iterate over this extens labels and sizes.
175    pub fn iter(&self) -> impl Iterator<Item = (String, usize)> + use<'_> {
176        self.labels()
177            .iter()
178            .zip(self.sizes().iter())
179            .map(|(l, s)| (l.clone(), *s))
180    }
181
182    /// Iterate points in this extent.
183    pub fn points(&self) -> ExtentPointsIterator {
184        ExtentPointsIterator {
185            extent: self,
186            pos: CartesianIterator::new(self.sizes().to_vec()),
187        }
188    }
189}
190
191impl std::fmt::Display for Extent {
192    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
193        let n = self.sizes().len();
194        for i in 0..n {
195            write!(f, "{}={}", self.labels()[i], self.sizes()[i])?;
196            if i != n - 1 {
197                write!(f, ",")?;
198            }
199        }
200        Ok(())
201    }
202}
203
204/// An iterator for points in an extent.
205pub struct ExtentPointsIterator<'a> {
206    extent: &'a Extent,
207    pos: CartesianIterator,
208}
209
210impl<'a> Iterator for ExtentPointsIterator<'a> {
211    type Item = Point;
212
213    fn next(&mut self) -> Option<Self::Item> {
214        Some(Point {
215            coords: self.pos.next()?,
216            extent: self.extent.clone(),
217        })
218    }
219}
220
221/// Errors that can occur when constructing or evaluating a `Point`.
222#[derive(Debug, Error)]
223pub enum PointError {
224    /// The number of coordinates does not match the number of
225    /// dimensions defined by the associated extent.
226    ///
227    /// This occurs when creating a `Point` with a coordinate vector
228    /// of incorrect length relative to the dimensionality of the
229    /// extent.
230    #[error("dimension mismatch: expected {expected}, got {actual}")]
231    DimMismatch {
232        /// Number of dimensions expected from the extent.
233        expected: usize,
234        /// Number of coordinates actually provided.
235        actual: usize,
236    },
237
238    /// The point is out of range for the extent.
239    #[error("out of range: size of extent is {size}; does not contain rank {rank}")]
240    OutOfRange { size: usize, rank: usize },
241}
242
243/// `Point` represents a specific coordinate within the
244/// multi-dimensional space defined by an `Extent`.
245///
246/// Coordinate values can be accessed by indexing:
247///
248/// ```
249/// use ndslice::extent;
250///
251/// let ext = extent!(zone = 2, host = 4, gpu = 8);
252/// let point = ext.point(vec![1, 2, 3]).unwrap();
253/// assert_eq!(point[0], 1);
254/// assert_eq!(point[1], 2);
255/// assert_eq!(point[2], 3);
256/// ```
257#[derive(Clone, Deserialize, Serialize, PartialEq, Eq, Hash, Debug)]
258pub struct Point {
259    coords: Vec<usize>,
260    extent: Extent,
261}
262
263impl Index<usize> for Point {
264    type Output = usize;
265
266    /// Returns the coordinate value for the given dimension index.
267    /// This allows using `point[0]` syntax instead of
268    /// `point.coords()[0]`.
269    fn index(&self, dim: usize) -> &Self::Output {
270        &self.coords[dim]
271    }
272}
273
274impl<'a> IntoIterator for &'a Point {
275    type Item = usize;
276    type IntoIter = std::iter::Cloned<std::slice::Iter<'a, usize>>;
277
278    /// Iterates over the coordinate values of this point.
279    ///
280    /// This allows using `for coord in &point { ... }` syntax to
281    /// iterate through each dimension's coordinate value.
282    fn into_iter(self) -> Self::IntoIter {
283        self.coords.iter().cloned()
284    }
285}
286
287fn _assert_point_traits()
288where
289    Point: Send + Sync + 'static,
290{
291}
292
293/// Extension trait for creating a `Point` from a coordinate vector
294/// and an `Extent`.
295///
296/// This trait provides the `.in_(&extent)` method, which constructs a
297/// `Point` using the caller as the coordinate vector and the given
298/// extent as the shape context.
299///
300/// # Example
301/// ```
302/// use ndslice::Extent;
303/// use ndslice::view::InExtent;
304/// let extent = Extent::new(vec!["x".into(), "y".into()], vec![3, 4]).unwrap();
305/// let point = vec![1, 2].in_(&extent).unwrap();
306/// assert_eq!(point.rank(), 1 * 4 + 2);
307/// ```
308pub trait InExtent {
309    fn in_(self, extent: &Extent) -> Result<Point, PointError>;
310}
311
312impl InExtent for Vec<usize> {
313    /// Creates a `Point` with the provided coordinates in the given
314    /// extent.
315    ///
316    /// Delegates to `Extent::point`.
317    fn in_(self, extent: &Extent) -> Result<Point, PointError> {
318        extent.point(self)
319    }
320}
321
322impl Point {
323    /// Returns a reference to the coordinate vector for this point.
324    pub fn coords(&self) -> &Vec<usize> {
325        &self.coords
326    }
327
328    /// Returns a reference to the extent associated with this point.
329    pub fn extent(&self) -> &Extent {
330        &self.extent
331    }
332
333    /// Computes the row-major logical rank of this point within its
334    /// extent.
335    ///
336    /// ```text
337    /// Σ (coord[i] × ∏(sizes[j] for j > i))
338    /// ```
339    ///
340    /// where `coord` is the point's coordinate and `sizes` is the
341    /// extent's dimension sizes.
342    pub fn rank(&self) -> usize {
343        let mut stride = 1;
344        let mut result = 0;
345        for (c, size) in self
346            .coords
347            .iter()
348            .rev()
349            .zip(self.extent().sizes().iter().rev())
350        {
351            result += *c * stride;
352            stride *= size;
353        }
354
355        result
356    }
357
358    /// The dimensionality of this point.
359    pub fn len(&self) -> usize {
360        self.coords.len()
361    }
362
363    /// Is this the 0d constant `[]`?
364    pub fn is_empty(&self) -> bool {
365        self.coords.is_empty()
366    }
367}
368
369impl std::fmt::Display for Point {
370    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
371        let n = self.coords.len();
372        for i in 0..n {
373            write!(f, "{}={}", self.extent.labels()[i], self.coords[i])?;
374            if i != n - 1 {
375                write!(f, ",")?;
376            }
377        }
378        Ok(())
379    }
380}
381
382/// Errors that occur while operating on views.
383#[derive(Debug, Error)]
384pub enum ViewError {
385    /// The provided dimension does not exist in the relevant extent.
386    #[error("no such dimension: {0}")]
387    InvalidDim(String),
388
389    /// A view was attempted to be constructed from an empty (resolved) range.
390    #[error("empty range: {range} for dimension {dim} of size {size}")]
391    EmptyRange {
392        range: Range,
393        dim: String,
394        size: usize,
395    },
396}
397
398/// A view is a collection of ranks, organized into an extent.
399#[derive(Debug, Clone, Serialize, Deserialize)]
400pub struct View {
401    labels: Vec<String>,
402    slice: Slice,
403}
404
405impl View {
406    /// The extent of this view. Every point in this space is defined.
407    pub fn extent(&self) -> Extent {
408        Extent::new(self.labels.clone(), self.slice.sizes().to_vec()).unwrap()
409    }
410
411    /// Iterate over the ranks in this view. The iterator returns both each rank,
412    /// as well as the corresponding point in the extent of this view.
413    pub fn iter(&self) -> ViewIterator {
414        ViewIterator {
415            extent: self.extent(),
416            pos: self.slice.iter(),
417        }
418    }
419}
420
421impl std::fmt::Display for View {
422    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
423        let n = self.labels.len();
424        for i in 0..n {
425            write!(f, "{}={}", self.labels[i], self.slice.sizes()[i])?;
426            if i != n - 1 {
427                write!(f, ",")?;
428            }
429        }
430        Ok(())
431    }
432}
433
434/// The iterator over views.
435pub struct ViewIterator {
436    extent: Extent,     // Note that `extent` and...
437    pos: SliceIterator, // ... `pos` share the same `Slice`.
438}
439
440impl Iterator for ViewIterator {
441    type Item = (Point, usize);
442
443    fn next(&mut self) -> Option<Self::Item> {
444        // This is a rank in the base space.
445        let rank = self.pos.next()?;
446        // Here, we convert to view space.
447        let coords = self.pos.slice.coordinates(rank).unwrap();
448        let point = coords.in_(&self.extent).unwrap();
449        Some((point, rank))
450    }
451}
452
453/// Viewable is a common trait implemented for data structures from which views
454/// may be created. This allows us to provide a consistent API for constructing
455/// and composing views.
456pub trait Viewable {
457    /// The labels of the dimensions in this view.
458    fn labels(&self) -> Vec<String>;
459
460    /// The slice representing this view.
461    /// Note: this representation may change.
462    fn slice(&self) -> Slice;
463}
464
465impl Viewable for View {
466    fn labels(&self) -> Vec<String> {
467        self.labels.clone()
468    }
469
470    fn slice(&self) -> Slice {
471        self.slice.clone()
472    }
473}
474
475impl Viewable for Extent {
476    fn labels(&self) -> Vec<String> {
477        self.labels().to_vec()
478    }
479
480    fn slice(&self) -> Slice {
481        self.to_slice()
482    }
483}
484
485// We would make this impl<T: Viewable> From<T> for View,
486// except this conflicts with the blanket impl for From<&T> for View.
487impl From<Extent> for View {
488    fn from(extent: Extent) -> Self {
489        View {
490            labels: extent.labels().to_vec(),
491            slice: extent.slice(),
492        }
493    }
494}
495
496/// Extension methods for view construction.
497pub trait ViewExt: Viewable {
498    /// Construct a view comprising the range of points along the provided dimension.
499    ///
500    /// ## Examples
501    ///
502    /// ```
503    /// use ndslice::Range;
504    /// use ndslice::ViewExt;
505    /// use ndslice::extent;
506    ///
507    /// let ext = extent!(zone = 4, host = 2, gpu = 8);
508    ///
509    /// // Subselect zone index 0.
510    /// assert_eq!(ext.range("zone", 0).unwrap().iter().count(), 16);
511    ///
512    /// // Even GPUs within zone 0
513    /// assert_eq!(
514    ///     ext.range("zone", 0)
515    ///         .unwrap()
516    ///         .range("gpu", Range(0, None, 2))
517    ///         .unwrap()
518    ///         .iter()
519    ///         .count(),
520    ///     8
521    /// );
522    /// ```
523    fn range<R: Into<Range>>(&self, dim: &str, range: R) -> Result<View, ViewError>;
524
525    /// Group by view on `dim`. The returned iterator enumerates all groups
526    /// as views in the extent of `dim` to the last dimension of the view.
527    ///
528    /// ## Examples
529    ///
530    /// ```
531    /// use ndslice::ViewExt;
532    /// use ndslice::extent;
533    ///
534    /// let ext = extent!(zone = 4, host = 2, gpu = 8);
535    ///
536    /// // We generate one view for each zone.
537    /// assert_eq!(ext.group_by("host").unwrap().count(), 4);
538    ///
539    /// let mut parts = ext.group_by("host").unwrap();
540    ///
541    /// let zone0 = parts.next().unwrap();
542    /// let mut zone0_points = zone0.iter();
543    /// assert_eq!(zone0.extent(), extent!(host = 2, gpu = 8));
544    /// assert_eq!(
545    ///     zone0_points.next().unwrap(),
546    ///     (extent!(host = 2, gpu = 8).point(vec![0, 0]).unwrap(), 0)
547    /// );
548    /// assert_eq!(
549    ///     zone0_points.next().unwrap(),
550    ///     (extent!(host = 2, gpu = 8).point(vec![0, 1]).unwrap(), 1)
551    /// );
552    ///
553    /// let zone1 = parts.next().unwrap();
554    /// assert_eq!(zone1.extent(), extent!(host = 2, gpu = 8));
555    /// assert_eq!(
556    ///     zone1.iter().next().unwrap(),
557    ///     (extent!(host = 2, gpu = 8).point(vec![0, 0]).unwrap(), 16)
558    /// );
559    /// ```
560    fn group_by(&self, dim: &str) -> Result<impl Iterator<Item = View>, ViewError>;
561}
562
563impl<T: Viewable> ViewExt for T {
564    fn range<R: Into<Range>>(&self, dim: &str, range: R) -> Result<View, ViewError> {
565        let range = range.into();
566        let dim = self
567            .labels()
568            .iter()
569            .position(|l| dim == l)
570            .ok_or_else(|| ViewError::InvalidDim(dim.to_string()))?;
571        let (mut offset, mut sizes, mut strides) = self.slice().into_inner();
572        let (begin, end, step) = range.resolve(sizes[dim]);
573        if end <= begin {
574            return Err(ViewError::EmptyRange {
575                range,
576                dim: dim.to_string(),
577                size: sizes[dim],
578            });
579        }
580
581        offset += strides[dim] * begin;
582        sizes[dim] = (end - begin).div_ceil(step);
583        strides[dim] *= step;
584        let slice = Slice::new(offset, sizes, strides).unwrap();
585
586        Ok(View {
587            labels: self.labels().clone(),
588            slice,
589        })
590    }
591
592    fn group_by(&self, dim: &str) -> Result<impl Iterator<Item = View>, ViewError> {
593        let dim = self
594            .labels()
595            .iter()
596            .position(|l| dim == l)
597            .ok_or_else(|| ViewError::InvalidDim(dim.to_string()))?;
598
599        let (offset, sizes, strides) = self.slice().into_inner();
600        let mut ranks = Slice::new(offset, sizes[..dim].to_vec(), strides[..dim].to_vec())
601            .unwrap()
602            .iter();
603
604        let labels = self.labels()[dim..].to_vec();
605        let sizes = sizes[dim..].to_vec();
606        let strides = strides[dim..].to_vec();
607
608        Ok(std::iter::from_fn(move || {
609            let rank = ranks.next()?;
610            let slice = Slice::new(rank, sizes.clone(), strides.clone()).unwrap();
611            Some(View {
612                labels: labels.clone(),
613                slice,
614            })
615        }))
616    }
617}
618
/// Builds an `Extent` from `label = size` pairs. Dimensions keep the order
/// in which they are written.
///
/// Each label is the identifier's text; each size is an expression
/// evaluated in order. Construction cannot fail, because the macro pairs
/// every label with exactly one size.
///
/// ```
/// let s = ndslice::extent!(host = 2, gpu = 8);
/// assert_eq!(s.labels(), &["host".to_string(), "gpu".to_string()]);
/// assert_eq!(s.sizes(), &[2, 8]);
/// ```
#[macro_export]
macro_rules! extent {
    ( $( $label:ident = $size:expr_2021 ),* $(,)? ) => {
        $crate::view::Extent::new(
            ::std::vec![$( stringify!($label).to_string() ),*],
            ::std::vec![$( $size ),*],
        )
        .unwrap()
    };
}
642
// Unit tests for extents, points, and views defined in this module.
#[cfg(test)]
mod test {
    use super::*;
    use crate::Shape;
    use crate::shape;

    // Covers point construction (both `Extent::point` and `InExtent::in_`),
    // rank <-> point round-trips, indexing, and the row-major order of
    // `Extent::points`.
    #[test]
    fn test_points_basic() {
        let extent = extent!(x = 4, y = 5, z = 6);
        let _p1 = extent.point(vec![1, 2, 3]).unwrap();
        let _p2 = vec![1, 2, 3].in_(&extent).unwrap();

        assert_eq!(extent.num_ranks(), 4 * 5 * 6);

        let p3 = extent.point_of_rank(0).unwrap();
        assert_eq!(p3.coords(), &[0, 0, 0]);
        assert_eq!(p3.rank(), 0);

        let p4 = extent.point_of_rank(1).unwrap();
        assert_eq!(p4.coords(), &[0, 0, 1]);
        assert_eq!(p4.rank(), 1);

        let p5 = extent.point_of_rank(2).unwrap();
        assert_eq!(p5.coords(), &[0, 0, 2]);
        assert_eq!(p5.rank(), 2);

        // Rank 31 = 1 * (5 * 6) + 0 * 6 + 1.
        let p6 = extent.point_of_rank(6 * 5 + 1).unwrap();
        assert_eq!(p6.coords(), &[1, 0, 1]);
        assert_eq!(p6.rank(), 6 * 5 + 1);
        assert_eq!(p6[0], 1);
        assert_eq!(p6[1], 0);
        assert_eq!(p6[2], 1);

        // `points()` must enumerate in row-major order: the i-th point has rank i.
        assert_eq!(extent.points().collect::<Vec<_>>().len(), 4 * 5 * 6);
        for (rank, point) in extent.points().enumerate() {
            let &[x, y, z] = &**point.coords() else {
                panic!("invalid coords");
            };
            assert_eq!(z + y * 6 + x * 6 * 5, rank);
            assert_eq!(point.rank(), rank);
        }
    }

    // Asserts that `$view` has extent `$extent`, and that iterating the view
    // yields exactly the listed `coords => rank` pairs, in order. Each
    // coordinate list becomes a point of `$extent`; `$rank` is the
    // base-space rank the view iterator is expected to report for it.
    macro_rules! assert_view {
        ($view:expr, $extent:expr,  $( $($coord:expr),+ => $rank:expr );* $(;)?) => {
            let view = $view;
            assert_eq!(view.extent(), $extent);
            let expected: Vec<_> = vec![$(($extent.point(vec![$($coord),+]).unwrap(), $rank)),*];
            let actual: Vec<_> = view.iter().collect();
            assert_eq!(actual, expected);
        };
    }

    // Checks `ViewExt::range` with single indices, open-ended ranges,
    // strided ranges, and chains of successive restrictions.
    #[test]
    fn test_view_basic() {
        let extent = extent!(x = 4, y = 4);
        assert_view!(
            extent.range("x", 0..2).unwrap(),
            extent!(x = 2, y = 4),
            0, 0 => 0;
            0, 1 => 1;
            0, 2 => 2;
            0, 3 => 3;
            1, 0 => 4;
            1, 1 => 5;
            1, 2 => 6;
            1, 3 => 7;
        );
        assert_view!(
            extent.range("x", 1).unwrap().range("y", 2..).unwrap(),
            extent!(x = 1, y = 2),
            0, 0 => 6;
            0, 1 => 7;
        );
        assert_view!(
            extent.range("y", Range(0, None, 2)).unwrap(),
            extent!(x = 4, y = 2),
            0, 0 => 0;
            0, 1 => 2;
            1, 0 => 4;
            1, 1 => 6;
            2, 0 => 8;
            2, 1 => 10;
            3, 0 => 12;
            3, 1 => 14;
        );
        assert_view!(
            extent.range("y", Range(0, None, 2)).unwrap().range("x", 2..).unwrap(),
            extent!(x = 2, y = 2),
            0, 0 => 8;
            0, 1 => 10;
            1, 0 => 12;
            1, 1 => 14;
        );

        let extent = extent!(x = 10, y = 2);
        assert_view!(
            extent.range("x", Range(0, None, 2)).unwrap(),
            extent!(x = 5, y = 2),
            0, 0 => 0;
            0, 1 => 1;
            1, 0 => 4;
            1, 1 => 5;
            2, 0 => 8;
            2, 1 => 9;
            3, 0 => 12;
            3, 1 => 13;
            4, 0 => 16;
            4, 1 => 17;
        );
        // Restricting the same dimension twice composes: indices are
        // relative to the already-strided view.
        assert_view!(
            extent.range("x", Range(0, None, 2)).unwrap().range("x", 2..).unwrap().range("y", 1).unwrap(),
            extent!(x = 3, y = 1),
            0, 0 => 9;
            1, 0 => 13;
            2, 0 => 17;
        );

        let extent = extent!(zone = 4, host = 2, gpu = 8);
        assert_view!(
            extent.range("zone", 0).unwrap().range("gpu", Range(0, None, 2)).unwrap(),
            extent!(zone = 1, host = 2, gpu = 4),
            0, 0, 0 => 0;
            0, 0, 1 => 2;
            0, 0, 2 => 4;
            0, 0, 3 => 6;
            0, 1, 0 => 8;
            0, 1, 1 => 10;
            0, 1, 2 => 12;
            0, 1, 3 => 14;
        );

        // A stride that does not divide the size rounds the count up
        // (3 elements, step 2 -> 2 selected).
        let extent = extent!(x = 3);
        assert_view!(
            extent.range("x", Range(0, None, 2)).unwrap(),
            extent!(x = 2),
            0 => 0;
            1 => 2;
        );
    }

    // `point[i]` returns the i-th coordinate.
    #[test]
    fn test_point_indexing() {
        let extent = Extent::new(vec!["x".into(), "y".into(), "z".into()], vec![4, 5, 6]).unwrap();
        let point = extent.point(vec![1, 2, 3]).unwrap();

        assert_eq!(point[0], 1);
        assert_eq!(point[1], 2);
        assert_eq!(point[2], 3);
    }

    // Indexing past the point's dimensionality panics.
    #[test]
    #[should_panic]
    fn test_point_indexing_out_of_bounds() {
        let extent = Extent::new(vec!["x".into(), "y".into()], vec![4, 5]).unwrap();
        let point = extent.point(vec![1, 2]).unwrap();

        let _ = point[5]; // Should panic
    }

    // `&Point` iterates its coordinates by value.
    #[test]
    fn test_point_into_iter() {
        let extent = Extent::new(vec!["x".into(), "y".into(), "z".into()], vec![4, 5, 6]).unwrap();
        let point = extent.point(vec![1, 2, 3]).unwrap();

        let coords: Vec<usize> = (&point).into_iter().collect();
        assert_eq!(coords, vec![1, 2, 3]);

        let mut sum = 0;
        for coord in &point {
            sum += coord;
        }
        assert_eq!(sum, 6);
    }

    // `Extent::iter` yields `(label, size)` pairs in dimension order.
    #[test]
    fn test_extent_basic() {
        let extent = extent!(x = 10, y = 5, z = 1);
        assert_eq!(
            extent.iter().collect::<Vec<_>>(),
            vec![
                ("x".to_string(), 10),
                ("y".to_string(), 5),
                ("z".to_string(), 1)
            ]
        );
    }

    // `Display` for extents, including the zero-dimensional case.
    #[test]
    fn test_extent_display() {
        let extent = Extent::new(vec!["x".into(), "y".into(), "z".into()], vec![4, 5, 6]).unwrap();
        assert_eq!(format!("{}", extent), "x=4,y=5,z=6");

        let empty_extent = Extent::new(vec![], vec![]).unwrap();
        assert_eq!(format!("{}", empty_extent), "");
    }

    // A zero-dimensional extent has exactly one point, `[]`, at rank 0.
    #[test]
    fn test_extent_0d() {
        let e = Extent::new(vec![], vec![]).unwrap();
        assert_eq!(e.num_ranks(), 1);
        let points: Vec<_> = e.points().collect();
        assert_eq!(points.len(), 1);
        assert_eq!(points[0].coords(), &[]);
        assert_eq!(points[0].rank(), 0);
    }

    // `Display` for points, plus rejection of a coordinate vector with the
    // wrong dimensionality.
    #[test]
    fn test_point_display() {
        let extent = Extent::new(vec!["x".into(), "y".into(), "z".into()], vec![4, 5, 6]).unwrap();
        let point = extent.point(vec![1, 2, 3]).unwrap();
        assert_eq!(format!("{}", point), "x=1,y=2,z=3");

        assert!(extent.point(vec![]).is_err());

        let empty_extent = Extent::new(vec![], vec![]).unwrap();
        let empty_point = empty_extent.point(vec![]).unwrap();
        assert_eq!(format!("{}", empty_point), "");
    }

    // Ranks of a strided sub-shape, mapped back through `Slice::coordinates`,
    // become dense ranks 0..n in the sub-shape's own extent.
    #[test]
    fn test_relative_point() {
        // Given a rank in the root shape, return the corresponding point in the
        // provided shape, which is a view of the root shape.
        pub fn relative_point(rank_on_root_mesh: usize, shape: &Shape) -> anyhow::Result<Point> {
            let coords = shape.slice().coordinates(rank_on_root_mesh)?;
            let extent = Extent::new(shape.labels().to_vec(), shape.slice().sizes().to_vec())?;
            Ok(extent.point(coords)?)
        }

        let root_shape = shape! { replicas = 4, hosts = 4, gpus = 4 };
        // rows are `hosts`, cols are gpus
        // replicas = 0
        //     0,    1,  2,    3,
        //     (4),  5,  (6),  7,
        //     8,    9,  10,   11,
        //     (12), 13, (14), 15,
        // replicas = 3, which is [replicas=0] + 48
        //     48,   49, 50,   51,
        //     (52), 53, (54), 55,
        //     56,   57, 58,   59,
        //     (60), 61, (62), 63,
        let sliced_shape = root_shape
            .select("replicas", crate::Range(0, Some(4), 3))
            .unwrap()
            .select("hosts", crate::Range(1, Some(4), 2))
            .unwrap()
            .select("gpus", crate::Range(0, Some(4), 2))
            .unwrap();
        let ranks_on_root_mesh = &[4, 6, 12, 14, 52, 54, 60, 62];
        assert_eq!(
            sliced_shape.slice().iter().collect::<Vec<_>>(),
            ranks_on_root_mesh,
        );

        let ranks_on_sliced_mesh = ranks_on_root_mesh
            .iter()
            .map(|&r| relative_point(r, &sliced_shape).unwrap().rank());
        assert_eq!(
            ranks_on_sliced_mesh.collect::<Vec<_>>(),
            vec![0, 1, 2, 3, 4, 5, 6, 7]
        );
    }

    // `group_by` yields one view per index combination of the dimensions
    // before the grouping dimension, each covering the remaining dimensions.
    #[test]
    fn test_iter_subviews() {
        let extent = extent!(zone = 4, host = 4, gpu = 8);

        assert_eq!(extent.group_by("gpu").unwrap().count(), 16);
        assert_eq!(extent.group_by("zone").unwrap().count(), 1);

        let mut parts = extent.group_by("gpu").unwrap();
        assert_view!(
            parts.next().unwrap(),
            extent!(gpu = 8),
            0 => 0;
            1 => 1;
            2 => 2;
            3 => 3;
            4 => 4;
            5 => 5;
            6 => 6;
            7 => 7;
        );
        assert_view!(
            parts.next().unwrap(),
            extent!(gpu = 8),
            0 => 8;
            1 => 9;
            2 => 10;
            3 => 11;
            4 => 12;
            5 => 13;
            6 => 14;
            7 => 15;
        );
    }
}