ndslice/
shape.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::fmt;
10use std::str::FromStr;
11
12use serde::Deserialize;
13use serde::Serialize;
14
15use crate::DimSliceIterator;
16use crate::Region;
17use crate::Slice;
18use crate::SliceError;
19use crate::selection::Selection;
20use crate::view::Extent;
21
// Note: `Shape::select` always retains every dimension, even one narrowed to a
// single index; only `Shape::at` (and `Shape::index`, built on it) drops dimensions.
23
24#[derive(Debug, thiserror::Error)]
25pub enum ShapeError {
26    #[error("label slice dimension mismatch: {labels_dim} != {slice_dim}")]
27    DimSliceMismatch { labels_dim: usize, slice_dim: usize },
28
29    #[error("invalid labels `{labels:?}`")]
30    InvalidLabels { labels: Vec<String> },
31
32    #[error("empty range {range}")]
33    EmptyRange { range: Range },
34
35    #[error("out of range {range} for dimension {dim} of size {size}")]
36    OutOfRange {
37        range: Range,
38        dim: String,
39        size: usize,
40    },
41
42    #[error("selection `{expr}` exceeds dimensionality {num_dim}")]
43    SelectionTooDeep { expr: Selection, num_dim: usize },
44
45    #[error("dynamic selection `{expr}`")]
46    SelectionDynamic { expr: Selection },
47
48    #[error("{index} out of range for dimension {dim} of size {size}")]
49    IndexOutOfRange {
50        index: usize,
51        dim: String,
52        size: usize,
53    },
54
55    #[error("failed to parse shape: {reason}")]
56    ParseError { reason: String },
57
58    #[error(transparent)]
59    SliceError(#[from] SliceError),
60}
61
62/// A shape is a [`Slice`] with labeled dimensions and a selection API.
63#[derive(Clone, Deserialize, Serialize, PartialEq, Hash, Debug)]
64pub struct Shape {
65    /// The labels for each dimension in slice.
66    labels: Vec<String>,
67    /// The slice itself, which describes the topology of the shape.
68    slice: Slice,
69}
70
71impl Shape {
72    /// Creates a new shape with the provided labels, which describe the
73    /// provided Slice.
74    ///
75    /// Shapes can also be constructed by way of the [`shape`] macro, which
76    /// creates a by-construction correct slice in row-major order given a set of
77    /// sized dimensions.
78    pub fn new(labels: Vec<String>, slice: Slice) -> Result<Self, ShapeError> {
79        if labels.len() != slice.num_dim() {
80            return Err(ShapeError::DimSliceMismatch {
81                labels_dim: labels.len(),
82                slice_dim: slice.num_dim(),
83            });
84        }
85        Ok(Self { labels, slice })
86    }
87
88    /// Select a single index along a named dimension, removing that
89    /// dimension entirely. This reduces the dimensionality by 1. In
90    /// effect it results in a cross section of the shape at the given
91    /// index in the given dimension.
92    pub fn at(&self, label: &str, index: usize) -> Result<Self, ShapeError> {
93        let dim = self.dim(label)?;
94        let slice = self.slice.at(dim, index).map_err(|err| match err {
95            SliceError::IndexOutOfRange { index, total } => ShapeError::OutOfRange {
96                range: Range(index, Some(index + 1), 1),
97                dim: label.to_string(),
98                size: total,
99            },
100            other => other.into(),
101        })?;
102        let mut labels = self.labels.clone();
103        labels.remove(dim);
104        Ok(Self { labels, slice })
105    }
106
107    /// Restrict this shape along a named dimension using a [`Range`].
108    /// The provided range must be nonempty.
109    ///
110    /// `select` is composable, it can be applied repeatedly, even on
111    /// the same dimension, to refine the view incrementally.
112    pub fn select<R: Into<Range>>(&self, label: &str, range: R) -> Result<Self, ShapeError> {
113        let dim = self.dim(label)?;
114        let range = range.into();
115        let (begin, end, step) = range.resolve(self.slice().sizes()[dim]);
116        let slice = self
117            .slice
118            .select(dim, begin, end, step)
119            .map_err(|err| match err {
120                SliceError::EmptyRange { .. } => ShapeError::EmptyRange { range },
121                SliceError::IndexOutOfRange { total, .. } => ShapeError::OutOfRange {
122                    range,
123                    dim: label.to_string(),
124                    size: total,
125                },
126                other => other.into(),
127            })?;
128        let labels = self.labels.clone();
129        Ok(Self { labels, slice })
130    }
131
132    /// Produces an iterator over subshapes by fixing the first `dims`
133    /// dimensions.
134    ///
135    /// For a shape of rank `n`, this yields `∏ sizes[0..dims]`
136    /// subshapes, each with the first `dims` dimensions restricted to
137    /// size 1. The remaining dimensions are left unconstrained.
138    ///
139    /// This is useful for structured traversal of slices within a
140    /// multidimensional shape. See [`SelectIterator`] for details and
141    /// examples.
142    ///
143    /// # Errors
144    /// Returns an error if `dims == 0` or `dims >= self.rank()`.
145    pub fn select_iter(&self, dims: usize) -> Result<SelectIterator<'_>, ShapeError> {
146        let num_dims = self.slice().num_dim();
147        if dims == 0 || dims >= num_dims {
148            return Err(ShapeError::SliceError(SliceError::IndexOutOfRange {
149                index: dims,
150                total: num_dims,
151            }));
152        }
153
154        Ok(SelectIterator {
155            shape: self,
156            iter: self.slice().dim_iter(dims),
157        })
158    }
159
160    /// Sub-set this shape by select a particular row of the given
161    /// indices The resulting shape will no longer have dimensions for
162    /// the given indices Example shape.index(vec![("gpu", 3),
163    /// ("host", 0)])
164    pub fn index(&self, indices: Vec<(String, usize)>) -> Result<Shape, ShapeError> {
165        let mut shape = self.clone();
166        for (label, index) in indices {
167            shape = shape.at(&label, index)?;
168        }
169        Ok(shape)
170    }
171
172    /// The per-dimension labels of this shape.
173    pub fn labels(&self) -> &[String] {
174        &self.labels
175    }
176
177    /// The slice describing the shape.
178    pub fn slice(&self) -> &Slice {
179        &self.slice
180    }
181
182    /// Return a set of labeled coordinates for the given rank.
183    pub fn coordinates(&self, rank: usize) -> Result<Vec<(String, usize)>, ShapeError> {
184        let coords = self.slice.coordinates(rank)?;
185        Ok(coords
186            .iter()
187            .zip(self.labels.iter())
188            .map(|(i, l)| (l.to_string(), *i))
189            .collect())
190    }
191
192    pub fn dim(&self, label: &str) -> Result<usize, ShapeError> {
193        self.labels
194            .iter()
195            .position(|l| l == label)
196            .ok_or_else(|| ShapeError::InvalidLabels {
197                labels: vec![label.to_string()],
198            })
199    }
200
201    /// Return the 0-dimensional single element shape
202    pub fn unity() -> Shape {
203        Shape::new(vec![], Slice::new(0, vec![], vec![]).expect("unity")).expect("unity")
204    }
205
206    /// The extent corresponding to this shape.
207    pub fn extent(&self) -> Extent {
208        Extent::new(self.labels.clone(), self.slice.sizes().to_vec()).unwrap()
209    }
210
211    /// The region corresponding to this shape.
212    pub fn region(&self) -> Region {
213        self.into()
214    }
215}
216
217impl From<Region> for Shape {
218    fn from(region: Region) -> Self {
219        let (labels, slice) = region.into_inner();
220        Shape::new(labels, slice)
221            .expect("Shape::new should not fail because a Region by definition is a valid Shape")
222    }
223}
224
225impl From<&Region> for Shape {
226    fn from(region: &Region) -> Self {
227        Shape::new(region.labels().to_vec(), region.slice().clone())
228            .expect("Shape::new should not fail because a Region by definition is a valid Shape")
229    }
230}
231
232/// Iterator over subshapes obtained by fixing a prefix of dimensions.
233///
234/// This iterator is produced by [`Shape::select_iter(dims)`], and
235/// yields one `Shape` per coordinate prefix in the first `dims`
236/// dimensions.
237///
238/// For a shape of `n` dimensions, each yielded shape has:
239/// - The first `dims` dimensions restricted to size 1 (i.e., fixed
240///   via `select`)
241/// - The remaining `n - dims` dimensions left unconstrained
242///
243/// This allows structured iteration over "slices" of the original
244/// shape: for example with `n` = 3, `select_iter(1)` walks through 2D
245/// planes, while `select_iter(2)` yields 1D subshapes.
246///
247/// # Example
248/// ```ignore
249/// let s = shape!(zone = 2, host = 2, gpu = 8);
250/// let views: Vec<_> = s.select_iter(2).unwrap().collect();
251/// assert_eq!(views.len(), 4);
252/// assert_eq!(views[0].slice().sizes(), &[1, 1, 8]);
253/// ```
254/// The above example can be interpreted as: for each `(zone, host)`
255/// pair, `select_iter(2)` yields a `Shape` describing the associated
256/// row of GPUs — a view into the `[1, 1, 8]` subregion of the full
257/// `[2, 2, 8]` shape.
258pub struct SelectIterator<'a> {
259    shape: &'a Shape,
260    iter: DimSliceIterator,
261}
262
263impl<'a> Iterator for SelectIterator<'a> {
264    type Item = Shape;
265
266    fn next(&mut self) -> Option<Self::Item> {
267        let pos = self.iter.next()?;
268        let mut shape = self.shape.clone();
269        for (dim, index) in pos.iter().enumerate() {
270            shape = shape.select(&self.shape.labels()[dim], *index).unwrap();
271        }
272        Some(shape)
273    }
274}
275
276impl fmt::Display for Shape {
277    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
278        // Just display the sizes of each dimension, for now.
279        // Once we have a selection algebra, we can provide a
280        // better Display implementation.
281        write!(f, "{{")?;
282        for dim in 0..self.labels.len() {
283            write!(f, "{}={}", self.labels[dim], self.slice.sizes()[dim])?;
284            if dim < self.labels.len() - 1 {
285                write!(f, ",")?;
286            }
287        }
288        write!(f, "}}")
289    }
290}
291
292impl FromStr for Shape {
293    type Err = ShapeError;
294
295    fn from_str(s: &str) -> Result<Self, Self::Err> {
296        let s = s.trim();
297
298        if !s.starts_with('{') || !s.ends_with('}') {
299            return Err(ShapeError::ParseError {
300                reason: "shape string must be enclosed in braces".to_string(),
301            });
302        }
303
304        let inner = &s[1..s.len() - 1].trim();
305
306        if inner.is_empty() {
307            return Ok(Shape::unity());
308        }
309
310        let mut labels = Vec::new();
311        let mut sizes = Vec::new();
312
313        for part in inner.split(',') {
314            let part = part.trim();
315            let mut split = part.split('=');
316
317            let label = split
318                .next()
319                .ok_or_else(|| ShapeError::ParseError {
320                    reason: format!("invalid dimension format: '{}'", part),
321                })?
322                .trim();
323
324            let size_str = split
325                .next()
326                .ok_or_else(|| ShapeError::ParseError {
327                    reason: format!("missing size for dimension '{}'", label),
328                })?
329                .trim();
330
331            if split.next().is_some() {
332                return Err(ShapeError::ParseError {
333                    reason: format!("invalid dimension format: '{}'", part),
334                });
335            }
336
337            if label.is_empty() {
338                return Err(ShapeError::ParseError {
339                    reason: format!("missing label in dimension: '{}'", part),
340                });
341            }
342
343            let size = size_str
344                .parse::<usize>()
345                .map_err(|_| ShapeError::ParseError {
346                    reason: format!("invalid size '{}' for dimension '{}'", size_str, label),
347                })?;
348
349            labels.push(label.to_string());
350            sizes.push(size);
351        }
352
353        let slice = Slice::new_row_major(sizes);
354        Shape::new(labels, slice)
355    }
356}
357
/// Construct a new [`Shape`](crate::shape::Shape) from `label = size` pairs,
/// laid out in row-major order (the last dimension is contiguous).
///
/// Size expressions are evaluated left to right; a trailing comma is
/// allowed. Panics if the resulting shape cannot be constructed.
///
/// ```
/// let s = ndslice::shape!(host = 2, gpu = 8);
/// assert_eq!(s.labels(), &["host".to_string(), "gpu".to_string()]);
/// assert_eq!(s.slice().sizes(), &[2, 8]);
/// assert_eq!(s.slice().strides(), &[8, 1]);
/// ```
#[macro_export]
macro_rules! shape {
    ( $( $label:ident = $size:expr ),* $(,)? ) => {{
        let labels: ::std::vec::Vec<::std::string::String> =
            ::std::vec![$( stringify!($label).to_string() ),*];
        let sizes = ::std::vec![$( $size ),*];
        $crate::shape::Shape::new(labels, $crate::Slice::new_row_major(sizes)).unwrap()
    }};
}
383
/// Perform a sub-selection on the provided [`Shape`] object.
///
/// This macro chains `.select()` calls to apply multiple labeled
/// dimension restrictions in a fluent way, stopping at the first error.
/// The shape argument may be any expression (not only an identifier),
/// and a trailing comma is accepted.
///
/// ```
/// let s = ndslice::shape!(host = 2, gpu = 8);
/// let s = ndslice::select!(s, host = 1, gpu = 4..).unwrap();
/// assert_eq!(s.labels(), &["host".to_string(), "gpu".to_string()]);
/// assert_eq!(s.slice().sizes(), &[1, 4]);
/// ```
#[macro_export]
macro_rules! select {
    // Last (or only) restriction.
    ($shape:expr, $label:ident = $range:expr $(,)?) => {
        $shape.select(stringify!($label), $range)
    };

    // Apply the first restriction, then recurse on the remaining tokens.
    ($shape:expr, $label:ident = $range:expr, $($rest:tt)+) => {
        $shape
            .select(stringify!($label), $range)
            .and_then(|shape| $crate::select!(shape, $($rest)+))
    };
}
405
406/// A range of indices, with a stride. Ranges are convertible from
407/// native Rust ranges.
408///
409/// Deriving `Eq`, `Ord` and `Hash` is sound because all fields are
410/// `Ord` and comparison is purely structural over `(start, end,
411/// step)`.
412#[derive(
413    Debug,
414    Clone,
415    Eq,
416    Hash,
417    PartialEq,
418    Serialize,
419    Deserialize,
420    PartialOrd,
421    Ord
422)]
423pub struct Range(pub usize, pub Option<usize>, pub usize);
424
425impl Range {
426    pub(crate) fn resolve(&self, size: usize) -> (usize, usize, usize) {
427        match self {
428            Range(begin, Some(end), stride) => (*begin, std::cmp::min(size, *end), *stride),
429            Range(begin, None, stride) => (*begin, size, *stride),
430        }
431    }
432
433    pub(crate) fn is_empty(&self) -> bool {
434        matches!(self, Range(begin, Some(end), _) if end <= begin)
435    }
436}
437
438impl fmt::Display for Range {
439    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
440        match self {
441            Range(begin, None, stride) => write!(f, "{}::{}", begin, stride),
442            Range(begin, Some(end), stride) => write!(f, "{}:{}:{}", begin, end, stride),
443        }
444    }
445}
446
447impl From<std::ops::Range<usize>> for Range {
448    fn from(r: std::ops::Range<usize>) -> Self {
449        Self(r.start, Some(r.end), 1)
450    }
451}
452
453impl From<std::ops::RangeInclusive<usize>> for Range {
454    fn from(r: std::ops::RangeInclusive<usize>) -> Self {
455        Self(*r.start(), Some(*r.end() + 1), 1)
456    }
457}
458
459impl From<std::ops::RangeFrom<usize>> for Range {
460    fn from(r: std::ops::RangeFrom<usize>) -> Self {
461        Self(r.start, None, 1)
462    }
463}
464
465impl From<usize> for Range {
466    fn from(idx: usize) -> Self {
467        Self(idx, Some(idx + 1), 1)
468    }
469}
470
// Unit tests for `Shape`, the `shape!`/`select!` macros, and `Range`.
// `assert_matches` requires the nightly `assert_matches` feature.
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;

    use super::*;

    // `shape!` builds a row-major slice and `Display` renders `{label=size,...}`.
    #[test]
    fn test_basic() {
        let s = shape!(host = 2, gpu = 8);
        assert_eq!(&s.labels, &["host".to_string(), "gpu".to_string()]);
        assert_eq!(s.slice.offset(), 0);
        assert_eq!(s.slice.sizes(), &[2, 8]);
        assert_eq!(s.slice.strides(), &[8, 1]);

        assert_eq!(s.to_string(), "{host=2,gpu=8}");
    }

    // `select!` restricts dimensions in place; the expected values are the
    // linear locations visited by the resulting slice.
    #[test]
    fn test_select() {
        let s = shape!(host = 2, gpu = 8);

        assert_eq!(
            s.slice().iter().collect::<Vec<_>>(),
            &[
                0,
                1,
                2,
                3,
                4,
                5,
                6,
                7,
                8,
                8 + 1,
                8 + 2,
                8 + 3,
                8 + 4,
                8 + 5,
                8 + 6,
                8 + 7
            ]
        );

        assert_eq!(
            select!(s, host = 1)
                .unwrap()
                .slice()
                .iter()
                .collect::<Vec<_>>(),
            &[8, 8 + 1, 8 + 2, 8 + 3, 8 + 4, 8 + 5, 8 + 6, 8 + 7]
        );

        assert_eq!(
            select!(s, gpu = 2..)
                .unwrap()
                .slice()
                .iter()
                .collect::<Vec<_>>(),
            &[2, 3, 4, 5, 6, 7, 8 + 2, 8 + 3, 8 + 4, 8 + 5, 8 + 6, 8 + 7]
        );

        assert_eq!(
            select!(s, gpu = 3..5)
                .unwrap()
                .slice()
                .iter()
                .collect::<Vec<_>>(),
            &[3, 4, 8 + 3, 8 + 4]
        );

        // Selections compose regardless of the order of dimensions named.
        assert_eq!(
            select!(s, gpu = 3..5, host = 1)
                .unwrap()
                .slice()
                .iter()
                .collect::<Vec<_>>(),
            &[8 + 3, 8 + 4]
        );
    }

    // `select_iter(2)` on a rank-3 shape yields one shape per (replica, host)
    // pair, in row-major order, each keeping all three dimensions.
    #[test]
    fn test_select_iter() {
        let s = shape!(replica = 2, host = 2, gpu = 8);
        let selections: Vec<_> = s.select_iter(2).unwrap().collect();
        assert_eq!(selections[0].slice().sizes(), &[1, 1, 8]);
        assert_eq!(selections[1].slice().sizes(), &[1, 1, 8]);
        assert_eq!(selections[2].slice().sizes(), &[1, 1, 8]);
        assert_eq!(selections[3].slice().sizes(), &[1, 1, 8]);
        assert_eq!(
            selections,
            &[
                select!(s, replica = 0, host = 0).unwrap(),
                select!(s, replica = 0, host = 1).unwrap(),
                select!(s, replica = 1, host = 0).unwrap(),
                select!(s, replica = 1, host = 1).unwrap()
            ]
        );
    }

    // `coordinates` maps a slice value to labeled per-dimension coordinates;
    // values outside the slice surface the underlying `SliceError`.
    #[test]
    fn test_coordinates() {
        let s = shape!(host = 2, gpu = 8);
        assert_eq!(
            s.coordinates(0).unwrap(),
            vec![("host".to_string(), 0), ("gpu".to_string(), 0)]
        );
        assert_eq!(
            s.coordinates(1).unwrap(),
            vec![("host".to_string(), 0), ("gpu".to_string(), 1)]
        );
        assert_eq!(
            s.coordinates(8).unwrap(),
            vec![("host".to_string(), 1), ("gpu".to_string(), 0)]
        );
        assert_eq!(
            s.coordinates(9).unwrap(),
            vec![("host".to_string(), 1), ("gpu".to_string(), 1)]
        );

        assert_matches!(
            s.coordinates(16).unwrap_err(),
            ShapeError::SliceError(SliceError::ValueNotInSlice { value: 16 })
        );
    }

    // Empty and out-of-range selections report errors in terms of the
    // caller's original `Range`.
    #[test]
    fn test_select_bad() {
        let s = shape!(host = 2, gpu = 8);

        assert_matches!(
            select!(s, gpu = 1..1).unwrap_err(),
            ShapeError::EmptyRange {
                range: Range(1, Some(1), 1)
            },
        );

        assert_matches!(
            select!(s, gpu = 8).unwrap_err(),
            ShapeError::OutOfRange {
                range: Range(8, Some(9), 1),
                dim,
                size: 8,
            } if dim == "gpu",
        );
    }

    // `index` removes each indexed dimension and folds its contribution
    // into the slice offset.
    #[test]
    fn test_shape_index() {
        let n_hosts = 5;
        let n_gpus = 7;

        // Index first dim
        let s = shape!(host = n_hosts, gpu = n_gpus);
        assert_eq!(
            s.index(vec![("host".to_string(), 0)]).unwrap(),
            Shape::new(
                vec!["gpu".to_string()],
                Slice::new(0, vec![n_gpus], vec![1]).unwrap()
            )
            .unwrap()
        );

        // Index last dims
        let offset = 1;
        assert_eq!(
            s.index(vec![("gpu".to_string(), offset)]).unwrap(),
            Shape::new(
                vec!["host".to_string()],
                Slice::new(offset, vec![n_hosts], vec![n_gpus]).unwrap()
            )
            .unwrap()
        );

        // Index middle dim
        let n_zone = 2;
        let s = shape!(zone = n_zone, host = n_hosts, gpu = n_gpus);
        let offset = 3;
        assert_eq!(
            s.index(vec![("host".to_string(), offset)]).unwrap(),
            Shape::new(
                vec!["zone".to_string(), "gpu".to_string()],
                Slice::new(
                    offset * n_gpus,
                    vec![n_zone, n_gpus],
                    vec![n_hosts * n_gpus, 1]
                )
                .unwrap()
            )
            .unwrap()
        );

        // Out of range
        assert!(
            shape!(gpu = n_gpus)
                .index(vec![("gpu".to_string(), n_gpus)])
                .is_err()
        );
        // Invalid dim
        assert!(
            shape!(gpu = n_gpus)
                .index(vec![("non-exist-dim".to_string(), 0)])
                .is_err()
        );
    }

    // A strided selection whose step does not divide the length rounds the
    // resulting size up.
    #[test]
    fn test_shape_select_stride_rounding() {
        let shape = shape!(x = 10);
        // Select x = 0..10 step 3 → expect indices [0, 3, 6, 9]
        let sub = shape.select("x", Range(0, Some(10), 3)).unwrap();
        let slice = sub.slice();
        // 10 / 3 = 3.33..., so ceil(10 / 3) = 4
        assert_eq!(
            slice,
            &Slice::new(0, vec![4], vec![3]).unwrap(),
            "Expected offset 0, size 4, stride 3"
        );
    }

    // `at` on the leading dimension drops it and shifts the offset by one stride.
    #[test]
    fn test_shape_at_removes_dimension() {
        let labels = vec![
            "batch".to_string(),
            "height".to_string(),
            "width".to_string(),
        ];
        let slice = Slice::new_row_major(vec![2, 3, 4]);
        let shape = Shape::new(labels, slice).unwrap();

        // Select index 1 from "batch" dimension
        let result = shape.at("batch", 1).unwrap();

        // Should have 2 dimensions now
        assert_eq!(result.labels(), &["height", "width"]);
        assert_eq!(result.slice().sizes(), &[3, 4]);
        assert_eq!(result.slice().offset(), 12); // 1 * 12 (batch stride)
    }

    // `at` on a middle dimension removes only that label.
    #[test]
    fn test_shape_at_middle_dimension() {
        let labels = vec![
            "batch".to_string(),
            "height".to_string(),
            "width".to_string(),
        ];
        let slice = Slice::new_row_major(vec![2, 3, 4]);
        let shape = Shape::new(labels, slice).unwrap();

        // Select index 1 from "height" dimension (middle)
        let result = shape.at("height", 1).unwrap();

        // Should remove middle label
        assert_eq!(result.labels(), &["batch", "width"]);
        assert_eq!(result.slice().sizes(), &[2, 4]);
        assert_eq!(result.slice().offset(), 4); // 1 * 4 (height stride)
    }

    // Unknown labels are rejected with `InvalidLabels`.
    #[test]
    fn test_shape_at_invalid_label() {
        let labels = vec!["batch".to_string(), "height".to_string()];
        let slice = Slice::new_row_major(vec![2, 3]);
        let shape = Shape::new(labels, slice).unwrap();

        let result = shape.at("nonexistent", 0);
        assert!(matches!(result, Err(ShapeError::InvalidLabels { .. })));
    }

    // Indices past the end of a dimension are rejected with `OutOfRange`.
    #[test]
    fn test_shape_at_index_out_of_range() {
        let labels = vec!["batch".to_string(), "height".to_string()];
        let slice = Slice::new_row_major(vec![2, 3]);
        let shape = Shape::new(labels, slice).unwrap();

        let result = shape.at("batch", 5); // batch only has size 2
        assert!(matches!(result, Err(ShapeError::OutOfRange { .. })));
    }

    // `Display` output parses back to an equal shape.
    #[test]
    fn test_shape_from_str_round_trip() {
        let test_cases = vec![
            shape!(host = 2, gpu = 8),
            shape!(x = 1),
            shape!(batch = 10, height = 224, width = 224, channels = 3),
            Shape::unity(), // empty shape
        ];

        for original in test_cases {
            let display_str = original.to_string();
            let parsed: Shape = display_str.parse().unwrap();
            assert_eq!(
                parsed, original,
                "Round-trip failed for shape: {}",
                display_str
            );
        }
    }

    // Accepted inputs, including surrounding whitespace and the empty shape.
    #[test]
    fn test_shape_from_str_valid_cases() {
        let test_cases = vec![
            ("{host=2,gpu=8}", shape!(host = 2, gpu = 8)),
            ("{x=1}", shape!(x = 1)),
            ("{ host = 2 , gpu = 8 }", shape!(host = 2, gpu = 8)), // with spaces
            ("{}", Shape::unity()),                                // empty shape
        ];

        for (input, expected) in test_cases {
            let parsed: Shape = input.parse().unwrap();
            assert_eq!(parsed, expected, "Failed to parse: {}", input);
        }
    }

    // Malformed inputs: missing braces, missing/extra '=', empty label or
    // size, non-numeric size, trailing comma.
    #[test]
    fn test_shape_from_str_error_cases() {
        let error_cases = vec![
            "host=2,gpu=8",
            "{host=2,gpu=8",
            "host=2,gpu=8}",
            "{host=2,gpu=}",
            "{host=,gpu=8}",
            "{host=2=3,gpu=8}",
            "{host=abc,gpu=8}",
            "{host=2,}",
            "{=8}",
        ];

        for input in error_cases {
            let result: Result<Shape, ShapeError> = input.parse();
            assert!(result.is_err(), "expected error for input: {}", input);
        }
    }
}