1use std::collections::BTreeSet;
10
11use crate::Selection;
12use crate::selection::LabelKey;
13use crate::selection::SelectionSYM;
14use crate::selection::dsl;
15use crate::shape;
16
17#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
22pub enum NormalizedSelection {
23 False,
24 True,
25 All(Box<NormalizedSelection>),
26 First(Box<NormalizedSelection>),
27 Range(shape::Range, Box<NormalizedSelection>),
28 Label(Vec<LabelKey>, Box<NormalizedSelection>),
29 Any(Box<NormalizedSelection>),
30 Union(BTreeSet<NormalizedSelection>),
31 Intersection(BTreeSet<NormalizedSelection>),
32}
33
34impl SelectionSYM for NormalizedSelection {
35 fn true_() -> Self {
36 Self::True
37 }
38
39 fn false_() -> Self {
40 Self::False
41 }
42
43 fn all(inner: Self) -> Self {
44 Self::All(Box::new(inner))
45 }
46
47 fn first(inner: Self) -> Self {
48 Self::First(Box::new(inner))
49 }
50
51 fn range<R: Into<shape::Range>>(range: R, inner: Self) -> Self {
52 Self::Range(range.into(), Box::new(inner))
53 }
54
55 fn label<L: Into<LabelKey>>(labels: Vec<L>, inner: Self) -> Self {
56 Self::Label(
57 labels.into_iter().map(Into::into).collect(),
58 Box::new(inner),
59 )
60 }
61
62 fn any(inner: Self) -> Self {
63 Self::Any(Box::new(inner))
64 }
65
66 fn intersection(lhs: Self, rhs: Self) -> Self {
67 let mut set = BTreeSet::new();
68 set.insert(lhs);
69 set.insert(rhs);
70 Self::Intersection(set)
71 }
72
73 fn union(lhs: Self, rhs: Self) -> Self {
74 let mut set = BTreeSet::new();
75 set.insert(lhs);
76 set.insert(rhs);
77 Self::Union(set)
78 }
79}
80
81impl NormalizedSelection {
82 pub fn trav<F>(self, mut f: F) -> Self
88 where
89 F: FnMut(Self) -> Self,
90 {
91 use NormalizedSelection::*;
92
93 match self {
94 All(inner) => All(Box::new(f(*inner))),
95 First(inner) => First(Box::new(f(*inner))),
96 Any(inner) => Any(Box::new(f(*inner))),
97 Range(r, inner) => Range(r, Box::new(f(*inner))),
98 Label(labels, inner) => Label(labels, Box::new(f(*inner))),
99 Union(set) => Union(set.into_iter().map(f).collect()),
100 Intersection(set) => Intersection(set.into_iter().map(f).collect()),
101 leaf @ (True | False) => leaf,
102 }
103 }
104}
105
106pub trait RewriteRule {
117 fn rewrite(&self, node: NormalizedSelection) -> NormalizedSelection;
120}
121
122impl<R1: RewriteRule, R2: RewriteRule> RewriteRule for (R1, R2) {
123 fn rewrite(&self, node: NormalizedSelection) -> NormalizedSelection {
124 self.1.rewrite(self.0.rewrite(node))
125 }
126}
127
128pub trait RewriteRuleExt: RewriteRule + Sized {
134 fn then<R: RewriteRule>(self, other: R) -> (Self, R) {
137 (self, other)
138 }
139}
140
141impl<T: RewriteRule> RewriteRuleExt for T {}
142
143impl From<NormalizedSelection> for Selection {
144 fn from(norm: NormalizedSelection) -> Self {
150 use NormalizedSelection::*;
151 use dsl::*;
152
153 match norm {
154 True => true_(),
155 False => false_(),
156 All(inner) => all((*inner).into()),
157 First(inner) => first((*inner).into()),
158 Any(inner) => any((*inner).into()),
159 Union(set) => set
160 .into_iter()
161 .map(Into::into)
162 .reduce(Selection::union)
163 .unwrap_or_else(false_),
164 Intersection(set) => set
165 .into_iter()
166 .map(Into::into)
167 .reduce(Selection::intersection)
168 .unwrap_or_else(true_),
169 Range(r, inner) => Selection::range(r, (*inner).into()),
170 Label(labels, inner) => Selection::label(labels, (*inner).into()),
171 }
172 }
173}
174
175#[derive(Default)]
177pub struct IdentityRules;
178
179impl RewriteRule for IdentityRules {
180 fn rewrite(&self, node: NormalizedSelection) -> NormalizedSelection {
195 use NormalizedSelection::*;
196
197 match node {
198 All(inner) => match *inner {
199 All(grandchild) => All(grandchild), True => True, False => False, _ => All(inner),
203 },
204
205 Intersection(mut set) => {
206 set.remove(&True); match set.len() {
208 0 => True,
209 1 => set.into_iter().next().unwrap(), _ => Intersection(set),
211 }
212 }
213
214 Union(mut set) => {
215 set.remove(&False); match set.len() {
217 0 => False,
218 1 => set.into_iter().next().unwrap(), _ => Union(set),
220 }
221 }
222
223 _ => node,
224 }
225 }
226}
227
228#[derive(Default)]
231pub struct FlatteningRules;
232
233impl RewriteRule for FlatteningRules {
234 fn rewrite(&self, node: NormalizedSelection) -> NormalizedSelection {
239 use NormalizedSelection::*;
240
241 match node {
242 Union(set) => {
243 let mut flattened = BTreeSet::new();
244 for item in set {
245 match item {
246 Union(inner_set) => {
247 flattened.extend(inner_set);
248 }
249 other => {
250 flattened.insert(other);
251 }
252 }
253 }
254 Union(flattened)
255 }
256 Intersection(set) => {
257 let mut flattened = BTreeSet::new();
258 for item in set {
259 match item {
260 Intersection(inner_set) => {
261 flattened.extend(inner_set);
262 }
263 other => {
264 flattened.insert(other);
265 }
266 }
267 }
268 Intersection(flattened)
269 }
270 _ => node,
271 }
272 }
273}
274
275#[derive(Default)]
281pub struct AbsorbtionRules;
282
283impl RewriteRule for AbsorbtionRules {
284 fn rewrite(&self, node: NormalizedSelection) -> NormalizedSelection {
289 use NormalizedSelection::*;
290
291 match node {
292 Union(set) => {
293 if set.contains(&True) {
294 True } else {
296 Union(set)
297 }
298 }
299 Intersection(set) => {
300 if set.contains(&False) {
301 False } else {
303 Intersection(set)
304 }
305 }
306 other => other,
307 }
308 }
309}
310
311impl NormalizedSelection {
312 pub fn rewrite_bottom_up(self, rule: &impl RewriteRule) -> Self {
313 let mapped = self.trav(|child| child.rewrite_bottom_up(rule));
314 rule.rewrite(mapped)
315 }
316}
317
#[cfg(test)]
mod tests {
    use super::*;
    use crate::assert_structurally_eq;
    use crate::selection;
    use crate::selection::parse::parse;

    #[test]
    fn normalization_deduplicates_and_reifies() {
        let sel = parse("(* & *) | (* & *)").unwrap();
        let norm = sel.fold::<NormalizedSelection>();

        use NormalizedSelection::*;
        // `* & *` folds to a one-member intersection, and the two identical
        // branches of `|` collapse into a single union member.
        let mut inner = BTreeSet::new();
        inner.insert(All(Box::new(True)));

        let mut outer = BTreeSet::new();
        outer.insert(Intersection(inner));

        assert_eq!(norm, Union(outer));

        use selection::dsl::*;
        let reified = norm.into();
        let expected = all(true_());

        assert_structurally_eq!(&reified, &expected);
    }

    #[test]
    fn normalize_smoke_test() {
        use crate::assert_structurally_eq;
        use crate::selection::dsl::*;
        use crate::selection::normalize;
        use crate::selection::parse::parse;

        let sel = parse("(*,*) | (*,*)").unwrap();
        let normed = normalize(&sel);
        let expected = true_();

        assert_structurally_eq!(&normed.into(), &expected);
    }

    #[test]
    fn test_union_flattening() {
        use NormalizedSelection::*;

        let inner_union = {
            let mut set = BTreeSet::new();
            set.insert(All(Box::new(True)));
            set.insert(Any(Box::new(True)));
            Union(set)
        };

        let outer_union = {
            let mut set = BTreeSet::new();
            set.insert(First(Box::new(True)));
            set.insert(inner_union);
            Union(set)
        };

        let rule = FlatteningRules;
        let result = rule.rewrite(outer_union);

        if let Union(set) = result {
            assert_eq!(set.len(), 3);
            assert!(set.contains(&First(Box::new(True))));
            assert!(set.contains(&All(Box::new(True))));
            assert!(set.contains(&Any(Box::new(True))));
        } else {
            panic!("Expected Union, got {:?}", result);
        }
    }

    #[test]
    fn test_intersection_flattening() {
        use NormalizedSelection::*;

        let inner_intersection = {
            let mut set = BTreeSet::new();
            set.insert(All(Box::new(True)));
            set.insert(Any(Box::new(True)));
            Intersection(set)
        };

        let outer_intersection = {
            let mut set = BTreeSet::new();
            set.insert(First(Box::new(True)));
            set.insert(inner_intersection);
            Intersection(set)
        };

        let rule = FlatteningRules;
        let result = rule.rewrite(outer_intersection);

        if let Intersection(set) = result {
            assert_eq!(set.len(), 3);
            assert!(set.contains(&First(Box::new(True))));
            assert!(set.contains(&All(Box::new(True))));
            assert!(set.contains(&Any(Box::new(True))));
        } else {
            panic!("Expected Intersection, got {:?}", result);
        }
    }

    // Previously declared after (and outside) this module, which trips
    // clippy's `items_after_test_module` lint.
    #[test]
    fn test_absorbtion_rules() {
        use NormalizedSelection::*;

        let union_case = {
            let mut set = BTreeSet::new();
            set.insert(True);
            set.insert(Any(Box::new(True)));
            Union(set)
        };

        let rule = AbsorbtionRules;
        let result = rule.rewrite(union_case);
        assert_eq!(result, True);

        let intersection_case = {
            let mut set = BTreeSet::new();
            set.insert(False);
            set.insert(All(Box::new(True)));
            Intersection(set)
        };

        let result = rule.rewrite(intersection_case);
        assert_eq!(result, False);
    }
}