ndslice/
selection.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! This module defines a recursive algebra for selecting coordinates
10//! in a multidimensional space.
11//!
//! A [`Selection`] describes constraints across dimensions of an
//! `ndslice::Slice`. Variants like [`All`](Selection::All),
//! [`First`](Selection::First), and [`Range`](Selection::Range)
//! operate dimensionally, while
//! [`Intersection`](Selection::Intersection) and
//! [`Union`](Selection::Union) allow for logical composition of
//! selections.
16//!
17//! ## Example
18//!
19//! Suppose a 3-dimensional mesh system of:
20//! - 2 zones
21//! - 4 hosts per zone
22//! - 8 GPUs per host
23//!
24//! The corresponding `Slice` will have shape `[2, 4, 8]`. An
25//! expression to denote the first 4 GPUs of host 0 together with the
//! last 4 GPUs of host 1 across all zones can be written as:
27//! ```rust
28//! use ndslice::selection::dsl::all;
29//! use ndslice::selection::dsl::range;
30//! use ndslice::selection::dsl::true_;
31//! use ndslice::selection::dsl::union;
32//!
33//! let s = all(range(0, range(0..4, true_())));
34//! let t = all(range(1, range(4.., true_())));
35//! let selection = union(s, t);
36//! ```
//! Assuming a row-major layout, that is the set of 2 x 2 x 4 = 16
//! coordinates *{(0, 0, 0), ..., (0, 0, 3), (0, 1, 4), ..., (0, 1, 7),
//! (1, 0, 0), ..., (1, 0, 3), (1, 1, 4), ..., (1, 1, 7)}*, where code
//! to print that set might read as follows.
41//! ```rust
42//! use ndslice::Slice;
43//! use ndslice::selection::EvalOpts;
44//! use ndslice::selection::dsl::all;
45//! use ndslice::selection::dsl::range;
46//! use ndslice::selection::dsl::true_;
47//! use ndslice::selection::dsl::union;
48//!
49//! let slice = Slice::new(0usize, vec![2, 4, 8], vec![32, 8, 1]).unwrap();
50//! let s = all(range(0, range(0..4, true_())));
51//! let t = all(range(1, range(4.., true_())));
52//!
53//! for r in union(s, t).eval(&EvalOpts::lenient(), &slice).unwrap() {
54//!     println!("{:?}", slice.coordinates(r).unwrap());
55//! }
56//! ```
//! The example uses the `eval` method described next.
58//!
59//! ## Evaluation
60//!
61//! Selections are evaluated against an `ndslice::Slice` using the
62//! [`Selection::eval`] method, which returns a lazy iterator over the
63//! flat (linearized) indices of elements that match.
64//!
65//! Evaluation proceeds recursively, dimension by dimension, with each
66//! variant of `Selection` contributing logic at a particular level of
67//! the slice.
68//!
69//! If a `Selection` is shorter than the number of dimensions, it is
70//! implicitly extended with `true_()` at the remaining levels. This
//! means `Selection::True` acts as the identity element (for
//! intersection), matching all remaining indices by default.
73
/// A parser for selection expressions in a compact textual syntax.
///
/// See [`crate::selection::parse`] for syntax details and examples.
pub mod parse;

/// Formatting utilities for `Selection` expressions.
///
/// This module defines pretty-printers and compact syntax renderers
/// for selections, based on implementations of the `SelectionSYM`
/// trait.
///
/// The `Display` implementation for [`Selection`] uses this module.
pub mod pretty;

/// A `TokenStream` to [`Selection`] parser. Used at compile time in
/// `sel!`. See [`crate::selection::parse`] for syntax details and
/// examples.
pub mod token_parser;

/// Shape navigation guided by [`Selection`] expressions.
pub mod routing;

/// Normalization logic for `Selection`.
pub mod normal;

// NOTE(review): undocumented; presumably shared helpers for tests —
// confirm and add a doc comment.
pub mod test_utils;
100
101use std::collections::BTreeSet;
102use std::collections::HashMap;
103use std::collections::HashSet;
104use std::fmt;
105
106use rand::Rng;
107use serde::Deserialize;
108use serde::Serialize;
109
110use crate::Slice;
111use crate::selection::normal::NormalizedSelection;
112use crate::selection::normal::RewriteRuleExt;
113use crate::shape;
114use crate::shape::ShapeError;
115use crate::slice::SliceError;
116
117/// This trait defines an abstract syntax without committing to a
118/// specific representation. It follow the
119/// [tagless-final](https://okmij.org/ftp/tagless-final/index.html)
120/// style where [`Selection`] is a default AST representation.
121pub trait SelectionSYM {
122    /// The identity selection (matches no nodes).
123    fn false_() -> Self;
124
125    /// The universal selection (matches all nodes).
126    fn true_() -> Self;
127
128    /// Selects all values along the current dimension, then applies
129    /// the inner selection.
130    fn all(selection: Self) -> Self;
131
132    /// Selects the first index along the current dimension for which
133    /// the inner selection is non-empty.
134    fn first(selection: Self) -> Self;
135
136    /// Selects values within the given range along the current
137    /// dimension, then applies the inner selection.
138    fn range<R: Into<shape::Range>>(range: R, selection: Self) -> Self;
139
140    /// Selects values along the current dimension that match the
141    /// given labels, then applies the inner selection.
142    fn label<L: Into<LabelKey>>(labels: Vec<L>, selection: Self) -> Self;
143
144    /// Selects a random index along the current dimension, then applies
145    /// the inner selection.
146    fn any(selection: Self) -> Self;
147
148    /// The intersection (logical AND) of two selection expressions.
149    fn intersection(lhs: Self, selection: Self) -> Self;
150
151    /// The union (logical OR) of two selection expressions.
152    fn union(lhs: Self, selection: Self) -> Self;
153}
154
155/// `SelectionSYM`-based constructors specialized to the [`Selection`]
156/// AST.
157pub mod dsl {
158
159    use super::LabelKey;
160    use super::Selection;
161    use super::SelectionSYM;
162    use crate::shape;
163
164    pub fn false_() -> Selection {
165        SelectionSYM::false_()
166    }
167    pub fn true_() -> Selection {
168        SelectionSYM::true_()
169    }
170    pub fn all(inner: Selection) -> Selection {
171        SelectionSYM::all(inner)
172    }
173    pub fn first(inner: Selection) -> Selection {
174        SelectionSYM::first(inner)
175    }
176    pub fn range<R: Into<shape::Range>>(r: R, inner: Selection) -> Selection {
177        SelectionSYM::range(r, inner)
178    }
179    pub fn label<L: Into<LabelKey>>(labels: Vec<L>, inner: Selection) -> Selection {
180        SelectionSYM::label(labels, inner)
181    }
182    pub fn any(inner: Selection) -> Selection {
183        SelectionSYM::any(inner)
184    }
185    pub fn intersection(lhs: Selection, rhs: Selection) -> Selection {
186        SelectionSYM::intersection(lhs, rhs)
187    }
188    pub fn union(lhs: Selection, rhs: Selection) -> Selection {
189        SelectionSYM::union(lhs, rhs)
190    }
191}
192
193impl SelectionSYM for Selection {
194    fn false_() -> Self {
195        ast::false_()
196    }
197    fn true_() -> Self {
198        ast::true_()
199    }
200    fn all(selection: Self) -> Self {
201        ast::all(selection)
202    }
203    fn first(selection: Self) -> Self {
204        ast::first(selection)
205    }
206    fn range<R: Into<shape::Range>>(range: R, selection: Self) -> Self {
207        ast::range(range, selection)
208    }
209    fn label<L: Into<LabelKey>>(labels: Vec<L>, selection: Selection) -> Selection {
210        let labels = labels.into_iter().map(|l| l.into()).collect();
211        Selection::Label(labels, Box::new(selection))
212    }
213    fn any(selection: Self) -> Self {
214        ast::any(selection)
215    }
216    fn intersection(lhs: Self, rhs: Self) -> Self {
217        ast::intersection(lhs, rhs)
218    }
219    fn union(lhs: Self, rhs: Self) -> Self {
220        ast::union(lhs, rhs)
221    }
222}
223
224impl fmt::Display for Selection {
225    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
226        write!(f, "{}", pretty::pretty(self))
227    }
228}
229
230/// A metadata label used to constrain values at a given coordinate
231/// dimension.
232///
233/// `LabelKey` represents attribute values associated with indices —
234/// for example, GPU model names like `"A100"` or capabilities like
235/// "AVX-512".
236///
237/// Labels are not dimension names (like `"zone"` or `"rack"`); they
238/// are **values** assigned to elements at a given dimension, and are
239/// used by `Selection::Label` to restrict which values are eligible
240/// during selection or routing.
241/// For example, a selection like `sel!(["A100"]*)` matches only
242/// indices at the current dimension whose associated label value is
243/// `"A100"`.
244///
245/// `Ord` is derived to allow deterministic sorting and set membership,
246/// based on lexicographic ordering of label strings.
247#[derive(
248    Clone,
249    Debug,
250    PartialEq,
251    Eq,
252    Hash,
253    Serialize,
254    Deserialize,
255    PartialOrd,
256    Ord
257)]
258pub enum LabelKey {
259    /// A plain string label value.
260    Value(String),
261}
262
263impl From<String> for LabelKey {
264    fn from(s: String) -> Self {
265        LabelKey::Value(s)
266    }
267}
268
269impl From<&str> for LabelKey {
270    fn from(s: &str) -> Self {
271        LabelKey::Value(s.to_string())
272    }
273}
274
275impl std::fmt::Display for LabelKey {
276    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
277        match self {
278            LabelKey::Value(s) => write!(f, "\"{}\"", s),
279        }
280    }
281}
282
/// An algebra for expressing node selection.
///
/// Dimensional variants (`All`, `First`, `Range`, `Any`) each act on
/// the current dimension and continue with their inner selection at
/// the next one. `Intersection` and `Union` combine two selections at
/// the same dimension. During evaluation, `Label` does not consume a
/// dimension itself: its inner selection is evaluated at the same
/// dimension (see `eval_label`).
///
/// Note: `Serialize`/`Deserialize` are derived, so renaming or
/// reordering variants can change the serialized form.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub enum Selection {
    /// A selection that never matches any node.
    False,

    /// A selection that always matches any node.
    True,

    /// Selects all values along the current dimension, continuing
    /// with the given selection.
    All(Box<Selection>),

    /// Selects the first value along the current dimension for which
    /// applying the inner selection yields any results.
    First(Box<Selection>),

    /// Selects values within a given range along the current
    /// dimension, continuing with the given selection.
    Range(shape::Range, Box<Selection>),

    /// Selects values based on metadata (i.e., labels) along the
    /// current dimension. This provides attribute-based selection.
    Label(Vec<LabelKey>, Box<Selection>),

    /// Selects a random index along the current dimension, continuing
    /// with the given selection.
    Any(Box<Selection>),

    /// The intersection (logical AND) of two selections.
    Intersection(Box<Selection>, Box<Selection>),

    /// The union (logical OR) of two selections.
    Union(Box<Selection>, Box<Selection>),
}
319
320// Compile-time check: ensure Selection is thread-safe and fully
321// owned.
322fn _assert_selection_traits()
323where
324    Selection: Send + Sync + 'static,
325{
326}
327
328/// Compares two `Selection` values for structural equality.
329///
330/// Two selections are structurally equal if they have the same shape
331/// and recursively equivalent substructure, but not necessarily the
332/// same pointer identity or formatting.
333pub fn structurally_equal(a: &Selection, b: &Selection) -> bool {
334    use Selection::*;
335    match (a, b) {
336        (False, False) => true,
337        (True, True) => true,
338        (All(x), All(y)) => structurally_equal(x, y),
339        (Any(x), Any(y)) => structurally_equal(x, y),
340        (First(x), First(y)) => structurally_equal(x, y),
341        (Range(r1, x), Range(r2, y)) => r1 == r2 && structurally_equal(x, y),
342        (Intersection(x1, y1), Intersection(x2, y2)) => {
343            structurally_equal(x1, x2) && structurally_equal(y1, y2)
344        }
345        (Union(x1, y1), Union(x2, y2)) => structurally_equal(x1, x2) && structurally_equal(y1, y2),
346        _ => false,
347    }
348}
349
350/// Normalizes a [`Selection`] toward a canonical form for structural
351/// comparison.
352///
353/// This rewrites the selection to eliminate redundant subtrees and
354/// bring structurally similar selections into a common
355/// representation. The result is suitable for comparison, hashing,
356/// and deduplication (e.g., in [`RoutingFrameKey`]).
357///
358/// Normalization preserves semantics but may alter syntactic
359/// structure. It is designed to improve over time as additional
360/// rewrites (e.g., flattening, simplification) are introduced.
361pub fn normalize(sel: &Selection) -> NormalizedSelection {
362    let rule = normal::FlatteningRules
363        .then(normal::IdentityRules)
364        .then(normal::AbsorbtionRules);
365    sel.fold::<normal::NormalizedSelection>()
366        .rewrite_bottom_up(&rule)
367}
368
369/// Wraps a normalized selection and derives `Eq` and `Hash`, relying
370/// on the canonical structure of the normalized form.
371///
372/// This ensures that logically equivalent selections (e.g., unions
373/// with reordered elements) compare equal and hash identically.
374#[derive(Debug, Clone, PartialEq, Eq, Hash)]
375pub struct NormalizedSelectionKey(NormalizedSelection);
376
377impl NormalizedSelectionKey {
378    /// Constructs a `NormalizedSelectionKey`, normalizing the input
379    /// selection.
380    pub fn new(sel: &Selection) -> Self {
381        Self(crate::selection::normalize(sel))
382    }
383
384    /// Access the normalized selection.
385    pub fn inner(&self) -> &NormalizedSelection {
386        &self.0
387    }
388
389    /// Consumes the key and returns the owned normalized selection.
390    pub fn into_inner(self) -> NormalizedSelection {
391        self.0
392    }
393}
394
395mod ast {
396    use super::LabelKey;
397    use super::Selection;
398    use crate::shape;
399
400    pub(crate) fn false_() -> Selection {
401        Selection::False
402    }
403    pub(crate) fn true_() -> Selection {
404        Selection::True
405    }
406    pub(crate) fn all(selection: Selection) -> Selection {
407        Selection::All(Box::new(selection))
408    }
409    pub(crate) fn first(selection: Selection) -> Selection {
410        Selection::First(Box::new(selection))
411    }
412    pub(crate) fn range<R: Into<shape::Range>>(range: R, selection: Selection) -> Selection {
413        Selection::Range(range.into(), Box::new(selection))
414    }
415    #[allow(dead_code)] // Harmless.
416    pub(crate) fn label<L: Into<LabelKey>>(labels: Vec<L>, selection: Selection) -> Selection {
417        let labels = labels.into_iter().map(Into::into).collect();
418        Selection::Label(labels, Box::new(selection))
419    }
420    pub(crate) fn any(selection: Selection) -> Selection {
421        Selection::Any(Box::new(selection))
422    }
423    pub(crate) fn intersection(lhs: Selection, rhs: Selection) -> Selection {
424        Selection::Intersection(Box::new(lhs), Box::new(rhs))
425    }
426    pub(crate) fn union(lhs: Selection, rhs: Selection) -> Selection {
427        Selection::Union(Box::new(lhs), Box::new(rhs))
428    }
429}
430
/// `EvalOpts` controls runtime behavior of [`Selection::eval`] by
/// enabling stricter validation rules.
///
/// Each flag, when `true`, turns the corresponding condition into an
/// `Err` from `eval`; [`EvalOpts::lenient`] leaves them all off.
pub struct EvalOpts {
    /// Fail `eval` on empty range expressions.
    pub disallow_empty_ranges: bool,

    /// Fail `eval` on a range beginning at or beyond the slice's
    /// extent in the evaluation context's dimension.
    pub disallow_out_of_range: bool,

    /// Fail `eval` if a selection can be shown to be not "static"
    /// (e.g. it contains a random `Any` choice).
    pub disallow_dynamic_selections: bool,
}

impl EvalOpts {
    /// Options under which validation rejects none of the optional
    /// conditions: empty ranges, out-of-range ranges, and dynamic
    /// selections are all accepted.
    pub fn lenient() -> Self {
        Self {
            disallow_empty_ranges: false,
            disallow_out_of_range: false,
            disallow_dynamic_selections: false,
        }
    }

    /// Options under which `eval` rejects empty and out-of-range
    /// ranges, aiming to fail with the same [`ShapeError`]s as
    /// `Shape::select`. Dynamic selections remain allowed.
    pub fn strict() -> Self {
        Self {
            disallow_empty_ranges: true,
            disallow_out_of_range: true,
            ..Self::lenient()
        }
    }
}
466
467impl Selection {
468    pub(crate) fn validate(&self, opts: &EvalOpts, slice: &Slice) -> Result<&Self, ShapeError> {
469        let depth = 0;
470        self.validate_rec(opts, slice, self, depth).map(|_| self)
471    }
472
473    fn validate_rec(
474        &self,
475        opts: &EvalOpts,
476        slice: &Slice,
477        top: &Selection,
478        dim: usize,
479    ) -> Result<(), ShapeError> {
480        if dim == slice.num_dim() {
481            // This enables us to maintain identities like 'all(true)
482            // <=> true' and 'all(false) <=> false' in leaf positions.
483            match self {
484                Selection::True | Selection::False => return Ok(()),
485                _ => {
486                    return Err(ShapeError::SelectionTooDeep {
487                        expr: top.clone(),
488                        num_dim: slice.num_dim(),
489                    });
490                }
491            }
492        }
493
494        match self {
495            Selection::False | Selection::True => Ok(()),
496            Selection::Range(range, s) => {
497                if range.is_empty() && opts.disallow_empty_ranges {
498                    return Err(ShapeError::EmptyRange {
499                        range: range.clone(),
500                    });
501                } else {
502                    if opts.disallow_out_of_range {
503                        let size = slice.sizes()[dim];
504                        let (min, _, _) = range.resolve(size);
505                        if min >= size {
506                            // Use EmptyRange here for now (evaluation would result in an empty range),
507                            // until we figure out how to differentiate between slices and shapes
508                            return Err(ShapeError::EmptyRange {
509                                range: range.clone(),
510                            });
511                        }
512                    }
513
514                    s.validate_rec(opts, slice, top, dim + 1)?;
515                }
516
517                Ok(())
518            }
519            Selection::Any(s) => {
520                if opts.disallow_dynamic_selections {
521                    return Err(ShapeError::SelectionDynamic { expr: top.clone() });
522                }
523                s.validate_rec(opts, slice, top, dim + 1)?;
524                Ok(())
525            }
526            Selection::All(s) | Selection::Label(_, s) | Selection::First(s) => {
527                s.validate_rec(opts, slice, top, dim + 1)?;
528                Ok(())
529            }
530            Selection::Intersection(a, b) | Selection::Union(a, b) => {
531                a.validate_rec(opts, slice, top, dim)?;
532                b.validate_rec(opts, slice, top, dim)?;
533                Ok(())
534            }
535        }
536    }
537
538    /// Lazily evaluates this selection against the given `slice`
539    /// yielding flat indices.
540    ///
541    /// Returns a boxed iterator that produces indices of elements
542    /// matching the selection expression when evaluated over `slice`.
543    ///
544    /// # Lifetimes
545    ///
546    /// The returned iterator borrows `slice` because the internal
547    /// iterators are implemented as closures that **capture**
548    /// `&slice` in their environment. Evaluation is lazy, so these
549    /// closures dereference `slice` each time a coordinate is
550    /// visited. The `'a` lifetime ensures that the iterator cannot
551    /// outlive the `slice` it reads from.
552    ///
553    /// # Why `Box<dyn Iterator>`
554    ///
555    /// The selection algebra supports multiple recursive strategies
556    /// (`All`, `Range`, `Intersection`, etc.) that return different
557    /// iterator types (e.g. `Selection::True` =>
558    /// `std::iter::once(...)`, `Selection::False` =>
559    /// `std::iter::empty()`). Returning `impl Iterator` is not
560    /// feasible because the precise type depends on dynamic selection
561    /// structure. Boxing erases this variability and allows a uniform
562    /// return type.
563    ///
564    /// # Canonical handling of 0-dimensional slices
565    ///
566    /// A `Slice` with zero dimensions represents the empty product
567    /// `∏_{i=1}^{0} Xᵢ`, which has exactly one element: the empty
568    /// tuple. To ensure that evaluation behaves uniformly across
569    /// dimensions, we canonically embed the 0-dimensional case into a
570    /// 1-dimensional slice of extent 1. That is, we reinterpret the
571    /// 0D slice as `Slice::new(offset, [1], [1])`, which is
572    /// semantically equivalent and enables evaluation to proceed
573    /// through the normal recursive machinery without special-casing.
574    /// The result is that selection expressions are always evaluated
575    /// over a slice with at least one dimension, and uniform logic
576    /// applies.
577    pub fn eval<'a>(
578        &self,
579        opts: &EvalOpts,
580        slice: &'a Slice,
581    ) -> Result<Box<dyn Iterator<Item = usize> + 'a>, ShapeError> {
582        // Canonically embed 0D as 1D (extent 1).
583        if slice.num_dim() == 0 {
584            let slice = Slice::new(slice.offset(), vec![1], vec![1]).unwrap();
585            return Ok(Box::new(
586                self.validate(opts, &slice)?
587                    .eval_rec(&slice, vec![0; 1], 0)
588                    .collect::<Vec<_>>()
589                    .into_iter(),
590            ));
591        }
592
593        Ok(self
594            .validate(opts, slice)?
595            .eval_rec(slice, vec![0; slice.num_dim()], 0))
596    }
597
598    fn eval_rec<'a>(
599        &self,
600        slice: &'a Slice,
601        env: Vec<usize>,
602        dim: usize,
603    ) -> Box<dyn Iterator<Item = usize> + 'a> {
604        if dim == slice.num_dim() {
605            match self {
606                Selection::True => return Box::new(std::iter::once(slice.location(&env).unwrap())),
607                Selection::False => return Box::new(std::iter::empty()),
608                _ => {
609                    panic!("structural combinator {self:?} at leaf level (dim = {dim}))",);
610                }
611            }
612        }
613
614        use itertools;
615        use itertools::EitherOrBoth;
616
617        match self {
618            Selection::False => Box::new(std::iter::empty()),
619            Selection::True => Box::new((0..slice.sizes()[dim]).flat_map(move |i| {
620                let mut env = env.clone();
621                env[dim] = i;
622                Selection::True.eval_rec(slice, env, dim + 1)
623            })),
624            Selection::All(select) => {
625                let select = Box::clone(select);
626                Box::new((0..slice.sizes()[dim]).flat_map(move |i| {
627                    let mut env = env.clone();
628                    env[dim] = i;
629                    select.eval_rec(slice, env, dim + 1)
630                }))
631            }
632            Selection::First(select) => {
633                let select = Box::clone(select);
634                Box::new(iterutils::first(slice.sizes()[dim], move |i| {
635                    let mut env = env.clone();
636                    env[dim] = i;
637                    select.eval_rec(slice, env, dim + 1)
638                }))
639            }
640            Selection::Range(range, select) => {
641                let select = Box::clone(select);
642                let (min, max, step) = range.resolve(slice.sizes()[dim]);
643                Box::new((min..max).step_by(step).flat_map(move |i| {
644                    let mut env = env.clone();
645                    env[dim] = i;
646                    select.eval_rec(slice, env, dim + 1)
647                }))
648            }
649
650            // Label-based selection: filters candidates at this
651            // dimension, then either selects one (Any) or recurses.
652            //
653            // When the inner selection is `Any`, we choose one match
654            // at random (eager). Otherwise, we recurse normally and
655            // filter the results lazily.
656            //
657            // This separation reflects that `Label(...)` does *not*
658            // consume a dimension — it restricts access to it while
659            // preserving dimensional structure.
660            //
661            // See `eval_label` for more on the distinction between
662            // filtering and traversal, and the underlying
663            // projection-based interpretation.
664            //
665            // For example:
666            //
667            //   sel!(*, ["foo"]?, *)  // select one host with label "foo", then all GPUs
668            //   = all(label(["foo"], any(all(true_()))))
669            //
670            //   sel!(*, ["foo"]*, *)  // select all hosts with label "foo", then all GPUs
671            //   = all(label(["foo"], all(all(true_()))))
672            //
673            // **Note:** Label filtering is not yet implemented — all coordinates
674            // are currently accepted.
675            Selection::Label(labels, inner) => {
676                Self::eval_label(labels, inner, slice, env, dim /*, provider */)
677            }
678            Selection::Any(select) => {
679                let select = Box::clone(select);
680                let r = {
681                    let upper = slice.sizes()[dim];
682                    let mut rng = rand::thread_rng();
683                    rng.gen_range(0..upper)
684                };
685                Box::new((r..r + 1).flat_map(move |i| {
686                    let mut env = env.clone();
687                    env[dim] = i;
688                    select.eval_rec(slice, env, dim + 1)
689                }))
690            }
691            Selection::Intersection(a, b) => Box::new(
692                itertools::merge_join_by(
693                    a.eval_rec(slice, env.clone(), dim),
694                    b.eval_rec(slice, env.clone(), dim),
695                    |x, y| x.cmp(y),
696                )
697                .filter_map(|either| match either {
698                    EitherOrBoth::Both(x, _) => Some(x),
699                    _ => None,
700                }),
701            ),
702            Selection::Union(a, b) => Box::new(
703                itertools::merge_join_by(
704                    a.eval_rec(slice, env.clone(), dim),
705                    b.eval_rec(slice, env.clone(), dim),
706                    |x, y| x.cmp(y),
707                )
708                .map(|either| match either {
709                    EitherOrBoth::Left(x) => x,
710                    EitherOrBoth::Right(y) => y,
711                    EitherOrBoth::Both(x, _) => x,
712                }),
713            ),
714        }
715    }
716
717    /// Evaluates a `Label(labels, inner)` selection.
718    ///
719    /// This operator filters coordinates along the current dimension
720    /// based on associated metadata (labels). It then evaluates the inner
721    /// selection at matching positions.
722    ///
723    /// Conceptually, this corresponds to computing a pullback along a
724    /// projection `p : E → B`, where:
725    ///
726    /// - `B` is the base space of coordinates (e.g. zones × hosts × gpus)
727    /// - `E` is the space of labeled coordinates
728    /// - `p⁻¹(S)` lifts a geometric selection `S ⊆ B` into the labeled
729    ///   space
730    ///
731    /// At runtime, we simulate `p⁻¹(S)` by traversing `B` and querying a
732    /// `LabelProvider` at each coordinate. Under the identity provider,
733    /// label filtering has no effect and `eval_label` reduces to the
734    /// geometric case.
735    ///
736    /// - If `inner` is `Any`, we select one matching index at random
737    /// - Otherwise, we recurse and filter lazily
738    ///
739    /// **Note:** Label filtering is not yet implemented — all coordinates
740    /// are currently accepted.
741    fn eval_label<'a>(
742        _labels: &[LabelKey],
743        inner: &Selection,
744        slice: &'a Slice,
745        env: Vec<usize>,
746        dim: usize,
747        // provider: &dyn LabelProvider  // TODO: add when ready
748    ) -> Box<dyn Iterator<Item = usize> + 'a> {
749        match inner {
750            // Case 1: label(..., any(...))
751            // - We evaluate all indices at this dimension that match
752            //   the label predicate.
753            // - From those, choose one at random and continue
754            //   evaluating the inner selection.
755            // - Semantically: filter → choose one → recurse
756            Selection::Any(sub_inner) => {
757                let matching: Vec<usize> = (0..slice.sizes()[dim])
758                    .filter(|&i| {
759                        let mut prefix = env.clone();
760                        prefix[dim] = i;
761                        true // TODO: provider.matches(dim, &prefix[0..=dim], labels)
762                    })
763                    .collect();
764
765                if matching.is_empty() {
766                    return Box::new(std::iter::empty());
767                }
768
769                let mut rng = rand::thread_rng();
770                let chosen = matching[rng.gen_range(0..matching.len())];
771
772                let mut coord = env;
773                coord[dim] = chosen;
774                sub_inner.eval_rec(slice, coord, dim + 1 /*, provider */)
775            }
776            // Case 2: label(..., inner)
777            //
778            // Applies label filtering after evaluating `inner`. We
779            // first recurse into `inner`, then lazily filter the
780            // resulting flat indices based on whether the coordinate
781            // at `dim` matches the given labels.
782            //
783            // This preserves laziness for all cases except `Any`,
784            // which requires eager collection and is handled
785            // separately.
786            _ => {
787                // evaluate the inner selection — recurse as usual
788                let iter = inner.eval_rec(slice, env.clone(), dim /* , provider */);
789                Box::new(iter.filter(move |&flat| {
790                    let _coord = slice.coordinates(flat);
791                    true // TODO: provider.matches(dim, &coord[0..=dim], labels)
792                }))
793            }
794        }
795    }
796
797    /// Returns `true` if this selection is equivalent to `True` under
798    /// the algebra.
799    ///
800    /// In the selection algebra, `All(True)` is considered equivalent
801    /// to `True`, and this identity extends recursively. For example:
802    ///
803    ///   - `All(True)`      ≡ `True`
804    ///   - `All(All(True))` ≡ `True`
805    ///   - `All(All(All(True)))` ≡ `True`
806    ///
807    /// This method checks whether the selection is structurally
808    /// identical to True, possibly wrapped in one or more All(...)
809    /// layers. It does **not** perform full normalization—only
810    /// structural matching sufficient to recognize this identity.
811    ///
812    /// Used to detect when a selection trivially selects all elements
813    /// at all levels.
814    ///
815    /// ## Limitations
816    ///
817    /// This is a **syntactic check** only. It does *not* recognize
818    /// semantically equivalent expressions such as:
819    ///
820    ///   - `Union(True, True)`
821    ///   - `All(Union(True, False))`
822    ///   - A union of all singleton ranges covering the full space
823    ///
824    /// For a semantic check, use evaluation against a known slice.
825    pub fn is_equivalent_to_true(mut sel: &Selection) -> bool {
826        while let Selection::All(inner) = sel {
827            sel = inner;
828        }
829        matches!(sel, Selection::True)
830    }
831
832    /// Evaluates whether the specified coordinates are part of the selection.
833    /// Returns true if they are, false otherwise.
834    ///
835    /// Example:
836    /// let selection = union(
837    ///     range(0..2, range(0..1, range(0..2, true_()))),
838    ///     range(0..2, range(1..2, range(0..2, true_()))),
839    /// );
840    ///
841    /// assert!(selection.contains(&[0, 0, 1]));
842    /// assert!(!selection.contains(&[2, 0, 1]));
843    pub fn contains(&self, coords: &[usize]) -> bool {
844        self.contains_rec(coords, 0)
845    }
846
847    fn contains_rec(&self, coords: &[usize], dim: usize) -> bool {
848        if dim >= coords.len() {
849            return matches!(self, Selection::True);
850        }
851
852        match self {
853            Selection::False => false,
854            Selection::True => true,
855            Selection::All(inner) => inner.contains_rec(coords, dim + 1),
856            Selection::Range(range, inner) => {
857                let (min, max, step) = range.resolve(coords.len());
858                let index = coords[dim];
859                index >= min
860                    && index < max
861                    && (index - min) % step == 0
862                    && inner.contains_rec(coords, dim + 1)
863            }
864            Selection::Intersection(a, b) => {
865                a.contains_rec(coords, dim) && b.contains_rec(coords, dim)
866            }
867            Selection::Union(a, b) => a.contains_rec(coords, dim) || b.contains_rec(coords, dim),
868            Selection::Label(_, _) | Selection::First(_) | Selection::Any(_) => {
869                unimplemented!()
870            }
871        }
872    }
873
874    /// Simplifies the intersection of two `Selection` expressions.
875    ///
876    /// Applies short-circuit logic to avoid constructing redundant or
877    /// degenerate intersections:
878    ///
879    /// - If either side is `False`, the result is `False`.
880    /// - If either side is `True`, the result is the other side.
881    /// - Otherwise, constructs an explicit `Intersection`.
882    ///
883    /// This is required during routing to make progress when
884    /// evaluating intersections. Without this reduction, routing may
885    /// stall — for example, in intersections like `Intersection(True,
886    /// X)`, which should simplify to `X`.
887    pub fn reduce_intersection(self: Selection, b: Selection) -> Selection {
888        match (&self, &b) {
889            (Selection::False, _) | (_, Selection::False) => Selection::False,
890            (Selection::True, other) | (other, Selection::True) => other.clone(),
891            _ => Selection::Intersection(Box::new(self), Box::new(b)),
892        }
893    }
894
895    /// Canonicalizes this selection to the specified number of
896    /// dimensions.
897    ///
898    /// Ensures that the selection has exactly `num_dims` dimensions
899    /// by recursively wrapping it in combinators (`All`, `Any`, etc.)
900    /// where needed. This transformation enforces a canonical
901    /// structural form suitable for dimensional evaluation (e.g.,
902    /// routing via `next_steps`).
903    ///
904    /// Examples:
905    /// - `True` becomes `All(All(...(True)))`
906    /// - `Any(True)` becomes `Any(Any(...(True)))`
907    /// - Fully specified selections are left unchanged.
908    ///
909    /// ---
910    ///
911    /// # Panics
912    /// Panics if `num_dims == 0`. Use a canonical embedding (e.g., 0D
913    /// → 1D) before calling this (see e.g. `RoutingFrame::root`).
914    pub(crate) fn canonicalize_to_dimensions(self, num_dims: usize) -> Selection {
915        assert!(
916            num_dims > 0,
917            "canonicalize_to_dimensions requires num_dims > 0"
918        );
919        self.canonicalize_to_dimensions_rec(0, num_dims)
920    }
921
922    fn canonicalize_to_dimensions_rec(self, dim: usize, num_dims: usize) -> Selection {
923        use crate::selection::dsl::*;
924
925        match self {
926            Selection::True if dim < num_dims => {
927                let mut out = true_();
928                for _ in (dim..num_dims).rev() {
929                    out = all(out);
930                }
931                out
932            }
933            Selection::False if dim < num_dims => {
934                let mut out = false_();
935                for _ in (dim..num_dims).rev() {
936                    out = all(out);
937                }
938                out
939            }
940            Selection::Any(inner) if dim < num_dims && matches!(*inner, Selection::True) => {
941                let mut out = true_();
942                for _ in (dim..num_dims).rev() {
943                    out = any(out);
944                }
945                out
946            }
947            Selection::All(inner) => all(inner.canonicalize_to_dimensions_rec(dim + 1, num_dims)),
948            Selection::Any(inner) => any(inner.canonicalize_to_dimensions_rec(dim + 1, num_dims)),
949            Selection::First(inner) => {
950                first(inner.canonicalize_to_dimensions_rec(dim + 1, num_dims))
951            }
952            Selection::Range(r, inner) => {
953                range(r, inner.canonicalize_to_dimensions_rec(dim + 1, num_dims))
954            }
955            Selection::Intersection(a, b) => intersection(
956                a.canonicalize_to_dimensions_rec(dim, num_dims),
957                b.canonicalize_to_dimensions_rec(dim, num_dims),
958            ),
959            Selection::Union(a, b) => union(
960                a.canonicalize_to_dimensions_rec(dim, num_dims),
961                b.canonicalize_to_dimensions_rec(dim, num_dims),
962            ),
963
964            other => other,
965        }
966    }
967
968    /// Recursively folds the `Selection` into an abstract syntax via
969    /// the `SelectionSYM` interface.
970    ///
971    /// This method structurally traverses the `Selection` tree and
972    /// reconstructs it using the operations provided by the
973    /// `SelectionSYM` trait. It is typically used to reify a
974    /// `Selection` into alternate forms, such as pretty-printers.
975    ///
976    /// # Type Parameters
977    ///
978    /// - `S`: An implementation of the `SelectionSYM` trait,
979    ///   providing the constructors for the target representation.
980    pub fn fold<S: SelectionSYM>(&self) -> S {
981        match self {
982            Selection::False => S::false_(),
983            Selection::True => S::true_(),
984            Selection::All(inner) => S::all(inner.fold::<S>()),
985            Selection::First(inner) => S::first(inner.fold::<S>()),
986            Selection::Range(r, inner) => S::range(r.clone(), inner.fold::<S>()),
987            Selection::Label(labels, inner) => S::label(labels.clone(), inner.fold::<S>()),
988            Selection::Any(inner) => S::any(inner.fold::<S>()),
989            Selection::Intersection(a, b) => S::intersection(a.fold::<S>(), b.fold::<S>()),
990            Selection::Union(a, b) => S::union(a.fold::<S>(), b.fold::<S>()),
991        }
992    }
993
994    /// Iterator over indices selected by `self` and not in
995    /// `exclusions`.
996    ///
997    /// Evaluates the selection against `slice` using `opts`, then
998    /// filters out any indices present in the exclusion set.
999    /// Evaluation is lazy and streaming; the exclusion set is used
1000    /// directly for fast membership checks.
1001    pub fn difference<'a>(
1002        &self,
1003        opts: &EvalOpts,
1004        slice: &'a Slice,
1005        exclusions: &'a HashSet<usize>,
1006    ) -> Result<impl Iterator<Item = usize> + use<'a>, ShapeError> {
1007        Ok(self
1008            .eval(opts, slice)?
1009            .filter(move |idx| !exclusions.contains(idx)))
1010    }
1011
1012    /// Calculate a new `Selection` that excludes the specified flat
1013    /// ranks.
1014    ///
1015    /// This computes `self \ exclusions` by evaluating `self`,
1016    /// removing the given ranks, and reconstructing a `Selection`
1017    /// that selects exactly the remaining elements.
1018    ///
1019    /// The result is a concrete, structurally uniform expression with
1020    /// predictable construction order and exact correspondence to the
1021    /// surviving ranks.
1022    pub fn without(
1023        &self,
1024        slice: &Slice,
1025        exclusions: &HashSet<usize>,
1026    ) -> Result<Selection, ShapeError> {
1027        let remaining = self
1028            .difference(&EvalOpts::strict(), slice, exclusions)?
1029            .collect::<BTreeSet<_>>();
1030        Ok(Selection::of_ranks(slice, &remaining)?)
1031    }
1032
1033    /// Converts a set of flat indices into a symbolic `Selection`
1034    /// expression over the given `slice`. Returns an error if any index
1035    /// is invalid.
1036    ///
1037    /// Each flat index is converted into coordinates using
1038    /// `slice.coordinates`, then folded into a nested chain of singleton
1039    /// ranges. The resulting selection evaluates exactly to the input
1040    /// indices.
1041    ///
1042    /// The selections are combined left-associatively using `union`, but
1043    /// since `union` is associative, the grouping does not affect
1044    /// correctness.
1045    ///
1046    /// The input `BTreeSet` ensures:
1047    /// - all indices are unique (no redundant singleton ranges),
1048    /// - the resulting selection has a stable, deterministic structure,
1049    /// - and iteration proceeds in ascending order, which helps produce
1050    ///   predictable routing trees and consistent test results.
1051    ///
1052    /// This choice avoids an explicit sort and makes downstream behavior
1053    /// more reproducible and auditable.
1054    pub fn of_ranks(slice: &Slice, ranks: &BTreeSet<usize>) -> Result<Selection, SliceError> {
1055        let selections = ranks
1056            .iter()
1057            .map(|&i| {
1058                Ok(slice
1059                    .coordinates(i)?
1060                    .into_iter()
1061                    .rev()
1062                    .fold(dsl::true_(), |acc, i| dsl::range(i..=i, acc)))
1063            })
1064            .collect::<Result<Vec<_>, SliceError>>()?;
1065
1066        Ok(selections
1067            .into_iter()
1068            .reduce(dsl::union)
1069            .unwrap_or_else(dsl::false_))
1070    }
1071} // impl Selection
1072
1073mod sealed {
1074    pub trait Sealed {}
1075    impl Sealed for crate::slice::Slice {}
1076}
1077
1078/// Connects the `select!` API to the `Selection` algebra by enabling
1079/// `base.reify_slice(slice)` syntax, where `base: Slice`.
1080///
1081/// The base slice defines the coordinate system in which the slice is
1082/// interpreted. Slices are themselves `Slice` values, typically
1083/// produced by `select!`, and are reified as `Selection` expressions
1084/// over the base.
1085pub trait ReifySlice: sealed::Sealed {
1086    /// Reify a slice as a `Selection` in the coordinate system of
1087    /// `self`.
1088    fn reify_slice(&self, slice: &Slice) -> Result<Selection, SliceError>;
1089
1090    /// Reify multiple slices as a union of selections in the
1091    /// coordinate system of `self`.
1092    fn reify_slices<V: AsRef<[Slice]>>(&self, slices: V) -> Result<Selection, SliceError>;
1093}
1094
1095impl ReifySlice for Slice {
1096    /// Constructs a [`Selection`] expression that symbolically
1097    /// matches all coordinates in the given `slice`, expressed in the
1098    /// coordinate system of the provided `base` slice (`self`).
1099    ///
1100    /// The result is a nested sequence of `range(start..end, step)`
1101    /// combinators that match the rectangular region covered by `slice`
1102    /// in base coordinates. This preserves geometry and layout when
1103    /// `slice` is *layout-aligned* — that is, each of its strides is
1104    /// a multiple of the corresponding base stride.
1105    ///
1106    /// If any dimension is not layout-aligned, the slice is reified
1107    /// by explicitly enumerating its coordinates.
1108    ///
1109    /// Returns [`dsl::false_()`] if the slice is empty.
1110    ///
1111    /// # Errors
1112    ///
1113    /// Returns an error if:
1114    /// - The base is not contiguous and row-major
1115    /// - The slice lies outside the bounds of the base
1116    ///
1117    /// # Example
1118    ///
1119    /// ```rust
1120    /// use ndslice::selection::ReifySlice;
1121    /// let shape = ndslice::shape!(x = 4, y = 4);
1122    /// let base = shape.slice();
1123    /// let selected = ndslice::select!(shape, x = 1..3, y = 2..4).unwrap();
1124    /// let slice = selected.slice();
1125    /// let selection = base.reify_slice(slice).unwrap();
1126    /// ```
1127    fn reify_slice(&self, slice: &Slice) -> Result<Selection, SliceError> {
1128        // Precondition: the base is contiguous and row major.
1129        if !self.is_contiguous() {
1130            return Err(SliceError::NonContiguous);
1131        }
1132
1133        if slice.is_empty() {
1134            return Ok(dsl::false_());
1135        }
1136
1137        if slice.num_dim() != self.num_dim()
1138            || slice.sizes().iter().zip(self.sizes()).any(|(&v, &s)| v > s)
1139        {
1140            return Selection::of_ranks(self, &slice.iter().collect::<BTreeSet<usize>>());
1141        }
1142
1143        let origin = self.coordinates(slice.offset())?;
1144        let mut acc = dsl::true_();
1145        for dim in (0..self.num_dim()).rev() {
1146            let start = origin[dim];
1147            let len = slice.sizes()[dim];
1148            let slice_stride = slice.strides()[dim];
1149            let base_stride = self.strides()[dim];
1150
1151            if slice_stride % base_stride == 0 {
1152                // Layout-aligned with base.
1153                let step = slice_stride / base_stride;
1154                let end = start + step * len;
1155                // Check that `end` is within bounds.
1156                if end - 1 > self.sizes()[dim] {
1157                    let bad = origin
1158                        .iter()
1159                        .enumerate()
1160                        .map(|(d, &x)| if d == dim { end } else { x })
1161                        .collect::<Vec<_>>();
1162                    return Err(SliceError::ValueNotInSlice {
1163                        value: self.location(&bad).unwrap(),
1164                    });
1165                }
1166                acc = dsl::range(crate::shape::Range(start, Some(end), step), acc);
1167            } else {
1168                // Irregular layout; fallback to explicit enumeration.
1169                return Selection::of_ranks(self, &slice.iter().collect::<BTreeSet<_>>());
1170            }
1171        }
1172
1173        Ok(acc)
1174    }
1175
1176    /// Converts a list of `slices` into a symbolic [`Selection`]
1177    /// expression over a common `base` [`Slice`].
1178    ///
1179    /// Each slice describes a rectangular subregion of the base. This
1180    /// function reifies each slice into a nested `range(.., ..)`
1181    /// expression in the base coordinate system and returns the union
1182    /// of all such selections.
1183    ///
1184    /// Empty slices are ignored.
1185    ///
1186    /// # Errors
1187    ///
1188    /// Returns an error if any slice:
1189    /// - Refers to coordinates not contained within the base
1190    ///
1191    /// # Example
1192    ///
1193    /// ```rust
1194    /// use ndslice::selection::ReifySlice;
1195    ///
1196    /// let shape = ndslice::shape!(x = 4, y = 4);
1197    /// let base = shape.slice();
1198    ///
1199    /// let a = ndslice::select!(shape, x = 0..2, y = 0..2)
1200    ///     .unwrap()
1201    ///     .slice()
1202    ///     .clone();
1203    /// let b = ndslice::select!(shape, x = 2..4, y = 2..4)
1204    ///     .unwrap()
1205    ///     .slice()
1206    ///     .clone();
1207    ///
1208    /// let sel = base.reify_slices(&[a, b]).unwrap();
1209    /// ```
1210    fn reify_slices<V: AsRef<[Slice]>>(&self, slices: V) -> Result<Selection, SliceError> {
1211        let slices = slices.as_ref();
1212        let mut selections = Vec::with_capacity(slices.len());
1213
1214        for slice in slices {
1215            if slice.is_empty() {
1216                continue;
1217            }
1218            selections.push(self.reify_slice(slice)?);
1219        }
1220
1221        let mut iter = selections.into_iter();
1222        let first = iter.next().unwrap_or_else(dsl::false_);
1223        Ok(iter.fold(first, dsl::union))
1224    }
1225}
1226
1227/// Trivial all(true) equivalence.
1228pub fn is_equivalent_true(sel: impl std::borrow::Borrow<Selection>) -> bool {
1229    Selection::is_equivalent_to_true(sel.borrow())
1230}
1231
mod iterutils {
    /// Yields the elements of the first non-empty iterator among
    /// `mk_iter(0)`, `mk_iter(1)`, …, `mk_iter(size - 1)`. Yields nothing
    /// if all of them are empty or `size == 0`.
    ///
    /// The search happens eagerly, when `first` is called. `mk_iter` is
    /// invoked in index order until one result has at least one element,
    /// and is never invoked for a later index. That element is pulled
    /// early (to test for emptiness); the rest are produced lazily.
    pub(crate) fn first<'a, F>(size: usize, mut mk_iter: F) -> impl Iterator<Item = usize> + 'a
    where
        F: FnMut(usize) -> Box<dyn Iterator<Item = usize> + 'a>,
    {
        let mut winner = None;
        for idx in 0..size {
            let mut candidate = mk_iter(idx).peekable();
            if candidate.peek().is_some() {
                winner = Some(candidate);
                break;
            }
        }
        winner.into_iter().flatten()
    }
}
1252
1253/// Construct a [`Selection`] from a [`Shape`] and a single labeled
1254/// constraint.
1255///
1256/// This function produces a multidimensional selection expression
1257/// that is structurally aligned with the shape. It applies the given
1258/// range to the named dimension, and fills all preceding dimensions
1259/// with [`all`] to maintain alignment. Trailing dimensions are left
1260/// unconstrained.
1261///
1262/// # Arguments
1263///
1264/// - `shape`: The labeled shape describing the coordinate space.
1265/// - `label`: The name of the dimension to constrain.
1266/// - `rng`: The range to apply in the selected dimension.
1267///
1268/// # Errors
1269///
1270/// Returns [`ShapeError::InvalidLabels`] if the label is not present
1271/// in the shape.
1272///
1273/// # Example
1274///
1275/// ```
1276/// use ndslice::shape;
1277/// use ndslice::selection::selection_from_one;
1278///
1279/// let shape = shape!(zone = 2, host = 4, gpu = 8);
1280/// let sel = selection_from_one(&shape, "host", 1..3).unwrap();
1281/// assert_eq!(format!("{sel:?}"), "All(Range(Range(1, Some(3), 1), True))"); // corresponds to (*, 1..3, *)
1282/// ```
1283///
1284/// [`all`]: crate::selection::dsl::all
1285/// [`Shape`]: crate::shape::Shape
1286/// [`Selection`]: crate::selection::Selection
1287pub fn selection_from_one<'a, R>(
1288    shape: &shape::Shape,
1289    label: &'a str,
1290    rng: R,
1291) -> Result<Selection, ShapeError>
1292where
1293    R: Into<shape::Range>,
1294{
1295    use crate::selection::dsl;
1296
1297    let Some(pos) = shape.labels().iter().position(|l| l == label) else {
1298        return Err(ShapeError::InvalidLabels {
1299            labels: vec![label.to_string()],
1300        });
1301    };
1302
1303    let mut selection = dsl::range(rng.into(), dsl::true_());
1304    for _ in 0..pos {
1305        selection = dsl::all(selection)
1306    }
1307
1308    Ok(selection)
1309}
1310
1311/// Construct a [`Selection`] from a [`Shape`] and multiple labeled
1312/// range constraints.
1313///
1314/// This function produces a multidimensional selection aligned with
1315/// the shape, applying the specified constraints to their
1316/// corresponding dimensions. All unconstrained dimensions are filled
1317/// with [`all`] to preserve structural alignment.
1318///
1319/// # Arguments
1320///
1321/// - `shape`: The labeled shape defining the full coordinate space.
1322/// - `constraints`: A slice of `(label, range)` pairs specifying
1323///   dimension constraints.
1324///
1325/// # Errors
1326///
1327/// Returns [`ShapeError::InvalidLabels`] if any label in
1328/// `constraints` is not present in the shape.
1329///
1330/// # Example
1331///
1332/// ```
1333/// use ndslice::selection::selection_from;
1334/// use ndslice::shape;
1335///
1336/// let shape = shape!(zone = 2, host = 4, gpu = 8);
1337/// let sel = selection_from(&shape, &[("host", 1..3), ("gpu", 0..4)]).unwrap();
1338/// assert_eq!(
1339///     format!("{sel:?}"),
1340///     "All(Range(Range(1, Some(3), 1), Range(Range(0, Some(4), 1), True)))"
1341/// );
1342/// ```
1343///
1344/// [`Shape`]: crate::shape::Shape
1345/// [`Selection`]: crate::selection::Selection
1346/// [`all`]: crate::selection::dsl::all
1347pub fn selection_from<'a, R>(
1348    shape: &shape::Shape,
1349    constraints: &[(&'a str, R)],
1350) -> Result<Selection, ShapeError>
1351where
1352    R: Clone + Into<shape::Range> + 'a,
1353{
1354    use crate::selection::dsl::*;
1355
1356    let mut label_to_constraint = HashMap::new();
1357    for (label, r) in constraints {
1358        if !shape.labels().iter().any(|l| l == label) {
1359            return Err(ShapeError::InvalidLabels {
1360                labels: vec![label.to_string()],
1361            });
1362        }
1363        label_to_constraint.insert(*label, r.clone().into());
1364    }
1365
1366    let selection = shape.labels().iter().rev().fold(true_(), |acc, label| {
1367        if let Some(rng) = label_to_constraint.get(label.as_str()) {
1368            range(rng.clone(), acc)
1369        } else {
1370            all(acc)
1371        }
1372    });
1373
1374    Ok(selection)
1375}
1376
/// Construct a [`Selection`] from a [`Shape`] using labeled dimension
/// constraints.
///
/// This macro provides a convenient syntax for specifying
/// sub-selections on a shape by labeling dimensions and applying
/// either exact indices or ranges. Internally, it wraps
/// [`selection_from_one`] and [`selection_from`] to produce a
/// fully-aligned [`Selection`] expression.
///
/// # Forms
///
/// - Single labeled range:
///   ```
///   let shape = ndslice::shape!(zone = 2, host = 4, gpu = 8);
///   let sel = ndslice::sel_from_shape!(&shape, host = 1..3);
///   ```
///
/// - Multiple exact indices (converted to `n..n+1`):
///   ```
///   let shape = ndslice::shape!(zone = 2, host = 4, gpu = 8);
///   let sel = ndslice::sel_from_shape!(&shape, zone = 1, gpu = 4);
///   ```
///
/// - Multiple labeled ranges:
///   ```
///   let shape = ndslice::shape!(zone = 2, host = 4, gpu = 8);
///   let sel = ndslice::sel_from_shape!(&shape, zone = 0..1, host = 1..3, gpu = 4..8);
///   ```
///
/// Arms are tried in order. A single `label = value` pair (without a
/// trailing comma) always matches the first form, even when `value` is
/// a literal. In that case the value is passed to [`selection_from_one`]
/// unchanged rather than converted to `n..n+1`, so it must itself
/// convert into a range.
///
/// # Panics
///
/// This macro calls `.unwrap()` on the result of the underlying
/// functions. It will panic if any label is not found in the shape.
///
/// # See Also
///
/// - [`selection_from_one`]
/// - [`selection_from`]
/// - [`Selection`]
/// - [`Shape`]
///
/// [`Selection`]: crate::selection::Selection
/// [`Shape`]: crate::shape::Shape
/// [`selection_from_one`]: crate::selection::selection_from_one
/// [`selection_from`]: crate::selection::selection_from
#[macro_export]
macro_rules! sel_from_shape {
    // Exactly one `label = expr`: forwarded as-is to `selection_from_one`.
    ($shape:expr_2021, $label:ident = $range:expr_2021) => {
        $crate::selection::selection_from_one($shape, stringify!($label), $range).unwrap()
    };

    // Only literal indices: each `n` becomes the half-open range `n..n+1`.
    ($shape:expr_2021, $($label:ident = $val:literal),* $(,)?) => {
        $crate::selection::selection_from($shape,
                                          &[
                                              $((stringify!($label), $val..$val+1)),*
                                          ]).unwrap()
    };

    // General case: range expressions forwarded unchanged.
    ($shape:expr_2021, $($label:ident = $range:expr_2021),* $(,)?) => {
        $crate::selection::selection_from($shape, &[
            $((stringify!($label), $range)),*
        ]).unwrap()
    };
}
1441
1442#[cfg(test)]
1443mod tests {
1444    use std::assert_matches::assert_matches;
1445    use std::collections::BTreeSet;
1446
1447    use super::EvalOpts;
1448    use super::ReifySlice;
1449    use super::Selection;
1450    use super::dsl::*;
1451    use super::is_equivalent_true;
1452    use crate::Range;
1453    use crate::Slice;
1454    use crate::assert_structurally_eq;
1455    use crate::select;
1456    use crate::shape;
1457    use crate::shape::ShapeError;
1458
1459    // A test slice: (zones = 2, hosts = 4, gpus = 8).
1460    fn test_slice() -> Slice {
1461        Slice::new(0usize, vec![2, 4, 8], vec![32, 8, 1]).unwrap()
1462    }
1463
1464    // Given expression `expr`, options `opts` and slice `slice`,
1465    // cannonical usage is:
1466    // ```rust
1467    // let nodes = expr.eval(&opts, slice.clone())?.collect::<Vec<usize>>();
1468    // ```
1469    // This utility cuts down on the syntactic repetition that results
1470    // from the above in the tests that follow.
1471    fn eval(expr: Selection, slice: &Slice) -> Vec<usize> {
1472        expr.eval(&EvalOpts::lenient(), slice).unwrap().collect()
1473    }
1474
1475    #[test]
1476    fn test_selection_00() {
1477        let slice = &test_slice();
1478
1479        // No GPUs on any host in any region.
1480        assert!(eval(false_(), slice).is_empty());
1481        assert!(eval(all(false_()), slice).is_empty());
1482        assert!(eval(all(all(false_())), slice).is_empty());
1483
1484        // All GPUs on all hosts in all regions.
1485        assert_eq!((0..=63).collect::<Vec<_>>(), eval(true_(), slice));
1486        assert_eq!(eval(true_(), slice), eval(all(true_()), slice));
1487        assert_eq!(eval(all(true_()), slice), eval(all(all(true_())), slice));
1488
1489        // Terminal `true_()` and `false_()` selections are allowed at
1490        // the leaf.
1491        assert_eq!(eval(true_(), slice), eval(all(all(all(true_()))), slice));
1492        assert!(eval(all(all(all(false_()))), slice).is_empty());
1493    }
1494
1495    #[test]
1496    fn test_selection_01() {
1497        let slice = &test_slice();
1498
1499        // Structural combinators beyond the slice's dimensionality
1500        // are invalid.
1501        let expr = all(all(all(all(true_()))));
1502        let result = expr.validate(&EvalOpts::lenient(), slice);
1503        assert!(
1504            matches!(result, Err(ShapeError::SelectionTooDeep { .. })),
1505            "Unexpected: {:?}",
1506            result
1507        );
1508        assert_eq!(
1509            format!("{}", result.unwrap_err()),
1510            "selection `all(all(all(all(true_()))))` exceeds dimensionality 3"
1511        );
1512    }
1513
1514    #[test]
1515    fn test_selection_02() {
1516        let slice = &test_slice();
1517
1518        // GPU 0 on host 0 in region 0.
1519        let select = range(0..=0, range(0..=0, range(0..=0, true_())));
1520        assert_eq!((0..=0).collect::<Vec<_>>(), eval(select, slice));
1521
1522        // GPU 1 on host 1 in region 1.
1523        let select = range(1..=1, range(1..=1, range(1..=1, true_())));
1524        assert_eq!((41..=41).collect::<Vec<_>>(), eval(select, slice));
1525
1526        // All GPUs on host 0 in all regions:
1527        let select = all(range(0..=0, all(true_())));
1528        assert_eq!(
1529            (0..=7).chain(32..=39).collect::<Vec<_>>(),
1530            eval(select, slice)
1531        );
1532
1533        // All GPUs on host 1 in all regions:
1534        let select = all(range(1..=1, all(true_())));
1535        assert_eq!(
1536            (8..=15).chain(40..=47).collect::<Vec<_>>(),
1537            eval(select, slice)
1538        );
1539
1540        // The first 4 GPUs on all hosts in all regions:
1541        let select = all(all(range(0..4, true_())));
1542        assert_eq!(
1543            (0..=3)
1544                .chain(8..=11)
1545                .chain(16..=19)
1546                .chain(24..=27)
1547                .chain(32..=35)
1548                .chain(40..=43)
1549                .chain(48..=51)
1550                .chain(56..=59)
1551                .collect::<Vec<_>>(),
1552            eval(select, slice)
1553        );
1554
1555        // The last 4 GPUs on all hosts in all regions:
1556        let select = all(all(range(4..8, true_())));
1557        assert_eq!(
1558            (4..=7)
1559                .chain(12..=15)
1560                .chain(20..=23)
1561                .chain(28..=31)
1562                .chain(36..=39)
1563                .chain(44..=47)
1564                .chain(52..=55)
1565                .chain(60..=63)
1566                .collect::<Vec<_>>(),
1567            eval(select, slice)
1568        );
1569
1570        // All regions, all hosts, odd GPUs:
1571        let select = all(all(range(shape::Range(1, None, 2), true_())));
1572        assert_eq!(
1573            (1..8)
1574                .step_by(2)
1575                .chain((9..16).step_by(2))
1576                .chain((17..24).step_by(2))
1577                .chain((25..32).step_by(2))
1578                .chain((33..40).step_by(2))
1579                .chain((41..48).step_by(2))
1580                .chain((49..56).step_by(2))
1581                .chain((57..64).step_by(2))
1582                .collect::<Vec<_>>(),
1583            eval(select, slice)
1584        );
1585    }
1586
1587    #[test]
1588    fn test_selection_03() {
1589        let slice = &test_slice();
1590
1591        assert_eq!(
1592            eval(intersection(true_(), true_()), slice),
1593            eval(true_(), slice)
1594        );
1595        assert_eq!(
1596            eval(intersection(true_(), false_()), slice),
1597            eval(false_(), slice)
1598        );
1599        assert_eq!(
1600            eval(intersection(false_(), true_()), slice),
1601            eval(false_(), slice)
1602        );
1603        assert_eq!(
1604            eval(intersection(false_(), false_()), slice),
1605            eval(false_(), slice)
1606        );
1607        assert_eq!(
1608            eval(
1609                intersection(
1610                    all(all(range(0..=3, true_()))),
1611                    all(all(range(4..=7, true_())))
1612                ),
1613                slice
1614            ),
1615            eval(false_(), slice)
1616        );
1617        assert_eq!(
1618            eval(intersection(true_(), all(all(range(4..8, true_())))), slice),
1619            eval(all(all(range(4..8, true_()))), slice)
1620        );
1621        assert_eq!(
1622            eval(
1623                intersection(
1624                    all(all(range(0..=4, true_()))),
1625                    all(all(range(4..=7, true_())))
1626                ),
1627                slice
1628            ),
1629            eval(all(all(range(4..=4, true_()))), slice)
1630        );
1631    }
1632
1633    #[test]
1634    fn test_selection_04() {
1635        let slice = &test_slice();
1636
1637        assert_eq!(eval(union(true_(), true_()), slice), eval(true_(), slice));
1638        assert_eq!(eval(union(false_(), true_()), slice), eval(true_(), slice));
1639        assert_eq!(eval(union(true_(), false_()), slice), eval(true_(), slice));
1640        assert_eq!(
1641            eval(union(false_(), false_()), slice),
1642            eval(false_(), slice)
1643        );
1644        assert_eq!(
1645            eval(
1646                union(
1647                    all(all(range(0..4, true_()))),
1648                    all(all(range(4.., true_())))
1649                ),
1650                slice
1651            ),
1652            eval(true_(), slice)
1653        );
1654
1655        // Across all regions, get the first 4 GPUs on host 0 and the
1656        // last 4 GPUs on host 1.
1657        let s = all(range(0..=0, range(0..4, true_())));
1658        let t = all(range(1..=1, range(4.., true_())));
1659        assert_eq!(
1660            (0..=3)
1661                .chain(12..=15)
1662                .chain(32..=35)
1663                .chain(44..=47)
1664                .collect::<Vec<_>>(),
1665            eval(union(s, t), slice)
1666        );
1667
1668        // All regions, all hosts, skip GPUs 2, 3, 4 and 5.
1669        assert_eq!(
1670            // z=0, h=0
1671            (0..=1)
1672                .chain(6..=7)
1673                // z=0, h=1
1674                .chain(8..=9)
1675                .chain(14..=15)
1676                // z=0, h=2
1677                .chain(16..=17)
1678                .chain(22..=23)
1679                // z=0, h=3
1680                .chain(24..=25)
1681                .chain(30..=31)
1682                // z=1, h=0
1683                .chain(32..=33)
1684                .chain(38..=39)
1685                // z=1, h=1
1686                .chain(40..=41)
1687                .chain(46..=47)
1688                // z=1, h=2
1689                .chain(48..=49)
1690                .chain(54..=55)
1691                // z=1, h=3
1692                .chain(56..=57)
1693                .chain(62..=63)
1694                .collect::<Vec<_>>(),
1695            eval(
1696                all(all(union(range(0..2, true_()), range(6..8, true_())))),
1697                slice
1698            )
1699        );
1700
1701        // All regions, all hosts, odd GPUs.
1702        assert_eq!(
1703            (1..8)
1704                .step_by(2)
1705                .chain((9..16).step_by(2))
1706                .chain((17..24).step_by(2))
1707                .chain((25..32).step_by(2))
1708                .chain((33..40).step_by(2))
1709                .chain((41..48).step_by(2))
1710                .chain((49..56).step_by(2))
1711                .chain((57..64).step_by(2))
1712                .collect::<Vec<_>>(),
1713            eval(
1714                all(all(union(
1715                    range(shape::Range(1, Some(4), 2), true_()),
1716                    range(shape::Range(5, Some(8), 2), true_())
1717                ))),
1718                slice
1719            )
1720        );
1721        assert_eq!(
1722            eval(
1723                all(all(union(
1724                    range(shape::Range(1, Some(4), 2), true_()),
1725                    range(shape::Range(5, Some(8), 2), true_())
1726                ))),
1727                slice
1728            ),
1729            eval(
1730                all(all(union(
1731                    union(range(1..=1, true_()), range(3..=3, true_()),),
1732                    union(range(5..=5, true_()), range(7..=7, true_()),),
1733                ))),
1734                slice
1735            ),
1736        );
1737    }
1738
1739    #[test]
1740    fn test_selection_05() {
1741        let slice = &test_slice();
1742
1743        // First region, first host, no GPU.
1744        assert!(eval(first(first(false_())), slice).is_empty());
1745        // First region, first host, first GPU.
1746        assert_eq!(vec![0], eval(first(first(range(0..1, true_()))), slice));
1747        // First region, first host, all GPUs.
1748        assert_eq!(
1749            (0..8).collect::<Vec<_>>(),
1750            eval(first(first(true_())), slice)
1751        );
1752
1753        // Terminal `true_()` and `false_()` selections are allowed at
1754        // the leaf.
1755        // First region, first host, no GPU.
1756        assert!(eval(first(first(first(false_()))), slice).is_empty());
1757        // First region, first host, first GPU.
1758        assert_eq!(vec![0], eval(first(first(first(true_()))), slice));
1759
1760        // All regions, first host, all GPUs.
1761        assert_eq!(
1762            (0..8).chain(32..40).collect::<Vec<_>>(),
1763            eval(all(first(true_())), slice)
1764        );
1765
1766        // First region, first host, GPUs 0, 1 and 2.
1767        assert_eq!(
1768            (0..3).collect::<Vec<_>>(),
1769            eval(first(first(range(0..=2, true_()))), slice)
1770        );
1771    }
1772
1773    #[test]
1774    fn test_selection_06() {
1775        let slice = &test_slice();
1776
1777        // Structural combinators beyond the slice's dimensionality
1778        // are invalid.
1779        let expr = first(first(first(first(true_()))));
1780        let result = expr.validate(&EvalOpts::lenient(), slice);
1781        assert!(
1782            matches!(result, Err(ShapeError::SelectionTooDeep { .. })),
1783            "Unexpected: {:?}",
1784            result
1785        );
1786        assert_eq!(
1787            format!("{}", result.unwrap_err()),
1788            "selection `first(first(first(first(true_()))))` exceeds dimensionality 3"
1789        );
1790    }
1791
1792    #[test]
1793    fn test_selection_07() {
1794        use crate::select;
1795        use crate::shape;
1796
1797        // 2 x 8 row-major.
1798        let s = shape!(host = 2, gpu = 8);
1799
1800        // All GPUs on host 1.
1801        assert_eq!(
1802            select!(s, host = 1)
1803                .unwrap()
1804                .slice()
1805                .iter()
1806                .collect::<Vec<_>>(),
1807            eval(range(1..2, true_()), s.slice())
1808        );
1809
1810        // All hosts, GPUs 2 through 7.
1811        assert_eq!(
1812            select!(s, gpu = 2..)
1813                .unwrap()
1814                .slice()
1815                .iter()
1816                .collect::<Vec<_>>(),
1817            eval(all(range(2.., true_())), s.slice())
1818        );
1819
1820        // All hosts, GPUs 3 and 4.
1821        assert_eq!(
1822            select!(s, gpu = 3..5)
1823                .unwrap()
1824                .slice()
1825                .iter()
1826                .collect::<Vec<_>>(),
1827            eval(all(range(3..5, true_())), s.slice())
1828        );
1829
1830        // GPUS 3 and 4 on host 1.
1831        assert_eq!(
1832            select!(s, gpu = 3..5, host = 1)
1833                .unwrap()
1834                .slice()
1835                .iter()
1836                .collect::<Vec<_>>(),
1837            eval(range(1..=1, range(3..5, true_())), s.slice())
1838        );
1839
1840        // All hosts, no GPUs.
1841        assert_matches!(
1842            select!(s, gpu = 1..1).unwrap_err(),
1843            ShapeError::EmptyRange {
1844                range: shape::Range(1, Some(1), 1)
1845            },
1846        );
1847        assert!(eval(all(range(1..1, true_())), s.slice()).is_empty());
1848
1849        // All hosts, GPU 8.
1850        assert_matches!(
1851            select!(s, gpu = 8).unwrap_err(),
1852            ShapeError::OutOfRange {
1853                range: shape::Range(8, Some(9), 1),
1854                dim,
1855                size: 8,
1856            } if dim == "gpu",
1857        );
1858        assert!(eval(all(range(8..8, true_())), s.slice()).is_empty());
1859    }
1860
1861    #[test]
1862    fn test_selection_08() {
1863        let s = &shape!(host = 2, gpu = 8);
1864
1865        assert_eq!(
1866            eval(range(1..2, true_()), s.slice()),
1867            eval(sel_from_shape!(s, host = 1), s.slice())
1868        );
1869
1870        assert_eq!(
1871            eval(all(range(2.., true_())), s.slice()),
1872            eval(sel_from_shape!(s, gpu = 2..), s.slice())
1873        );
1874
1875        assert_eq!(
1876            eval(all(range(3..5, true_())), s.slice()),
1877            eval(sel_from_shape!(s, gpu = 3..5), s.slice())
1878        );
1879
1880        assert_eq!(
1881            eval(range(1..2, range(3..5, true_())), s.slice()),
1882            eval(sel_from_shape!(s, host = 1..2, gpu = 3..5), s.slice())
1883        );
1884
1885        assert_eq!(
1886            eval(
1887                union(
1888                    sel_from_shape!(s, host = 0..1, gpu = 0..4),
1889                    sel_from_shape!(s, host = 1..2, gpu = 4..5)
1890                ),
1891                s.slice()
1892            ),
1893            eval(
1894                union(
1895                    range(0..1, range(0..4, true_())),
1896                    range(1..2, range(4..5, true_()))
1897                ),
1898                s.slice()
1899            )
1900        );
1901    }
1902
1903    #[test]
1904    fn test_selection_09() {
1905        let slice = &test_slice(); // 2 x 4 x 8
1906
1907        // Identity.
1908        assert_eq!(eval(any(false_()), slice), eval(false_(), slice));
1909
1910        // An arbitrary GPU.
1911        let res = eval(any(any(any(true_()))), slice);
1912        assert_eq!(res.len(), 1);
1913        assert!(res[0] < 64);
1914
1915        // The first 4 GPUs of any host in region-0.
1916        let res = eval(range(0, any(range(0..4, true_()))), slice);
1917        assert!((0..4).any(|host| res == eval(range(0, range(host, range(0..4, true_()))), slice)));
1918
1919        // Any GPU on host-0 in region-0.
1920        let res = eval(range(0, range(0, any(true_()))), slice);
1921        assert_eq!(res.len(), 1);
1922        assert!(res[0] < 8);
1923
1924        // All GPUs on any host in region-0.
1925        let res = eval(range(0, any(true_())), slice);
1926        assert!((0..4).any(|host| res == eval(range(0, range(host, true_())), slice)));
1927
1928        // All GPUs on any host in region-1.
1929        let res = eval(range(1, any(true_())), slice);
1930        assert!((0..4).any(|host| res == eval(range(1, range(host, true_())), slice)));
1931
1932        // Any two GPUs on host-0 in region-0.
1933        let mut res = vec![];
1934        while res.len() < 2 {
1935            res = eval(
1936                union(
1937                    range(0, range(0, any(true_()))),
1938                    range(0, range(0, any(true_()))),
1939                ),
1940                slice,
1941            );
1942        }
1943        assert_matches!(res.as_slice(), [i, j] if *i < *j && *i < 8 && *j < 8);
1944    }
1945
1946    #[test]
1947    fn test_eval_zero_dim_slice() {
1948        let slice_0d = Slice::new(1, vec![], vec![]).unwrap();
1949        // Let s be a slice with dim(s) = 0. Then: ∃! x ∈ s :
1950        // coordsₛ(x) = ().
1951        assert_eq!(slice_0d.coordinates(1).unwrap(), vec![]);
1952
1953        assert_eq!(eval(true_(), &slice_0d), vec![1]);
1954        assert_eq!(eval(false_(), &slice_0d), vec![]);
1955        assert_eq!(eval(all(true_()), &slice_0d), vec![1]);
1956        assert_eq!(eval(all(false_()), &slice_0d), vec![]);
1957        assert_eq!(eval(union(true_(), true_()), &slice_0d), vec![1]);
1958        assert_eq!(eval(intersection(true_(), false_()), &slice_0d), vec![]);
1959    }
1960
1961    #[test]
1962    fn test_selection_10() {
1963        let slice = &test_slice();
1964        let opts = EvalOpts {
1965            disallow_dynamic_selections: true,
1966            ..EvalOpts::lenient()
1967        };
1968        let expr = any(any(any(true_())));
1969        let res = expr.validate(&opts, slice);
1970        assert_matches!(res, Err(ShapeError::SelectionDynamic { .. }));
1971    }
1972
1973    #[test]
1974    fn test_13() {
1975        // Structural identity: `all(true)` <=> `true`.
1976        assert!(is_equivalent_true(true_()));
1977        assert!(is_equivalent_true(all(true_())));
1978        assert!(is_equivalent_true(all(all(true_()))));
1979        assert!(is_equivalent_true(all(all(all(true_())))));
1980        assert!(is_equivalent_true(all(all(all(all(true_()))))));
1981        assert!(is_equivalent_true(all(all(all(all(all(true_())))))));
1982        // ...
1983
1984        assert!(!is_equivalent_true(false_()));
1985        assert!(!is_equivalent_true(union(true_(), true_())));
1986        assert!(!is_equivalent_true(range(0..=0, true_())));
1987        assert!(!is_equivalent_true(all(false_())));
1988    }
1989
1990    #[test]
1991    fn test_14() {
1992        use std::collections::HashSet;
1993
1994        use crate::selection::NormalizedSelectionKey;
1995        use crate::selection::dsl::*;
1996
1997        let a = all(all(true_()));
1998        let b = all(all(true_()));
1999
2000        let key_a = NormalizedSelectionKey::new(&a);
2001        let key_b = NormalizedSelectionKey::new(&b);
2002
2003        // They should be structurally equal.
2004        assert_eq!(key_a, key_b);
2005
2006        // Their hashes should agree, and they deduplicate in a set.
2007        let mut set = HashSet::new();
2008        set.insert(key_a);
2009        assert!(set.contains(&key_b));
2010    }
2011
2012    #[test]
2013    fn test_contains_true() {
2014        let selection = true_();
2015        assert!(selection.contains(&[0, 0, 0]));
2016        assert!(selection.contains(&[1, 2, 3]));
2017    }
2018
2019    #[test]
2020    fn test_contains_false() {
2021        let selection = false_();
2022        assert!(!selection.contains(&[0, 0, 0]));
2023        assert!(!selection.contains(&[1, 2, 3]));
2024    }
2025
2026    #[test]
2027    fn test_contains_all() {
2028        let selection = all(true_());
2029        assert!(selection.contains(&[0, 0, 0]));
2030        assert!(selection.contains(&[1, 2, 3]));
2031    }
2032
2033    #[test]
2034    fn test_contains_range() {
2035        let selection = range(1..3, true_());
2036        assert!(selection.contains(&[1, 0, 0]));
2037        assert!(!selection.contains(&[3, 0, 0]));
2038    }
2039
2040    #[test]
2041    fn test_contains_intersection() {
2042        let selection = intersection(range(1..3, true_()), range(2..4, true_()));
2043        assert!(selection.contains(&[2, 0, 0]));
2044        assert!(!selection.contains(&[1, 0, 0]));
2045    }
2046
2047    #[test]
2048    fn test_contains_union() {
2049        let selection = union(range(1..2, true_()), range(3..4, true_()));
2050        assert!(selection.contains(&[1, 0, 0]));
2051        assert!(!selection.contains(&[2, 0, 0]));
2052    }
2053
2054    #[test]
2055    #[should_panic(expected = "not implemented")]
2056    fn test_contains_any() {
2057        let selection = any(true_());
2058        selection.contains(&[0, 0, 0]);
2059    }
2060
2061    #[test]
2062    #[should_panic(expected = "not implemented")]
2063    fn test_contains_label() {
2064        let selection = label(vec!["zone".to_string()], true_());
2065        selection.contains(&[1, 2, 3]);
2066    }
2067
2068    #[test]
2069    #[should_panic(expected = "not implemented")]
2070    fn test_contains_first() {
2071        let selection = first(true_());
2072        selection.contains(&[0, 0, 0]);
2073    }
2074
2075    #[test]
2076    fn test_difference_1d() {
2077        assert_eq!(
2078            true_()
2079                .difference(
2080                    &EvalOpts::strict(),
2081                    &Slice::new_row_major([5]),
2082                    &[2usize, 4].into(),
2083                )
2084                .unwrap()
2085                .collect::<Vec<_>>(),
2086            vec![0, 1, 3]
2087        );
2088    }
2089
2090    #[test]
2091    fn test_difference_empty_selection() {
2092        assert_eq!(
2093            false_()
2094                .difference(
2095                    &EvalOpts::strict(),
2096                    &Slice::new_row_major([3]),
2097                    &[0usize, 1].into(),
2098                )
2099                .unwrap()
2100                .collect::<Vec<_>>(),
2101            vec![]
2102        );
2103    }
2104
2105    #[test]
2106    fn test_difference_2d() {
2107        // [[0, 1, 2],
2108        //  [3, 4, 5]]
2109        // Select everything, exclude the second row.
2110        assert_eq!(
2111            all(all(true_()))
2112                .difference(
2113                    &EvalOpts::strict(),
2114                    &Slice::new_row_major([2, 3]),
2115                    &[3usize, 4, 5].into(),
2116                )
2117                .unwrap()
2118                .collect::<Vec<_>>(),
2119            vec![0, 1, 2]
2120        );
2121    }
2122
2123    #[test]
2124    fn test_of_ranks_1d() {
2125        let slice = Slice::new_row_major([5]);
2126        let ranks = BTreeSet::from([1, 3]);
2127        let selection = Selection::of_ranks(&slice, &ranks).unwrap();
2128        assert_eq!(
2129            selection
2130                .eval(&EvalOpts::strict(), &slice)
2131                .unwrap()
2132                .collect::<Vec<_>>(),
2133            vec![1, 3]
2134        )
2135    }
2136
2137    #[test]
2138    fn test_of_ranks_empty_set() {
2139        let slice = Slice::new_row_major([4]);
2140        let ranks = BTreeSet::new();
2141        let selection = Selection::of_ranks(&slice, &ranks).unwrap();
2142        assert_eq!(
2143            selection
2144                .eval(&EvalOpts::strict(), &slice)
2145                .unwrap()
2146                .collect::<Vec<_>>(),
2147            vec![]
2148        )
2149    }
2150
2151    #[test]
2152    fn test_of_ranks_singleton_structural() {
2153        let slice = Slice::new_row_major([5]);
2154        let ranks = BTreeSet::from([2]);
2155        let actual = Selection::of_ranks(&slice, &ranks).unwrap();
2156        let expected = range(2..=2, true_());
2157        assert_structurally_eq!(&actual, &expected);
2158    }
2159
2160    #[test]
2161    fn test_of_ranks_union_2d_structural() {
2162        let slice = Slice::new_row_major([2, 3]);
2163        // [ [0, 1, 2],
2164        //   [3, 4, 5] ]
2165        // We'll select (0, 2), (1, 0) and (1, 1).
2166        let ranks = BTreeSet::from([2, 3, 4]);
2167        let actual = Selection::of_ranks(&slice, &ranks).unwrap();
2168        // Each rank becomes a nested selection:
2169        // 2 -> (0, 2) -> range(0, range(2, true_()))
2170        // 3 -> (1, 0) -> range(1, range(0, true_()))
2171        // 4 -> (1, 1) -> range(1, range(1, true_()))
2172        //
2173        // Their union is:
2174        let expected = union(
2175            union(range(0, range(2, true_())), range(1, range(0, true_()))),
2176            range(1, range(1, true_())),
2177        );
2178        assert_structurally_eq!(&actual, &expected);
2179    }
2180
2181    #[test]
2182    fn test_of_ranks_3d_structural() {
2183        let slice = Slice::new_row_major([2, 2, 2]);
2184        // [ [ [0, 1],
2185        //     [2, 3] ],
2186        //   [ [4, 5],
2187        //     [6, 7] ] ]
2188        let ranks = BTreeSet::from([1, 6]);
2189        let actual = Selection::of_ranks(&slice, &ranks).unwrap();
2190        let expected = union(
2191            range(0, range(0, range(1, true_()))), // (0, 0, 1)
2192            range(1, range(1, range(0, true_()))), // (1, 1, 0)
2193        );
2194        assert_structurally_eq!(&actual, &expected);
2195    }
2196
2197    #[test]
2198    fn test_of_ranks_invalid_index() {
2199        let slice = Slice::new_row_major([4]);
2200        let ranks = BTreeSet::from([0, 4]); // 4 is out of bounds
2201        assert!(
2202            Selection::of_ranks(&slice, &ranks).is_err(),
2203            "expected out-of-bounds error"
2204        );
2205    }
2206
2207    #[test]
2208    fn test_reify_slice_empty() {
2209        let slice = Slice::new_row_major([0]);
2210        let selection = slice.reify_slice(&slice).unwrap();
2211        let expected = false_();
2212        assert_structurally_eq!(&selection, expected);
2213        assert_eq!(
2214            selection
2215                .eval(&EvalOpts::lenient(), &slice)
2216                .unwrap()
2217                .collect::<Vec<_>>(),
2218            vec![]
2219        );
2220    }
2221
2222    #[test]
2223    fn test_reify_slice_1d() {
2224        let shape = shape!(x = 6); // 1D shape with 6 elements
2225        let base = shape.slice();
2226
2227        let selected = select!(shape, x = 2..5).unwrap();
2228        let view = selected.slice();
2229
2230        let selection = base.reify_slice(view).unwrap();
2231        let expected = range(2..5, true_());
2232        assert_structurally_eq!(&selection, expected);
2233
2234        let flat: Vec<_> = selection.eval(&EvalOpts::strict(), base).unwrap().collect();
2235        assert_eq!(flat, vec![2, 3, 4]);
2236    }
2237
2238    #[test]
2239    fn test_reify_slice_2d() {
2240        let shape = shape!(x = 4, y = 5); // 2D shape: 4 rows, 5 columns
2241        let base = shape.slice();
2242
2243        // Select the middle 2x3 block: rows 1..3 and columns 2..5
2244        let selected = select!(shape, x = 1..3, y = 2..5).unwrap();
2245        let view = selected.slice();
2246        let selection = base.reify_slice(view).unwrap();
2247        let expected = range(1..3, range(2..5, true_()));
2248        assert_structurally_eq!(&selection, expected);
2249
2250        let flat: Vec<_> = selection.eval(&EvalOpts::strict(), base).unwrap().collect();
2251        assert_eq!(
2252            flat,
2253            vec![
2254                base.location(&[1, 2]).unwrap(),
2255                base.location(&[1, 3]).unwrap(),
2256                base.location(&[1, 4]).unwrap(),
2257                base.location(&[2, 2]).unwrap(),
2258                base.location(&[2, 3]).unwrap(),
2259                base.location(&[2, 4]).unwrap(),
2260            ]
2261        );
2262    }
2263
2264    #[test]
2265    #[allow(clippy::identity_op)]
2266    fn test_reify_slice_1d_with_stride() {
2267        let shape = shape!(x = 7); // 1D shape with 7 elements
2268        let selected = shape.select("x", Range(0, None, 2)).unwrap();
2269        let view = selected.slice();
2270        assert_eq!(view, &Slice::new(0, vec![4], vec![1 * 2]).unwrap());
2271
2272        let base = shape.slice();
2273        let selection = base.reify_slice(view).unwrap();
2274        // Note: ceil(7 / 2) = 4, hence end = 0 + 2 × 4 = 8. See the
2275        // more detailed explanation in
2276        // `test_reify_slice_2d_with_stride`.
2277        let expected = range(Range(0, Some(8), 2), true_());
2278        assert_structurally_eq!(&selection, expected);
2279
2280        let flat: Vec<_> = selection.eval(&EvalOpts::strict(), base).unwrap().collect();
2281        assert_eq!(flat, vec![0, 2, 4, 6]);
2282    }
2283
2284    #[test]
2285    #[allow(clippy::identity_op)]
2286    fn test_reify_slice_2d_with_stride() {
2287        // 4 x 4: x = 4, y = 4.
2288        let base = shape!(x = 4, y = 4);
2289        // Step 1: select odd rows (x = 1..4 step 2)
2290        let shape = base.select("x", Range(1, Some(4), 2)).unwrap();
2291        // Step 2: then select odd columns (y = 1..4 step 2)
2292        let shape = shape.select("y", Range(1, Some(4), 2)).unwrap();
2293        let view = shape.slice();
2294        assert_eq!(
2295            view,
2296            &Slice::new(5, vec![2, 2], vec![4 * 2, 1 * 2]).unwrap()
2297        );
2298
2299        let base = base.slice();
2300        let selection = base.reify_slice(view).unwrap();
2301        // We use `end = start + step * len` to reify the selection.
2302        // Note: This may yield `end > original_end` (e.g., 5 instead of 4)
2303        // when the selection length was computed via ceiling division.
2304        // This is safe: the resulting range will still select the correct
2305        // indices (e.g., 1 and 3 for Range(1, Some(5), 2)).
2306        let expected = range(Range(1, Some(5), 2), range(Range(1, Some(5), 2), true_()));
2307        assert_structurally_eq!(&selection, expected);
2308
2309        let flat: Vec<_> = selection.eval(&EvalOpts::strict(), base).unwrap().collect();
2310        assert_eq!(flat, vec![5, 7, 13, 15]);
2311    }
2312
2313    #[test]
2314    fn test_reify_slice_selects_column_across_rows() {
2315        let shape = shape!(host = 2, gpu = 4); // shape [2, 4]
2316        let base = shape.slice();
2317
2318        // Select the 3rd GPU (index 2) across both hosts
2319        let selected = select!(shape, gpu = 2).unwrap(); // (0, 2) and (1, 2)
2320        let view = selected.slice();
2321        let coordinates: Vec<_> = view.iter().map(|i| view.coordinates(i).unwrap()).collect();
2322        assert_eq!(coordinates, [[0, 0], [1, 0]]);
2323
2324        let selection = base.reify_slice(view).unwrap();
2325        let expected = range(0..2, range(2..3, true_()));
2326        assert_structurally_eq!(&selection, expected);
2327
2328        let actual = selection
2329            .eval(&EvalOpts::strict(), base)
2330            .unwrap()
2331            .collect::<Vec<_>>();
2332        assert_eq!(
2333            actual,
2334            vec![
2335                base.location(&[0, 2]).unwrap(),
2336                base.location(&[1, 2]).unwrap()
2337            ]
2338        );
2339    }
2340
2341    #[test]
2342    fn test_reify_slice_dimension_mismatch() {
2343        let shape = shape!(host = 2, gpu = 4);
2344        let base = shape.slice();
2345
2346        // Select the 3rd GPU (index 2) across both hosts i.e. flat
2347        // indices [2, 6]
2348        let indices = vec![
2349            base.location(&[0, 2]).unwrap(),
2350            base.location(&[1, 2]).unwrap(),
2351        ];
2352
2353        let view = Slice::new(indices[0], vec![indices.len()], vec![4]).unwrap();
2354        let selection = base.reify_slice(&view).unwrap();
2355
2356        let expected = Selection::of_ranks(base, &indices.iter().cloned().collect()).unwrap();
2357        assert_structurally_eq!(&selection, expected);
2358
2359        let actual: Vec<_> = selection.eval(&EvalOpts::strict(), base).unwrap().collect();
2360        assert_eq!(actual, indices);
2361    }
2362
2363    #[test]
2364    fn test_union_of_slices_empty() {
2365        let base = Slice::new_row_major([2]);
2366        let sel = base.reify_slices(&[]).unwrap();
2367        assert_structurally_eq!(&sel, &false_());
2368        assert_eq!(
2369            sel.eval(&EvalOpts::strict(), &base)
2370                .unwrap()
2371                .collect::<Vec<_>>(),
2372            vec![]
2373        );
2374    }
2375
2376    #[test]
2377    fn test_union_of_slices_singleton() {
2378        let shape = shape!(x = 3);
2379        let base = shape.slice();
2380        let selected = select!(shape, x = 1).unwrap();
2381        let view = selected.slice().clone();
2382
2383        let selection = base.reify_slices(&[view]).unwrap();
2384        let expected = range(1..=1, true_());
2385        assert_structurally_eq!(&selection, &expected);
2386
2387        assert_eq!(
2388            selection
2389                .eval(&EvalOpts::strict(), base)
2390                .unwrap()
2391                .collect::<Vec<_>>(),
2392            vec![1],
2393        );
2394    }
2395
2396    #[test]
2397    fn test_union_of_slices_disjoint() {
2398        let shape = shape!(x = 2, y = 2); // 2x2 grid
2399        let base = shape.slice();
2400
2401        // View A: (0, *)
2402        let a = select!(shape, x = 0).unwrap();
2403        let view_a = a.slice().clone();
2404
2405        // View B: (1, *)
2406        let b = select!(shape, x = 1).unwrap();
2407        let view_b = b.slice().clone();
2408
2409        let selection = base.reify_slices(&[view_a, view_b]).unwrap();
2410        let expected = union(
2411            range(0..1, range(0..2, true_())),
2412            range(1..2, range(0..2, true_())),
2413        );
2414        assert_structurally_eq!(&selection, &expected);
2415        assert_eq!(
2416            selection
2417                .eval(&EvalOpts::strict(), base)
2418                .unwrap()
2419                .collect::<Vec<_>>(),
2420            base.iter().collect::<Vec<_>>()
2421        );
2422    }
2423
2424    #[test]
2425    fn test_union_of_slices_overlapping() {
2426        let shape = shape!(x = 1, y = 4); // 1x4 grid
2427        let base = shape.slice();
2428
2429        let selected1 = select!(shape, y = 0..2).unwrap();
2430        let view1 = selected1.slice().clone();
2431
2432        let selected2 = select!(shape, y = 1..4).unwrap();
2433        let view2 = selected2.slice().clone();
2434
2435        let selection = base.reify_slices(&[view1, view2]).unwrap();
2436        let expected = union(
2437            range(0..1, range(0..2, true_())),
2438            range(0..1, range(1..4, true_())),
2439        );
2440        assert_structurally_eq!(&selection, &expected);
2441
2442        assert_eq!(
2443            selection
2444                .eval(&EvalOpts::strict(), base)
2445                .unwrap()
2446                .collect::<Vec<_>>(),
2447            base.iter().collect::<Vec<_>>()
2448        );
2449    }
2450
2451    #[test]
2452    fn test_canonicalize_to_dimensions() {
2453        assert_structurally_eq!(
2454            true_().canonicalize_to_dimensions(3),
2455            &all(all(all(true_())))
2456        );
2457        assert_structurally_eq!(
2458            all(true_()).canonicalize_to_dimensions(3),
2459            &all(all(all(true_())))
2460        );
2461        assert_structurally_eq!(
2462            all(all(true_())).canonicalize_to_dimensions(3),
2463            &all(all(all(true_())))
2464        );
2465        assert_structurally_eq!(
2466            all(all(all(true_()))).canonicalize_to_dimensions(3),
2467            &all(all(all(true_())))
2468        );
2469
2470        assert_structurally_eq!(
2471            false_().canonicalize_to_dimensions(3),
2472            &all(all(all(false_())))
2473        );
2474        assert_structurally_eq!(
2475            all(false_()).canonicalize_to_dimensions(3),
2476            &all(all(all(false_())))
2477        );
2478        assert_structurally_eq!(
2479            all(all(false_())).canonicalize_to_dimensions(3),
2480            &all(all(all(false_())))
2481        );
2482        assert_structurally_eq!(
2483            all(all(all(false_()))).canonicalize_to_dimensions(3),
2484            &all(all(all(false_())))
2485        );
2486
2487        assert_structurally_eq!(
2488            any(true_()).canonicalize_to_dimensions(3),
2489            &any(any(any(true_())))
2490        );
2491        assert_structurally_eq!(
2492            any(any(true_())).canonicalize_to_dimensions(3),
2493            &any(any(any(true_())))
2494        );
2495        assert_structurally_eq!(
2496            any(any(any(true_()))).canonicalize_to_dimensions(3),
2497            &any(any(any(true_())))
2498        );
2499
2500        // 0:1 -> 0:1, *, * <=> range(0..1, all(all(true_())))
2501        assert_structurally_eq!(
2502            range(0..1, true_()).canonicalize_to_dimensions(3),
2503            &range(0..1, all(all(true_())))
2504        );
2505        // *, 0:1 -> *, 0:1, * <=> all(range(0..1, all(true_())))
2506        assert_structurally_eq!(
2507            all(range(0..1, true_())).canonicalize_to_dimensions(3),
2508            &all(range(0..1, all(true_())))
2509        );
2510        // 0:1, ? -> 0:1, ?, ? <=> range(0..1, any(any(true_())))
2511        assert_structurally_eq!(
2512            range(0..1, any(true_())).canonicalize_to_dimensions(3),
2513            &range(0..1, any(any(true_())))
2514        );
2515        // 0:1, ?, * -> 0:1, ?, * <=> range(0..1, any(all(true_())))
2516        assert_structurally_eq!(
2517            range(0..1, any(all(true_()))).canonicalize_to_dimensions(3),
2518            &range(0..1, any(all(true_())))
2519        );
2520    }
2521}