ndslice/selection/
normal.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::collections::BTreeSet;
10
11use crate::Selection;
12use crate::selection::LabelKey;
13use crate::selection::SelectionSYM;
14use crate::selection::dsl;
15use crate::shape;
16
17/// A normalized form of `Selection`, used during canonicalization.
18///
19/// This structure uses `BTreeSet` for `Union` and `Intersection` to
20/// enable flattening, deduplication, and deterministic ordering.
21#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
22pub enum NormalizedSelection {
23    False,
24    True,
25    All(Box<NormalizedSelection>),
26    First(Box<NormalizedSelection>),
27    Range(shape::Range, Box<NormalizedSelection>),
28    Label(Vec<LabelKey>, Box<NormalizedSelection>),
29    Any(Box<NormalizedSelection>),
30    Union(BTreeSet<NormalizedSelection>),
31    Intersection(BTreeSet<NormalizedSelection>),
32}
33
34impl SelectionSYM for NormalizedSelection {
35    fn true_() -> Self {
36        Self::True
37    }
38
39    fn false_() -> Self {
40        Self::False
41    }
42
43    fn all(inner: Self) -> Self {
44        Self::All(Box::new(inner))
45    }
46
47    fn first(inner: Self) -> Self {
48        Self::First(Box::new(inner))
49    }
50
51    fn range<R: Into<shape::Range>>(range: R, inner: Self) -> Self {
52        Self::Range(range.into(), Box::new(inner))
53    }
54
55    fn label<L: Into<LabelKey>>(labels: Vec<L>, inner: Self) -> Self {
56        Self::Label(
57            labels.into_iter().map(Into::into).collect(),
58            Box::new(inner),
59        )
60    }
61
62    fn any(inner: Self) -> Self {
63        Self::Any(Box::new(inner))
64    }
65
66    fn intersection(lhs: Self, rhs: Self) -> Self {
67        let mut set = BTreeSet::new();
68        set.insert(lhs);
69        set.insert(rhs);
70        Self::Intersection(set)
71    }
72
73    fn union(lhs: Self, rhs: Self) -> Self {
74        let mut set = BTreeSet::new();
75        set.insert(lhs);
76        set.insert(rhs);
77        Self::Union(set)
78    }
79}
80
81impl NormalizedSelection {
82    /// Applies a transformation to each child node of the selection.
83    ///
84    /// This performs a single-layer traversal, applying `f` to each
85    /// immediate child and reconstructing the outer node with the
86    /// transformed children.
87    pub fn trav<F>(self, mut f: F) -> Self
88    where
89        F: FnMut(Self) -> Self,
90    {
91        use NormalizedSelection::*;
92
93        match self {
94            All(inner) => All(Box::new(f(*inner))),
95            First(inner) => First(Box::new(f(*inner))),
96            Any(inner) => Any(Box::new(f(*inner))),
97            Range(r, inner) => Range(r, Box::new(f(*inner))),
98            Label(labels, inner) => Label(labels, Box::new(f(*inner))),
99            Union(set) => Union(set.into_iter().map(f).collect()),
100            Intersection(set) => Intersection(set.into_iter().map(f).collect()),
101            leaf @ (True | False) => leaf,
102        }
103    }
104}
105
106/// A trait representing a single bottom-up rewrite rule on normalized
107/// selections.
108///
109/// Implementors define a transformation step applied after children
110/// have been rewritten. These rules are composed into normalization
111/// passes (see [`normalize`]) to simplify or canonicalize selection
112/// expressions.
113///
114/// This trait forms the basis for extensible normalization. Future
115/// systems may support top-down or contextual rewrites as well.
116pub trait RewriteRule {
117    /// Applies a rewrite step to a node whose children have already
118    /// been recursively rewritten.
119    fn rewrite(&self, node: NormalizedSelection) -> NormalizedSelection;
120}
121
122impl<R1: RewriteRule, R2: RewriteRule> RewriteRule for (R1, R2) {
123    fn rewrite(&self, node: NormalizedSelection) -> NormalizedSelection {
124        self.1.rewrite(self.0.rewrite(node))
125    }
126}
127
128/// Extension trait for composing rewrite rules in a fluent style.
129///
130/// This trait provides a `then` method that allows chaining rewrite
131/// rules together, creating a pipeline where rules are applied
132/// left-to-right.
133pub trait RewriteRuleExt: RewriteRule + Sized {
134    /// Chains this rule with another rule, creating a composite rule
135    /// that applies `self` first, then `other`.
136    fn then<R: RewriteRule>(self, other: R) -> (Self, R) {
137        (self, other)
138    }
139}
140
141impl<T: RewriteRule> RewriteRuleExt for T {}
142
143impl From<NormalizedSelection> for Selection {
144    /// Converts the normalized form back into a standard `Selection`.
145    ///
146    /// Logical semantics are preserved, but normalized shape (e.g.,
147    /// set-based unions and intersections) is reconstructed as
148    /// left-associated binary trees.
149    fn from(norm: NormalizedSelection) -> Self {
150        use NormalizedSelection::*;
151        use dsl::*;
152
153        match norm {
154            True => true_(),
155            False => false_(),
156            All(inner) => all((*inner).into()),
157            First(inner) => first((*inner).into()),
158            Any(inner) => any((*inner).into()),
159            Union(set) => set
160                .into_iter()
161                .map(Into::into)
162                .reduce(Selection::union)
163                .unwrap_or_else(false_),
164            Intersection(set) => set
165                .into_iter()
166                .map(Into::into)
167                .reduce(Selection::intersection)
168                .unwrap_or_else(true_),
169            Range(r, inner) => Selection::range(r, (*inner).into()),
170            Label(labels, inner) => Selection::label(labels, (*inner).into()),
171        }
172    }
173}
174
175/// A normalization rule that applies simple algebraic identities.
176#[derive(Default)]
177pub struct IdentityRules;
178
179impl RewriteRule for IdentityRules {
180    // Identity rewrites:
181    //
182    // - All(All(x))           → All(x)    // idempotence
183    // - All(True)             → True      // identity
184    // - All(False)            → False     // passthrough
185    // - Intersection(True, x) → x         // identity
186    // - Intersection({x})     → x         // trivial
187    // - Intersection({})      → True      // identity
188    // - Union(False, x)       → x         // identity
189    // - Union({x})            → x         // trivial
190    // - Union({})             → False     // trivial
191    //
192    // Absorbtion rules like `Union(True, x) → x` are handled in a
193    // different rewrite.
194    fn rewrite(&self, node: NormalizedSelection) -> NormalizedSelection {
195        use NormalizedSelection::*;
196
197        match node {
198            All(inner) => match *inner {
199                All(grandchild) => All(grandchild), // All(All(x)) → All(x)
200                True => True,                       // All(True) → True
201                False => False,                     // All(False) → False
202                _ => All(inner),
203            },
204
205            Intersection(mut set) => {
206                set.remove(&True); // Intersection(True, ...)  → ...
207                match set.len() {
208                    0 => True,
209                    1 => set.into_iter().next().unwrap(), // Intersection(x) → x
210                    _ => Intersection(set),
211                }
212            }
213
214            Union(mut set) => {
215                set.remove(&False); // Union(False, ...) → ...
216                match set.len() {
217                    0 => False,
218                    1 => set.into_iter().next().unwrap(), // Union(x) → x
219                    _ => Union(set),
220                }
221            }
222
223            _ => node,
224        }
225    }
226}
227
228/// A normalization rule that flattens nested unions and
229/// intersections.
230#[derive(Default)]
231pub struct FlatteningRules;
232
233impl RewriteRule for FlatteningRules {
234    // Flattening rewrites:
235    //
236    // - Union(a, Union(b, c))               → Union(a, b, c)           // flatten nested unions
237    // - Intersection(a, Intersection(b, c)) → Intersection(a, b, c)    // flatten nested intersections
238    fn rewrite(&self, node: NormalizedSelection) -> NormalizedSelection {
239        use NormalizedSelection::*;
240
241        match node {
242            Union(set) => {
243                let mut flattened = BTreeSet::new();
244                for item in set {
245                    match item {
246                        Union(inner_set) => {
247                            flattened.extend(inner_set);
248                        }
249                        other => {
250                            flattened.insert(other);
251                        }
252                    }
253                }
254                Union(flattened)
255            }
256            Intersection(set) => {
257                let mut flattened = BTreeSet::new();
258                for item in set {
259                    match item {
260                        Intersection(inner_set) => {
261                            flattened.extend(inner_set);
262                        }
263                        other => {
264                            flattened.insert(other);
265                        }
266                    }
267                }
268                Intersection(flattened)
269            }
270            _ => node,
271        }
272    }
273}
274
275/// A normalization rule that applies absorption laws for unions and
276/// intersections.
277///
278/// A union containing `True` always evaluates to `True`, and an
279/// intersection containing `False` always evaluates to `False`.
280#[derive(Default)]
281pub struct AbsorbtionRules;
282
283impl RewriteRule for AbsorbtionRules {
284    // Absorption rewrites:
285    //
286    // - Union(..., True, ...) → True
287    // - Intersection(..., False, ...) → False
288    fn rewrite(&self, node: NormalizedSelection) -> NormalizedSelection {
289        use NormalizedSelection::*;
290
291        match node {
292            Union(set) => {
293                if set.contains(&True) {
294                    True // Union(..., True, ...) → True
295                } else {
296                    Union(set)
297                }
298            }
299            Intersection(set) => {
300                if set.contains(&False) {
301                    False // Intersection(..., False, ...) → False
302                } else {
303                    Intersection(set)
304                }
305            }
306            other => other,
307        }
308    }
309}
310
311impl NormalizedSelection {
312    pub fn rewrite_bottom_up(self, rule: &impl RewriteRule) -> Self {
313        let mapped = self.trav(|child| child.rewrite_bottom_up(rule));
314        rule.rewrite(mapped)
315    }
316}
317
#[cfg(test)]
mod tests {
    use super::*;
    use crate::assert_structurally_eq;
    use crate::selection;
    use crate::selection::parse::parse;

    /// Verifies that:
    /// - Duplicate subtrees are structurally deduplicated by
    ///   normalization
    /// - The normalized form reifies to the expected `Selection` in
    ///   this case
    #[test]
    fn normalization_deduplicates_and_reifies() {
        let sel = parse("(* & *) | (* & *)").unwrap();
        let norm = sel.fold::<NormalizedSelection>();

        // Expected: Union { Intersection { All(True) } }
        use NormalizedSelection::*;
        let mut inner = BTreeSet::new();
        inner.insert(All(Box::new(True)));

        let mut outer = BTreeSet::new();
        outer.insert(Intersection(inner));

        assert_eq!(norm, Union(outer));

        use selection::dsl::*;
        let reified = norm.into();
        let expected = all(true_());

        assert_structurally_eq!(&reified, &expected);
    }

    /// End-to-end check of `normalize` on a small expression.
    #[test]
    fn normalize_smoke_test() {
        use crate::assert_structurally_eq;
        use crate::selection::dsl::*;
        use crate::selection::normalize;
        use crate::selection::parse::parse;

        // The expression (*,*) | (*,*) parses as
        // Union(All(All(True)), All(All(True))) and normalizes all
        // the way down to True.
        let sel = parse("(*,*) | (*,*)").unwrap();
        let normed = normalize(&sel);
        let expected = true_();

        assert_structurally_eq!(&normed.into(), &expected);
    }

    /// `FlatteningRules` merges a directly nested union into its parent.
    #[test]
    fn test_union_flattening() {
        use NormalizedSelection::*;

        // Create Union(a, Union(b, c)) manually
        let inner_union = {
            let mut set = BTreeSet::new();
            set.insert(All(Box::new(True))); // represents 'b'
            set.insert(Any(Box::new(True))); // represents 'c'
            Union(set)
        };

        let outer_union = {
            let mut set = BTreeSet::new();
            set.insert(First(Box::new(True))); // represents 'a'
            set.insert(inner_union);
            Union(set)
        };

        let rule = FlatteningRules;
        let result = rule.rewrite(outer_union);

        // Should be flattened to Union(a, b, c)
        if let Union(set) = result {
            assert_eq!(set.len(), 3);
            assert!(set.contains(&First(Box::new(True))));
            assert!(set.contains(&All(Box::new(True))));
            assert!(set.contains(&Any(Box::new(True))));
        } else {
            panic!("Expected Union, got {:?}", result);
        }
    }

    /// `FlatteningRules` merges a directly nested intersection into its
    /// parent.
    #[test]
    fn test_intersection_flattening() {
        use NormalizedSelection::*;

        // Create Intersection(a, Intersection(b, c)) manually
        let inner_intersection = {
            let mut set = BTreeSet::new();
            set.insert(All(Box::new(True))); // represents 'b'
            set.insert(Any(Box::new(True))); // represents 'c'
            Intersection(set)
        };

        let outer_intersection = {
            let mut set = BTreeSet::new();
            set.insert(First(Box::new(True))); // represents 'a'
            set.insert(inner_intersection);
            Intersection(set)
        };

        let rule = FlatteningRules;
        let result = rule.rewrite(outer_intersection);

        // Should be flattened to Intersection(a, b, c)
        if let Intersection(set) = result {
            assert_eq!(set.len(), 3);
            assert!(set.contains(&First(Box::new(True))));
            assert!(set.contains(&All(Box::new(True))));
            assert!(set.contains(&Any(Box::new(True))));
        } else {
            panic!("Expected Intersection, got {:?}", result);
        }
    }

    /// `AbsorbtionRules` collapses a union containing `True` to `True`
    /// and an intersection containing `False` to `False`.
    ///
    /// This test lives inside `mod tests` alongside the others so that
    /// all unit tests share the module's `#[cfg(test)]` gate and
    /// imports.
    #[test]
    fn test_absorbtion_rules() {
        use NormalizedSelection::*;

        // Union(True, Any(True)) should absorb to True
        let union_case = {
            let mut set = BTreeSet::new();
            set.insert(True);
            set.insert(Any(Box::new(True)));
            Union(set)
        };

        let rule = AbsorbtionRules;
        let result = rule.rewrite(union_case);
        assert_eq!(result, True);

        // Intersection(False, All(True)) should absorb to False
        let intersection_case = {
            let mut set = BTreeSet::new();
            set.insert(False);
            set.insert(All(Box::new(True)));
            Intersection(set)
        };

        let result = rule.rewrite(intersection_case);
        assert_eq!(result, False);
    }
}