ndslice/
shape.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::fmt;
10use std::str::FromStr;
11
12use serde::Deserialize;
13use serde::Serialize;
14
15use crate::DimSliceIterator;
16use crate::Region;
17use crate::Slice;
18use crate::SliceError;
19use crate::selection::Selection;
20use crate::view::Extent;
21
// `Shape::select` always retains dimensions, even when a selection narrows one
// to a single element; `Shape::at` is the operation that removes a dimension.
23
/// Errors produced by [`Shape`] construction, selection, indexing, and
/// parsing.
#[derive(Debug, thiserror::Error)]
pub enum ShapeError {
    /// The number of labels passed to [`Shape::new`] does not match the
    /// slice's number of dimensions.
    #[error("label slice dimension mismatch: {labels_dim} != {slice_dim}")]
    DimSliceMismatch { labels_dim: usize, slice_dim: usize },

    /// A label does not name any dimension of the shape (see [`Shape::dim`]).
    #[error("invalid labels `{labels:?}`")]
    InvalidLabels { labels: Vec<String> },

    /// [`Shape::select`] was given a range that the slice reports as empty.
    #[error("empty range {range}")]
    EmptyRange { range: Range },

    /// A range or single index lies outside dimension `dim`, which has
    /// `size` elements (reported by [`Shape::select`] and [`Shape::at`]).
    #[error("out of range {range} for dimension {dim} of size {size}")]
    OutOfRange {
        range: Range,
        dim: String,
        size: usize,
    },

    /// Per its message, a selection expression that exceeds the shape's
    /// dimensionality `num_dim`. Not constructed in this file.
    #[error("selection `{expr}` exceeds dimensionality {num_dim}")]
    SelectionTooDeep { expr: Selection, num_dim: usize },

    /// Per its message, a "dynamic" selection expression.
    /// NOTE(review): presumably one that cannot be resolved statically —
    /// confirm at the construction site. Not constructed in this file.
    #[error("dynamic selection `{expr}`")]
    SelectionDynamic { expr: Selection },

    /// An index lies outside dimension `dim` of size `size`. Not
    /// constructed in this file; [`Shape::at`] reports out-of-range
    /// indices as [`ShapeError::OutOfRange`] instead.
    #[error("{index} out of range for dimension {dim} of size {size}")]
    IndexOutOfRange {
        index: usize,
        dim: String,
        size: usize,
    },

    /// A string could not be parsed as a [`Shape`] by its `FromStr` impl.
    #[error("failed to parse shape: {reason}")]
    ParseError { reason: String },

    /// An error from the underlying [`Slice`]; `#[from]` lets `?` convert
    /// a [`SliceError`] into this variant.
    #[error(transparent)]
    SliceError(#[from] SliceError),
}
61
/// A shape is a [`Slice`] with labeled dimensions and a selection API.
///
/// Invariant: `labels.len() == slice.num_dim()`, checked by [`Shape::new`].
/// NOTE(review): `Deserialize` is derived, so a deserialized `Shape` does
/// not pass through that check — confirm that serialized inputs are trusted.
#[derive(Clone, Deserialize, Serialize, PartialEq, Hash, Debug)]
pub struct Shape {
    /// The labels for each dimension in slice, in dimension order.
    labels: Vec<String>,
    /// The slice itself, which describes the topology of the shape.
    slice: Slice,
}
70
71impl Shape {
72    /// Creates a new shape with the provided labels, which describe the
73    /// provided Slice.
74    ///
75    /// Shapes can also be constructed by way of the [`shape`] macro, which
76    /// creates a by-construction correct slice in row-major order given a set of
77    /// sized dimensions.
78    pub fn new(labels: Vec<String>, slice: Slice) -> Result<Self, ShapeError> {
79        if labels.len() != slice.num_dim() {
80            return Err(ShapeError::DimSliceMismatch {
81                labels_dim: labels.len(),
82                slice_dim: slice.num_dim(),
83            });
84        }
85        Ok(Self { labels, slice })
86    }
87
88    /// Select a single index along a named dimension, removing that
89    /// dimension entirely. This reduces the dimensionality by 1. In
90    /// effect it results in a cross section of the shape at the given
91    /// index in the given dimension.
92    pub fn at(&self, label: &str, index: usize) -> Result<Self, ShapeError> {
93        let dim = self.dim(label)?;
94        let slice = self.slice.at(dim, index).map_err(|err| match err {
95            SliceError::IndexOutOfRange { index, total } => ShapeError::OutOfRange {
96                range: Range(index, Some(index + 1), 1),
97                dim: label.to_string(),
98                size: total,
99            },
100            other => other.into(),
101        })?;
102        let mut labels = self.labels.clone();
103        labels.remove(dim);
104        Ok(Self { labels, slice })
105    }
106
107    /// Restrict this shape along a named dimension using a [`Range`].
108    /// The provided range must be nonempty.
109    ///
110    /// `select` is composable, it can be applied repeatedly, even on
111    /// the same dimension, to refine the view incrementally.
112    pub fn select<R: Into<Range>>(&self, label: &str, range: R) -> Result<Self, ShapeError> {
113        let dim = self.dim(label)?;
114        let range = range.into();
115        let (begin, end, step) = range.resolve(self.slice().sizes()[dim]);
116        let slice = self
117            .slice
118            .select(dim, begin, end, step)
119            .map_err(|err| match err {
120                SliceError::EmptyRange { .. } => ShapeError::EmptyRange { range },
121                SliceError::IndexOutOfRange { total, .. } => ShapeError::OutOfRange {
122                    range,
123                    dim: label.to_string(),
124                    size: total,
125                },
126                other => other.into(),
127            })?;
128        let labels = self.labels.clone();
129        Ok(Self { labels, slice })
130    }
131
132    /// Produces an iterator over subshapes by fixing the first `dims`
133    /// dimensions.
134    ///
135    /// For a shape of rank `n`, this yields `∏ sizes[0..dims]`
136    /// subshapes, each with the first `dims` dimensions restricted to
137    /// size 1. The remaining dimensions are left unconstrained.
138    ///
139    /// This is useful for structured traversal of slices within a
140    /// multidimensional shape. See [`SelectIterator`] for details and
141    /// examples.
142    ///
143    /// # Errors
144    /// Returns an error if `dims == 0` or `dims >= self.rank()`.
145    pub fn select_iter(&self, dims: usize) -> Result<SelectIterator<'_>, ShapeError> {
146        let num_dims = self.slice().num_dim();
147        if dims == 0 || dims >= num_dims {
148            return Err(ShapeError::SliceError(SliceError::IndexOutOfRange {
149                index: dims,
150                total: num_dims,
151            }));
152        }
153
154        Ok(SelectIterator {
155            shape: self,
156            iter: self.slice().dim_iter(dims),
157        })
158    }
159
160    /// Sub-set this shape by select a particular row of the given
161    /// indices The resulting shape will no longer have dimensions for
162    /// the given indices Example shape.index(vec![("gpu", 3),
163    /// ("host", 0)])
164    pub fn index(&self, indices: Vec<(String, usize)>) -> Result<Shape, ShapeError> {
165        let mut shape = self.clone();
166        for (label, index) in indices {
167            shape = shape.at(&label, index)?;
168        }
169        Ok(shape)
170    }
171
172    /// The per-dimension labels of this shape.
173    pub fn labels(&self) -> &[String] {
174        &self.labels
175    }
176
177    /// The slice describing the shape.
178    pub fn slice(&self) -> &Slice {
179        &self.slice
180    }
181
182    /// Return a set of labeled coordinates for the given rank.
183    pub fn coordinates(&self, rank: usize) -> Result<Vec<(String, usize)>, ShapeError> {
184        let coords = self.slice.coordinates(rank)?;
185        Ok(coords
186            .iter()
187            .zip(self.labels.iter())
188            .map(|(i, l)| (l.to_string(), *i))
189            .collect())
190    }
191
192    pub fn dim(&self, label: &str) -> Result<usize, ShapeError> {
193        self.labels
194            .iter()
195            .position(|l| l == label)
196            .ok_or_else(|| ShapeError::InvalidLabels {
197                labels: vec![label.to_string()],
198            })
199    }
200
201    /// Return the 0-dimensional single element shape
202    pub fn unity() -> Shape {
203        Shape::new(vec![], Slice::new(0, vec![], vec![]).expect("unity")).expect("unity")
204    }
205
206    /// The extent corresponding to this shape.
207    pub fn extent(&self) -> Extent {
208        Extent::new(self.labels.clone(), self.slice.sizes().to_vec()).unwrap()
209    }
210
211    /// The region corresponding to this shape.
212    pub fn region(&self) -> Region {
213        self.into()
214    }
215}
216
217impl From<&Region> for Shape {
218    fn from(region: &Region) -> Self {
219        Shape::new(region.labels().to_vec(), region.slice().clone())
220            .expect("Shape::new should not fail because a Region by definition is a valid Shape")
221    }
222}
223
/// Iterator over subshapes obtained by fixing a prefix of dimensions.
///
/// This iterator is produced by [`Shape::select_iter`] (called with a
/// prefix length `dims`), and yields one `Shape` per coordinate prefix in
/// the first `dims` dimensions.
///
/// For a shape of `n` dimensions, each yielded shape has:
/// - The first `dims` dimensions restricted to size 1 (i.e., fixed
///   via `select`)
/// - The remaining `n - dims` dimensions left unconstrained
///
/// This allows structured iteration over "slices" of the original
/// shape: for example with `n` = 3, `select_iter(1)` walks through 2D
/// planes, while `select_iter(2)` yields 1D subshapes.
///
/// # Example
/// ```ignore
/// let s = shape!(zone = 2, host = 2, gpu = 8);
/// let views: Vec<_> = s.select_iter(2).unwrap().collect();
/// assert_eq!(views.len(), 4);
/// assert_eq!(views[0].slice().sizes(), &[1, 1, 8]);
/// ```
/// The above example can be interpreted as: for each `(zone, host)`
/// pair, `select_iter(2)` yields a `Shape` describing the associated
/// row of GPUs — a view into the `[1, 1, 8]` subregion of the full
/// `[2, 2, 8]` shape.
pub struct SelectIterator<'a> {
    /// The shape being traversed; each item is a selection applied to a
    /// clone of it.
    shape: &'a Shape,
    /// Source of the index prefixes, one entry per fixed dimension.
    iter: DimSliceIterator,
}
254
255impl<'a> Iterator for SelectIterator<'a> {
256    type Item = Shape;
257
258    fn next(&mut self) -> Option<Self::Item> {
259        let pos = self.iter.next()?;
260        let mut shape = self.shape.clone();
261        for (dim, index) in pos.iter().enumerate() {
262            shape = shape.select(&self.shape.labels()[dim], *index).unwrap();
263        }
264        Some(shape)
265    }
266}
267
268impl fmt::Display for Shape {
269    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
270        // Just display the sizes of each dimension, for now.
271        // Once we have a selection algebra, we can provide a
272        // better Display implementation.
273        write!(f, "{{")?;
274        for dim in 0..self.labels.len() {
275            write!(f, "{}={}", self.labels[dim], self.slice.sizes()[dim])?;
276            if dim < self.labels.len() - 1 {
277                write!(f, ",")?;
278            }
279        }
280        write!(f, "}}")
281    }
282}
283
284impl FromStr for Shape {
285    type Err = ShapeError;
286
287    fn from_str(s: &str) -> Result<Self, Self::Err> {
288        let s = s.trim();
289
290        if !s.starts_with('{') || !s.ends_with('}') {
291            return Err(ShapeError::ParseError {
292                reason: "shape string must be enclosed in braces".to_string(),
293            });
294        }
295
296        let inner = &s[1..s.len() - 1].trim();
297
298        if inner.is_empty() {
299            return Ok(Shape::unity());
300        }
301
302        let mut labels = Vec::new();
303        let mut sizes = Vec::new();
304
305        for part in inner.split(',') {
306            let part = part.trim();
307            let mut split = part.split('=');
308
309            let label = split
310                .next()
311                .ok_or_else(|| ShapeError::ParseError {
312                    reason: format!("invalid dimension format: '{}'", part),
313                })?
314                .trim();
315
316            let size_str = split
317                .next()
318                .ok_or_else(|| ShapeError::ParseError {
319                    reason: format!("missing size for dimension '{}'", label),
320                })?
321                .trim();
322
323            if split.next().is_some() {
324                return Err(ShapeError::ParseError {
325                    reason: format!("invalid dimension format: '{}'", part),
326                });
327            }
328
329            if label.is_empty() {
330                return Err(ShapeError::ParseError {
331                    reason: format!("missing label in dimension: '{}'", part),
332                });
333            }
334
335            let size = size_str
336                .parse::<usize>()
337                .map_err(|_| ShapeError::ParseError {
338                    reason: format!("invalid size '{}' for dimension '{}'", size_str, label),
339                })?;
340
341            labels.push(label.to_string());
342            sizes.push(size);
343        }
344
345        let slice = Slice::new_row_major(sizes);
346        Shape::new(labels, slice)
347    }
348}
349
/// Construct a new shape with the given set of dimension-size pairs in row-major
/// order.
///
/// The first pair becomes the outermost (largest-stride) dimension. Each
/// size expression is evaluated once, left to right. The macro panics if
/// `Shape::new` rejects the result.
///
/// ```
/// let s = ndslice::shape!(host = 2, gpu = 8);
/// assert_eq!(s.labels(), &["host".to_string(), "gpu".to_string()]);
/// assert_eq!(s.slice().sizes(), &[2, 8]);
/// assert_eq!(s.slice().strides(), &[8, 1]);
/// ```
#[macro_export]
macro_rules! shape {
    ( $( $label:ident = $size:expr ),* $(,)? ) => {{
        let labels: Vec<String> = vec![$( stringify!($label).to_string() ),*];
        let sizes = vec![$( $size ),*];
        $crate::shape::Shape::new(labels, $crate::Slice::new_row_major(sizes)).unwrap()
    }};
}
375
/// Perform a sub-selection on the provided [`Shape`] object.
///
/// This macro chains `.select()` calls to apply multiple labeled
/// dimension restrictions in a fluent way, stopping at the first error.
/// The shape argument may be any expression (it is evaluated once), and a
/// trailing comma is accepted, as in [`shape!`].
///
/// ```
/// let s = ndslice::shape!(host = 2, gpu = 8);
/// let s = ndslice::select!(s, host = 1, gpu = 4..).unwrap();
/// assert_eq!(s.labels(), &["host".to_string(), "gpu".to_string()]);
/// assert_eq!(s.slice().sizes(), &[1, 4]);
/// ```
#[macro_export]
macro_rules! select {
    ($shape:expr, $label:ident = $range:expr $(,)?) => {
        $shape.select(stringify!($label), $range)
    };

    ($shape:expr, $label:ident = $range:expr, $($labels:ident = $ranges:expr),+ $(,)?) => {
        $shape
            .select(stringify!($label), $range)
            .and_then(|shape| $crate::select!(shape, $($labels = $ranges),+))
    };
}
397
/// A range of indices, with a stride. Ranges are convertible from
/// native Rust ranges.
///
/// The fields are `Range(begin, end, step)`. An `end` of `None` makes the
/// range open-ended: when resolved against a dimension it extends to the
/// dimension's size. An explicit `end` is exclusive and is clamped to that
/// size. `Display` renders `begin:end:step`, or `begin::step` for an open
/// end.
///
/// Deriving `Eq`, `Ord` and `Hash` is sound because all fields are
/// `Ord` and comparison is purely structural over `(start, end,
/// step)`.
#[derive(
    Debug,
    Clone,
    Eq,
    Hash,
    PartialEq,
    Serialize,
    Deserialize,
    PartialOrd,
    Ord
)]
pub struct Range(pub usize, pub Option<usize>, pub usize);
416
417impl Range {
418    pub(crate) fn resolve(&self, size: usize) -> (usize, usize, usize) {
419        match self {
420            Range(begin, Some(end), stride) => (*begin, std::cmp::min(size, *end), *stride),
421            Range(begin, None, stride) => (*begin, size, *stride),
422        }
423    }
424
425    pub(crate) fn is_empty(&self) -> bool {
426        matches!(self, Range(begin, Some(end), _) if end <= begin)
427    }
428}
429
430impl fmt::Display for Range {
431    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
432        match self {
433            Range(begin, None, stride) => write!(f, "{}::{}", begin, stride),
434            Range(begin, Some(end), stride) => write!(f, "{}:{}:{}", begin, end, stride),
435        }
436    }
437}
438
439impl From<std::ops::Range<usize>> for Range {
440    fn from(r: std::ops::Range<usize>) -> Self {
441        Self(r.start, Some(r.end), 1)
442    }
443}
444
445impl From<std::ops::RangeInclusive<usize>> for Range {
446    fn from(r: std::ops::RangeInclusive<usize>) -> Self {
447        Self(*r.start(), Some(*r.end() + 1), 1)
448    }
449}
450
451impl From<std::ops::RangeFrom<usize>> for Range {
452    fn from(r: std::ops::RangeFrom<usize>) -> Self {
453        Self(r.start, None, 1)
454    }
455}
456
457impl From<usize> for Range {
458    fn from(idx: usize) -> Self {
459        Self(idx, Some(idx + 1), 1)
460    }
461}
462
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;

    use super::*;

    #[test]
    fn test_basic() {
        let s = shape!(host = 2, gpu = 8);
        assert_eq!(&s.labels, &["host".to_string(), "gpu".to_string()]);
        assert_eq!(s.slice.offset(), 0);
        assert_eq!(s.slice.sizes(), &[2, 8]);
        assert_eq!(s.slice.strides(), &[8, 1]);

        assert_eq!(s.to_string(), "{host=2,gpu=8}");
    }

    #[test]
    fn test_select() {
        let s = shape!(host = 2, gpu = 8);

        assert_eq!(
            s.slice().iter().collect::<Vec<_>>(),
            &[
                0,
                1,
                2,
                3,
                4,
                5,
                6,
                7,
                8,
                8 + 1,
                8 + 2,
                8 + 3,
                8 + 4,
                8 + 5,
                8 + 6,
                8 + 7
            ]
        );

        assert_eq!(
            select!(s, host = 1)
                .unwrap()
                .slice()
                .iter()
                .collect::<Vec<_>>(),
            &[8, 8 + 1, 8 + 2, 8 + 3, 8 + 4, 8 + 5, 8 + 6, 8 + 7]
        );

        assert_eq!(
            select!(s, gpu = 2..)
                .unwrap()
                .slice()
                .iter()
                .collect::<Vec<_>>(),
            &[2, 3, 4, 5, 6, 7, 8 + 2, 8 + 3, 8 + 4, 8 + 5, 8 + 6, 8 + 7]
        );

        assert_eq!(
            select!(s, gpu = 3..5)
                .unwrap()
                .slice()
                .iter()
                .collect::<Vec<_>>(),
            &[3, 4, 8 + 3, 8 + 4]
        );

        assert_eq!(
            select!(s, gpu = 3..5, host = 1)
                .unwrap()
                .slice()
                .iter()
                .collect::<Vec<_>>(),
            &[8 + 3, 8 + 4]
        );
    }

    #[test]
    fn test_select_iter() {
        let s = shape!(replica = 2, host = 2, gpu = 8);
        let selections: Vec<_> = s.select_iter(2).unwrap().collect();
        assert_eq!(selections[0].slice().sizes(), &[1, 1, 8]);
        assert_eq!(selections[1].slice().sizes(), &[1, 1, 8]);
        assert_eq!(selections[2].slice().sizes(), &[1, 1, 8]);
        assert_eq!(selections[3].slice().sizes(), &[1, 1, 8]);
        assert_eq!(
            selections,
            &[
                select!(s, replica = 0, host = 0).unwrap(),
                select!(s, replica = 0, host = 1).unwrap(),
                select!(s, replica = 1, host = 0).unwrap(),
                select!(s, replica = 1, host = 1).unwrap()
            ]
        );
    }

    #[test]
    fn test_coordinates() {
        let s = shape!(host = 2, gpu = 8);
        assert_eq!(
            s.coordinates(0).unwrap(),
            vec![("host".to_string(), 0), ("gpu".to_string(), 0)]
        );
        assert_eq!(
            s.coordinates(1).unwrap(),
            vec![("host".to_string(), 0), ("gpu".to_string(), 1)]
        );
        assert_eq!(
            s.coordinates(8).unwrap(),
            vec![("host".to_string(), 1), ("gpu".to_string(), 0)]
        );
        assert_eq!(
            s.coordinates(9).unwrap(),
            vec![("host".to_string(), 1), ("gpu".to_string(), 1)]
        );

        assert_matches!(
            s.coordinates(16).unwrap_err(),
            ShapeError::SliceError(SliceError::ValueNotInSlice { value: 16 })
        );
    }

    #[test]
    fn test_select_bad() {
        let s = shape!(host = 2, gpu = 8);

        assert_matches!(
            select!(s, gpu = 1..1).unwrap_err(),
            ShapeError::EmptyRange {
                range: Range(1, Some(1), 1)
            },
        );

        assert_matches!(
            select!(s, gpu = 8).unwrap_err(),
            ShapeError::OutOfRange {
                range: Range(8, Some(9), 1),
                dim,
                size: 8,
            } if dim == "gpu",
        );
    }

    #[test]
    fn test_shape_index() {
        let n_hosts = 5;
        let n_gpus = 7;

        // Index first dim
        let s = shape!(host = n_hosts, gpu = n_gpus);
        assert_eq!(
            s.index(vec![("host".to_string(), 0)]).unwrap(),
            Shape::new(
                vec!["gpu".to_string()],
                Slice::new(0, vec![n_gpus], vec![1]).unwrap()
            )
            .unwrap()
        );

        // Index last dims
        let offset = 1;
        assert_eq!(
            s.index(vec![("gpu".to_string(), offset)]).unwrap(),
            Shape::new(
                vec!["host".to_string()],
                Slice::new(offset, vec![n_hosts], vec![n_gpus]).unwrap()
            )
            .unwrap()
        );

        // Index middle dim
        let n_zone = 2;
        let s = shape!(zone = n_zone, host = n_hosts, gpu = n_gpus);
        let offset = 3;
        assert_eq!(
            s.index(vec![("host".to_string(), offset)]).unwrap(),
            Shape::new(
                vec!["zone".to_string(), "gpu".to_string()],
                Slice::new(
                    offset * n_gpus,
                    vec![n_zone, n_gpus],
                    vec![n_hosts * n_gpus, 1]
                )
                .unwrap()
            )
            .unwrap()
        );

        // Out of range
        assert!(
            shape!(gpu = n_gpus)
                .index(vec![("gpu".to_string(), n_gpus)])
                .is_err()
        );
        // Invalid dim
        assert!(
            shape!(gpu = n_gpus)
                .index(vec![("non-exist-dim".to_string(), 0)])
                .is_err()
        );
    }

    #[test]
    fn test_shape_select_stride_rounding() {
        let shape = shape!(x = 10);
        // Select x = 0..10 step 3 → expect indices [0, 3, 6, 9]
        let sub = shape.select("x", Range(0, Some(10), 3)).unwrap();
        let slice = sub.slice();
        // 10 / 3 = 3.33..., so ceil(10 / 3) = 4
        assert_eq!(
            slice,
            &Slice::new(0, vec![4], vec![3]).unwrap(),
            "Expected offset 0, size 4, stride 3"
        );
    }

    #[test]
    fn test_shape_at_removes_dimension() {
        let labels = vec![
            "batch".to_string(),
            "height".to_string(),
            "width".to_string(),
        ];
        let slice = Slice::new_row_major(vec![2, 3, 4]);
        let shape = Shape::new(labels, slice).unwrap();

        // Select index 1 from "batch" dimension
        let result = shape.at("batch", 1).unwrap();

        // Should have 2 dimensions now
        assert_eq!(result.labels(), &["height", "width"]);
        assert_eq!(result.slice().sizes(), &[3, 4]);
        assert_eq!(result.slice().offset(), 12); // 1 * 12 (batch stride)
    }

    #[test]
    fn test_shape_at_middle_dimension() {
        let labels = vec![
            "batch".to_string(),
            "height".to_string(),
            "width".to_string(),
        ];
        let slice = Slice::new_row_major(vec![2, 3, 4]);
        let shape = Shape::new(labels, slice).unwrap();

        // Select index 1 from "height" dimension (middle)
        let result = shape.at("height", 1).unwrap();

        // Should remove middle label
        assert_eq!(result.labels(), &["batch", "width"]);
        assert_eq!(result.slice().sizes(), &[2, 4]);
        assert_eq!(result.slice().offset(), 4); // 1 * 4 (height stride)
    }

    #[test]
    fn test_shape_at_invalid_label() {
        let labels = vec!["batch".to_string(), "height".to_string()];
        let slice = Slice::new_row_major(vec![2, 3]);
        let shape = Shape::new(labels, slice).unwrap();

        let result = shape.at("nonexistent", 0);
        assert!(matches!(result, Err(ShapeError::InvalidLabels { .. })));
    }

    #[test]
    fn test_shape_at_index_out_of_range() {
        let labels = vec!["batch".to_string(), "height".to_string()];
        let slice = Slice::new_row_major(vec![2, 3]);
        let shape = Shape::new(labels, slice).unwrap();

        let result = shape.at("batch", 5); // batch only has size 2
        assert!(matches!(result, Err(ShapeError::OutOfRange { .. })));
    }

    #[test]
    fn test_shape_from_str_round_trip() {
        let test_cases = vec![
            shape!(host = 2, gpu = 8),
            shape!(x = 1),
            shape!(batch = 10, height = 224, width = 224, channels = 3),
            Shape::unity(), // empty shape
        ];

        for original in test_cases {
            let display_str = original.to_string();
            let parsed: Shape = display_str.parse().unwrap();
            assert_eq!(
                parsed, original,
                "Round-trip failed for shape: {}",
                display_str
            );
        }
    }

    #[test]
    fn test_shape_from_str_valid_cases() {
        let test_cases = vec![
            ("{host=2,gpu=8}", shape!(host = 2, gpu = 8)),
            ("{x=1}", shape!(x = 1)),
            ("{ host = 2 , gpu = 8 }", shape!(host = 2, gpu = 8)), // with spaces
            ("{}", Shape::unity()),                                // empty shape
        ];

        for (input, expected) in test_cases {
            let parsed: Shape = input.parse().unwrap();
            assert_eq!(parsed, expected, "Failed to parse: {}", input);
        }
    }

    #[test]
    fn test_shape_from_str_error_cases() {
        let error_cases = vec![
            "host=2,gpu=8",
            "{host=2,gpu=8",
            "host=2,gpu=8}",
            "{host=2,gpu=}",
            "{host=,gpu=8}",
            "{host=2=3,gpu=8}",
            "{host=abc,gpu=8}",
            "{host=2,}",
            "{=8}",
        ];

        for input in error_cases {
            let result: Result<Shape, ShapeError> = input.parse();
            assert!(result.is_err(), "expected error for input: {}", input);
        }
    }

    /// Pins the exact `ParseError` message reported for each kind of
    /// malformed input.
    #[test]
    fn test_shape_from_str_error_messages() {
        let reason = |input: &str| match input.parse::<Shape>() {
            Err(ShapeError::ParseError { reason }) => reason,
            other => panic!("expected ParseError for {input:?}, got {other:?}"),
        };

        assert_eq!(reason("host=2"), "shape string must be enclosed in braces");
        assert_eq!(reason("{host}"), "missing size for dimension 'host'");
        assert_eq!(reason("{host=2=3}"), "invalid dimension format: 'host=2=3'");
        assert_eq!(reason("{=8}"), "missing label in dimension: '=8'");
        assert_eq!(reason("{host=abc}"), "invalid size 'abc' for dimension 'host'");
    }

    /// Covers the `Range` conversions, `Display`, and crate-internal
    /// resolution helpers for ordinary (non-overflowing) inputs.
    #[test]
    fn test_range_conversions_display_and_resolve() {
        assert_eq!(Range::from(2..5), Range(2, Some(5), 1));
        assert_eq!(Range::from(2..=5), Range(2, Some(6), 1));
        assert_eq!(Range::from(3..), Range(3, None, 1));
        assert_eq!(Range::from(7usize), Range(7, Some(8), 1));

        assert_eq!(Range(1, Some(4), 2).to_string(), "1:4:2");
        assert_eq!(Range(3, None, 1).to_string(), "3::1");

        assert_eq!(Range(0, Some(10), 1).resolve(4), (0, 4, 1));
        assert_eq!(Range(2, None, 3).resolve(8), (2, 8, 3));
        assert!(Range(5, Some(5), 1).is_empty());
        assert!(!Range(5, None, 1).is_empty());
    }
}