ndslice/
selection.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! This module defines a recursive algebra for selecting coordinates
10//! in a multidimensional space.
11//!
//! A `Selection` describes constraints across dimensions of an
//! `ndslice::Slice`. Variants like [`Selection::All`],
//! [`Selection::First`], and [`Selection::Range`] operate
//! dimensionally, while [`Selection::Intersection`] and
//! [`Selection::Union`] allow for logical composition of selections.
16//!
17//! ## Example
18//!
19//! Suppose a 3-dimensional mesh system of:
20//! - 2 zones
21//! - 4 hosts per zone
22//! - 8 GPUs per host
23//!
//! The corresponding `Slice` will have shape `[2, 4, 8]`. An
//! expression to denote the first 4 GPUs of host 0 together with the
//! last 4 GPUs on host 1 across all zones can be written as:
27//! ```rust
28//! use ndslice::selection::dsl::all;
29//! use ndslice::selection::dsl::range;
30//! use ndslice::selection::dsl::true_;
31//! use ndslice::selection::dsl::union;
32//!
33//! let s = all(range(0, range(0..4, true_())));
34//! let t = all(range(1, range(4.., true_())));
35//! let selection = union(s, t);
36//! ```
//! Assuming a row-major layout, that is the set of 4 x 2 x 2 = 16
//! coordinates *{(0, 0, 0), ..., (0, 0, 3), (0, 1, 4), ..., (0, 1, 7),
//! (1, 0, 0), ..., (1, 0, 3), (1, 1, 4), ..., (1, 1, 7)}*, where code
//! to print that set might read as follows.
41//! ```rust
42//! use ndslice::Slice;
43//! use ndslice::selection::EvalOpts;
44//! use ndslice::selection::dsl::all;
45//! use ndslice::selection::dsl::range;
46//! use ndslice::selection::dsl::true_;
47//! use ndslice::selection::dsl::union;
48//!
49//! let slice = Slice::new(0usize, vec![2, 4, 8], vec![32, 8, 1]).unwrap();
50//! let s = all(range(0, range(0..4, true_())));
51//! let t = all(range(1, range(4.., true_())));
52//!
53//! for r in union(s, t).eval(&EvalOpts::lenient(), &slice).unwrap() {
54//!     println!("{:?}", slice.coordinates(r).unwrap());
55//! }
56//! ```
57//! which is using the `eval` function described next.
58//!
59//! ## Evaluation
60//!
61//! Selections are evaluated against an `ndslice::Slice` using the
62//! [`Selection::eval`] method, which returns a lazy iterator over the
63//! flat (linearized) indices of elements that match.
64//!
65//! Evaluation proceeds recursively, dimension by dimension, with each
66//! variant of `Selection` contributing logic at a particular level of
67//! the slice.
68//!
69//! If a `Selection` is shorter than the number of dimensions, it is
70//! implicitly extended with `true_()` at the remaining levels. This
71//! means `Selection::True` acts as the identity element, matching all
72//! remaining indices by default.
73
74/// A parser for selection expressions in a compact textual syntax.
75///
76/// See [`selection::parse`] for syntax details and examples.
77pub mod parse;
78
79/// Formatting utilities for `Selection` expressions.
80///
81/// This module defines pretty-printers and compact syntax renderers
82/// for selections, based on implementations of the `SelectionSYM`
83/// trait.
84///
85/// The `Display` implementation for [`Selection`] uses this module.
86pub mod pretty;
87
/// A `TokenStream` to [`Selection`] parser. Used at compile time in
/// [`sel!`]. See [`selection::parse`] for syntax details and
/// examples.
91pub mod token_parser;
92
93/// Shape navigation guided by [`Selection`] expressions.
94pub mod routing;
95
96/// Normalization logic for `Selection`.
97pub mod normal;
98
99pub mod test_utils;
100
101use std::collections::BTreeSet;
102use std::collections::HashMap;
103use std::collections::HashSet;
104use std::fmt;
105
106use rand::Rng;
107use serde::Deserialize;
108use serde::Serialize;
109
110use crate::Slice;
111use crate::selection::normal::NormalizedSelection;
112use crate::selection::normal::RewriteRuleExt;
113use crate::shape;
114use crate::shape::ShapeError;
115use crate::slice::SliceError;
116
117/// This trait defines an abstract syntax without committing to a
118/// specific representation. It follow the
119/// [tagless-final](https://okmij.org/ftp/tagless-final/index.html)
120/// style where [`Selection`] is a default AST representation.
121pub trait SelectionSYM {
122    /// The identity selection (matches no nodes).
123    fn false_() -> Self;
124
125    /// The universal selection (matches all nodes).
126    fn true_() -> Self;
127
128    /// Selects all values along the current dimension, then applies
129    /// the inner selection.
130    fn all(selection: Self) -> Self;
131
132    /// Selects the first index along the current dimension for which
133    /// the inner selection is non-empty.
134    fn first(selection: Self) -> Self;
135
136    /// Selects values within the given range along the current
137    /// dimension, then applies the inner selection.
138    fn range<R: Into<shape::Range>>(range: R, selection: Self) -> Self;
139
140    /// Selects values along the current dimension that match the
141    /// given labels, then applies the inner selection.
142    fn label<L: Into<LabelKey>>(labels: Vec<L>, selection: Self) -> Self;
143
144    /// Selects a random index along the current dimension, then applies
145    /// the inner selection.
146    fn any(selection: Self) -> Self;
147
148    /// The intersection (logical AND) of two selection expressions.
149    fn intersection(lhs: Self, selection: Self) -> Self;
150
151    /// The union (logical OR) of two selection expressions.
152    fn union(lhs: Self, selection: Self) -> Self;
153}
154
155/// `SelectionSYM`-based constructors specialized to the [`Selection`]
156/// AST.
157pub mod dsl {
158
159    use super::LabelKey;
160    use super::Selection;
161    use super::SelectionSYM;
162    use crate::shape;
163
164    pub fn false_() -> Selection {
165        SelectionSYM::false_()
166    }
167    pub fn true_() -> Selection {
168        SelectionSYM::true_()
169    }
170    pub fn all(inner: Selection) -> Selection {
171        SelectionSYM::all(inner)
172    }
173    pub fn first(inner: Selection) -> Selection {
174        SelectionSYM::first(inner)
175    }
176    pub fn range<R: Into<shape::Range>>(r: R, inner: Selection) -> Selection {
177        SelectionSYM::range(r, inner)
178    }
179    pub fn label<L: Into<LabelKey>>(labels: Vec<L>, inner: Selection) -> Selection {
180        SelectionSYM::label(labels, inner)
181    }
182    pub fn any(inner: Selection) -> Selection {
183        SelectionSYM::any(inner)
184    }
185    pub fn intersection(lhs: Selection, rhs: Selection) -> Selection {
186        SelectionSYM::intersection(lhs, rhs)
187    }
188    pub fn union(lhs: Selection, rhs: Selection) -> Selection {
189        SelectionSYM::union(lhs, rhs)
190    }
191}
192
193impl SelectionSYM for Selection {
194    fn false_() -> Self {
195        ast::false_()
196    }
197    fn true_() -> Self {
198        ast::true_()
199    }
200    fn all(selection: Self) -> Self {
201        ast::all(selection)
202    }
203    fn first(selection: Self) -> Self {
204        ast::first(selection)
205    }
206    fn range<R: Into<shape::Range>>(range: R, selection: Self) -> Self {
207        ast::range(range, selection)
208    }
209    fn label<L: Into<LabelKey>>(labels: Vec<L>, selection: Selection) -> Selection {
210        let labels = labels.into_iter().map(|l| l.into()).collect();
211        Selection::Label(labels, Box::new(selection))
212    }
213    fn any(selection: Self) -> Self {
214        ast::any(selection)
215    }
216    fn intersection(lhs: Self, rhs: Self) -> Self {
217        ast::intersection(lhs, rhs)
218    }
219    fn union(lhs: Self, rhs: Self) -> Self {
220        ast::union(lhs, rhs)
221    }
222}
223
224impl fmt::Display for Selection {
225    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
226        write!(f, "{}", pretty::pretty(self))
227    }
228}
229
230/// A metadata label used to constrain values at a given coordinate
231/// dimension.
232///
233/// `LabelKey` represents attribute values associated with indices —
234/// for example, GPU model names like `"A100"` or capabilities like
235/// "AVX-512".
236///
237/// Labels are not dimension names (like `"zone"` or `"rack"`); they
238/// are **values** assigned to elements at a given dimension, and are
239/// used by `Selection::Label` to restrict which values are eligible
240/// during selection or routing.
241/// For example, a selection like `sel!(["A100"]*)` matches only
242/// indices at the current dimension whose associated label value is
243/// `"A100"`.
244///
245/// `Ord` is derived to allow deterministic sorting and set membership,
246/// based on lexicographic ordering of label strings.
247#[derive(
248    Clone,
249    Debug,
250    PartialEq,
251    Eq,
252    Hash,
253    Serialize,
254    Deserialize,
255    PartialOrd,
256    Ord
257)]
258pub enum LabelKey {
259    /// A plain string label value.
260    Value(String),
261}
262
263impl From<String> for LabelKey {
264    fn from(s: String) -> Self {
265        LabelKey::Value(s)
266    }
267}
268
269impl From<&str> for LabelKey {
270    fn from(s: &str) -> Self {
271        LabelKey::Value(s.to_string())
272    }
273}
274
275impl std::fmt::Display for LabelKey {
276    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
277        match self {
278            LabelKey::Value(s) => write!(f, "\"{}\"", s),
279        }
280    }
281}
282
283/// An algebra for expressing node selection.
284#[derive(Debug, Clone, Serialize, Deserialize)]
285#[non_exhaustive]
286pub enum Selection {
287    /// A selection that never matches any node.
288    False,
289
290    /// A selection that always matches any node.
291    True,
292
293    /// Selects all values along the current dimension, continuing
294    /// with the given selection.
295    All(Box<Selection>),
296
297    /// Selects the first value along the current dimension for which
298    /// applying the inner selection yields any results.
299    First(Box<Selection>),
300
301    /// Selects values within a given range along the current
302    /// dimension, continuing with the given selection.
303    Range(shape::Range, Box<Selection>),
304
305    /// Selects values based on metadata (i.e., labels) along the
306    /// current dimension. This provides attribute-based selection.
307    Label(Vec<LabelKey>, Box<Selection>),
308
309    /// Selects a random index along the current dimension, continuing
310    /// with the given selection.
311    Any(Box<Selection>),
312
313    /// The intersection (logical AND) of two selections.
314    Intersection(Box<Selection>, Box<Selection>),
315
316    /// The union (logical OR) of two selections.
317    Union(Box<Selection>, Box<Selection>),
318}
319
320// Compile-time check: ensure Selection is thread-safe and fully
321// owned.
322fn _assert_selection_traits()
323where
324    Selection: Send + Sync + 'static,
325{
326}
327
328/// Compares two `Selection` values for structural equality.
329///
330/// Two selections are structurally equal if they have the same shape
331/// and recursively equivalent substructure, but not necessarily the
332/// same pointer identity or formatting.
333pub fn structurally_equal(a: &Selection, b: &Selection) -> bool {
334    use Selection::*;
335    match (a, b) {
336        (False, False) => true,
337        (True, True) => true,
338        (All(x), All(y)) => structurally_equal(x, y),
339        (Any(x), Any(y)) => structurally_equal(x, y),
340        (First(x), First(y)) => structurally_equal(x, y),
341        (Range(r1, x), Range(r2, y)) => r1 == r2 && structurally_equal(x, y),
342        (Intersection(x1, y1), Intersection(x2, y2)) => {
343            structurally_equal(x1, x2) && structurally_equal(y1, y2)
344        }
345        (Union(x1, y1), Union(x2, y2)) => structurally_equal(x1, x2) && structurally_equal(y1, y2),
346        _ => false,
347    }
348}
349
350/// Normalizes a [`Selection`] toward a canonical form for structural
351/// comparison.
352///
353/// This rewrites the selection to eliminate redundant subtrees and
354/// bring structurally similar selections into a common
355/// representation. The result is suitable for comparison, hashing,
356/// and deduplication (e.g., in [`RoutingFrameKey`]).
357///
358/// Normalization preserves semantics but may alter syntactic
359/// structure. It is designed to improve over time as additional
360/// rewrites (e.g., flattening, simplification) are introduced.
361pub fn normalize(sel: &Selection) -> NormalizedSelection {
362    let rule = normal::FlatteningRules
363        .then(normal::IdentityRules)
364        .then(normal::AbsorbtionRules);
365    sel.fold::<normal::NormalizedSelection>()
366        .rewrite_bottom_up(&rule)
367}
368
369/// Wraps a normalized selection and derives `Eq` and `Hash`, relying
370/// on the canonical structure of the normalized form.
371///
372/// This ensures that logically equivalent selections (e.g., unions
373/// with reordered elements) compare equal and hash identically.
374#[derive(Debug, Clone, PartialEq, Eq, Hash)]
375pub struct NormalizedSelectionKey(NormalizedSelection);
376
377impl NormalizedSelectionKey {
378    /// Constructs a `NormalizedSelectionKey`, normalizing the input
379    /// selection.
380    pub fn new(sel: &Selection) -> Self {
381        Self(crate::selection::normalize(sel))
382    }
383
384    /// Access the normalized selection.
385    pub fn inner(&self) -> &NormalizedSelection {
386        &self.0
387    }
388
389    /// Consumes the key and returns the owned normalized selection.
390    pub fn into_inner(self) -> NormalizedSelection {
391        self.0
392    }
393}
394
395mod ast {
396    use super::LabelKey;
397    use super::Selection;
398    use crate::shape;
399
400    pub(crate) fn false_() -> Selection {
401        Selection::False
402    }
403    pub(crate) fn true_() -> Selection {
404        Selection::True
405    }
406    pub(crate) fn all(selection: Selection) -> Selection {
407        Selection::All(Box::new(selection))
408    }
409    pub(crate) fn first(selection: Selection) -> Selection {
410        Selection::First(Box::new(selection))
411    }
412    pub(crate) fn range<R: Into<shape::Range>>(range: R, selection: Selection) -> Selection {
413        Selection::Range(range.into(), Box::new(selection))
414    }
415    #[allow(dead_code)] // Harmless.
416    pub(crate) fn label<L: Into<LabelKey>>(labels: Vec<L>, selection: Selection) -> Selection {
417        let labels = labels.into_iter().map(Into::into).collect();
418        Selection::Label(labels, Box::new(selection))
419    }
420    pub(crate) fn any(selection: Selection) -> Selection {
421        Selection::Any(Box::new(selection))
422    }
423    pub(crate) fn intersection(lhs: Selection, rhs: Selection) -> Selection {
424        Selection::Intersection(Box::new(lhs), Box::new(rhs))
425    }
426    pub(crate) fn union(lhs: Selection, rhs: Selection) -> Selection {
427        Selection::Union(Box::new(lhs), Box::new(rhs))
428    }
429}
430
/// `EvalOpts` controls runtime behavior of [`Selection::eval`] by
/// enforcing stricter validation rules.
///
/// Use [`EvalOpts::lenient`] to accept everything that can be
/// evaluated, or [`EvalOpts::strict`] to reject empty and out-of-range
/// ranges.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct EvalOpts {
    /// Fail `eval` on empty range expressions.
    pub disallow_empty_ranges: bool,

    /// Fail `eval` on a range beginning at or beyond the slice's extent
    /// in the evaluation context's dimension.
    pub disallow_out_of_range: bool,

    /// Fail `eval` if a selection can be shown to be not "static".
    /// Currently this rejects `Any`, whose result is random.
    pub disallow_dynamic_selections: bool,
}
444
445impl EvalOpts {
446    // Produce empty iterators but don't panic.
447    pub fn lenient() -> Self {
448        Self {
449            disallow_empty_ranges: false,
450            disallow_out_of_range: false,
451            disallow_dynamic_selections: false,
452        }
453    }
454
455    // `eval()` should fail with all the same [`shape::ShapeError`]s
456    // as [`Shape::select()`].
457    #[allow(dead_code)]
458    pub fn strict() -> Self {
459        Self {
460            disallow_empty_ranges: true,
461            disallow_out_of_range: true,
462            ..Self::lenient()
463        }
464    }
465}
466
467impl Selection {
468    pub(crate) fn validate(&self, opts: &EvalOpts, slice: &Slice) -> Result<&Self, ShapeError> {
469        let depth = 0;
470        self.validate_rec(opts, slice, self, depth).map(|_| self)
471    }
472
473    fn validate_rec(
474        &self,
475        opts: &EvalOpts,
476        slice: &Slice,
477        top: &Selection,
478        dim: usize,
479    ) -> Result<(), ShapeError> {
480        if dim == slice.num_dim() {
481            // This enables us to maintain identities like 'all(true)
482            // <=> true' and 'all(false) <=> false' in leaf positions.
483            match self {
484                Selection::True | Selection::False => return Ok(()),
485                _ => {
486                    return Err(ShapeError::SelectionTooDeep {
487                        expr: top.clone(),
488                        num_dim: slice.num_dim(),
489                    });
490                }
491            }
492        }
493
494        match self {
495            Selection::False | Selection::True => Ok(()),
496            Selection::Range(range, s) => {
497                if range.is_empty() && opts.disallow_empty_ranges {
498                    return Err(ShapeError::EmptyRange {
499                        range: range.clone(),
500                    });
501                } else {
502                    if opts.disallow_out_of_range {
503                        let size = slice.sizes()[dim];
504                        let (min, _, _) = range.resolve(size);
505                        if min >= size {
506                            // Use EmptyRange here for now (evaluation would result in an empty range),
507                            // until we figure out how to differentiate between slices and shapes
508                            return Err(ShapeError::EmptyRange {
509                                range: range.clone(),
510                            });
511                        }
512                    }
513
514                    s.validate_rec(opts, slice, top, dim + 1)?;
515                }
516
517                Ok(())
518            }
519            Selection::Any(s) => {
520                if opts.disallow_dynamic_selections {
521                    return Err(ShapeError::SelectionDynamic { expr: top.clone() });
522                }
523                s.validate_rec(opts, slice, top, dim + 1)?;
524                Ok(())
525            }
526            Selection::All(s) | Selection::Label(_, s) | Selection::First(s) => {
527                s.validate_rec(opts, slice, top, dim + 1)?;
528                Ok(())
529            }
530            Selection::Intersection(a, b) | Selection::Union(a, b) => {
531                a.validate_rec(opts, slice, top, dim)?;
532                b.validate_rec(opts, slice, top, dim)?;
533                Ok(())
534            }
535        }
536    }
537
538    /// Lazily evaluates this selection against the given `slice`
539    /// yielding flat indices.
540    ///
541    /// Returns a boxed iterator that produces indices of elements
542    /// matching the selection expression when evaluated over `slice`.
543    ///
544    /// # Lifetimes
545    ///
546    /// The returned iterator borrows `slice` because the internal
547    /// iterators are implemented as closures that **capture**
548    /// `&slice` in their environment. Evaluation is lazy, so these
549    /// closures dereference `slice` each time a coordinate is
550    /// visited. The `'a` lifetime ensures that the iterator cannot
551    /// outlive the `slice` it reads from.
552    ///
553    /// # Why `Box<dyn Iterator>`
554    ///
555    /// The selection algebra supports multiple recursive strategies
556    /// (`All`, `Range`, `Intersection`, etc.) that return different
557    /// iterator types (e.g. `Selection::True` =>
558    /// `std::iter::once(...)`, `Selection::False` =>
559    /// `std::iter::empty()`). Returning `impl Iterator` is not
560    /// feasible because the precise type depends on dynamic selection
561    /// structure. Boxing erases this variability and allows a uniform
562    /// return type.
563    ///
564    /// # Canonical handling of 0-dimensional slices
565    ///
566    /// A `Slice` with zero dimensions represents the empty product
567    /// `∏_{i=1}^{0} Xᵢ`, which has exactly one element: the empty
568    /// tuple. To ensure that evaluation behaves uniformly across
569    /// dimensions, we canonically embed the 0-dimensional case into a
570    /// 1-dimensional slice of extent 1. That is, we reinterpret the
571    /// 0D slice as `Slice::new(offset, [1], [1])`, which is
572    /// semantically equivalent and enables evaluation to proceed
573    /// through the normal recursive machinery without special-casing.
574    /// The result is that selection expressions are always evaluated
575    /// over a slice with at least one dimension, and uniform logic
576    /// applies.
577    pub fn eval<'a>(
578        &self,
579        opts: &EvalOpts,
580        slice: &'a Slice,
581    ) -> Result<Box<dyn Iterator<Item = usize> + 'a>, ShapeError> {
582        // Canonically embed 0D as 1D (extent 1).
583        if slice.num_dim() == 0 {
584            let slice = Slice::new(slice.offset(), vec![1], vec![1]).unwrap();
585            return Ok(Box::new(
586                self.validate(opts, &slice)?
587                    .eval_rec(&slice, vec![0; 1], 0)
588                    .collect::<Vec<_>>()
589                    .into_iter(),
590            ));
591        }
592
593        Ok(self
594            .validate(opts, slice)?
595            .eval_rec(slice, vec![0; slice.num_dim()], 0))
596    }
597
598    fn eval_rec<'a>(
599        &self,
600        slice: &'a Slice,
601        env: Vec<usize>,
602        dim: usize,
603    ) -> Box<dyn Iterator<Item = usize> + 'a> {
604        if dim == slice.num_dim() {
605            match self {
606                Selection::True => return Box::new(std::iter::once(slice.location(&env).unwrap())),
607                Selection::False => return Box::new(std::iter::empty()),
608                _ => {
609                    panic!("structural combinator {self:?} at leaf level (dim = {dim}))",);
610                }
611            }
612        }
613
614        use itertools;
615        use itertools::EitherOrBoth;
616
617        match self {
618            Selection::False => Box::new(std::iter::empty()),
619            Selection::True => Box::new((0..slice.sizes()[dim]).flat_map(move |i| {
620                let mut env = env.clone();
621                env[dim] = i;
622                Selection::True.eval_rec(slice, env, dim + 1)
623            })),
624            Selection::All(select) => {
625                let select = Box::clone(select);
626                Box::new((0..slice.sizes()[dim]).flat_map(move |i| {
627                    let mut env = env.clone();
628                    env[dim] = i;
629                    select.eval_rec(slice, env, dim + 1)
630                }))
631            }
632            Selection::First(select) => {
633                let select = Box::clone(select);
634                Box::new(iterutils::first(slice.sizes()[dim], move |i| {
635                    let mut env = env.clone();
636                    env[dim] = i;
637                    select.eval_rec(slice, env, dim + 1)
638                }))
639            }
640            Selection::Range(range, select) => {
641                let select = Box::clone(select);
642                let (min, max, step) = range.resolve(slice.sizes()[dim]);
643                Box::new((min..max).step_by(step).flat_map(move |i| {
644                    let mut env = env.clone();
645                    env[dim] = i;
646                    select.eval_rec(slice, env, dim + 1)
647                }))
648            }
649
650            // Label-based selection: filters candidates at this
651            // dimension, then either selects one (Any) or recurses.
652            //
653            // When the inner selection is `Any`, we choose one match
654            // at random (eager). Otherwise, we recurse normally and
655            // filter the results lazily.
656            //
657            // This separation reflects that `Label(...)` does *not*
658            // consume a dimension — it restricts access to it while
659            // preserving dimensional structure.
660            //
661            // See `eval_label` for more on the distinction between
662            // filtering and traversal, and the underlying
663            // projection-based interpretation.
664            //
665            // For example:
666            //
667            //   sel!(*, ["foo"]?, *)  // select one host with label "foo", then all GPUs
668            //   = all(label(["foo"], any(all(true_()))))
669            //
670            //   sel!(*, ["foo"]*, *)  // select all hosts with label "foo", then all GPUs
671            //   = all(label(["foo"], all(all(true_()))))
672            //
673            // **Note:** Label filtering is not yet implemented — all coordinates
674            // are currently accepted.
675            Selection::Label(labels, inner) => {
676                Self::eval_label(labels, inner, slice, env, dim /*, provider */)
677            }
678            Selection::Any(select) => {
679                let select = Box::clone(select);
680                let r = {
681                    let upper = slice.sizes()[dim];
682                    let mut rng = rand::thread_rng();
683                    rng.gen_range(0..upper)
684                };
685                Box::new((r..r + 1).flat_map(move |i| {
686                    let mut env = env.clone();
687                    env[dim] = i;
688                    select.eval_rec(slice, env, dim + 1)
689                }))
690            }
691            Selection::Intersection(a, b) => Box::new(
692                itertools::merge_join_by(
693                    a.eval_rec(slice, env.clone(), dim),
694                    b.eval_rec(slice, env.clone(), dim),
695                    |x, y| x.cmp(y),
696                )
697                .filter_map(|either| match either {
698                    EitherOrBoth::Both(x, _) => Some(x),
699                    _ => None,
700                }),
701            ),
702            Selection::Union(a, b) => Box::new(
703                itertools::merge_join_by(
704                    a.eval_rec(slice, env.clone(), dim),
705                    b.eval_rec(slice, env.clone(), dim),
706                    |x, y| x.cmp(y),
707                )
708                .map(|either| match either {
709                    EitherOrBoth::Left(x) => x,
710                    EitherOrBoth::Right(y) => y,
711                    EitherOrBoth::Both(x, _) => x,
712                }),
713            ),
714        }
715    }
716
717    /// Evaluates a `Label(labels, inner)` selection.
718    ///
719    /// This operator filters coordinates along the current dimension
720    /// based on associated metadata (labels). It then evaluates the inner
721    /// selection at matching positions.
722    ///
723    /// Conceptually, this corresponds to computing a pullback along a
724    /// projection `p : E → B`, where:
725    ///
726    /// - `B` is the base space of coordinates (e.g. zones × hosts × gpus)
727    /// - `E` is the space of labeled coordinates
728    /// - `p⁻¹(S)` lifts a geometric selection `S ⊆ B` into the labeled
729    ///   space
730    ///
731    /// At runtime, we simulate `p⁻¹(S)` by traversing `B` and querying a
732    /// `LabelProvider` at each coordinate. Under the identity provider,
733    /// label filtering has no effect and `eval_label` reduces to the
734    /// geometric case.
735    ///
736    /// - If `inner` is `Any`, we select one matching index at random
737    /// - Otherwise, we recurse and filter lazily
738    ///
739    /// **Note:** Label filtering is not yet implemented — all coordinates
740    /// are currently accepted.
741    fn eval_label<'a>(
742        _labels: &[LabelKey],
743        inner: &Selection,
744        slice: &'a Slice,
745        env: Vec<usize>,
746        dim: usize,
747        // provider: &dyn LabelProvider  // TODO: add when ready
748    ) -> Box<dyn Iterator<Item = usize> + 'a> {
749        match inner {
750            // Case 1: label(..., any(...))
751            // - We evaluate all indices at this dimension that match
752            //   the label predicate.
753            // - From those, choose one at random and continue
754            //   evaluating the inner selection.
755            // - Semantically: filter → choose one → recurse
756            Selection::Any(sub_inner) => {
757                let matching: Vec<usize> = (0..slice.sizes()[dim])
758                    .filter(|&i| {
759                        let mut prefix = env.clone();
760                        prefix[dim] = i;
761                        true // TODO: provider.matches(dim, &prefix[0..=dim], labels)
762                    })
763                    .collect();
764
765                if matching.is_empty() {
766                    return Box::new(std::iter::empty());
767                }
768
769                let mut rng = rand::thread_rng();
770                let chosen = matching[rng.gen_range(0..matching.len())];
771
772                let mut coord = env;
773                coord[dim] = chosen;
774                sub_inner.eval_rec(slice, coord, dim + 1 /*, provider */)
775            }
776            // Case 2: label(..., inner)
777            //
778            // Applies label filtering after evaluating `inner`. We
779            // first recurse into `inner`, then lazily filter the
780            // resulting flat indices based on whether the coordinate
781            // at `dim` matches the given labels.
782            //
783            // This preserves laziness for all cases except `Any`,
784            // which requires eager collection and is handled
785            // separately.
786            _ => {
787                // evaluate the inner selection — recurse as usual
788                let iter = inner.eval_rec(slice, env.clone(), dim /* , provider */);
789                Box::new(iter.filter(move |&flat| {
790                    let _coord = slice.coordinates(flat);
791                    true // TODO: provider.matches(dim, &coord[0..=dim], labels)
792                }))
793            }
794        }
795    }
796
797    /// Returns `true` if this selection is equivalent to `True` under
798    /// the algebra.
799    ///
800    /// In the selection algebra, `All(True)` is considered equivalent
801    /// to `True`, and this identity extends recursively. For example:
802    ///
803    ///   - `All(True)`      ≡ `True`
804    ///   - `All(All(True))` ≡ `True`
805    ///   - `All(All(All(True)))` ≡ `True`
806    ///
807    /// This method checks whether the selection is structurally
808    /// identical to True, possibly wrapped in one or more All(...)
809    /// layers. It does **not** perform full normalization—only
810    /// structural matching sufficient to recognize this identity.
811    ///
812    /// Used to detect when a selection trivially selects all elements
813    /// at all levels.
814    ///
815    /// ## Limitations
816    ///
817    /// This is a **syntactic check** only. It does *not* recognize
818    /// semantically equivalent expressions such as:
819    ///
820    ///   - `Union(True, True)`
821    ///   - `All(Union(True, False))`
822    ///   - A union of all singleton ranges covering the full space
823    ///
824    /// For a semantic check, use evaluation against a known slice.
825    pub fn is_equivalent_to_true(mut sel: &Selection) -> bool {
826        while let Selection::All(inner) = sel {
827            sel = inner;
828        }
829        matches!(sel, Selection::True)
830    }
831
832    /// Evaluates whether the specified coordinates are part of the selection.
833    /// Returns true if they are, false otherwise.
834    ///
835    /// Example:
836    /// let selection = union(
837    ///     range(0..2, range(0..1, range(0..2, true_()))),
838    ///     range(0..2, range(1..2, range(0..2, true_()))),
839    /// );
840    ///
841    /// assert!(selection.contains(&[0, 0, 1]));
842    /// assert!(!selection.contains(&[2, 0, 1]));
843    pub fn contains(&self, coords: &[usize]) -> bool {
844        self.contains_rec(coords, 0)
845    }
846
847    fn contains_rec(&self, coords: &[usize], dim: usize) -> bool {
848        if dim >= coords.len() {
849            return matches!(self, Selection::True);
850        }
851
852        match self {
853            Selection::False => false,
854            Selection::True => true,
855            Selection::All(inner) => inner.contains_rec(coords, dim + 1),
856            Selection::Range(range, inner) => {
857                let (min, max, step) = range.resolve(coords.len());
858                let index = coords[dim];
859                index >= min
860                    && index < max
861                    && (index - min).is_multiple_of(step)
862                    && inner.contains_rec(coords, dim + 1)
863            }
864            Selection::Intersection(a, b) => {
865                a.contains_rec(coords, dim) && b.contains_rec(coords, dim)
866            }
867            Selection::Union(a, b) => a.contains_rec(coords, dim) || b.contains_rec(coords, dim),
868            Selection::Label(_, _) | Selection::First(_) | Selection::Any(_) => {
869                unimplemented!()
870            }
871        }
872    }
873
874    /// Simplifies the intersection of two `Selection` expressions.
875    ///
876    /// Applies short-circuit logic to avoid constructing redundant or
877    /// degenerate intersections:
878    ///
879    /// - If either side is `False`, the result is `False`.
880    /// - If either side is `True`, the result is the other side.
881    /// - Otherwise, constructs an explicit `Intersection`.
882    ///
883    /// This is required during routing to make progress when
884    /// evaluating intersections. Without this reduction, routing may
885    /// stall — for example, in intersections like `Intersection(True,
886    /// X)`, which should simplify to `X`.
887    pub fn reduce_intersection(self: Selection, b: Selection) -> Selection {
888        match (&self, &b) {
889            (Selection::False, _) | (_, Selection::False) => Selection::False,
890            (Selection::True, other) | (other, Selection::True) => other.clone(),
891            _ => Selection::Intersection(Box::new(self), Box::new(b)),
892        }
893    }
894
895    /// Canonicalizes this selection to the specified number of
896    /// dimensions.
897    ///
898    /// Ensures that the selection has exactly `num_dims` dimensions
899    /// by recursively wrapping it in combinators (`All`, `Any`, etc.)
900    /// where needed. This transformation enforces a canonical
901    /// structural form suitable for dimensional evaluation (e.g.,
902    /// routing via `next_steps`).
903    ///
904    /// Examples:
905    /// - `True` becomes `All(All(...(True)))`
906    /// - `Any(True)` becomes `Any(Any(...(True)))`
907    /// - Fully specified selections are left unchanged.
908    ///
909    /// ---
910    ///
911    /// # Panics
912    /// Panics if `num_dims == 0`. Use a canonical embedding (e.g., 0D
913    /// → 1D) before calling this (see e.g. `RoutingFrame::root`).
914    pub(crate) fn canonicalize_to_dimensions(self, num_dims: usize) -> Selection {
915        assert!(
916            num_dims > 0,
917            "canonicalize_to_dimensions requires num_dims > 0"
918        );
919        self.canonicalize_to_dimensions_rec(0, num_dims)
920    }
921
922    fn canonicalize_to_dimensions_rec(self, dim: usize, num_dims: usize) -> Selection {
923        use crate::selection::dsl::*;
924
925        match self {
926            Selection::True if dim < num_dims => {
927                let mut out = true_();
928                for _ in (dim..num_dims).rev() {
929                    out = all(out);
930                }
931                out
932            }
933            Selection::False if dim < num_dims => {
934                let mut out = false_();
935                for _ in (dim..num_dims).rev() {
936                    out = all(out);
937                }
938                out
939            }
940            Selection::Any(inner) if dim < num_dims && matches!(*inner, Selection::True) => {
941                let mut out = true_();
942                for _ in (dim..num_dims).rev() {
943                    out = any(out);
944                }
945                out
946            }
947            Selection::All(inner) => all(inner.canonicalize_to_dimensions_rec(dim + 1, num_dims)),
948            Selection::Any(inner) => any(inner.canonicalize_to_dimensions_rec(dim + 1, num_dims)),
949            Selection::First(inner) => {
950                first(inner.canonicalize_to_dimensions_rec(dim + 1, num_dims))
951            }
952            Selection::Range(r, inner) => {
953                range(r, inner.canonicalize_to_dimensions_rec(dim + 1, num_dims))
954            }
955            Selection::Intersection(a, b) => intersection(
956                a.canonicalize_to_dimensions_rec(dim, num_dims),
957                b.canonicalize_to_dimensions_rec(dim, num_dims),
958            ),
959            Selection::Union(a, b) => union(
960                a.canonicalize_to_dimensions_rec(dim, num_dims),
961                b.canonicalize_to_dimensions_rec(dim, num_dims),
962            ),
963
964            other => other,
965        }
966    }
967
968    /// Recursively folds the `Selection` into an abstract syntax via
969    /// the `SelectionSYM` interface.
970    ///
971    /// This method structurally traverses the `Selection` tree and
972    /// reconstructs it using the operations provided by the
973    /// `SelectionSYM` trait. It is typically used to reify a
974    /// `Selection` into alternate forms, such as pretty-printers.
975    ///
976    /// # Type Parameters
977    ///
978    /// - `S`: An implementation of the `SelectionSYM` trait,
979    ///   providing the constructors for the target representation.
980    pub fn fold<S: SelectionSYM>(&self) -> S {
981        match self {
982            Selection::False => S::false_(),
983            Selection::True => S::true_(),
984            Selection::All(inner) => S::all(inner.fold::<S>()),
985            Selection::First(inner) => S::first(inner.fold::<S>()),
986            Selection::Range(r, inner) => S::range(r.clone(), inner.fold::<S>()),
987            Selection::Label(labels, inner) => S::label(labels.clone(), inner.fold::<S>()),
988            Selection::Any(inner) => S::any(inner.fold::<S>()),
989            Selection::Intersection(a, b) => S::intersection(a.fold::<S>(), b.fold::<S>()),
990            Selection::Union(a, b) => S::union(a.fold::<S>(), b.fold::<S>()),
991        }
992    }
993
994    /// Iterator over indices selected by `self` and not in
995    /// `exclusions`.
996    ///
997    /// Evaluates the selection against `slice` using `opts`, then
998    /// filters out any indices present in the exclusion set.
999    /// Evaluation is lazy and streaming; the exclusion set is used
1000    /// directly for fast membership checks.
1001    pub fn difference<'a>(
1002        &self,
1003        opts: &EvalOpts,
1004        slice: &'a Slice,
1005        exclusions: &'a HashSet<usize>,
1006    ) -> Result<impl Iterator<Item = usize> + use<'a>, ShapeError> {
1007        Ok(self
1008            .eval(opts, slice)?
1009            .filter(move |idx| !exclusions.contains(idx)))
1010    }
1011
1012    /// Calculate a new `Selection` that excludes the specified flat
1013    /// ranks.
1014    ///
1015    /// This computes `self \ exclusions` by evaluating `self`,
1016    /// removing the given ranks, and reconstructing a `Selection`
1017    /// that selects exactly the remaining elements.
1018    ///
1019    /// The result is a concrete, structurally uniform expression with
1020    /// predictable construction order and exact correspondence to the
1021    /// surviving ranks.
1022    pub fn without(
1023        &self,
1024        slice: &Slice,
1025        exclusions: &HashSet<usize>,
1026    ) -> Result<Selection, ShapeError> {
1027        let remaining = self
1028            .difference(&EvalOpts::strict(), slice, exclusions)?
1029            .collect::<BTreeSet<_>>();
1030        Ok(Selection::of_ranks(slice, &remaining)?)
1031    }
1032
1033    /// Converts a set of flat indices into a symbolic `Selection`
1034    /// expression over the given `slice`. Returns an error if any index
1035    /// is invalid.
1036    ///
1037    /// Each flat index is converted into coordinates using
1038    /// `slice.coordinates`, then folded into a nested chain of singleton
1039    /// ranges. The resulting selection evaluates exactly to the input
1040    /// indices.
1041    ///
1042    /// The selections are combined left-associatively using `union`, but
1043    /// since `union` is associative, the grouping does not affect
1044    /// correctness.
1045    ///
1046    /// The input `BTreeSet` ensures:
1047    /// - all indices are unique (no redundant singleton ranges),
1048    /// - the resulting selection has a stable, deterministic structure,
1049    /// - and iteration proceeds in ascending order, which helps produce
1050    ///   predictable routing trees and consistent test results.
1051    ///
1052    /// This choice avoids an explicit sort and makes downstream behavior
1053    /// more reproducible and auditable.
1054    pub fn of_ranks(slice: &Slice, ranks: &BTreeSet<usize>) -> Result<Selection, SliceError> {
1055        let selections = ranks
1056            .iter()
1057            .map(|&i| {
1058                Ok(slice
1059                    .coordinates(i)?
1060                    .into_iter()
1061                    .rev()
1062                    .fold(dsl::true_(), |acc, i| dsl::range(i..=i, acc)))
1063            })
1064            .collect::<Result<Vec<_>, SliceError>>()?;
1065
1066        Ok(selections
1067            .into_iter()
1068            .reduce(dsl::union)
1069            .unwrap_or_else(dsl::false_))
1070    }
1071} // impl Selection
1072
1073mod sealed {
1074    pub trait Sealed {}
1075    impl Sealed for crate::slice::Slice {}
1076}
1077
/// Connects the `select!` API to the `Selection` algebra by enabling
/// `base.reify_slice(slice)` syntax, where `base: Slice`.
///
/// The base slice defines the coordinate system in which the slice is
/// interpreted. The slices being reified are themselves `Slice`
/// values, typically produced by `select!`, and are turned into
/// `Selection` expressions over the base.
///
/// This trait is sealed: its supertrait `sealed::Sealed` lives in a
/// private module, so the trait cannot be implemented outside this
/// crate.
pub trait ReifySlice: sealed::Sealed {
    /// Reify a slice as a `Selection` in the coordinate system of
    /// `self`.
    fn reify_slice(&self, slice: &Slice) -> Result<Selection, SliceError>;

    /// Reify multiple slices as a union of selections in the
    /// coordinate system of `self`.
    fn reify_slices<V: AsRef<[Slice]>>(&self, slices: V) -> Result<Selection, SliceError>;
}
1094
1095impl ReifySlice for Slice {
1096    /// Constructs a [`Selection`] expression that symbolically
1097    /// matches all coordinates in the given `slice`, expressed in the
1098    /// coordinate system of the provided `base` slice (`self`).
1099    ///
1100    /// The result is a nested sequence of `range(start..end, step)`
1101    /// combinators that match the rectangular region covered by `slice`
1102    /// in base coordinates. This preserves geometry and layout when
1103    /// `slice` is *layout-aligned* — that is, each of its strides is
1104    /// a multiple of the corresponding base stride.
1105    ///
1106    /// If any dimension is not layout-aligned, the slice is reified
1107    /// by explicitly enumerating its coordinates.
1108    ///
1109    /// Returns [`dsl::false_()`] if the slice is empty.
1110    ///
1111    /// # Errors
1112    ///
1113    /// Returns an error if:
1114    /// - The base is not contiguous and row-major
1115    /// - The slice lies outside the bounds of the base
1116    ///
1117    /// # Example
1118    ///
1119    /// ```rust
1120    /// use ndslice::selection::ReifySlice;
1121    /// let shape = ndslice::shape!(x = 4, y = 4);
1122    /// let base = shape.slice();
1123    /// let selected = ndslice::select!(shape, x = 1..3, y = 2..4).unwrap();
1124    /// let slice = selected.slice();
1125    /// let selection = base.reify_slice(slice).unwrap();
1126    /// ```
1127    fn reify_slice(&self, slice: &Slice) -> Result<Selection, SliceError> {
1128        // Precondition: the base is contiguous and row major.
1129        if !self.is_contiguous() {
1130            return Err(SliceError::NonContiguous);
1131        }
1132
1133        if slice.is_empty() {
1134            return Ok(dsl::false_());
1135        }
1136
1137        if slice.num_dim() != self.num_dim()
1138            || slice.sizes().iter().zip(self.sizes()).any(|(&v, &s)| v > s)
1139        {
1140            return Selection::of_ranks(self, &slice.iter().collect::<BTreeSet<usize>>());
1141        }
1142
1143        let origin = self.coordinates(slice.offset())?;
1144        let mut acc = dsl::true_();
1145        for dim in (0..self.num_dim()).rev() {
1146            let start = origin[dim];
1147            let len = slice.sizes()[dim];
1148            let slice_stride = slice.strides()[dim];
1149            let base_stride = self.strides()[dim];
1150
1151            if slice_stride.is_multiple_of(base_stride) {
1152                // Layout-aligned with base.
1153                let step = slice_stride / base_stride;
1154                let end = start + step * len;
1155                // Check that `end` is within bounds.
1156                if end - 1 > self.sizes()[dim] {
1157                    let bad = origin
1158                        .iter()
1159                        .enumerate()
1160                        .map(|(d, &x)| if d == dim { end } else { x })
1161                        .collect::<Vec<_>>();
1162                    return Err(SliceError::ValueNotInSlice {
1163                        value: self.location(&bad).unwrap(),
1164                    });
1165                }
1166                acc = dsl::range(crate::shape::Range(start, Some(end), step), acc);
1167            } else {
1168                // Irregular layout; fallback to explicit enumeration.
1169                return Selection::of_ranks(self, &slice.iter().collect::<BTreeSet<_>>());
1170            }
1171        }
1172
1173        Ok(acc)
1174    }
1175
1176    /// Converts a list of `slices` into a symbolic [`Selection`]
1177    /// expression over a common `base` [`Slice`].
1178    ///
1179    /// Each slice describes a rectangular subregion of the base. This
1180    /// function reifies each slice into a nested `range(.., ..)`
1181    /// expression in the base coordinate system and returns the union
1182    /// of all such selections.
1183    ///
1184    /// Empty slices are ignored.
1185    ///
1186    /// # Errors
1187    ///
1188    /// Returns an error if any slice:
1189    /// - Refers to coordinates not contained within the base
1190    ///
1191    /// # Example
1192    ///
1193    /// ```rust
1194    /// use ndslice::selection::ReifySlice;
1195    ///
1196    /// let shape = ndslice::shape!(x = 4, y = 4);
1197    /// let base = shape.slice();
1198    ///
1199    /// let a = ndslice::select!(shape, x = 0..2, y = 0..2)
1200    ///     .unwrap()
1201    ///     .slice()
1202    ///     .clone();
1203    /// let b = ndslice::select!(shape, x = 2..4, y = 2..4)
1204    ///     .unwrap()
1205    ///     .slice()
1206    ///     .clone();
1207    ///
1208    /// let sel = base.reify_slices(&[a, b]).unwrap();
1209    /// ```
1210    fn reify_slices<V: AsRef<[Slice]>>(&self, slices: V) -> Result<Selection, SliceError> {
1211        let slices = slices.as_ref();
1212        let mut selections = Vec::with_capacity(slices.len());
1213
1214        for slice in slices {
1215            if slice.is_empty() {
1216                continue;
1217            }
1218            selections.push(self.reify_slice(slice)?);
1219        }
1220
1221        let mut iter = selections.into_iter();
1222        let first = iter.next().unwrap_or_else(dsl::false_);
1223        Ok(iter.fold(first, dsl::union))
1224    }
1225}
1226
1227/// Trivial all(true) equivalence.
1228pub fn is_equivalent_true(sel: impl std::borrow::Borrow<Selection>) -> bool {
1229    Selection::is_equivalent_to_true(sel.borrow())
1230}
1231
mod iterutils {
    /// Returns an iterator over the first non-empty iterator produced
    /// by applying `mk_iter` to the indices `0..size`, in order.
    ///
    /// The search for that iterator happens when this function is
    /// called: each candidate is peeked until one yields an element.
    /// The chosen iterator, with its peeked element intact, is then
    /// yielded lazily. If `size == 0` or every candidate is empty, the
    /// result is empty.
    pub(crate) fn first<'a, F>(size: usize, mut mk_iter: F) -> impl Iterator<Item = usize> + 'a
    where
        F: FnMut(usize) -> Box<dyn Iterator<Item = usize> + 'a>,
    {
        let winner = (0..size).find_map(move |index| {
            let mut candidate = mk_iter(index).peekable();
            // An empty candidate yields `None`, so the search moves on.
            candidate.peek()?;
            Some(candidate)
        });
        winner.into_iter().flatten()
    }
}
1252
1253/// Construct a [`Selection`] from a [`Shape`] and a single labeled
1254/// constraint.
1255///
1256/// This function produces a multidimensional selection expression
1257/// that is structurally aligned with the shape. It applies the given
1258/// range to the named dimension, and fills all preceding dimensions
1259/// with [`all`] to maintain alignment. Trailing dimensions are left
1260/// unconstrained.
1261///
1262/// # Arguments
1263///
1264/// - `shape`: The labeled shape describing the coordinate space.
1265/// - `label`: The name of the dimension to constrain.
1266/// - `rng`: The range to apply in the selected dimension.
1267///
1268/// # Errors
1269///
1270/// Returns [`ShapeError::InvalidLabels`] if the label is not present
1271/// in the shape.
1272///
1273/// # Example
1274///
1275/// ```
1276/// use ndslice::shape;
1277/// use ndslice::selection::selection_from_one;
1278///
1279/// let shape = shape!(zone = 2, host = 4, gpu = 8);
1280/// let sel = selection_from_one(&shape, "host", 1..3).unwrap();
1281/// assert_eq!(format!("{sel:?}"), "All(Range(Range(1, Some(3), 1), True))"); // corresponds to (*, 1..3, *)
1282/// ```
1283///
1284/// [`all`]: crate::selection::dsl::all
1285/// [`Shape`]: crate::shape::Shape
1286/// [`Selection`]: crate::selection::Selection
1287pub fn selection_from_one<R>(
1288    shape: &shape::Shape,
1289    label: &str,
1290    rng: R,
1291) -> Result<Selection, ShapeError>
1292where
1293    R: Into<shape::Range>,
1294{
1295    use crate::selection::dsl;
1296
1297    let Some(pos) = shape.labels().iter().position(|l| l == label) else {
1298        return Err(ShapeError::InvalidLabels {
1299            labels: vec![label.to_string()],
1300        });
1301    };
1302
1303    let mut selection = dsl::range(rng.into(), dsl::true_());
1304    for _ in 0..pos {
1305        selection = dsl::all(selection)
1306    }
1307
1308    Ok(selection)
1309}
1310
1311/// Construct a [`Selection`] from a [`Shape`] and multiple labeled
1312/// range constraints.
1313///
1314/// This function produces a multidimensional selection aligned with
1315/// the shape, applying the specified constraints to their
1316/// corresponding dimensions. All unconstrained dimensions are filled
1317/// with [`all`] to preserve structural alignment.
1318///
1319/// # Arguments
1320///
1321/// - `shape`: The labeled shape defining the full coordinate space.
1322/// - `constraints`: A slice of `(label, range)` pairs specifying
1323///   dimension constraints.
1324///
1325/// # Errors
1326///
1327/// Returns [`ShapeError::InvalidLabels`] if any label in
1328/// `constraints` is not present in the shape.
1329///
1330/// # Example
1331///
1332/// ```
1333/// use ndslice::selection::selection_from;
1334/// use ndslice::shape;
1335///
1336/// let shape = shape!(zone = 2, host = 4, gpu = 8);
1337/// let sel = selection_from(&shape, &[("host", 1..3), ("gpu", 0..4)]).unwrap();
1338/// assert_eq!(
1339///     format!("{sel:?}"),
1340///     "All(Range(Range(1, Some(3), 1), Range(Range(0, Some(4), 1), True)))"
1341/// );
1342/// ```
1343///
1344/// [`Shape`]: crate::shape::Shape
1345/// [`Selection`]: crate::selection::Selection
1346/// [`all`]: crate::selection::dsl::all
1347pub fn selection_from<'a, R>(
1348    shape: &shape::Shape,
1349    constraints: &[(&'a str, R)],
1350) -> Result<Selection, ShapeError>
1351where
1352    R: Clone + Into<shape::Range> + 'a,
1353{
1354    use crate::selection::dsl::*;
1355
1356    let mut label_to_constraint = HashMap::new();
1357    for (label, r) in constraints {
1358        if !shape.labels().iter().any(|l| l == label) {
1359            return Err(ShapeError::InvalidLabels {
1360                labels: vec![label.to_string()],
1361            });
1362        }
1363        label_to_constraint.insert(*label, r.clone().into());
1364    }
1365
1366    let selection = shape.labels().iter().rev().fold(true_(), |acc, label| {
1367        if let Some(rng) = label_to_constraint.get(label.as_str()) {
1368            range(rng.clone(), acc)
1369        } else {
1370            all(acc)
1371        }
1372    });
1373
1374    Ok(selection)
1375}
1376
/// Construct a [`Selection`] from a [`Shape`] using labeled dimension
/// constraints.
///
/// This macro provides a convenient syntax for specifying
/// sub-selections on a shape by labeling dimensions and applying
/// either exact indices or ranges. Internally, it wraps
/// [`selection_from_one`] and [`selection_from`] to produce a
/// fully-aligned [`Selection`] expression.
///
/// # Forms
///
/// - Single labeled range:
///   ```
///   let shape = ndslice::shape!(zone = 2, host = 4, gpu = 8);
///   let sel = ndslice::sel_from_shape!(&shape, host = 1..3);
///   ```
///
/// - Multiple exact indices (converted to `n..n+1`):
///   ```
///   let shape = ndslice::shape!(zone = 2, host = 4, gpu = 8);
///   let sel = ndslice::sel_from_shape!(&shape, zone = 1, gpu = 4);
///   ```
///
/// - Multiple labeled ranges:
///   ```
///   let shape = ndslice::shape!(zone = 2, host = 4, gpu = 8);
///   let sel = ndslice::sel_from_shape!(&shape, zone = 0..1, host = 1..3, gpu = 4..8);
///   ```
///
/// # Matching
///
/// The arms are tried in order:
///
/// 1. A lone `label = expr` always takes the first arm, even when
///    `expr` is a literal. The value is passed unchanged to
///    [`selection_from_one`], so its type needs an
///    `Into<shape::Range>` conversion.
/// 2. Otherwise, if every value is a literal, each literal `n` is
///    expanded to `n..n+1`.
/// 3. Otherwise, the values are passed through unchanged.
///
/// In arms 2 and 3 all values are collected into a single array, so
/// they must share one type. Mixing literals and ranges in one
/// invocation therefore does not type-check.
///
/// # Panics
///
/// This macro calls `.unwrap()` on the result of the underlying
/// functions. It will panic if any label is not found in the shape.
///
/// # See Also
///
/// - [`selection_from_one`]
/// - [`selection_from`]
/// - [`Selection`]
/// - [`Shape`]
///
/// [`Selection`]: crate::selection::Selection
/// [`Shape`]: crate::shape::Shape
/// [`selection_from_one`]: crate::selection::selection_from_one
/// [`selection_from`]: crate::selection::selection_from
#[macro_export]
macro_rules! sel_from_shape {
    // Exactly one `label = expr`: constrain a single dimension.
    ($shape:expr, $label:ident = $range:expr) => {
        $crate::selection::selection_from_one($shape, stringify!($label), $range).unwrap()
    };

    // Several `label = <literal>` pairs: each literal becomes a
    // one-element range `n..n+1`.
    ($shape:expr, $($label:ident = $val:literal),* $(,)?) => {
        $crate::selection::selection_from($shape,
                                          &[
                                              $((stringify!($label), $val..$val+1)),*
                                          ]).unwrap()
    };

    // Several `label = <expr>` pairs, passed through as-is.
    ($shape:expr, $($label:ident = $range:expr),* $(,)?) => {
        $crate::selection::selection_from($shape, &[
            $((stringify!($label), $range)),*
        ]).unwrap()
    };
}
1441
1442#[cfg(test)]
1443mod tests {
1444    use std::assert_matches::assert_matches;
1445    use std::collections::BTreeSet;
1446
1447    use super::EvalOpts;
1448    use super::ReifySlice;
1449    use super::Selection;
1450    use super::dsl::*;
1451    use super::is_equivalent_true;
1452    use crate::Range;
1453    use crate::Slice;
1454    use crate::assert_structurally_eq;
1455    use crate::select;
1456    use crate::shape;
1457    use crate::shape::ShapeError;
1458
    // A test slice: (zones = 2, hosts = 4, gpus = 8), offset 0, with
    // row-major strides [32, 8, 1] — i.e. flat ranks 0..=63.
    fn test_slice() -> Slice {
        Slice::new(0usize, vec![2, 4, 8], vec![32, 8, 1]).unwrap()
    }
1463
    // Given expression `expr`, options `opts` and slice `slice`,
    // canonical usage is:
    // ```rust
    // let nodes = expr.eval(&opts, &slice)?.collect::<Vec<usize>>();
    // ```
    // This utility cuts down on the syntactic repetition that results
    // from the above in the tests that follow. It always evaluates with
    // `EvalOpts::lenient()` and panics if evaluation fails.
    fn eval(expr: Selection, slice: &Slice) -> Vec<usize> {
        expr.eval(&EvalOpts::lenient(), slice).unwrap().collect()
    }
1474
    // `false_()` and `true_()`, alone or under up to three `all`s (the
    // test slice's depth), select nothing and everything respectively.
    #[test]
    fn test_selection_00() {
        let slice = &test_slice();

        // No GPUs on any host in any region.
        assert!(eval(false_(), slice).is_empty());
        assert!(eval(all(false_()), slice).is_empty());
        assert!(eval(all(all(false_())), slice).is_empty());

        // All GPUs on all hosts in all regions.
        assert_eq!((0..=63).collect::<Vec<_>>(), eval(true_(), slice));
        assert_eq!(eval(true_(), slice), eval(all(true_()), slice));
        assert_eq!(eval(all(true_()), slice), eval(all(all(true_())), slice));

        // Terminal `true_()` and `false_()` selections are allowed at
        // the leaf.
        assert_eq!(eval(true_(), slice), eval(all(all(all(true_()))), slice));
        assert!(eval(all(all(all(false_()))), slice).is_empty());
    }
1494
    // Four `all`s over the 3-D test slice are one level too deep:
    // `validate` must reject them with `ShapeError::SelectionTooDeep`,
    // and the error message must render the offending expression.
    #[test]
    fn test_selection_01() {
        let slice = &test_slice();

        // Structural combinators beyond the slice's dimensionality
        // are invalid.
        let expr = all(all(all(all(true_()))));
        let result = expr.validate(&EvalOpts::lenient(), slice);
        assert!(
            matches!(result, Err(ShapeError::SelectionTooDeep { .. })),
            "Unexpected: {:?}",
            result
        );
        assert_eq!(
            format!("{}", result.unwrap_err()),
            "selection `all(all(all(all(true_()))))` exceeds dimensionality 3"
        );
    }
1513
    // Per-dimension `range`/`all` combinations select the expected flat
    // ranks in the row-major 2x4x8 test slice (rank = 32*z + 8*h + g).
    #[test]
    fn test_selection_02() {
        let slice = &test_slice();

        // GPU 0 on host 0 in region 0.
        let select = range(0..=0, range(0..=0, range(0..=0, true_())));
        assert_eq!((0..=0).collect::<Vec<_>>(), eval(select, slice));

        // GPU 1 on host 1 in region 1.
        let select = range(1..=1, range(1..=1, range(1..=1, true_())));
        assert_eq!((41..=41).collect::<Vec<_>>(), eval(select, slice));

        // All GPUs on host 0 in all regions:
        let select = all(range(0..=0, all(true_())));
        assert_eq!(
            (0..=7).chain(32..=39).collect::<Vec<_>>(),
            eval(select, slice)
        );

        // All GPUs on host 1 in all regions:
        let select = all(range(1..=1, all(true_())));
        assert_eq!(
            (8..=15).chain(40..=47).collect::<Vec<_>>(),
            eval(select, slice)
        );

        // The first 4 GPUs on all hosts in all regions:
        let select = all(all(range(0..4, true_())));
        assert_eq!(
            (0..=3)
                .chain(8..=11)
                .chain(16..=19)
                .chain(24..=27)
                .chain(32..=35)
                .chain(40..=43)
                .chain(48..=51)
                .chain(56..=59)
                .collect::<Vec<_>>(),
            eval(select, slice)
        );

        // The last 4 GPUs on all hosts in all regions:
        let select = all(all(range(4..8, true_())));
        assert_eq!(
            (4..=7)
                .chain(12..=15)
                .chain(20..=23)
                .chain(28..=31)
                .chain(36..=39)
                .chain(44..=47)
                .chain(52..=55)
                .chain(60..=63)
                .collect::<Vec<_>>(),
            eval(select, slice)
        );

        // All regions, all hosts, odd GPUs:
        let select = all(all(range(shape::Range(1, None, 2), true_())));
        assert_eq!(
            (1..8)
                .step_by(2)
                .chain((9..16).step_by(2))
                .chain((17..24).step_by(2))
                .chain((25..32).step_by(2))
                .chain((33..40).step_by(2))
                .chain((41..48).step_by(2))
                .chain((49..56).step_by(2))
                .chain((57..64).step_by(2))
                .collect::<Vec<_>>(),
            eval(select, slice)
        );
    }
1586
    // `intersection` behaves like set intersection: `true_()` is its
    // identity, `false_()` annihilates, disjoint ranges yield nothing,
    // and overlapping ranges keep only the overlap.
    #[test]
    fn test_selection_03() {
        let slice = &test_slice();

        assert_eq!(
            eval(intersection(true_(), true_()), slice),
            eval(true_(), slice)
        );
        assert_eq!(
            eval(intersection(true_(), false_()), slice),
            eval(false_(), slice)
        );
        assert_eq!(
            eval(intersection(false_(), true_()), slice),
            eval(false_(), slice)
        );
        assert_eq!(
            eval(intersection(false_(), false_()), slice),
            eval(false_(), slice)
        );
        // Disjoint GPU ranges (0..=3 and 4..=7) share nothing.
        assert_eq!(
            eval(
                intersection(
                    all(all(range(0..=3, true_()))),
                    all(all(range(4..=7, true_())))
                ),
                slice
            ),
            eval(false_(), slice)
        );
        assert_eq!(
            eval(intersection(true_(), all(all(range(4..8, true_())))), slice),
            eval(all(all(range(4..8, true_()))), slice)
        );
        // Ranges 0..=4 and 4..=7 overlap only at GPU 4.
        assert_eq!(
            eval(
                intersection(
                    all(all(range(0..=4, true_()))),
                    all(all(range(4..=7, true_())))
                ),
                slice
            ),
            eval(all(all(range(4..=4, true_()))), slice)
        );
    }
1632
1633    #[test]
1634    fn test_selection_04() {
1635        let slice = &test_slice();
1636
1637        assert_eq!(eval(union(true_(), true_()), slice), eval(true_(), slice));
1638        assert_eq!(eval(union(false_(), true_()), slice), eval(true_(), slice));
1639        assert_eq!(eval(union(true_(), false_()), slice), eval(true_(), slice));
1640        assert_eq!(
1641            eval(union(false_(), false_()), slice),
1642            eval(false_(), slice)
1643        );
1644        assert_eq!(
1645            eval(
1646                union(
1647                    all(all(range(0..4, true_()))),
1648                    all(all(range(4.., true_())))
1649                ),
1650                slice
1651            ),
1652            eval(true_(), slice)
1653        );
1654
1655        // Across all regions, get the first 4 GPUs on host 0 and the
1656        // last 4 GPUs on host 1.
1657        let s = all(range(0..=0, range(0..4, true_())));
1658        let t = all(range(1..=1, range(4.., true_())));
1659        assert_eq!(
1660            (0..=3)
1661                .chain(12..=15)
1662                .chain(32..=35)
1663                .chain(44..=47)
1664                .collect::<Vec<_>>(),
1665            eval(union(s, t), slice)
1666        );
1667
1668        // All regions, all hosts, skip GPUs 2, 3, 4 and 5.
1669        assert_eq!(
1670            // z=0, h=0
1671            (0..=1)
1672                .chain(6..=7)
1673                // z=0, h=1
1674                .chain(8..=9)
1675                .chain(14..=15)
1676                // z=0, h=2
1677                .chain(16..=17)
1678                .chain(22..=23)
1679                // z=0, h=3
1680                .chain(24..=25)
1681                .chain(30..=31)
1682                // z=1, h=0
1683                .chain(32..=33)
1684                .chain(38..=39)
1685                // z=1, h=1
1686                .chain(40..=41)
1687                .chain(46..=47)
1688                // z=1, h=2
1689                .chain(48..=49)
1690                .chain(54..=55)
1691                // z=1, h=3
1692                .chain(56..=57)
1693                .chain(62..=63)
1694                .collect::<Vec<_>>(),
1695            eval(
1696                all(all(union(range(0..2, true_()), range(6..8, true_())))),
1697                slice
1698            )
1699        );
1700
1701        // All regions, all hosts, odd GPUs.
1702        assert_eq!(
1703            (1..8)
1704                .step_by(2)
1705                .chain((9..16).step_by(2))
1706                .chain((17..24).step_by(2))
1707                .chain((25..32).step_by(2))
1708                .chain((33..40).step_by(2))
1709                .chain((41..48).step_by(2))
1710                .chain((49..56).step_by(2))
1711                .chain((57..64).step_by(2))
1712                .collect::<Vec<_>>(),
1713            eval(
1714                all(all(union(
1715                    range(shape::Range(1, Some(4), 2), true_()),
1716                    range(shape::Range(5, Some(8), 2), true_())
1717                ))),
1718                slice
1719            )
1720        );
1721        assert_eq!(
1722            eval(
1723                all(all(union(
1724                    range(shape::Range(1, Some(4), 2), true_()),
1725                    range(shape::Range(5, Some(8), 2), true_())
1726                ))),
1727                slice
1728            ),
1729            eval(
1730                all(all(union(
1731                    union(range(1..=1, true_()), range(3..=3, true_()),),
1732                    union(range(5..=5, true_()), range(7..=7, true_()),),
1733                ))),
1734                slice
1735            ),
1736        );
1737    }
1738
1739    #[test]
1740    fn test_selection_05() {
1741        let slice = &test_slice();
1742
1743        // First region, first host, no GPU.
1744        assert!(eval(first(first(false_())), slice).is_empty());
1745        // First region, first host, first GPU.
1746        assert_eq!(vec![0], eval(first(first(range(0..1, true_()))), slice));
1747        // First region, first host, all GPUs.
1748        assert_eq!(
1749            (0..8).collect::<Vec<_>>(),
1750            eval(first(first(true_())), slice)
1751        );
1752
1753        // Terminal `true_()` and `false_()` selections are allowed at
1754        // the leaf.
1755        // First region, first host, no GPU.
1756        assert!(eval(first(first(first(false_()))), slice).is_empty());
1757        // First region, first host, first GPU.
1758        assert_eq!(vec![0], eval(first(first(first(true_()))), slice));
1759
1760        // All regions, first host, all GPUs.
1761        assert_eq!(
1762            (0..8).chain(32..40).collect::<Vec<_>>(),
1763            eval(all(first(true_())), slice)
1764        );
1765
1766        // First region, first host, GPUs 0, 1 and 2.
1767        assert_eq!(
1768            (0..3).collect::<Vec<_>>(),
1769            eval(first(first(range(0..=2, true_()))), slice)
1770        );
1771    }
1772
1773    #[test]
1774    fn test_selection_06() {
1775        let slice = &test_slice();
1776
1777        // Structural combinators beyond the slice's dimensionality
1778        // are invalid.
1779        let expr = first(first(first(first(true_()))));
1780        let result = expr.validate(&EvalOpts::lenient(), slice);
1781        assert!(
1782            matches!(result, Err(ShapeError::SelectionTooDeep { .. })),
1783            "Unexpected: {:?}",
1784            result
1785        );
1786        assert_eq!(
1787            format!("{}", result.unwrap_err()),
1788            "selection `first(first(first(first(true_()))))` exceeds dimensionality 3"
1789        );
1790    }
1791
1792    #[test]
1793    fn test_selection_07() {
1794        use crate::select;
1795        use crate::shape;
1796
1797        // 2 x 8 row-major.
1798        let s = shape!(host = 2, gpu = 8);
1799
1800        // All GPUs on host 1.
1801        assert_eq!(
1802            select!(s, host = 1)
1803                .unwrap()
1804                .slice()
1805                .iter()
1806                .collect::<Vec<_>>(),
1807            eval(range(1..2, true_()), s.slice())
1808        );
1809
1810        // All hosts, GPUs 2 through 7.
1811        assert_eq!(
1812            select!(s, gpu = 2..)
1813                .unwrap()
1814                .slice()
1815                .iter()
1816                .collect::<Vec<_>>(),
1817            eval(all(range(2.., true_())), s.slice())
1818        );
1819
1820        // All hosts, GPUs 3 and 4.
1821        assert_eq!(
1822            select!(s, gpu = 3..5)
1823                .unwrap()
1824                .slice()
1825                .iter()
1826                .collect::<Vec<_>>(),
1827            eval(all(range(3..5, true_())), s.slice())
1828        );
1829
1830        // GPUS 3 and 4 on host 1.
1831        assert_eq!(
1832            select!(s, gpu = 3..5, host = 1)
1833                .unwrap()
1834                .slice()
1835                .iter()
1836                .collect::<Vec<_>>(),
1837            eval(range(1..=1, range(3..5, true_())), s.slice())
1838        );
1839
1840        // All hosts, no GPUs.
1841        assert_matches!(
1842            select!(s, gpu = 1..1).unwrap_err(),
1843            ShapeError::EmptyRange {
1844                range: shape::Range(1, Some(1), 1)
1845            },
1846        );
1847        assert!(eval(all(range(1..1, true_())), s.slice()).is_empty());
1848
1849        // All hosts, GPU 8.
1850        assert_matches!(
1851            select!(s, gpu = 8).unwrap_err(),
1852            ShapeError::OutOfRange {
1853                range: shape::Range(8, Some(9), 1),
1854                dim,
1855                size: 8,
1856            } if dim == "gpu",
1857        );
1858        assert!(eval(all(range(8..8, true_())), s.slice()).is_empty());
1859    }
1860
1861    #[test]
1862    fn test_selection_08() {
1863        let s = &shape!(host = 2, gpu = 8);
1864
1865        assert_eq!(
1866            eval(range(1..2, true_()), s.slice()),
1867            eval(sel_from_shape!(s, host = 1), s.slice())
1868        );
1869
1870        assert_eq!(
1871            eval(all(range(2.., true_())), s.slice()),
1872            eval(sel_from_shape!(s, gpu = 2..), s.slice())
1873        );
1874
1875        assert_eq!(
1876            eval(all(range(3..5, true_())), s.slice()),
1877            eval(sel_from_shape!(s, gpu = 3..5), s.slice())
1878        );
1879
1880        assert_eq!(
1881            eval(range(1..2, range(3..5, true_())), s.slice()),
1882            eval(sel_from_shape!(s, host = 1..2, gpu = 3..5), s.slice())
1883        );
1884
1885        assert_eq!(
1886            eval(
1887                union(
1888                    sel_from_shape!(s, host = 0..1, gpu = 0..4),
1889                    sel_from_shape!(s, host = 1..2, gpu = 4..5)
1890                ),
1891                s.slice()
1892            ),
1893            eval(
1894                union(
1895                    range(0..1, range(0..4, true_())),
1896                    range(1..2, range(4..5, true_()))
1897                ),
1898                s.slice()
1899            )
1900        );
1901    }
1902
1903    #[test]
1904    fn test_selection_09() {
1905        let slice = &test_slice(); // 2 x 4 x 8
1906
1907        // Identity.
1908        assert_eq!(eval(any(false_()), slice), eval(false_(), slice));
1909
1910        // An arbitrary GPU.
1911        let res = eval(any(any(any(true_()))), slice);
1912        assert_eq!(res.len(), 1);
1913        assert!(res[0] < 64);
1914
1915        // The first 4 GPUs of any host in region-0.
1916        let res = eval(range(0, any(range(0..4, true_()))), slice);
1917        assert!((0..4).any(|host| res == eval(range(0, range(host, range(0..4, true_()))), slice)));
1918
1919        // Any GPU on host-0 in region-0.
1920        let res = eval(range(0, range(0, any(true_()))), slice);
1921        assert_eq!(res.len(), 1);
1922        assert!(res[0] < 8);
1923
1924        // All GPUs on any host in region-0.
1925        let res = eval(range(0, any(true_())), slice);
1926        assert!((0..4).any(|host| res == eval(range(0, range(host, true_())), slice)));
1927
1928        // All GPUs on any host in region-1.
1929        let res = eval(range(1, any(true_())), slice);
1930        assert!((0..4).any(|host| res == eval(range(1, range(host, true_())), slice)));
1931
1932        // Any two GPUs on host-0 in region-0.
1933        let mut res = vec![];
1934        while res.len() < 2 {
1935            res = eval(
1936                union(
1937                    range(0, range(0, any(true_()))),
1938                    range(0, range(0, any(true_()))),
1939                ),
1940                slice,
1941            );
1942        }
1943        assert_matches!(res.as_slice(), [i, j] if *i < *j && *i < 8 && *j < 8);
1944    }
1945
1946    #[test]
1947    fn test_eval_zero_dim_slice() {
1948        let slice_0d = Slice::new(1, vec![], vec![]).unwrap();
1949        // Let s be a slice with dim(s) = 0. Then: ∃! x ∈ s :
1950        // coordsₛ(x) = ().
1951        assert_eq!(slice_0d.coordinates(1).unwrap(), Vec::<usize>::new());
1952
1953        assert_eq!(eval(true_(), &slice_0d), vec![1]);
1954        assert_eq!(eval(false_(), &slice_0d), Vec::<usize>::new());
1955        assert_eq!(eval(all(true_()), &slice_0d), vec![1]);
1956        assert_eq!(eval(all(false_()), &slice_0d), Vec::<usize>::new());
1957        assert_eq!(eval(union(true_(), true_()), &slice_0d), vec![1]);
1958        assert_eq!(
1959            eval(intersection(true_(), false_()), &slice_0d),
1960            Vec::<usize>::new()
1961        );
1962    }
1963
1964    #[test]
1965    fn test_selection_10() {
1966        let slice = &test_slice();
1967        let opts = EvalOpts {
1968            disallow_dynamic_selections: true,
1969            ..EvalOpts::lenient()
1970        };
1971        let expr = any(any(any(true_())));
1972        let res = expr.validate(&opts, slice);
1973        assert_matches!(res, Err(ShapeError::SelectionDynamic { .. }));
1974    }
1975
1976    #[test]
1977    fn test_13() {
1978        // Structural identity: `all(true)` <=> `true`.
1979        assert!(is_equivalent_true(true_()));
1980        assert!(is_equivalent_true(all(true_())));
1981        assert!(is_equivalent_true(all(all(true_()))));
1982        assert!(is_equivalent_true(all(all(all(true_())))));
1983        assert!(is_equivalent_true(all(all(all(all(true_()))))));
1984        assert!(is_equivalent_true(all(all(all(all(all(true_())))))));
1985        // ...
1986
1987        assert!(!is_equivalent_true(false_()));
1988        assert!(!is_equivalent_true(union(true_(), true_())));
1989        assert!(!is_equivalent_true(range(0..=0, true_())));
1990        assert!(!is_equivalent_true(all(false_())));
1991    }
1992
1993    #[test]
1994    fn test_14() {
1995        use std::collections::HashSet;
1996
1997        use crate::selection::NormalizedSelectionKey;
1998        use crate::selection::dsl::*;
1999
2000        let a = all(all(true_()));
2001        let b = all(all(true_()));
2002
2003        let key_a = NormalizedSelectionKey::new(&a);
2004        let key_b = NormalizedSelectionKey::new(&b);
2005
2006        // They should be structurally equal.
2007        assert_eq!(key_a, key_b);
2008
2009        // Their hashes should agree, and they deduplicate in a set.
2010        let mut set = HashSet::new();
2011        set.insert(key_a);
2012        assert!(set.contains(&key_b));
2013    }
2014
2015    #[test]
2016    fn test_contains_true() {
2017        let selection = true_();
2018        assert!(selection.contains(&[0, 0, 0]));
2019        assert!(selection.contains(&[1, 2, 3]));
2020    }
2021
2022    #[test]
2023    fn test_contains_false() {
2024        let selection = false_();
2025        assert!(!selection.contains(&[0, 0, 0]));
2026        assert!(!selection.contains(&[1, 2, 3]));
2027    }
2028
2029    #[test]
2030    fn test_contains_all() {
2031        let selection = all(true_());
2032        assert!(selection.contains(&[0, 0, 0]));
2033        assert!(selection.contains(&[1, 2, 3]));
2034    }
2035
2036    #[test]
2037    fn test_contains_range() {
2038        let selection = range(1..3, true_());
2039        assert!(selection.contains(&[1, 0, 0]));
2040        assert!(!selection.contains(&[3, 0, 0]));
2041    }
2042
2043    #[test]
2044    fn test_contains_intersection() {
2045        let selection = intersection(range(1..3, true_()), range(2..4, true_()));
2046        assert!(selection.contains(&[2, 0, 0]));
2047        assert!(!selection.contains(&[1, 0, 0]));
2048    }
2049
2050    #[test]
2051    fn test_contains_union() {
2052        let selection = union(range(1..2, true_()), range(3..4, true_()));
2053        assert!(selection.contains(&[1, 0, 0]));
2054        assert!(!selection.contains(&[2, 0, 0]));
2055    }
2056
2057    #[test]
2058    #[should_panic(expected = "not implemented")]
2059    fn test_contains_any() {
2060        let selection = any(true_());
2061        selection.contains(&[0, 0, 0]);
2062    }
2063
2064    #[test]
2065    #[should_panic(expected = "not implemented")]
2066    fn test_contains_label() {
2067        let selection = label(vec!["zone".to_string()], true_());
2068        selection.contains(&[1, 2, 3]);
2069    }
2070
2071    #[test]
2072    #[should_panic(expected = "not implemented")]
2073    fn test_contains_first() {
2074        let selection = first(true_());
2075        selection.contains(&[0, 0, 0]);
2076    }
2077
2078    #[test]
2079    fn test_difference_1d() {
2080        assert_eq!(
2081            true_()
2082                .difference(
2083                    &EvalOpts::strict(),
2084                    &Slice::new_row_major([5]),
2085                    &[2usize, 4].into(),
2086                )
2087                .unwrap()
2088                .collect::<Vec<_>>(),
2089            vec![0, 1, 3]
2090        );
2091    }
2092
2093    #[test]
2094    fn test_difference_empty_selection() {
2095        assert_eq!(
2096            false_()
2097                .difference(
2098                    &EvalOpts::strict(),
2099                    &Slice::new_row_major([3]),
2100                    &[0usize, 1].into(),
2101                )
2102                .unwrap()
2103                .collect::<Vec<_>>(),
2104            Vec::<usize>::new()
2105        );
2106    }
2107
2108    #[test]
2109    fn test_difference_2d() {
2110        // [[0, 1, 2],
2111        //  [3, 4, 5]]
2112        // Select everything, exclude the second row.
2113        assert_eq!(
2114            all(all(true_()))
2115                .difference(
2116                    &EvalOpts::strict(),
2117                    &Slice::new_row_major([2, 3]),
2118                    &[3usize, 4, 5].into(),
2119                )
2120                .unwrap()
2121                .collect::<Vec<_>>(),
2122            vec![0, 1, 2]
2123        );
2124    }
2125
2126    #[test]
2127    fn test_of_ranks_1d() {
2128        let slice = Slice::new_row_major([5]);
2129        let ranks = BTreeSet::from([1, 3]);
2130        let selection = Selection::of_ranks(&slice, &ranks).unwrap();
2131        assert_eq!(
2132            selection
2133                .eval(&EvalOpts::strict(), &slice)
2134                .unwrap()
2135                .collect::<Vec<_>>(),
2136            vec![1, 3]
2137        )
2138    }
2139
2140    #[test]
2141    fn test_of_ranks_empty_set() {
2142        let slice = Slice::new_row_major([4]);
2143        let ranks = BTreeSet::new();
2144        let selection = Selection::of_ranks(&slice, &ranks).unwrap();
2145        assert_eq!(
2146            selection
2147                .eval(&EvalOpts::strict(), &slice)
2148                .unwrap()
2149                .collect::<Vec<_>>(),
2150            Vec::<usize>::new()
2151        )
2152    }
2153
2154    #[test]
2155    fn test_of_ranks_singleton_structural() {
2156        let slice = Slice::new_row_major([5]);
2157        let ranks = BTreeSet::from([2]);
2158        let actual = Selection::of_ranks(&slice, &ranks).unwrap();
2159        let expected = range(2..=2, true_());
2160        assert_structurally_eq!(&actual, &expected);
2161    }
2162
2163    #[test]
2164    fn test_of_ranks_union_2d_structural() {
2165        let slice = Slice::new_row_major([2, 3]);
2166        // [ [0, 1, 2],
2167        //   [3, 4, 5] ]
2168        // We'll select (0, 2), (1, 0) and (1, 1).
2169        let ranks = BTreeSet::from([2, 3, 4]);
2170        let actual = Selection::of_ranks(&slice, &ranks).unwrap();
2171        // Each rank becomes a nested selection:
2172        // 2 -> (0, 2) -> range(0, range(2, true_()))
2173        // 3 -> (1, 0) -> range(1, range(0, true_()))
2174        // 4 -> (1, 1) -> range(1, range(1, true_()))
2175        //
2176        // Their union is:
2177        let expected = union(
2178            union(range(0, range(2, true_())), range(1, range(0, true_()))),
2179            range(1, range(1, true_())),
2180        );
2181        assert_structurally_eq!(&actual, &expected);
2182    }
2183
2184    #[test]
2185    fn test_of_ranks_3d_structural() {
2186        let slice = Slice::new_row_major([2, 2, 2]);
2187        // [ [ [0, 1],
2188        //     [2, 3] ],
2189        //   [ [4, 5],
2190        //     [6, 7] ] ]
2191        let ranks = BTreeSet::from([1, 6]);
2192        let actual = Selection::of_ranks(&slice, &ranks).unwrap();
2193        let expected = union(
2194            range(0, range(0, range(1, true_()))), // (0, 0, 1)
2195            range(1, range(1, range(0, true_()))), // (1, 1, 0)
2196        );
2197        assert_structurally_eq!(&actual, &expected);
2198    }
2199
2200    #[test]
2201    fn test_of_ranks_invalid_index() {
2202        let slice = Slice::new_row_major([4]);
2203        let ranks = BTreeSet::from([0, 4]); // 4 is out of bounds
2204        assert!(
2205            Selection::of_ranks(&slice, &ranks).is_err(),
2206            "expected out-of-bounds error"
2207        );
2208    }
2209
2210    #[test]
2211    fn test_reify_slice_empty() {
2212        let slice = Slice::new_row_major([0]);
2213        let selection = slice.reify_slice(&slice).unwrap();
2214        let expected = false_();
2215        assert_structurally_eq!(&selection, expected);
2216        assert_eq!(
2217            selection
2218                .eval(&EvalOpts::lenient(), &slice)
2219                .unwrap()
2220                .collect::<Vec<_>>(),
2221            Vec::<usize>::new()
2222        );
2223    }
2224
2225    #[test]
2226    fn test_reify_slice_1d() {
2227        let shape = shape!(x = 6); // 1D shape with 6 elements
2228        let base = shape.slice();
2229
2230        let selected = select!(shape, x = 2..5).unwrap();
2231        let view = selected.slice();
2232
2233        let selection = base.reify_slice(view).unwrap();
2234        let expected = range(2..5, true_());
2235        assert_structurally_eq!(&selection, expected);
2236
2237        let flat: Vec<_> = selection.eval(&EvalOpts::strict(), base).unwrap().collect();
2238        assert_eq!(flat, vec![2, 3, 4]);
2239    }
2240
2241    #[test]
2242    fn test_reify_slice_2d() {
2243        let shape = shape!(x = 4, y = 5); // 2D shape: 4 rows, 5 columns
2244        let base = shape.slice();
2245
2246        // Select the middle 2x3 block: rows 1..3 and columns 2..5
2247        let selected = select!(shape, x = 1..3, y = 2..5).unwrap();
2248        let view = selected.slice();
2249        let selection = base.reify_slice(view).unwrap();
2250        let expected = range(1..3, range(2..5, true_()));
2251        assert_structurally_eq!(&selection, expected);
2252
2253        let flat: Vec<_> = selection.eval(&EvalOpts::strict(), base).unwrap().collect();
2254        assert_eq!(
2255            flat,
2256            vec![
2257                base.location(&[1, 2]).unwrap(),
2258                base.location(&[1, 3]).unwrap(),
2259                base.location(&[1, 4]).unwrap(),
2260                base.location(&[2, 2]).unwrap(),
2261                base.location(&[2, 3]).unwrap(),
2262                base.location(&[2, 4]).unwrap(),
2263            ]
2264        );
2265    }
2266
2267    #[test]
2268    #[allow(clippy::identity_op)]
2269    fn test_reify_slice_1d_with_stride() {
2270        let shape = shape!(x = 7); // 1D shape with 7 elements
2271        let selected = shape.select("x", Range(0, None, 2)).unwrap();
2272        let view = selected.slice();
2273        assert_eq!(view, &Slice::new(0, vec![4], vec![1 * 2]).unwrap());
2274
2275        let base = shape.slice();
2276        let selection = base.reify_slice(view).unwrap();
2277        // Note: ceil(7 / 2) = 4, hence end = 0 + 2 × 4 = 8. See the
2278        // more detailed explanation in
2279        // `test_reify_slice_2d_with_stride`.
2280        let expected = range(Range(0, Some(8), 2), true_());
2281        assert_structurally_eq!(&selection, expected);
2282
2283        let flat: Vec<_> = selection.eval(&EvalOpts::strict(), base).unwrap().collect();
2284        assert_eq!(flat, vec![0, 2, 4, 6]);
2285    }
2286
2287    #[test]
2288    #[allow(clippy::identity_op)]
2289    fn test_reify_slice_2d_with_stride() {
2290        // 4 x 4: x = 4, y = 4.
2291        let base = shape!(x = 4, y = 4);
2292        // Step 1: select odd rows (x = 1..4 step 2)
2293        let shape = base.select("x", Range(1, Some(4), 2)).unwrap();
2294        // Step 2: then select odd columns (y = 1..4 step 2)
2295        let shape = shape.select("y", Range(1, Some(4), 2)).unwrap();
2296        let view = shape.slice();
2297        assert_eq!(
2298            view,
2299            &Slice::new(5, vec![2, 2], vec![4 * 2, 1 * 2]).unwrap()
2300        );
2301
2302        let base = base.slice();
2303        let selection = base.reify_slice(view).unwrap();
2304        // We use `end = start + step * len` to reify the selection.
2305        // Note: This may yield `end > original_end` (e.g., 5 instead of 4)
2306        // when the selection length was computed via ceiling division.
2307        // This is safe: the resulting range will still select the correct
2308        // indices (e.g., 1 and 3 for Range(1, Some(5), 2)).
2309        let expected = range(Range(1, Some(5), 2), range(Range(1, Some(5), 2), true_()));
2310        assert_structurally_eq!(&selection, expected);
2311
2312        let flat: Vec<_> = selection.eval(&EvalOpts::strict(), base).unwrap().collect();
2313        assert_eq!(flat, vec![5, 7, 13, 15]);
2314    }
2315
2316    #[test]
2317    fn test_reify_slice_selects_column_across_rows() {
2318        let shape = shape!(host = 2, gpu = 4); // shape [2, 4]
2319        let base = shape.slice();
2320
2321        // Select the 3rd GPU (index 2) across both hosts
2322        let selected = select!(shape, gpu = 2).unwrap(); // (0, 2) and (1, 2)
2323        let view = selected.slice();
2324        let coordinates: Vec<_> = view.iter().map(|i| view.coordinates(i).unwrap()).collect();
2325        assert_eq!(coordinates, [[0, 0], [1, 0]]);
2326
2327        let selection = base.reify_slice(view).unwrap();
2328        let expected = range(0..2, range(2..3, true_()));
2329        assert_structurally_eq!(&selection, expected);
2330
2331        let actual = selection
2332            .eval(&EvalOpts::strict(), base)
2333            .unwrap()
2334            .collect::<Vec<_>>();
2335        assert_eq!(
2336            actual,
2337            vec![
2338                base.location(&[0, 2]).unwrap(),
2339                base.location(&[1, 2]).unwrap()
2340            ]
2341        );
2342    }
2343
2344    #[test]
2345    fn test_reify_slice_dimension_mismatch() {
2346        let shape = shape!(host = 2, gpu = 4);
2347        let base = shape.slice();
2348
2349        // Select the 3rd GPU (index 2) across both hosts i.e. flat
2350        // indices [2, 6]
2351        let indices = vec![
2352            base.location(&[0, 2]).unwrap(),
2353            base.location(&[1, 2]).unwrap(),
2354        ];
2355
2356        let view = Slice::new(indices[0], vec![indices.len()], vec![4]).unwrap();
2357        let selection = base.reify_slice(&view).unwrap();
2358
2359        let expected = Selection::of_ranks(base, &indices.iter().cloned().collect()).unwrap();
2360        assert_structurally_eq!(&selection, expected);
2361
2362        let actual: Vec<_> = selection.eval(&EvalOpts::strict(), base).unwrap().collect();
2363        assert_eq!(actual, indices);
2364    }
2365
2366    #[test]
2367    fn test_union_of_slices_empty() {
2368        let base = Slice::new_row_major([2]);
2369        let sel = base.reify_slices(&[]).unwrap();
2370        assert_structurally_eq!(&sel, &false_());
2371        assert_eq!(
2372            sel.eval(&EvalOpts::strict(), &base)
2373                .unwrap()
2374                .collect::<Vec<_>>(),
2375            Vec::<usize>::new()
2376        );
2377    }
2378
2379    #[test]
2380    fn test_union_of_slices_singleton() {
2381        let shape = shape!(x = 3);
2382        let base = shape.slice();
2383        let selected = select!(shape, x = 1).unwrap();
2384        let view = selected.slice().clone();
2385
2386        let selection = base.reify_slices(&[view]).unwrap();
2387        let expected = range(1..=1, true_());
2388        assert_structurally_eq!(&selection, &expected);
2389
2390        assert_eq!(
2391            selection
2392                .eval(&EvalOpts::strict(), base)
2393                .unwrap()
2394                .collect::<Vec<_>>(),
2395            vec![1],
2396        );
2397    }
2398
2399    #[test]
2400    fn test_union_of_slices_disjoint() {
2401        let shape = shape!(x = 2, y = 2); // 2x2 grid
2402        let base = shape.slice();
2403
2404        // View A: (0, *)
2405        let a = select!(shape, x = 0).unwrap();
2406        let view_a = a.slice().clone();
2407
2408        // View B: (1, *)
2409        let b = select!(shape, x = 1).unwrap();
2410        let view_b = b.slice().clone();
2411
2412        let selection = base.reify_slices(&[view_a, view_b]).unwrap();
2413        let expected = union(
2414            range(0..1, range(0..2, true_())),
2415            range(1..2, range(0..2, true_())),
2416        );
2417        assert_structurally_eq!(&selection, &expected);
2418        assert_eq!(
2419            selection
2420                .eval(&EvalOpts::strict(), base)
2421                .unwrap()
2422                .collect::<Vec<_>>(),
2423            base.iter().collect::<Vec<_>>()
2424        );
2425    }
2426
2427    #[test]
2428    fn test_union_of_slices_overlapping() {
2429        let shape = shape!(x = 1, y = 4); // 1x4 grid
2430        let base = shape.slice();
2431
2432        let selected1 = select!(shape, y = 0..2).unwrap();
2433        let view1 = selected1.slice().clone();
2434
2435        let selected2 = select!(shape, y = 1..4).unwrap();
2436        let view2 = selected2.slice().clone();
2437
2438        let selection = base.reify_slices(&[view1, view2]).unwrap();
2439        let expected = union(
2440            range(0..1, range(0..2, true_())),
2441            range(0..1, range(1..4, true_())),
2442        );
2443        assert_structurally_eq!(&selection, &expected);
2444
2445        assert_eq!(
2446            selection
2447                .eval(&EvalOpts::strict(), base)
2448                .unwrap()
2449                .collect::<Vec<_>>(),
2450            base.iter().collect::<Vec<_>>()
2451        );
2452    }
2453
    /// `canonicalize_to_dimensions(3)` pads a selection with structural
    /// combinators until it spans three dimensions. Per the assertions
    /// below: a bare `true_()`/`false_()` tail is wrapped in `all`; an
    /// `any` chain is extended with further `any`s; a selection that
    /// already spans three dimensions is unchanged.
    #[test]
    fn test_canonicalize_to_dimensions() {
        // `true_()` under 0..=3 `all`s is padded with `all` to depth 3.
        assert_structurally_eq!(
            true_().canonicalize_to_dimensions(3),
            &all(all(all(true_())))
        );
        assert_structurally_eq!(
            all(true_()).canonicalize_to_dimensions(3),
            &all(all(all(true_())))
        );
        assert_structurally_eq!(
            all(all(true_())).canonicalize_to_dimensions(3),
            &all(all(all(true_())))
        );
        assert_structurally_eq!(
            all(all(all(true_()))).canonicalize_to_dimensions(3),
            &all(all(all(true_())))
        );

        // `false_()` is padded the same way: its leaf stays `false_()`
        // and `all`s are added above it.
        assert_structurally_eq!(
            false_().canonicalize_to_dimensions(3),
            &all(all(all(false_())))
        );
        assert_structurally_eq!(
            all(false_()).canonicalize_to_dimensions(3),
            &all(all(all(false_())))
        );
        assert_structurally_eq!(
            all(all(false_())).canonicalize_to_dimensions(3),
            &all(all(all(false_())))
        );
        assert_structurally_eq!(
            all(all(all(false_()))).canonicalize_to_dimensions(3),
            &all(all(all(false_())))
        );

        // A chain of `any`s is extended with `any`, not `all`.
        assert_structurally_eq!(
            any(true_()).canonicalize_to_dimensions(3),
            &any(any(any(true_())))
        );
        assert_structurally_eq!(
            any(any(true_())).canonicalize_to_dimensions(3),
            &any(any(any(true_())))
        );
        assert_structurally_eq!(
            any(any(any(true_()))).canonicalize_to_dimensions(3),
            &any(any(any(true_())))
        );

        // Mixed prefixes: the remaining dimensions are filled according
        // to the combinator directly above the `true_()` leaf.
        // 0:1 -> 0:1, *, * <=> range(0..1, all(all(true_())))
        assert_structurally_eq!(
            range(0..1, true_()).canonicalize_to_dimensions(3),
            &range(0..1, all(all(true_())))
        );
        // *, 0:1 -> *, 0:1, * <=> all(range(0..1, all(true_())))
        assert_structurally_eq!(
            all(range(0..1, true_())).canonicalize_to_dimensions(3),
            &all(range(0..1, all(true_())))
        );
        // 0:1, ? -> 0:1, ?, ? <=> range(0..1, any(any(true_())))
        assert_structurally_eq!(
            range(0..1, any(true_())).canonicalize_to_dimensions(3),
            &range(0..1, any(any(true_())))
        );
        // 0:1, ?, * -> 0:1, ?, * <=> range(0..1, any(all(true_())))
        assert_structurally_eq!(
            range(0..1, any(all(true_()))).canonicalize_to_dimensions(3),
            &range(0..1, any(all(true_())))
        );
    }
2524}