ndslice/
shape.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::fmt;
10use std::str::FromStr;
11
12use serde::Deserialize;
13use serde::Serialize;
14
15use crate::DimSliceIterator;
16use crate::Region;
17use crate::Slice;
18use crate::SliceError;
19use crate::selection::Selection;
20use crate::view::Extent;
21
// `select` keeps every dimension, even one narrowed to a single index; only `at` (and `index`,
// which is built on it) removes dimensions from a shape.
23
24#[derive(Debug, thiserror::Error)]
25pub enum ShapeError {
26    #[error("label slice dimension mismatch: {labels_dim} != {slice_dim}")]
27    DimSliceMismatch { labels_dim: usize, slice_dim: usize },
28
29    #[error("invalid labels `{labels:?}`")]
30    InvalidLabels { labels: Vec<String> },
31
32    #[error("empty range {range}")]
33    EmptyRange { range: Range },
34
35    #[error("out of range {range} for dimension {dim} of size {size}")]
36    OutOfRange {
37        range: Range,
38        dim: String,
39        size: usize,
40    },
41
42    #[error("selection `{expr}` exceeds dimensionality {num_dim}")]
43    SelectionTooDeep { expr: Selection, num_dim: usize },
44
45    #[error("dynamic selection `{expr}`")]
46    SelectionDynamic { expr: Selection },
47
48    #[error("{index} out of range for dimension {dim} of size {size}")]
49    IndexOutOfRange {
50        index: usize,
51        dim: String,
52        size: usize,
53    },
54
55    #[error("failed to parse shape: {reason}")]
56    ParseError { reason: String },
57
58    #[error(transparent)]
59    SliceError(#[from] SliceError),
60}
61
62/// A shape is a [`Slice`] with labeled dimensions and a selection API.
63#[derive(Clone, Deserialize, Serialize, PartialEq, Hash, Debug)]
64pub struct Shape {
65    /// The labels for each dimension in slice.
66    labels: Vec<String>,
67    /// The slice itself, which describes the topology of the shape.
68    slice: Slice,
69}
70
71impl Shape {
72    /// Creates a new shape with the provided labels, which describe the
73    /// provided Slice.
74    ///
75    /// Shapes can also be constructed by way of the [`shape`] macro, which
76    /// creates a by-construction correct slice in row-major order given a set of
77    /// sized dimensions.
78    pub fn new(labels: Vec<String>, slice: Slice) -> Result<Self, ShapeError> {
79        if labels.len() != slice.num_dim() {
80            return Err(ShapeError::DimSliceMismatch {
81                labels_dim: labels.len(),
82                slice_dim: slice.num_dim(),
83            });
84        }
85        Ok(Self { labels, slice })
86    }
87
88    /// Select a single index along a named dimension, removing that
89    /// dimension entirely. This reduces the dimensionality by 1. In
90    /// effect it results in a cross section of the shape at the given
91    /// index in the given dimension.
92    pub fn at(&self, label: &str, index: usize) -> Result<Self, ShapeError> {
93        let dim = self.dim(label)?;
94        let slice = self.slice.at(dim, index).map_err(|err| match err {
95            SliceError::IndexOutOfRange { index, total } => ShapeError::OutOfRange {
96                range: Range(index, Some(index + 1), 1),
97                dim: label.to_string(),
98                size: total,
99            },
100            other => other.into(),
101        })?;
102        let mut labels = self.labels.clone();
103        labels.remove(dim);
104        Ok(Self { labels, slice })
105    }
106
107    /// Restrict this shape along a named dimension using a [`Range`].
108    /// The provided range must be nonempty.
109    ///
110    /// `select` is composable, it can be applied repeatedly, even on
111    /// the same dimension, to refine the view incrementally.
112    pub fn select<R: Into<Range>>(&self, label: &str, range: R) -> Result<Self, ShapeError> {
113        let dim = self.dim(label)?;
114        let range = range.into();
115        let (begin, end, step) = range.resolve(self.slice().sizes()[dim]);
116        let slice = self
117            .slice
118            .select(dim, begin, end, step)
119            .map_err(|err| match err {
120                SliceError::EmptyRange { .. } => ShapeError::EmptyRange { range },
121                SliceError::IndexOutOfRange { total, .. } => ShapeError::OutOfRange {
122                    range,
123                    dim: label.to_string(),
124                    size: total,
125                },
126                other => other.into(),
127            })?;
128        let labels = self.labels.clone();
129        Ok(Self { labels, slice })
130    }
131
132    /// Produces an iterator over subshapes by fixing the first `dims`
133    /// dimensions.
134    ///
135    /// For a shape of rank `n`, this yields `∏ sizes[0..dims]`
136    /// subshapes, each with the first `dims` dimensions restricted to
137    /// size 1. The remaining dimensions are left unconstrained.
138    ///
139    /// This is useful for structured traversal of slices within a
140    /// multidimensional shape. See [`SelectIterator`] for details and
141    /// examples.
142    ///
143    /// # Errors
144    /// Returns an error if `dims == 0` or `dims >= self.rank()`.
145    pub fn select_iter(&self, dims: usize) -> Result<SelectIterator<'_>, ShapeError> {
146        let num_dims = self.slice().num_dim();
147        if dims == 0 || dims >= num_dims {
148            return Err(ShapeError::SliceError(SliceError::IndexOutOfRange {
149                index: dims,
150                total: num_dims,
151            }));
152        }
153
154        Ok(SelectIterator {
155            shape: self,
156            iter: self.slice().dim_iter(dims),
157        })
158    }
159
160    /// Sub-set this shape by select a particular row of the given
161    /// indices The resulting shape will no longer have dimensions for
162    /// the given indices Example shape.index(vec![("gpu", 3),
163    /// ("host", 0)])
164    pub fn index(&self, indices: Vec<(String, usize)>) -> Result<Shape, ShapeError> {
165        let mut shape = self.clone();
166        for (label, index) in indices {
167            shape = shape.at(&label, index)?;
168        }
169        Ok(shape)
170    }
171
172    /// The per-dimension labels of this shape.
173    pub fn labels(&self) -> &[String] {
174        &self.labels
175    }
176
177    /// The slice describing the shape.
178    pub fn slice(&self) -> &Slice {
179        &self.slice
180    }
181
182    /// Return a set of labeled coordinates for the given rank.
183    pub fn coordinates(&self, rank: usize) -> Result<Vec<(String, usize)>, ShapeError> {
184        let coords = self.slice.coordinates(rank)?;
185        Ok(coords
186            .iter()
187            .zip(self.labels.iter())
188            .map(|(i, l)| (l.to_string(), *i))
189            .collect())
190    }
191
192    pub fn dim(&self, label: &str) -> Result<usize, ShapeError> {
193        self.labels
194            .iter()
195            .position(|l| l == label)
196            .ok_or_else(|| ShapeError::InvalidLabels {
197                labels: vec![label.to_string()],
198            })
199    }
200
201    /// Return the 0-dimensional single element shape
202    pub fn unity() -> Shape {
203        Shape::new(vec![], Slice::new(0, vec![], vec![]).expect("unity")).expect("unity")
204    }
205
206    /// The extent corresponding to this shape.
207    pub fn extent(&self) -> Extent {
208        Extent::new(self.labels.clone(), self.slice.sizes().to_vec()).unwrap()
209    }
210
211    /// The region corresponding to this shape.
212    pub fn region(&self) -> Region {
213        self.into()
214    }
215}
216
217impl From<Region> for Shape {
218    fn from(region: Region) -> Self {
219        let (labels, slice) = region.into_inner();
220        Shape::new(labels, slice)
221            .expect("Shape::new should not fail because a Region by definition is a valid Shape")
222    }
223}
224
225impl From<&Region> for Shape {
226    fn from(region: &Region) -> Self {
227        Shape::new(region.labels().to_vec(), region.slice().clone())
228            .expect("Shape::new should not fail because a Region by definition is a valid Shape")
229    }
230}
231
232/// Iterator over subshapes obtained by fixing a prefix of dimensions.
233///
234/// This iterator is produced by [`Shape::select_iter(dims)`], and
235/// yields one `Shape` per coordinate prefix in the first `dims`
236/// dimensions.
237///
238/// For a shape of `n` dimensions, each yielded shape has:
239/// - The first `dims` dimensions restricted to size 1 (i.e., fixed
240///   via `select`)
241/// - The remaining `n - dims` dimensions left unconstrained
242///
243/// This allows structured iteration over "slices" of the original
244/// shape: for example with `n` = 3, `select_iter(1)` walks through 2D
245/// planes, while `select_iter(2)` yields 1D subshapes.
246///
247/// # Example
248/// ```ignore
249/// let s = shape!(zone = 2, host = 2, gpu = 8);
250/// let views: Vec<_> = s.select_iter(2).unwrap().collect();
251/// assert_eq!(views.len(), 4);
252/// assert_eq!(views[0].slice().sizes(), &[1, 1, 8]);
253/// ```
254/// The above example can be interpreted as: for each `(zone, host)`
255/// pair, `select_iter(2)` yields a `Shape` describing the associated
256/// row of GPUs — a view into the `[1, 1, 8]` subregion of the full
257/// `[2, 2, 8]` shape.
258pub struct SelectIterator<'a> {
259    shape: &'a Shape,
260    iter: DimSliceIterator,
261}
262
263impl<'a> Iterator for SelectIterator<'a> {
264    type Item = Shape;
265
266    fn next(&mut self) -> Option<Self::Item> {
267        let pos = self.iter.next()?;
268        let mut shape = self.shape.clone();
269        for (dim, index) in pos.iter().enumerate() {
270            shape = shape.select(&self.shape.labels()[dim], *index).unwrap();
271        }
272        Some(shape)
273    }
274}
275
276impl fmt::Display for Shape {
277    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
278        // Just display the sizes of each dimension, for now.
279        // Once we have a selection algebra, we can provide a
280        // better Display implementation.
281        write!(f, "{{")?;
282        for dim in 0..self.labels.len() {
283            write!(f, "{}={}", self.labels[dim], self.slice.sizes()[dim])?;
284            if dim < self.labels.len() - 1 {
285                write!(f, ",")?;
286            }
287        }
288        write!(f, "}}")
289    }
290}
291
292impl FromStr for Shape {
293    type Err = ShapeError;
294
295    fn from_str(s: &str) -> Result<Self, Self::Err> {
296        let s = s.trim();
297
298        if !s.starts_with('{') || !s.ends_with('}') {
299            return Err(ShapeError::ParseError {
300                reason: "shape string must be enclosed in braces".to_string(),
301            });
302        }
303
304        let inner = &s[1..s.len() - 1].trim();
305
306        if inner.is_empty() {
307            return Ok(Shape::unity());
308        }
309
310        let mut labels = Vec::new();
311        let mut sizes = Vec::new();
312
313        for part in inner.split(',') {
314            let part = part.trim();
315            let mut split = part.split('=');
316
317            let label = split
318                .next()
319                .ok_or_else(|| ShapeError::ParseError {
320                    reason: format!("invalid dimension format: '{}'", part),
321                })?
322                .trim();
323
324            let size_str = split
325                .next()
326                .ok_or_else(|| ShapeError::ParseError {
327                    reason: format!("missing size for dimension '{}'", label),
328                })?
329                .trim();
330
331            if split.next().is_some() {
332                return Err(ShapeError::ParseError {
333                    reason: format!("invalid dimension format: '{}'", part),
334                });
335            }
336
337            if label.is_empty() {
338                return Err(ShapeError::ParseError {
339                    reason: format!("missing label in dimension: '{}'", part),
340                });
341            }
342
343            let size = size_str
344                .parse::<usize>()
345                .map_err(|_| ShapeError::ParseError {
346                    reason: format!("invalid size '{}' for dimension '{}'", size_str, label),
347                })?;
348
349            labels.push(label.to_string());
350            sizes.push(size);
351        }
352
353        let slice = Slice::new_row_major(sizes);
354        Shape::new(labels, slice)
355    }
356}
357
358static SHAPE_CACHED_TYPEHASH: std::sync::LazyLock<u64> =
359    std::sync::LazyLock::new(|| cityhasher::hash(<Shape as typeuri::Named>::typename()));
360
361impl typeuri::Named for Shape {
362    fn typename() -> &'static str {
363        "ndslice::shape::Shape"
364    }
365
366    fn typehash() -> u64 {
367        *SHAPE_CACHED_TYPEHASH
368    }
369}
370
371hyperactor_config::impl_attrvalue!(Shape);
372
/// Construct a new shape with the given set of dimension-size pairs in row-major
/// order.
///
/// Each `label = size` pair becomes one dimension, in the order written, and
/// a trailing comma is accepted.
///
/// ```
/// let s = ndslice::shape!(host = 2, gpu = 8);
/// assert_eq!(s.labels(), &["host".to_string(), "gpu".to_string()]);
/// assert_eq!(s.slice().sizes(), &[2, 8]);
/// assert_eq!(s.slice().strides(), &[8, 1]);
/// ```
#[macro_export]
macro_rules! shape {
    ( $( $label:ident = $size:expr ),* $(,)? ) => {{
        let labels: ::std::vec::Vec<::std::string::String> =
            ::std::vec![$( ::std::string::String::from(stringify!($label)) ),*];
        let sizes = ::std::vec![$( $size ),*];
        $crate::shape::Shape::new(labels, $crate::Slice::new_row_major(sizes)).unwrap()
    }};
}
398
/// Perform a sub-selection on the provided [`Shape`] object.
///
/// This macro chains `.select()` calls to apply multiple labeled
/// dimension restrictions in a fluent way, stopping at the first error.
/// The shape may be any expression (not only a variable), and a trailing
/// comma is accepted.
///
/// ```
/// let s = ndslice::shape!(host = 2, gpu = 8);
/// let s = ndslice::select!(s, host = 1, gpu = 4..).unwrap();
/// assert_eq!(s.labels(), &["host".to_string(), "gpu".to_string()]);
/// assert_eq!(s.slice().sizes(), &[1, 4]);
///
/// let t = ndslice::select!(ndslice::shape!(host = 2, gpu = 8), gpu = 0..2,).unwrap();
/// assert_eq!(t.slice().sizes(), &[2, 2]);
/// ```
#[macro_export]
macro_rules! select {
    // A single restriction: forward straight to `select`.
    ($shape:expr, $label:ident = $range:expr $(,)?) => {
        $shape.select(stringify!($label), $range)
    };

    // Several restrictions: apply the first, then recurse on the result.
    ($shape:expr, $label:ident = $range:expr, $($labels:ident = $ranges:expr),+ $(,)?) => {
        $shape
            .select(stringify!($label), $range)
            .and_then(|shape| $crate::select!(shape, $($labels = $ranges),+))
    };
}
420
421/// A range of indices, with a stride. Ranges are convertible from
422/// native Rust ranges.
423///
424/// Deriving `Eq`, `Ord` and `Hash` is sound because all fields are
425/// `Ord` and comparison is purely structural over `(start, end,
426/// step)`.
427#[derive(
428    Debug,
429    Clone,
430    Eq,
431    Hash,
432    PartialEq,
433    Serialize,
434    Deserialize,
435    PartialOrd,
436    Ord
437)]
438pub struct Range(pub usize, pub Option<usize>, pub usize);
439
440impl Range {
441    pub(crate) fn resolve(&self, size: usize) -> (usize, usize, usize) {
442        match self {
443            Range(begin, Some(end), stride) => (*begin, std::cmp::min(size, *end), *stride),
444            Range(begin, None, stride) => (*begin, size, *stride),
445        }
446    }
447
448    pub(crate) fn is_empty(&self) -> bool {
449        matches!(self, Range(begin, Some(end), _) if end <= begin)
450    }
451}
452
453impl fmt::Display for Range {
454    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
455        match self {
456            Range(begin, None, stride) => write!(f, "{}::{}", begin, stride),
457            Range(begin, Some(end), stride) => write!(f, "{}:{}:{}", begin, end, stride),
458        }
459    }
460}
461
462impl From<std::ops::Range<usize>> for Range {
463    fn from(r: std::ops::Range<usize>) -> Self {
464        Self(r.start, Some(r.end), 1)
465    }
466}
467
468impl From<std::ops::RangeInclusive<usize>> for Range {
469    fn from(r: std::ops::RangeInclusive<usize>) -> Self {
470        Self(*r.start(), Some(*r.end() + 1), 1)
471    }
472}
473
474impl From<std::ops::RangeFrom<usize>> for Range {
475    fn from(r: std::ops::RangeFrom<usize>) -> Self {
476        Self(r.start, None, 1)
477    }
478}
479
480impl From<usize> for Range {
481    fn from(idx: usize) -> Self {
482        Self(idx, Some(idx + 1), 1)
483    }
484}
485
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;

    use super::*;

    /// Builds a row-major shape from parallel label and size lists.
    fn labeled(labels: &[&str], sizes: Vec<usize>) -> Shape {
        let labels = labels.iter().map(|label| label.to_string()).collect();
        Shape::new(labels, Slice::new_row_major(sizes)).unwrap()
    }

    /// Flattened ranks covered by `shape`, in iteration order.
    fn ranks(shape: &Shape) -> Vec<usize> {
        shape.slice().iter().collect()
    }

    #[test]
    fn test_basic() {
        let s = shape!(host = 2, gpu = 8);
        assert_eq!(s.labels(), &["host", "gpu"]);

        let slice = s.slice();
        assert_eq!(slice.offset(), 0);
        assert_eq!(slice.sizes(), &[2, 8]);
        assert_eq!(slice.strides(), &[8, 1]);

        assert_eq!(format!("{s}"), "{host=2,gpu=8}");
    }

    #[test]
    fn test_select() {
        let s = shape!(host = 2, gpu = 8);
        assert_eq!(ranks(&s), (0..16usize).collect::<Vec<_>>());

        let second_host = select!(s, host = 1).unwrap();
        assert_eq!(ranks(&second_host), (8..16usize).collect::<Vec<_>>());

        let gpu_tail = select!(s, gpu = 2..).unwrap();
        let expected: Vec<usize> = (2..8).chain(10..16).collect();
        assert_eq!(ranks(&gpu_tail), expected);

        let gpu_middle = select!(s, gpu = 3..5).unwrap();
        assert_eq!(ranks(&gpu_middle), [3, 4, 11, 12]);

        let cell = select!(s, gpu = 3..5, host = 1).unwrap();
        assert_eq!(ranks(&cell), [11, 12]);
    }

    #[test]
    fn test_select_iter() {
        let s = shape!(replica = 2, host = 2, gpu = 8);
        let selections: Vec<Shape> = s.select_iter(2).unwrap().collect();
        for selection in &selections {
            assert_eq!(selection.slice().sizes(), &[1, 1, 8]);
        }

        let expected: Vec<Shape> = [(0usize, 0usize), (0, 1), (1, 0), (1, 1)]
            .iter()
            .map(|&(r, h)| select!(s, replica = r, host = h).unwrap())
            .collect();
        assert_eq!(selections, expected);
    }

    #[test]
    fn test_coordinates() {
        let s = shape!(host = 2, gpu = 8);
        // (rank, expected host, expected gpu)
        let cases = [(0, 0, 0), (1, 0, 1), (8, 1, 0), (9, 1, 1)];
        for (rank, host, gpu) in cases {
            assert_eq!(
                s.coordinates(rank).unwrap(),
                vec![("host".to_string(), host), ("gpu".to_string(), gpu)]
            );
        }

        assert_matches!(
            s.coordinates(16).unwrap_err(),
            ShapeError::SliceError(SliceError::ValueNotInSlice { value: 16 })
        );
    }

    #[test]
    fn test_select_bad() {
        let s = shape!(host = 2, gpu = 8);

        let empty = select!(s, gpu = 1..1).unwrap_err();
        assert_matches!(
            empty,
            ShapeError::EmptyRange {
                range: Range(1, Some(1), 1)
            },
        );

        let past_end = select!(s, gpu = 8).unwrap_err();
        assert_matches!(
            past_end,
            ShapeError::OutOfRange {
                range: Range(8, Some(9), 1),
                dim,
                size: 8,
            } if dim == "gpu",
        );
    }

    #[test]
    fn test_shape_index() {
        let (n_zone, n_hosts, n_gpus) = (2, 5, 7);
        let s = shape!(host = n_hosts, gpu = n_gpus);

        // Indexing the first dimension leaves one contiguous row.
        let expected = Shape::new(
            vec!["gpu".to_string()],
            Slice::new(0, vec![n_gpus], vec![1]).unwrap(),
        )
        .unwrap();
        assert_eq!(s.index(vec![("host".to_string(), 0)]).unwrap(), expected);

        // Indexing the last dimension leaves a strided column.
        let offset = 1;
        let expected = Shape::new(
            vec!["host".to_string()],
            Slice::new(offset, vec![n_hosts], vec![n_gpus]).unwrap(),
        )
        .unwrap();
        assert_eq!(s.index(vec![("gpu".to_string(), offset)]).unwrap(), expected);

        // Indexing a middle dimension keeps the outer and inner ones.
        let s = shape!(zone = n_zone, host = n_hosts, gpu = n_gpus);
        let offset = 3;
        let expected = Shape::new(
            vec!["zone".to_string(), "gpu".to_string()],
            Slice::new(
                offset * n_gpus,
                vec![n_zone, n_gpus],
                vec![n_hosts * n_gpus, 1],
            )
            .unwrap(),
        )
        .unwrap();
        assert_eq!(s.index(vec![("host".to_string(), offset)]).unwrap(), expected);

        // Out-of-range indices and unknown labels are both errors.
        let gpus = shape!(gpu = n_gpus);
        assert!(gpus.index(vec![("gpu".to_string(), n_gpus)]).is_err());
        assert!(gpus.index(vec![("non-exist-dim".to_string(), 0)]).is_err());
    }

    #[test]
    fn test_shape_select_stride_rounding() {
        let shape = shape!(x = 10);
        // x = 0..10 step 3 selects [0, 3, 6, 9]: ceil(10 / 3) = 4 elements.
        let sub = shape.select("x", Range(0, Some(10), 3)).unwrap();
        assert_eq!(
            sub.slice(),
            &Slice::new(0, vec![4], vec![3]).unwrap(),
            "Expected offset 0, size 4, stride 3"
        );
    }

    #[test]
    fn test_shape_at_removes_dimension() {
        let shape = labeled(&["batch", "height", "width"], vec![2, 3, 4]);

        // Fixing batch = 1 drops that dimension.
        let result = shape.at("batch", 1).unwrap();
        assert_eq!(result.labels(), &["height", "width"]);
        assert_eq!(result.slice().sizes(), &[3, 4]);
        assert_eq!(result.slice().offset(), 12); // 1 * 12 (batch stride)
    }

    #[test]
    fn test_shape_at_middle_dimension() {
        let shape = labeled(&["batch", "height", "width"], vec![2, 3, 4]);

        // Fixing height = 1 drops the middle dimension.
        let result = shape.at("height", 1).unwrap();
        assert_eq!(result.labels(), &["batch", "width"]);
        assert_eq!(result.slice().sizes(), &[2, 4]);
        assert_eq!(result.slice().offset(), 4); // 1 * 4 (height stride)
    }

    #[test]
    fn test_shape_at_invalid_label() {
        let shape = labeled(&["batch", "height"], vec![2, 3]);
        assert_matches!(
            shape.at("nonexistent", 0),
            Err(ShapeError::InvalidLabels { .. })
        );
    }

    #[test]
    fn test_shape_at_index_out_of_range() {
        let shape = labeled(&["batch", "height"], vec![2, 3]);
        // batch only has size 2
        assert_matches!(shape.at("batch", 5), Err(ShapeError::OutOfRange { .. }));
    }

    #[test]
    fn test_shape_from_str_round_trip() {
        let originals = vec![
            shape!(host = 2, gpu = 8),
            shape!(x = 1),
            shape!(batch = 10, height = 224, width = 224, channels = 3),
            Shape::unity(), // empty shape
        ];

        for original in &originals {
            let text = original.to_string();
            let parsed: Shape = text.parse().unwrap();
            assert_eq!(&parsed, original, "Round-trip failed for shape: {}", text);
        }
    }

    #[test]
    fn test_shape_from_str_valid_cases() {
        let cases = vec![
            ("{host=2,gpu=8}", shape!(host = 2, gpu = 8)),
            ("{x=1}", shape!(x = 1)),
            ("{ host = 2 , gpu = 8 }", shape!(host = 2, gpu = 8)), // with spaces
            ("{}", Shape::unity()),                                // empty shape
        ];

        for (input, expected) in cases {
            let parsed: Shape = input.parse().unwrap();
            assert_eq!(parsed, expected, "Failed to parse: {}", input);
        }
    }

    #[test]
    fn test_shape_from_str_error_cases() {
        let inputs = [
            "host=2,gpu=8",
            "{host=2,gpu=8",
            "host=2,gpu=8}",
            "{host=2,gpu=}",
            "{host=,gpu=8}",
            "{host=2=3,gpu=8}",
            "{host=abc,gpu=8}",
            "{host=2,}",
            "{=8}",
        ];

        for input in inputs {
            let result: Result<Shape, ShapeError> = input.parse();
            assert!(result.is_err(), "expected error for input: {}", input);
        }
    }
}