1pub mod parse;
78
79pub mod pretty;
87
88pub mod token_parser;
92
93pub mod routing;
95
96pub mod normal;
98
99pub mod test_utils;
100
101use std::collections::BTreeSet;
102use std::collections::HashMap;
103use std::collections::HashSet;
104use std::fmt;
105
106use rand::Rng;
107use serde::Deserialize;
108use serde::Serialize;
109
110use crate::Slice;
111use crate::selection::normal::NormalizedSelection;
112use crate::selection::normal::RewriteRuleExt;
113use crate::shape;
114use crate::shape::ShapeError;
115use crate::slice::SliceError;
116
117pub trait SelectionSYM {
122 fn false_() -> Self;
124
125 fn true_() -> Self;
127
128 fn all(selection: Self) -> Self;
131
132 fn first(selection: Self) -> Self;
135
136 fn range<R: Into<shape::Range>>(range: R, selection: Self) -> Self;
139
140 fn label<L: Into<LabelKey>>(labels: Vec<L>, selection: Self) -> Self;
143
144 fn any(selection: Self) -> Self;
147
148 fn intersection(lhs: Self, selection: Self) -> Self;
150
151 fn union(lhs: Self, selection: Self) -> Self;
153}
154
155pub mod dsl {
158
159 use super::LabelKey;
160 use super::Selection;
161 use super::SelectionSYM;
162 use crate::shape;
163
164 pub fn false_() -> Selection {
165 SelectionSYM::false_()
166 }
167 pub fn true_() -> Selection {
168 SelectionSYM::true_()
169 }
170 pub fn all(inner: Selection) -> Selection {
171 SelectionSYM::all(inner)
172 }
173 pub fn first(inner: Selection) -> Selection {
174 SelectionSYM::first(inner)
175 }
176 pub fn range<R: Into<shape::Range>>(r: R, inner: Selection) -> Selection {
177 SelectionSYM::range(r, inner)
178 }
179 pub fn label<L: Into<LabelKey>>(labels: Vec<L>, inner: Selection) -> Selection {
180 SelectionSYM::label(labels, inner)
181 }
182 pub fn any(inner: Selection) -> Selection {
183 SelectionSYM::any(inner)
184 }
185 pub fn intersection(lhs: Selection, rhs: Selection) -> Selection {
186 SelectionSYM::intersection(lhs, rhs)
187 }
188 pub fn union(lhs: Selection, rhs: Selection) -> Selection {
189 SelectionSYM::union(lhs, rhs)
190 }
191}
192
193impl SelectionSYM for Selection {
194 fn false_() -> Self {
195 ast::false_()
196 }
197 fn true_() -> Self {
198 ast::true_()
199 }
200 fn all(selection: Self) -> Self {
201 ast::all(selection)
202 }
203 fn first(selection: Self) -> Self {
204 ast::first(selection)
205 }
206 fn range<R: Into<shape::Range>>(range: R, selection: Self) -> Self {
207 ast::range(range, selection)
208 }
209 fn label<L: Into<LabelKey>>(labels: Vec<L>, selection: Selection) -> Selection {
210 let labels = labels.into_iter().map(|l| l.into()).collect();
211 Selection::Label(labels, Box::new(selection))
212 }
213 fn any(selection: Self) -> Self {
214 ast::any(selection)
215 }
216 fn intersection(lhs: Self, rhs: Self) -> Self {
217 ast::intersection(lhs, rhs)
218 }
219 fn union(lhs: Self, rhs: Self) -> Self {
220 ast::union(lhs, rhs)
221 }
222}
223
224impl fmt::Display for Selection {
225 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
226 write!(f, "{}", pretty::pretty(self))
227 }
228}
229
/// A label attached to a selection via [`Selection::Label`].
///
/// Currently the only form is a plain string value. The derives make keys
/// hashable, totally ordered and serde-serializable.
#[derive(
    Clone,
    Debug,
    PartialEq,
    Eq,
    Hash,
    Serialize,
    Deserialize,
    PartialOrd,
    Ord
)]
pub enum LabelKey {
    /// A string-valued label.
    Value(String),
}
262
263impl From<String> for LabelKey {
264 fn from(s: String) -> Self {
265 LabelKey::Value(s)
266 }
267}
268
269impl From<&str> for LabelKey {
270 fn from(s: &str) -> Self {
271 LabelKey::Value(s.to_string())
272 }
273}
274
275impl std::fmt::Display for LabelKey {
276 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
277 match self {
278 LabelKey::Value(s) => write!(f, "\"{}\"", s),
279 }
280 }
281}
282
/// A selection expression over the dimensions of a [`Slice`].
///
/// During evaluation, `All`, `First`, `Range` and `Any` each apply to the
/// current dimension and evaluate their inner selection at the next one.
/// `Intersection` and `Union` combine two selections at the same dimension.
/// `True`/`False` select everything/nothing from whatever dimensions remain.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub enum Selection {
    /// Selects nothing.
    False,

    /// Selects every element of the remaining dimensions.
    True,

    /// Every index of the current dimension, then the inner selection.
    All(Box<Selection>),

    /// The first index of the current dimension whose inner selection is
    /// non-empty.
    First(Box<Selection>),

    /// The indices of the current dimension covered by the range, then the
    /// inner selection.
    Range(shape::Range, Box<Selection>),

    /// Labels attached to an inner selection. During `eval` the inner
    /// selection is evaluated at the same dimension (the label does not
    /// advance it).
    Label(Vec<LabelKey>, Box<Selection>),

    /// One randomly chosen index of the current dimension, then the inner
    /// selection.
    Any(Box<Selection>),

    /// Elements selected by both operands.
    Intersection(Box<Selection>, Box<Selection>),

    /// Elements selected by either operand.
    Union(Box<Selection>, Box<Selection>),
}
319
/// Compile-time assertion: this only type-checks while `Selection` is
/// `Send + Sync + 'static`. It is never called.
fn _assert_selection_traits()
where
    Selection: Send + Sync + 'static,
{
}
327
328pub fn structurally_equal(a: &Selection, b: &Selection) -> bool {
334 use Selection::*;
335 match (a, b) {
336 (False, False) => true,
337 (True, True) => true,
338 (All(x), All(y)) => structurally_equal(x, y),
339 (Any(x), Any(y)) => structurally_equal(x, y),
340 (First(x), First(y)) => structurally_equal(x, y),
341 (Range(r1, x), Range(r2, y)) => r1 == r2 && structurally_equal(x, y),
342 (Intersection(x1, y1), Intersection(x2, y2)) => {
343 structurally_equal(x1, x2) && structurally_equal(y1, y2)
344 }
345 (Union(x1, y1), Union(x2, y2)) => structurally_equal(x1, x2) && structurally_equal(y1, y2),
346 _ => false,
347 }
348}
349
350pub fn normalize(sel: &Selection) -> NormalizedSelection {
362 let rule = normal::FlatteningRules
363 .then(normal::IdentityRules)
364 .then(normal::AbsorbtionRules);
365 sel.fold::<normal::NormalizedSelection>()
366 .rewrite_bottom_up(&rule)
367}
368
/// A hashable, comparable key for a [`Selection`], wrapping its normalized
/// form. Equality and hashing are those of the wrapped
/// [`NormalizedSelection`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct NormalizedSelectionKey(NormalizedSelection);
376
377impl NormalizedSelectionKey {
378 pub fn new(sel: &Selection) -> Self {
381 Self(crate::selection::normalize(sel))
382 }
383
384 pub fn inner(&self) -> &NormalizedSelection {
386 &self.0
387 }
388
389 pub fn into_inner(self) -> NormalizedSelection {
391 self.0
392 }
393}
394
395mod ast {
396 use super::LabelKey;
397 use super::Selection;
398 use crate::shape;
399
400 pub(crate) fn false_() -> Selection {
401 Selection::False
402 }
403 pub(crate) fn true_() -> Selection {
404 Selection::True
405 }
406 pub(crate) fn all(selection: Selection) -> Selection {
407 Selection::All(Box::new(selection))
408 }
409 pub(crate) fn first(selection: Selection) -> Selection {
410 Selection::First(Box::new(selection))
411 }
412 pub(crate) fn range<R: Into<shape::Range>>(range: R, selection: Selection) -> Selection {
413 Selection::Range(range.into(), Box::new(selection))
414 }
415 #[allow(dead_code)] pub(crate) fn label<L: Into<LabelKey>>(labels: Vec<L>, selection: Selection) -> Selection {
417 let labels = labels.into_iter().map(Into::into).collect();
418 Selection::Label(labels, Box::new(selection))
419 }
420 pub(crate) fn any(selection: Selection) -> Selection {
421 Selection::Any(Box::new(selection))
422 }
423 pub(crate) fn intersection(lhs: Selection, rhs: Selection) -> Selection {
424 Selection::Intersection(Box::new(lhs), Box::new(rhs))
425 }
426 pub(crate) fn union(lhs: Selection, rhs: Selection) -> Selection {
427 Selection::Union(Box::new(lhs), Box::new(rhs))
428 }
429}
430
/// Options controlling which selections [`Selection::eval`] accepts
/// (checked by `Selection::validate` before evaluation).
pub struct EvalOpts {
    /// Reject any `Range` for which `Range::is_empty` is true.
    pub disallow_empty_ranges: bool,

    /// Reject any `Range` whose resolved start is at or beyond the size of
    /// the dimension it applies to.
    pub disallow_out_of_range: bool,

    /// Reject selections containing `Any`, whose result is randomized.
    pub disallow_dynamic_selections: bool,
}
444
445impl EvalOpts {
446 pub fn lenient() -> Self {
448 Self {
449 disallow_empty_ranges: false,
450 disallow_out_of_range: false,
451 disallow_dynamic_selections: false,
452 }
453 }
454
455 #[allow(dead_code)]
458 pub fn strict() -> Self {
459 Self {
460 disallow_empty_ranges: true,
461 disallow_out_of_range: true,
462 ..Self::lenient()
463 }
464 }
465}
466
467impl Selection {
468 pub(crate) fn validate(&self, opts: &EvalOpts, slice: &Slice) -> Result<&Self, ShapeError> {
469 let depth = 0;
470 self.validate_rec(opts, slice, self, depth).map(|_| self)
471 }
472
473 fn validate_rec(
474 &self,
475 opts: &EvalOpts,
476 slice: &Slice,
477 top: &Selection,
478 dim: usize,
479 ) -> Result<(), ShapeError> {
480 if dim == slice.num_dim() {
481 match self {
484 Selection::True | Selection::False => return Ok(()),
485 _ => {
486 return Err(ShapeError::SelectionTooDeep {
487 expr: top.clone(),
488 num_dim: slice.num_dim(),
489 });
490 }
491 }
492 }
493
494 match self {
495 Selection::False | Selection::True => Ok(()),
496 Selection::Range(range, s) => {
497 if range.is_empty() && opts.disallow_empty_ranges {
498 return Err(ShapeError::EmptyRange {
499 range: range.clone(),
500 });
501 } else {
502 if opts.disallow_out_of_range {
503 let size = slice.sizes()[dim];
504 let (min, _, _) = range.resolve(size);
505 if min >= size {
506 return Err(ShapeError::EmptyRange {
509 range: range.clone(),
510 });
511 }
512 }
513
514 s.validate_rec(opts, slice, top, dim + 1)?;
515 }
516
517 Ok(())
518 }
519 Selection::Any(s) => {
520 if opts.disallow_dynamic_selections {
521 return Err(ShapeError::SelectionDynamic { expr: top.clone() });
522 }
523 s.validate_rec(opts, slice, top, dim + 1)?;
524 Ok(())
525 }
526 Selection::All(s) | Selection::Label(_, s) | Selection::First(s) => {
527 s.validate_rec(opts, slice, top, dim + 1)?;
528 Ok(())
529 }
530 Selection::Intersection(a, b) | Selection::Union(a, b) => {
531 a.validate_rec(opts, slice, top, dim)?;
532 b.validate_rec(opts, slice, top, dim)?;
533 Ok(())
534 }
535 }
536 }
537
538 pub fn eval<'a>(
578 &self,
579 opts: &EvalOpts,
580 slice: &'a Slice,
581 ) -> Result<Box<dyn Iterator<Item = usize> + 'a>, ShapeError> {
582 if slice.num_dim() == 0 {
584 let slice = Slice::new(slice.offset(), vec![1], vec![1]).unwrap();
585 return Ok(Box::new(
586 self.validate(opts, &slice)?
587 .eval_rec(&slice, vec![0; 1], 0)
588 .collect::<Vec<_>>()
589 .into_iter(),
590 ));
591 }
592
593 Ok(self
594 .validate(opts, slice)?
595 .eval_rec(slice, vec![0; slice.num_dim()], 0))
596 }
597
598 fn eval_rec<'a>(
599 &self,
600 slice: &'a Slice,
601 env: Vec<usize>,
602 dim: usize,
603 ) -> Box<dyn Iterator<Item = usize> + 'a> {
604 if dim == slice.num_dim() {
605 match self {
606 Selection::True => return Box::new(std::iter::once(slice.location(&env).unwrap())),
607 Selection::False => return Box::new(std::iter::empty()),
608 _ => {
609 panic!("structural combinator {self:?} at leaf level (dim = {dim}))",);
610 }
611 }
612 }
613
614 use itertools;
615 use itertools::EitherOrBoth;
616
617 match self {
618 Selection::False => Box::new(std::iter::empty()),
619 Selection::True => Box::new((0..slice.sizes()[dim]).flat_map(move |i| {
620 let mut env = env.clone();
621 env[dim] = i;
622 Selection::True.eval_rec(slice, env, dim + 1)
623 })),
624 Selection::All(select) => {
625 let select = Box::clone(select);
626 Box::new((0..slice.sizes()[dim]).flat_map(move |i| {
627 let mut env = env.clone();
628 env[dim] = i;
629 select.eval_rec(slice, env, dim + 1)
630 }))
631 }
632 Selection::First(select) => {
633 let select = Box::clone(select);
634 Box::new(iterutils::first(slice.sizes()[dim], move |i| {
635 let mut env = env.clone();
636 env[dim] = i;
637 select.eval_rec(slice, env, dim + 1)
638 }))
639 }
640 Selection::Range(range, select) => {
641 let select = Box::clone(select);
642 let (min, max, step) = range.resolve(slice.sizes()[dim]);
643 Box::new((min..max).step_by(step).flat_map(move |i| {
644 let mut env = env.clone();
645 env[dim] = i;
646 select.eval_rec(slice, env, dim + 1)
647 }))
648 }
649
650 Selection::Label(labels, inner) => {
676 Self::eval_label(labels, inner, slice, env, dim )
677 }
678 Selection::Any(select) => {
679 let select = Box::clone(select);
680 let r = {
681 let upper = slice.sizes()[dim];
682 let mut rng = rand::thread_rng();
683 rng.gen_range(0..upper)
684 };
685 Box::new((r..r + 1).flat_map(move |i| {
686 let mut env = env.clone();
687 env[dim] = i;
688 select.eval_rec(slice, env, dim + 1)
689 }))
690 }
691 Selection::Intersection(a, b) => Box::new(
692 itertools::merge_join_by(
693 a.eval_rec(slice, env.clone(), dim),
694 b.eval_rec(slice, env.clone(), dim),
695 |x, y| x.cmp(y),
696 )
697 .filter_map(|either| match either {
698 EitherOrBoth::Both(x, _) => Some(x),
699 _ => None,
700 }),
701 ),
702 Selection::Union(a, b) => Box::new(
703 itertools::merge_join_by(
704 a.eval_rec(slice, env.clone(), dim),
705 b.eval_rec(slice, env.clone(), dim),
706 |x, y| x.cmp(y),
707 )
708 .map(|either| match either {
709 EitherOrBoth::Left(x) => x,
710 EitherOrBoth::Right(y) => y,
711 EitherOrBoth::Both(x, _) => x,
712 }),
713 ),
714 }
715 }
716
717 fn eval_label<'a>(
742 _labels: &[LabelKey],
743 inner: &Selection,
744 slice: &'a Slice,
745 env: Vec<usize>,
746 dim: usize,
747 ) -> Box<dyn Iterator<Item = usize> + 'a> {
749 match inner {
750 Selection::Any(sub_inner) => {
757 let matching: Vec<usize> = (0..slice.sizes()[dim])
758 .filter(|&i| {
759 let mut prefix = env.clone();
760 prefix[dim] = i;
761 true })
763 .collect();
764
765 if matching.is_empty() {
766 return Box::new(std::iter::empty());
767 }
768
769 let mut rng = rand::thread_rng();
770 let chosen = matching[rng.gen_range(0..matching.len())];
771
772 let mut coord = env;
773 coord[dim] = chosen;
774 sub_inner.eval_rec(slice, coord, dim + 1 )
775 }
776 _ => {
787 let iter = inner.eval_rec(slice, env.clone(), dim );
789 Box::new(iter.filter(move |&flat| {
790 let _coord = slice.coordinates(flat);
791 true }))
793 }
794 }
795 }
796
797 pub fn is_equivalent_to_true(mut sel: &Selection) -> bool {
826 while let Selection::All(inner) = sel {
827 sel = inner;
828 }
829 matches!(sel, Selection::True)
830 }
831
832 pub fn contains(&self, coords: &[usize]) -> bool {
844 self.contains_rec(coords, 0)
845 }
846
847 fn contains_rec(&self, coords: &[usize], dim: usize) -> bool {
848 if dim >= coords.len() {
849 return matches!(self, Selection::True);
850 }
851
852 match self {
853 Selection::False => false,
854 Selection::True => true,
855 Selection::All(inner) => inner.contains_rec(coords, dim + 1),
856 Selection::Range(range, inner) => {
857 let (min, max, step) = range.resolve(coords.len());
858 let index = coords[dim];
859 index >= min
860 && index < max
861 && (index - min) % step == 0
862 && inner.contains_rec(coords, dim + 1)
863 }
864 Selection::Intersection(a, b) => {
865 a.contains_rec(coords, dim) && b.contains_rec(coords, dim)
866 }
867 Selection::Union(a, b) => a.contains_rec(coords, dim) || b.contains_rec(coords, dim),
868 Selection::Label(_, _) | Selection::First(_) | Selection::Any(_) => {
869 unimplemented!()
870 }
871 }
872 }
873
874 pub fn reduce_intersection(self: Selection, b: Selection) -> Selection {
888 match (&self, &b) {
889 (Selection::False, _) | (_, Selection::False) => Selection::False,
890 (Selection::True, other) | (other, Selection::True) => other.clone(),
891 _ => Selection::Intersection(Box::new(self), Box::new(b)),
892 }
893 }
894
895 pub(crate) fn canonicalize_to_dimensions(self, num_dims: usize) -> Selection {
915 assert!(
916 num_dims > 0,
917 "canonicalize_to_dimensions requires num_dims > 0"
918 );
919 self.canonicalize_to_dimensions_rec(0, num_dims)
920 }
921
922 fn canonicalize_to_dimensions_rec(self, dim: usize, num_dims: usize) -> Selection {
923 use crate::selection::dsl::*;
924
925 match self {
926 Selection::True if dim < num_dims => {
927 let mut out = true_();
928 for _ in (dim..num_dims).rev() {
929 out = all(out);
930 }
931 out
932 }
933 Selection::False if dim < num_dims => {
934 let mut out = false_();
935 for _ in (dim..num_dims).rev() {
936 out = all(out);
937 }
938 out
939 }
940 Selection::Any(inner) if dim < num_dims && matches!(*inner, Selection::True) => {
941 let mut out = true_();
942 for _ in (dim..num_dims).rev() {
943 out = any(out);
944 }
945 out
946 }
947 Selection::All(inner) => all(inner.canonicalize_to_dimensions_rec(dim + 1, num_dims)),
948 Selection::Any(inner) => any(inner.canonicalize_to_dimensions_rec(dim + 1, num_dims)),
949 Selection::First(inner) => {
950 first(inner.canonicalize_to_dimensions_rec(dim + 1, num_dims))
951 }
952 Selection::Range(r, inner) => {
953 range(r, inner.canonicalize_to_dimensions_rec(dim + 1, num_dims))
954 }
955 Selection::Intersection(a, b) => intersection(
956 a.canonicalize_to_dimensions_rec(dim, num_dims),
957 b.canonicalize_to_dimensions_rec(dim, num_dims),
958 ),
959 Selection::Union(a, b) => union(
960 a.canonicalize_to_dimensions_rec(dim, num_dims),
961 b.canonicalize_to_dimensions_rec(dim, num_dims),
962 ),
963
964 other => other,
965 }
966 }
967
968 pub fn fold<S: SelectionSYM>(&self) -> S {
981 match self {
982 Selection::False => S::false_(),
983 Selection::True => S::true_(),
984 Selection::All(inner) => S::all(inner.fold::<S>()),
985 Selection::First(inner) => S::first(inner.fold::<S>()),
986 Selection::Range(r, inner) => S::range(r.clone(), inner.fold::<S>()),
987 Selection::Label(labels, inner) => S::label(labels.clone(), inner.fold::<S>()),
988 Selection::Any(inner) => S::any(inner.fold::<S>()),
989 Selection::Intersection(a, b) => S::intersection(a.fold::<S>(), b.fold::<S>()),
990 Selection::Union(a, b) => S::union(a.fold::<S>(), b.fold::<S>()),
991 }
992 }
993
994 pub fn difference<'a>(
1002 &self,
1003 opts: &EvalOpts,
1004 slice: &'a Slice,
1005 exclusions: &'a HashSet<usize>,
1006 ) -> Result<impl Iterator<Item = usize> + use<'a>, ShapeError> {
1007 Ok(self
1008 .eval(opts, slice)?
1009 .filter(move |idx| !exclusions.contains(idx)))
1010 }
1011
1012 pub fn without(
1023 &self,
1024 slice: &Slice,
1025 exclusions: &HashSet<usize>,
1026 ) -> Result<Selection, ShapeError> {
1027 let remaining = self
1028 .difference(&EvalOpts::strict(), slice, exclusions)?
1029 .collect::<BTreeSet<_>>();
1030 Ok(Selection::of_ranks(slice, &remaining)?)
1031 }
1032
1033 pub fn of_ranks(slice: &Slice, ranks: &BTreeSet<usize>) -> Result<Selection, SliceError> {
1055 let selections = ranks
1056 .iter()
1057 .map(|&i| {
1058 Ok(slice
1059 .coordinates(i)?
1060 .into_iter()
1061 .rev()
1062 .fold(dsl::true_(), |acc, i| dsl::range(i..=i, acc)))
1063 })
1064 .collect::<Result<Vec<_>, SliceError>>()?;
1065
1066 Ok(selections
1067 .into_iter()
1068 .reduce(dsl::union)
1069 .unwrap_or_else(dsl::false_))
1070 }
1071} mod sealed {
1074 pub trait Sealed {}
1075 impl Sealed for crate::slice::Slice {}
1076}
1077
/// Converts [`Slice`]s into [`Selection`]s expressed relative to a base
/// slice (`self`).
///
/// Sealed via `sealed::Sealed`; the only implementation is for [`Slice`].
pub trait ReifySlice: sealed::Sealed {
    /// Builds a selection over `self` describing the elements of `slice`.
    fn reify_slice(&self, slice: &Slice) -> Result<Selection, SliceError>;

    /// Like `reify_slice` for each non-empty slice in `slices`, combined
    /// with `union`.
    fn reify_slices<V: AsRef<[Slice]>>(&self, slices: V) -> Result<Selection, SliceError>;
}
1094
1095impl ReifySlice for Slice {
1096 fn reify_slice(&self, slice: &Slice) -> Result<Selection, SliceError> {
1128 if !self.is_contiguous() {
1130 return Err(SliceError::NonContiguous);
1131 }
1132
1133 if slice.is_empty() {
1134 return Ok(dsl::false_());
1135 }
1136
1137 if slice.num_dim() != self.num_dim()
1138 || slice.sizes().iter().zip(self.sizes()).any(|(&v, &s)| v > s)
1139 {
1140 return Selection::of_ranks(self, &slice.iter().collect::<BTreeSet<usize>>());
1141 }
1142
1143 let origin = self.coordinates(slice.offset())?;
1144 let mut acc = dsl::true_();
1145 for dim in (0..self.num_dim()).rev() {
1146 let start = origin[dim];
1147 let len = slice.sizes()[dim];
1148 let slice_stride = slice.strides()[dim];
1149 let base_stride = self.strides()[dim];
1150
1151 if slice_stride % base_stride == 0 {
1152 let step = slice_stride / base_stride;
1154 let end = start + step * len;
1155 if end - 1 > self.sizes()[dim] {
1157 let bad = origin
1158 .iter()
1159 .enumerate()
1160 .map(|(d, &x)| if d == dim { end } else { x })
1161 .collect::<Vec<_>>();
1162 return Err(SliceError::ValueNotInSlice {
1163 value: self.location(&bad).unwrap(),
1164 });
1165 }
1166 acc = dsl::range(crate::shape::Range(start, Some(end), step), acc);
1167 } else {
1168 return Selection::of_ranks(self, &slice.iter().collect::<BTreeSet<_>>());
1170 }
1171 }
1172
1173 Ok(acc)
1174 }
1175
1176 fn reify_slices<V: AsRef<[Slice]>>(&self, slices: V) -> Result<Selection, SliceError> {
1211 let slices = slices.as_ref();
1212 let mut selections = Vec::with_capacity(slices.len());
1213
1214 for slice in slices {
1215 if slice.is_empty() {
1216 continue;
1217 }
1218 selections.push(self.reify_slice(slice)?);
1219 }
1220
1221 let mut iter = selections.into_iter();
1222 let first = iter.next().unwrap_or_else(dsl::false_);
1223 Ok(iter.fold(first, dsl::union))
1224 }
1225}
1226
1227pub fn is_equivalent_true(sel: impl std::borrow::Borrow<Selection>) -> bool {
1229 Selection::is_equivalent_to_true(sel.borrow())
1230}
1231
mod iterutils {
    /// Yields the elements of the first non-empty iterator among
    /// `mk_iter(0)`, `mk_iter(1)`, ..., `mk_iter(size - 1)`, or nothing if
    /// all of them are empty (or `size == 0`).
    ///
    /// Candidates are built in index order until one produces an element;
    /// that search happens when `first` is called.
    pub(crate) fn first<'a, F>(size: usize, mut mk_iter: F) -> impl Iterator<Item = usize> + 'a
    where
        F: FnMut(usize) -> Box<dyn Iterator<Item = usize> + 'a>,
    {
        let winner = (0..size).find_map(move |idx| {
            let mut candidate = mk_iter(idx).peekable();
            let non_empty = candidate.peek().is_some();
            non_empty.then_some(candidate)
        });
        winner.into_iter().flatten()
    }
}
1252
1253pub fn selection_from_one<'a, R>(
1288 shape: &shape::Shape,
1289 label: &'a str,
1290 rng: R,
1291) -> Result<Selection, ShapeError>
1292where
1293 R: Into<shape::Range>,
1294{
1295 use crate::selection::dsl;
1296
1297 let Some(pos) = shape.labels().iter().position(|l| l == label) else {
1298 return Err(ShapeError::InvalidLabels {
1299 labels: vec![label.to_string()],
1300 });
1301 };
1302
1303 let mut selection = dsl::range(rng.into(), dsl::true_());
1304 for _ in 0..pos {
1305 selection = dsl::all(selection)
1306 }
1307
1308 Ok(selection)
1309}
1310
1311pub fn selection_from<'a, R>(
1348 shape: &shape::Shape,
1349 constraints: &[(&'a str, R)],
1350) -> Result<Selection, ShapeError>
1351where
1352 R: Clone + Into<shape::Range> + 'a,
1353{
1354 use crate::selection::dsl::*;
1355
1356 let mut label_to_constraint = HashMap::new();
1357 for (label, r) in constraints {
1358 if !shape.labels().iter().any(|l| l == label) {
1359 return Err(ShapeError::InvalidLabels {
1360 labels: vec![label.to_string()],
1361 });
1362 }
1363 label_to_constraint.insert(*label, r.clone().into());
1364 }
1365
1366 let selection = shape.labels().iter().rev().fold(true_(), |acc, label| {
1367 if let Some(rng) = label_to_constraint.get(label.as_str()) {
1368 range(rng.clone(), acc)
1369 } else {
1370 all(acc)
1371 }
1372 });
1373
1374 Ok(selection)
1375}
1376
#[macro_export]
/// Builds a [`Selection`] from label constraints on a shape, panicking
/// (via `unwrap`) if any label is not a label of the shape.
///
/// Arms are tried in order:
/// 1. a single `label = expr` goes to `selection_from_one`;
/// 2. a list of `label = literal` turns each literal `v` into `v..v+1`;
/// 3. a list of `label = expr` passes each expression as a range.
macro_rules! sel_from_shape {
    // Single constraint: any range-convertible expression.
    ($shape:expr_2021, $label:ident = $range:expr_2021) => {
        $crate::selection::selection_from_one($shape, stringify!($label), $range).unwrap()
    };

    // Several point constraints: each literal `v` becomes the range `v..v+1`.
    ($shape:expr_2021, $($label:ident = $val:literal),* $(,)?) => {
        $crate::selection::selection_from($shape,
            &[
                $((stringify!($label), $val..$val+1)),*
            ]).unwrap()
    };

    // Several range constraints. All values share one array, so they must
    // have the same type.
    ($shape:expr_2021, $($label:ident = $range:expr_2021),* $(,)?) => {
        $crate::selection::selection_from($shape, &[
            $((stringify!($label), $range)),*
        ]).unwrap()
    };
}
1441
1442#[cfg(test)]
1443mod tests {
1444 use std::assert_matches::assert_matches;
1445 use std::collections::BTreeSet;
1446
1447 use super::EvalOpts;
1448 use super::ReifySlice;
1449 use super::Selection;
1450 use super::dsl::*;
1451 use super::is_equivalent_true;
1452 use crate::Range;
1453 use crate::Slice;
1454 use crate::assert_structurally_eq;
1455 use crate::select;
1456 use crate::shape;
1457 use crate::shape::ShapeError;
1458
1459 fn test_slice() -> Slice {
1461 Slice::new(0usize, vec![2, 4, 8], vec![32, 8, 1]).unwrap()
1462 }
1463
1464 fn eval(expr: Selection, slice: &Slice) -> Vec<usize> {
1472 expr.eval(&EvalOpts::lenient(), slice).unwrap().collect()
1473 }
1474
1475 #[test]
1476 fn test_selection_00() {
1477 let slice = &test_slice();
1478
1479 assert!(eval(false_(), slice).is_empty());
1481 assert!(eval(all(false_()), slice).is_empty());
1482 assert!(eval(all(all(false_())), slice).is_empty());
1483
1484 assert_eq!((0..=63).collect::<Vec<_>>(), eval(true_(), slice));
1486 assert_eq!(eval(true_(), slice), eval(all(true_()), slice));
1487 assert_eq!(eval(all(true_()), slice), eval(all(all(true_())), slice));
1488
1489 assert_eq!(eval(true_(), slice), eval(all(all(all(true_()))), slice));
1492 assert!(eval(all(all(all(false_()))), slice).is_empty());
1493 }
1494
1495 #[test]
1496 fn test_selection_01() {
1497 let slice = &test_slice();
1498
1499 let expr = all(all(all(all(true_()))));
1502 let result = expr.validate(&EvalOpts::lenient(), slice);
1503 assert!(
1504 matches!(result, Err(ShapeError::SelectionTooDeep { .. })),
1505 "Unexpected: {:?}",
1506 result
1507 );
1508 assert_eq!(
1509 format!("{}", result.unwrap_err()),
1510 "selection `all(all(all(all(true_()))))` exceeds dimensionality 3"
1511 );
1512 }
1513
1514 #[test]
1515 fn test_selection_02() {
1516 let slice = &test_slice();
1517
1518 let select = range(0..=0, range(0..=0, range(0..=0, true_())));
1520 assert_eq!((0..=0).collect::<Vec<_>>(), eval(select, slice));
1521
1522 let select = range(1..=1, range(1..=1, range(1..=1, true_())));
1524 assert_eq!((41..=41).collect::<Vec<_>>(), eval(select, slice));
1525
1526 let select = all(range(0..=0, all(true_())));
1528 assert_eq!(
1529 (0..=7).chain(32..=39).collect::<Vec<_>>(),
1530 eval(select, slice)
1531 );
1532
1533 let select = all(range(1..=1, all(true_())));
1535 assert_eq!(
1536 (8..=15).chain(40..=47).collect::<Vec<_>>(),
1537 eval(select, slice)
1538 );
1539
1540 let select = all(all(range(0..4, true_())));
1542 assert_eq!(
1543 (0..=3)
1544 .chain(8..=11)
1545 .chain(16..=19)
1546 .chain(24..=27)
1547 .chain(32..=35)
1548 .chain(40..=43)
1549 .chain(48..=51)
1550 .chain(56..=59)
1551 .collect::<Vec<_>>(),
1552 eval(select, slice)
1553 );
1554
1555 let select = all(all(range(4..8, true_())));
1557 assert_eq!(
1558 (4..=7)
1559 .chain(12..=15)
1560 .chain(20..=23)
1561 .chain(28..=31)
1562 .chain(36..=39)
1563 .chain(44..=47)
1564 .chain(52..=55)
1565 .chain(60..=63)
1566 .collect::<Vec<_>>(),
1567 eval(select, slice)
1568 );
1569
1570 let select = all(all(range(shape::Range(1, None, 2), true_())));
1572 assert_eq!(
1573 (1..8)
1574 .step_by(2)
1575 .chain((9..16).step_by(2))
1576 .chain((17..24).step_by(2))
1577 .chain((25..32).step_by(2))
1578 .chain((33..40).step_by(2))
1579 .chain((41..48).step_by(2))
1580 .chain((49..56).step_by(2))
1581 .chain((57..64).step_by(2))
1582 .collect::<Vec<_>>(),
1583 eval(select, slice)
1584 );
1585 }
1586
1587 #[test]
1588 fn test_selection_03() {
1589 let slice = &test_slice();
1590
1591 assert_eq!(
1592 eval(intersection(true_(), true_()), slice),
1593 eval(true_(), slice)
1594 );
1595 assert_eq!(
1596 eval(intersection(true_(), false_()), slice),
1597 eval(false_(), slice)
1598 );
1599 assert_eq!(
1600 eval(intersection(false_(), true_()), slice),
1601 eval(false_(), slice)
1602 );
1603 assert_eq!(
1604 eval(intersection(false_(), false_()), slice),
1605 eval(false_(), slice)
1606 );
1607 assert_eq!(
1608 eval(
1609 intersection(
1610 all(all(range(0..=3, true_()))),
1611 all(all(range(4..=7, true_())))
1612 ),
1613 slice
1614 ),
1615 eval(false_(), slice)
1616 );
1617 assert_eq!(
1618 eval(intersection(true_(), all(all(range(4..8, true_())))), slice),
1619 eval(all(all(range(4..8, true_()))), slice)
1620 );
1621 assert_eq!(
1622 eval(
1623 intersection(
1624 all(all(range(0..=4, true_()))),
1625 all(all(range(4..=7, true_())))
1626 ),
1627 slice
1628 ),
1629 eval(all(all(range(4..=4, true_()))), slice)
1630 );
1631 }
1632
1633 #[test]
1634 fn test_selection_04() {
1635 let slice = &test_slice();
1636
1637 assert_eq!(eval(union(true_(), true_()), slice), eval(true_(), slice));
1638 assert_eq!(eval(union(false_(), true_()), slice), eval(true_(), slice));
1639 assert_eq!(eval(union(true_(), false_()), slice), eval(true_(), slice));
1640 assert_eq!(
1641 eval(union(false_(), false_()), slice),
1642 eval(false_(), slice)
1643 );
1644 assert_eq!(
1645 eval(
1646 union(
1647 all(all(range(0..4, true_()))),
1648 all(all(range(4.., true_())))
1649 ),
1650 slice
1651 ),
1652 eval(true_(), slice)
1653 );
1654
1655 let s = all(range(0..=0, range(0..4, true_())));
1658 let t = all(range(1..=1, range(4.., true_())));
1659 assert_eq!(
1660 (0..=3)
1661 .chain(12..=15)
1662 .chain(32..=35)
1663 .chain(44..=47)
1664 .collect::<Vec<_>>(),
1665 eval(union(s, t), slice)
1666 );
1667
1668 assert_eq!(
1670 (0..=1)
1672 .chain(6..=7)
1673 .chain(8..=9)
1675 .chain(14..=15)
1676 .chain(16..=17)
1678 .chain(22..=23)
1679 .chain(24..=25)
1681 .chain(30..=31)
1682 .chain(32..=33)
1684 .chain(38..=39)
1685 .chain(40..=41)
1687 .chain(46..=47)
1688 .chain(48..=49)
1690 .chain(54..=55)
1691 .chain(56..=57)
1693 .chain(62..=63)
1694 .collect::<Vec<_>>(),
1695 eval(
1696 all(all(union(range(0..2, true_()), range(6..8, true_())))),
1697 slice
1698 )
1699 );
1700
1701 assert_eq!(
1703 (1..8)
1704 .step_by(2)
1705 .chain((9..16).step_by(2))
1706 .chain((17..24).step_by(2))
1707 .chain((25..32).step_by(2))
1708 .chain((33..40).step_by(2))
1709 .chain((41..48).step_by(2))
1710 .chain((49..56).step_by(2))
1711 .chain((57..64).step_by(2))
1712 .collect::<Vec<_>>(),
1713 eval(
1714 all(all(union(
1715 range(shape::Range(1, Some(4), 2), true_()),
1716 range(shape::Range(5, Some(8), 2), true_())
1717 ))),
1718 slice
1719 )
1720 );
1721 assert_eq!(
1722 eval(
1723 all(all(union(
1724 range(shape::Range(1, Some(4), 2), true_()),
1725 range(shape::Range(5, Some(8), 2), true_())
1726 ))),
1727 slice
1728 ),
1729 eval(
1730 all(all(union(
1731 union(range(1..=1, true_()), range(3..=3, true_()),),
1732 union(range(5..=5, true_()), range(7..=7, true_()),),
1733 ))),
1734 slice
1735 ),
1736 );
1737 }
1738
1739 #[test]
1740 fn test_selection_05() {
1741 let slice = &test_slice();
1742
1743 assert!(eval(first(first(false_())), slice).is_empty());
1745 assert_eq!(vec![0], eval(first(first(range(0..1, true_()))), slice));
1747 assert_eq!(
1749 (0..8).collect::<Vec<_>>(),
1750 eval(first(first(true_())), slice)
1751 );
1752
1753 assert!(eval(first(first(first(false_()))), slice).is_empty());
1757 assert_eq!(vec![0], eval(first(first(first(true_()))), slice));
1759
1760 assert_eq!(
1762 (0..8).chain(32..40).collect::<Vec<_>>(),
1763 eval(all(first(true_())), slice)
1764 );
1765
1766 assert_eq!(
1768 (0..3).collect::<Vec<_>>(),
1769 eval(first(first(range(0..=2, true_()))), slice)
1770 );
1771 }
1772
1773 #[test]
1774 fn test_selection_06() {
1775 let slice = &test_slice();
1776
1777 let expr = first(first(first(first(true_()))));
1780 let result = expr.validate(&EvalOpts::lenient(), slice);
1781 assert!(
1782 matches!(result, Err(ShapeError::SelectionTooDeep { .. })),
1783 "Unexpected: {:?}",
1784 result
1785 );
1786 assert_eq!(
1787 format!("{}", result.unwrap_err()),
1788 "selection `first(first(first(first(true_()))))` exceeds dimensionality 3"
1789 );
1790 }
1791
1792 #[test]
1793 fn test_selection_07() {
1794 use crate::select;
1795 use crate::shape;
1796
1797 let s = shape!(host = 2, gpu = 8);
1799
1800 assert_eq!(
1802 select!(s, host = 1)
1803 .unwrap()
1804 .slice()
1805 .iter()
1806 .collect::<Vec<_>>(),
1807 eval(range(1..2, true_()), s.slice())
1808 );
1809
1810 assert_eq!(
1812 select!(s, gpu = 2..)
1813 .unwrap()
1814 .slice()
1815 .iter()
1816 .collect::<Vec<_>>(),
1817 eval(all(range(2.., true_())), s.slice())
1818 );
1819
1820 assert_eq!(
1822 select!(s, gpu = 3..5)
1823 .unwrap()
1824 .slice()
1825 .iter()
1826 .collect::<Vec<_>>(),
1827 eval(all(range(3..5, true_())), s.slice())
1828 );
1829
1830 assert_eq!(
1832 select!(s, gpu = 3..5, host = 1)
1833 .unwrap()
1834 .slice()
1835 .iter()
1836 .collect::<Vec<_>>(),
1837 eval(range(1..=1, range(3..5, true_())), s.slice())
1838 );
1839
1840 assert_matches!(
1842 select!(s, gpu = 1..1).unwrap_err(),
1843 ShapeError::EmptyRange {
1844 range: shape::Range(1, Some(1), 1)
1845 },
1846 );
1847 assert!(eval(all(range(1..1, true_())), s.slice()).is_empty());
1848
1849 assert_matches!(
1851 select!(s, gpu = 8).unwrap_err(),
1852 ShapeError::OutOfRange {
1853 range: shape::Range(8, Some(9), 1),
1854 dim,
1855 size: 8,
1856 } if dim == "gpu",
1857 );
1858 assert!(eval(all(range(8..8, true_())), s.slice()).is_empty());
1859 }
1860
1861 #[test]
1862 fn test_selection_08() {
1863 let s = &shape!(host = 2, gpu = 8);
1864
1865 assert_eq!(
1866 eval(range(1..2, true_()), s.slice()),
1867 eval(sel_from_shape!(s, host = 1), s.slice())
1868 );
1869
1870 assert_eq!(
1871 eval(all(range(2.., true_())), s.slice()),
1872 eval(sel_from_shape!(s, gpu = 2..), s.slice())
1873 );
1874
1875 assert_eq!(
1876 eval(all(range(3..5, true_())), s.slice()),
1877 eval(sel_from_shape!(s, gpu = 3..5), s.slice())
1878 );
1879
1880 assert_eq!(
1881 eval(range(1..2, range(3..5, true_())), s.slice()),
1882 eval(sel_from_shape!(s, host = 1..2, gpu = 3..5), s.slice())
1883 );
1884
1885 assert_eq!(
1886 eval(
1887 union(
1888 sel_from_shape!(s, host = 0..1, gpu = 0..4),
1889 sel_from_shape!(s, host = 1..2, gpu = 4..5)
1890 ),
1891 s.slice()
1892 ),
1893 eval(
1894 union(
1895 range(0..1, range(0..4, true_())),
1896 range(1..2, range(4..5, true_()))
1897 ),
1898 s.slice()
1899 )
1900 );
1901 }
1902
1903 #[test]
1904 fn test_selection_09() {
1905 let slice = &test_slice(); assert_eq!(eval(any(false_()), slice), eval(false_(), slice));
1909
1910 let res = eval(any(any(any(true_()))), slice);
1912 assert_eq!(res.len(), 1);
1913 assert!(res[0] < 64);
1914
1915 let res = eval(range(0, any(range(0..4, true_()))), slice);
1917 assert!((0..4).any(|host| res == eval(range(0, range(host, range(0..4, true_()))), slice)));
1918
1919 let res = eval(range(0, range(0, any(true_()))), slice);
1921 assert_eq!(res.len(), 1);
1922 assert!(res[0] < 8);
1923
1924 let res = eval(range(0, any(true_())), slice);
1926 assert!((0..4).any(|host| res == eval(range(0, range(host, true_())), slice)));
1927
1928 let res = eval(range(1, any(true_())), slice);
1930 assert!((0..4).any(|host| res == eval(range(1, range(host, true_())), slice)));
1931
1932 let mut res = vec![];
1934 while res.len() < 2 {
1935 res = eval(
1936 union(
1937 range(0, range(0, any(true_()))),
1938 range(0, range(0, any(true_()))),
1939 ),
1940 slice,
1941 );
1942 }
1943 assert_matches!(res.as_slice(), [i, j] if *i < *j && *i < 8 && *j < 8);
1944 }
1945
1946 #[test]
1947 fn test_eval_zero_dim_slice() {
1948 let slice_0d = Slice::new(1, vec![], vec![]).unwrap();
1949 assert_eq!(slice_0d.coordinates(1).unwrap(), vec![]);
1952
1953 assert_eq!(eval(true_(), &slice_0d), vec![1]);
1954 assert_eq!(eval(false_(), &slice_0d), vec![]);
1955 assert_eq!(eval(all(true_()), &slice_0d), vec![1]);
1956 assert_eq!(eval(all(false_()), &slice_0d), vec![]);
1957 assert_eq!(eval(union(true_(), true_()), &slice_0d), vec![1]);
1958 assert_eq!(eval(intersection(true_(), false_()), &slice_0d), vec![]);
1959 }
1960
1961 #[test]
1962 fn test_selection_10() {
1963 let slice = &test_slice();
1964 let opts = EvalOpts {
1965 disallow_dynamic_selections: true,
1966 ..EvalOpts::lenient()
1967 };
1968 let expr = any(any(any(true_())));
1969 let res = expr.validate(&opts, slice);
1970 assert_matches!(res, Err(ShapeError::SelectionDynamic { .. }));
1971 }
1972
1973 #[test]
1974 fn test_13() {
1975 assert!(is_equivalent_true(true_()));
1977 assert!(is_equivalent_true(all(true_())));
1978 assert!(is_equivalent_true(all(all(true_()))));
1979 assert!(is_equivalent_true(all(all(all(true_())))));
1980 assert!(is_equivalent_true(all(all(all(all(true_()))))));
1981 assert!(is_equivalent_true(all(all(all(all(all(true_())))))));
1982 assert!(!is_equivalent_true(false_()));
1985 assert!(!is_equivalent_true(union(true_(), true_())));
1986 assert!(!is_equivalent_true(range(0..=0, true_())));
1987 assert!(!is_equivalent_true(all(false_())));
1988 }
1989
1990 #[test]
1991 fn test_14() {
1992 use std::collections::HashSet;
1993
1994 use crate::selection::NormalizedSelectionKey;
1995 use crate::selection::dsl::*;
1996
1997 let a = all(all(true_()));
1998 let b = all(all(true_()));
1999
2000 let key_a = NormalizedSelectionKey::new(&a);
2001 let key_b = NormalizedSelectionKey::new(&b);
2002
2003 assert_eq!(key_a, key_b);
2005
2006 let mut set = HashSet::new();
2008 set.insert(key_a);
2009 assert!(set.contains(&key_b));
2010 }
2011
2012 #[test]
2013 fn test_contains_true() {
2014 let selection = true_();
2015 assert!(selection.contains(&[0, 0, 0]));
2016 assert!(selection.contains(&[1, 2, 3]));
2017 }
2018
2019 #[test]
2020 fn test_contains_false() {
2021 let selection = false_();
2022 assert!(!selection.contains(&[0, 0, 0]));
2023 assert!(!selection.contains(&[1, 2, 3]));
2024 }
2025
2026 #[test]
2027 fn test_contains_all() {
2028 let selection = all(true_());
2029 assert!(selection.contains(&[0, 0, 0]));
2030 assert!(selection.contains(&[1, 2, 3]));
2031 }
2032
2033 #[test]
2034 fn test_contains_range() {
2035 let selection = range(1..3, true_());
2036 assert!(selection.contains(&[1, 0, 0]));
2037 assert!(!selection.contains(&[3, 0, 0]));
2038 }
2039
2040 #[test]
2041 fn test_contains_intersection() {
2042 let selection = intersection(range(1..3, true_()), range(2..4, true_()));
2043 assert!(selection.contains(&[2, 0, 0]));
2044 assert!(!selection.contains(&[1, 0, 0]));
2045 }
2046
2047 #[test]
2048 fn test_contains_union() {
2049 let selection = union(range(1..2, true_()), range(3..4, true_()));
2050 assert!(selection.contains(&[1, 0, 0]));
2051 assert!(!selection.contains(&[2, 0, 0]));
2052 }
2053
2054 #[test]
2055 #[should_panic(expected = "not implemented")]
2056 fn test_contains_any() {
2057 let selection = any(true_());
2058 selection.contains(&[0, 0, 0]);
2059 }
2060
2061 #[test]
2062 #[should_panic(expected = "not implemented")]
2063 fn test_contains_label() {
2064 let selection = label(vec!["zone".to_string()], true_());
2065 selection.contains(&[1, 2, 3]);
2066 }
2067
2068 #[test]
2069 #[should_panic(expected = "not implemented")]
2070 fn test_contains_first() {
2071 let selection = first(true_());
2072 selection.contains(&[0, 0, 0]);
2073 }
2074
2075 #[test]
2076 fn test_difference_1d() {
2077 assert_eq!(
2078 true_()
2079 .difference(
2080 &EvalOpts::strict(),
2081 &Slice::new_row_major([5]),
2082 &[2usize, 4].into(),
2083 )
2084 .unwrap()
2085 .collect::<Vec<_>>(),
2086 vec![0, 1, 3]
2087 );
2088 }
2089
2090 #[test]
2091 fn test_difference_empty_selection() {
2092 assert_eq!(
2093 false_()
2094 .difference(
2095 &EvalOpts::strict(),
2096 &Slice::new_row_major([3]),
2097 &[0usize, 1].into(),
2098 )
2099 .unwrap()
2100 .collect::<Vec<_>>(),
2101 vec![]
2102 );
2103 }
2104
2105 #[test]
2106 fn test_difference_2d() {
2107 assert_eq!(
2111 all(all(true_()))
2112 .difference(
2113 &EvalOpts::strict(),
2114 &Slice::new_row_major([2, 3]),
2115 &[3usize, 4, 5].into(),
2116 )
2117 .unwrap()
2118 .collect::<Vec<_>>(),
2119 vec![0, 1, 2]
2120 );
2121 }
2122
2123 #[test]
2124 fn test_of_ranks_1d() {
2125 let slice = Slice::new_row_major([5]);
2126 let ranks = BTreeSet::from([1, 3]);
2127 let selection = Selection::of_ranks(&slice, &ranks).unwrap();
2128 assert_eq!(
2129 selection
2130 .eval(&EvalOpts::strict(), &slice)
2131 .unwrap()
2132 .collect::<Vec<_>>(),
2133 vec![1, 3]
2134 )
2135 }
2136
2137 #[test]
2138 fn test_of_ranks_empty_set() {
2139 let slice = Slice::new_row_major([4]);
2140 let ranks = BTreeSet::new();
2141 let selection = Selection::of_ranks(&slice, &ranks).unwrap();
2142 assert_eq!(
2143 selection
2144 .eval(&EvalOpts::strict(), &slice)
2145 .unwrap()
2146 .collect::<Vec<_>>(),
2147 vec![]
2148 )
2149 }
2150
2151 #[test]
2152 fn test_of_ranks_singleton_structural() {
2153 let slice = Slice::new_row_major([5]);
2154 let ranks = BTreeSet::from([2]);
2155 let actual = Selection::of_ranks(&slice, &ranks).unwrap();
2156 let expected = range(2..=2, true_());
2157 assert_structurally_eq!(&actual, &expected);
2158 }
2159
2160 #[test]
2161 fn test_of_ranks_union_2d_structural() {
2162 let slice = Slice::new_row_major([2, 3]);
2163 let ranks = BTreeSet::from([2, 3, 4]);
2167 let actual = Selection::of_ranks(&slice, &ranks).unwrap();
2168 let expected = union(
2175 union(range(0, range(2, true_())), range(1, range(0, true_()))),
2176 range(1, range(1, true_())),
2177 );
2178 assert_structurally_eq!(&actual, &expected);
2179 }
2180
2181 #[test]
2182 fn test_of_ranks_3d_structural() {
2183 let slice = Slice::new_row_major([2, 2, 2]);
2184 let ranks = BTreeSet::from([1, 6]);
2189 let actual = Selection::of_ranks(&slice, &ranks).unwrap();
2190 let expected = union(
2191 range(0, range(0, range(1, true_()))), range(1, range(1, range(0, true_()))), );
2194 assert_structurally_eq!(&actual, &expected);
2195 }
2196
2197 #[test]
2198 fn test_of_ranks_invalid_index() {
2199 let slice = Slice::new_row_major([4]);
2200 let ranks = BTreeSet::from([0, 4]); assert!(
2202 Selection::of_ranks(&slice, &ranks).is_err(),
2203 "expected out-of-bounds error"
2204 );
2205 }
2206
2207 #[test]
2208 fn test_reify_slice_empty() {
2209 let slice = Slice::new_row_major([0]);
2210 let selection = slice.reify_slice(&slice).unwrap();
2211 let expected = false_();
2212 assert_structurally_eq!(&selection, expected);
2213 assert_eq!(
2214 selection
2215 .eval(&EvalOpts::lenient(), &slice)
2216 .unwrap()
2217 .collect::<Vec<_>>(),
2218 vec![]
2219 );
2220 }
2221
2222 #[test]
2223 fn test_reify_slice_1d() {
2224 let shape = shape!(x = 6); let base = shape.slice();
2226
2227 let selected = select!(shape, x = 2..5).unwrap();
2228 let view = selected.slice();
2229
2230 let selection = base.reify_slice(view).unwrap();
2231 let expected = range(2..5, true_());
2232 assert_structurally_eq!(&selection, expected);
2233
2234 let flat: Vec<_> = selection.eval(&EvalOpts::strict(), base).unwrap().collect();
2235 assert_eq!(flat, vec![2, 3, 4]);
2236 }
2237
2238 #[test]
2239 fn test_reify_slice_2d() {
2240 let shape = shape!(x = 4, y = 5); let base = shape.slice();
2242
2243 let selected = select!(shape, x = 1..3, y = 2..5).unwrap();
2245 let view = selected.slice();
2246 let selection = base.reify_slice(view).unwrap();
2247 let expected = range(1..3, range(2..5, true_()));
2248 assert_structurally_eq!(&selection, expected);
2249
2250 let flat: Vec<_> = selection.eval(&EvalOpts::strict(), base).unwrap().collect();
2251 assert_eq!(
2252 flat,
2253 vec![
2254 base.location(&[1, 2]).unwrap(),
2255 base.location(&[1, 3]).unwrap(),
2256 base.location(&[1, 4]).unwrap(),
2257 base.location(&[2, 2]).unwrap(),
2258 base.location(&[2, 3]).unwrap(),
2259 base.location(&[2, 4]).unwrap(),
2260 ]
2261 );
2262 }
2263
2264 #[test]
2265 #[allow(clippy::identity_op)]
2266 fn test_reify_slice_1d_with_stride() {
2267 let shape = shape!(x = 7); let selected = shape.select("x", Range(0, None, 2)).unwrap();
2269 let view = selected.slice();
2270 assert_eq!(view, &Slice::new(0, vec![4], vec![1 * 2]).unwrap());
2271
2272 let base = shape.slice();
2273 let selection = base.reify_slice(view).unwrap();
2274 let expected = range(Range(0, Some(8), 2), true_());
2278 assert_structurally_eq!(&selection, expected);
2279
2280 let flat: Vec<_> = selection.eval(&EvalOpts::strict(), base).unwrap().collect();
2281 assert_eq!(flat, vec![0, 2, 4, 6]);
2282 }
2283
2284 #[test]
2285 #[allow(clippy::identity_op)]
2286 fn test_reify_slice_2d_with_stride() {
2287 let base = shape!(x = 4, y = 4);
2289 let shape = base.select("x", Range(1, Some(4), 2)).unwrap();
2291 let shape = shape.select("y", Range(1, Some(4), 2)).unwrap();
2293 let view = shape.slice();
2294 assert_eq!(
2295 view,
2296 &Slice::new(5, vec![2, 2], vec![4 * 2, 1 * 2]).unwrap()
2297 );
2298
2299 let base = base.slice();
2300 let selection = base.reify_slice(view).unwrap();
2301 let expected = range(Range(1, Some(5), 2), range(Range(1, Some(5), 2), true_()));
2307 assert_structurally_eq!(&selection, expected);
2308
2309 let flat: Vec<_> = selection.eval(&EvalOpts::strict(), base).unwrap().collect();
2310 assert_eq!(flat, vec![5, 7, 13, 15]);
2311 }
2312
2313 #[test]
2314 fn test_reify_slice_selects_column_across_rows() {
2315 let shape = shape!(host = 2, gpu = 4); let base = shape.slice();
2317
2318 let selected = select!(shape, gpu = 2).unwrap(); let view = selected.slice();
2321 let coordinates: Vec<_> = view.iter().map(|i| view.coordinates(i).unwrap()).collect();
2322 assert_eq!(coordinates, [[0, 0], [1, 0]]);
2323
2324 let selection = base.reify_slice(view).unwrap();
2325 let expected = range(0..2, range(2..3, true_()));
2326 assert_structurally_eq!(&selection, expected);
2327
2328 let actual = selection
2329 .eval(&EvalOpts::strict(), base)
2330 .unwrap()
2331 .collect::<Vec<_>>();
2332 assert_eq!(
2333 actual,
2334 vec![
2335 base.location(&[0, 2]).unwrap(),
2336 base.location(&[1, 2]).unwrap()
2337 ]
2338 );
2339 }
2340
2341 #[test]
2342 fn test_reify_slice_dimension_mismatch() {
2343 let shape = shape!(host = 2, gpu = 4);
2344 let base = shape.slice();
2345
2346 let indices = vec![
2349 base.location(&[0, 2]).unwrap(),
2350 base.location(&[1, 2]).unwrap(),
2351 ];
2352
2353 let view = Slice::new(indices[0], vec![indices.len()], vec![4]).unwrap();
2354 let selection = base.reify_slice(&view).unwrap();
2355
2356 let expected = Selection::of_ranks(base, &indices.iter().cloned().collect()).unwrap();
2357 assert_structurally_eq!(&selection, expected);
2358
2359 let actual: Vec<_> = selection.eval(&EvalOpts::strict(), base).unwrap().collect();
2360 assert_eq!(actual, indices);
2361 }
2362
2363 #[test]
2364 fn test_union_of_slices_empty() {
2365 let base = Slice::new_row_major([2]);
2366 let sel = base.reify_slices(&[]).unwrap();
2367 assert_structurally_eq!(&sel, &false_());
2368 assert_eq!(
2369 sel.eval(&EvalOpts::strict(), &base)
2370 .unwrap()
2371 .collect::<Vec<_>>(),
2372 vec![]
2373 );
2374 }
2375
2376 #[test]
2377 fn test_union_of_slices_singleton() {
2378 let shape = shape!(x = 3);
2379 let base = shape.slice();
2380 let selected = select!(shape, x = 1).unwrap();
2381 let view = selected.slice().clone();
2382
2383 let selection = base.reify_slices(&[view]).unwrap();
2384 let expected = range(1..=1, true_());
2385 assert_structurally_eq!(&selection, &expected);
2386
2387 assert_eq!(
2388 selection
2389 .eval(&EvalOpts::strict(), base)
2390 .unwrap()
2391 .collect::<Vec<_>>(),
2392 vec![1],
2393 );
2394 }
2395
2396 #[test]
2397 fn test_union_of_slices_disjoint() {
2398 let shape = shape!(x = 2, y = 2); let base = shape.slice();
2400
2401 let a = select!(shape, x = 0).unwrap();
2403 let view_a = a.slice().clone();
2404
2405 let b = select!(shape, x = 1).unwrap();
2407 let view_b = b.slice().clone();
2408
2409 let selection = base.reify_slices(&[view_a, view_b]).unwrap();
2410 let expected = union(
2411 range(0..1, range(0..2, true_())),
2412 range(1..2, range(0..2, true_())),
2413 );
2414 assert_structurally_eq!(&selection, &expected);
2415 assert_eq!(
2416 selection
2417 .eval(&EvalOpts::strict(), base)
2418 .unwrap()
2419 .collect::<Vec<_>>(),
2420 base.iter().collect::<Vec<_>>()
2421 );
2422 }
2423
2424 #[test]
2425 fn test_union_of_slices_overlapping() {
2426 let shape = shape!(x = 1, y = 4); let base = shape.slice();
2428
2429 let selected1 = select!(shape, y = 0..2).unwrap();
2430 let view1 = selected1.slice().clone();
2431
2432 let selected2 = select!(shape, y = 1..4).unwrap();
2433 let view2 = selected2.slice().clone();
2434
2435 let selection = base.reify_slices(&[view1, view2]).unwrap();
2436 let expected = union(
2437 range(0..1, range(0..2, true_())),
2438 range(0..1, range(1..4, true_())),
2439 );
2440 assert_structurally_eq!(&selection, &expected);
2441
2442 assert_eq!(
2443 selection
2444 .eval(&EvalOpts::strict(), base)
2445 .unwrap()
2446 .collect::<Vec<_>>(),
2447 base.iter().collect::<Vec<_>>()
2448 );
2449 }
2450
2451 #[test]
2452 fn test_canonicalize_to_dimensions() {
2453 assert_structurally_eq!(
2454 true_().canonicalize_to_dimensions(3),
2455 &all(all(all(true_())))
2456 );
2457 assert_structurally_eq!(
2458 all(true_()).canonicalize_to_dimensions(3),
2459 &all(all(all(true_())))
2460 );
2461 assert_structurally_eq!(
2462 all(all(true_())).canonicalize_to_dimensions(3),
2463 &all(all(all(true_())))
2464 );
2465 assert_structurally_eq!(
2466 all(all(all(true_()))).canonicalize_to_dimensions(3),
2467 &all(all(all(true_())))
2468 );
2469
2470 assert_structurally_eq!(
2471 false_().canonicalize_to_dimensions(3),
2472 &all(all(all(false_())))
2473 );
2474 assert_structurally_eq!(
2475 all(false_()).canonicalize_to_dimensions(3),
2476 &all(all(all(false_())))
2477 );
2478 assert_structurally_eq!(
2479 all(all(false_())).canonicalize_to_dimensions(3),
2480 &all(all(all(false_())))
2481 );
2482 assert_structurally_eq!(
2483 all(all(all(false_()))).canonicalize_to_dimensions(3),
2484 &all(all(all(false_())))
2485 );
2486
2487 assert_structurally_eq!(
2488 any(true_()).canonicalize_to_dimensions(3),
2489 &any(any(any(true_())))
2490 );
2491 assert_structurally_eq!(
2492 any(any(true_())).canonicalize_to_dimensions(3),
2493 &any(any(any(true_())))
2494 );
2495 assert_structurally_eq!(
2496 any(any(any(true_()))).canonicalize_to_dimensions(3),
2497 &any(any(any(true_())))
2498 );
2499
2500 assert_structurally_eq!(
2502 range(0..1, true_()).canonicalize_to_dimensions(3),
2503 &range(0..1, all(all(true_())))
2504 );
2505 assert_structurally_eq!(
2507 all(range(0..1, true_())).canonicalize_to_dimensions(3),
2508 &all(range(0..1, all(true_())))
2509 );
2510 assert_structurally_eq!(
2512 range(0..1, any(true_())).canonicalize_to_dimensions(3),
2513 &range(0..1, any(any(true_())))
2514 );
2515 assert_structurally_eq!(
2517 range(0..1, any(all(true_()))).canonicalize_to_dimensions(3),
2518 &range(0..1, any(all(true_())))
2519 );
2520 }
2521}