ndslice/selection/
token_parser.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! A `TokenStream` to [`Selection`] parser used by the `sel!`
10//! procedural macro.
11//!
12//! This module implements a compile-time parser that converts a
13//! [`proc_macro2::TokenStream`] into a [`Selection`] syntax tree.
14//!
15//! The grammar and interpretation of selection expressions are the
16//! same as those described in the [`parse`] module. See that module
17//! for full documentation of the syntax and semantics.
18//!
19//! See [`parse_tokens`] for the entry point, and
20//! [`selection_to_tokens`] for the inverse.
21
22use std::iter::Peekable;
23
24use proc_macro2::Delimiter;
25use proc_macro2::TokenStream;
26use proc_macro2::TokenTree;
27use quote::quote;
28
29use crate::selection::Selection;
30use crate::selection::dsl;
31use crate::shape;
32
33// Selection expressions grammar:
34// ```text
35// expression ::= union
36// union      ::= intersection ('|' intersection)*
37// intersection ::= dimension ('&' dimension)*
38// dimension  ::= group (',' group)*
39// group      ::= range | index | * | ? | (expression)
40// ```
41
42/// Parses a [`proc_macro2::TokenStream`] representing a selection
43/// expression into a [`Selection`] syntax tree.
44///
45/// This is intended for use at compile time in the `sel!` procedural macro,
46/// and is the inverse of [`selection_to_tokens`].
47pub fn parse_tokens(tokens: TokenStream) -> Result<Selection, String> {
48    let mut iter = tokens.into_iter().peekable();
49    parse_expression(&mut iter)
50}
51
52pub fn selection_to_tokens(sel: &Selection) -> proc_macro2::TokenStream {
53    match sel {
54        Selection::True => quote!(Selection::True),
55        Selection::False => quote!(Selection::False),
56        Selection::All(inner) => {
57            let inner = selection_to_tokens(inner);
58            quote!(Selection::All(Box::new(#inner)))
59        }
60        Selection::Range(r, inner) => {
61            let start = r.0;
62            let end = match &r.1 {
63                Some(e) => quote!(Some(#e)),
64                None => quote!(None),
65            };
66            let step = r.2;
67            let inner = selection_to_tokens(inner);
68            quote! {
69                ::ndslice::selection::Selection::Range(
70                    ::ndslice::shape::Range(#start, #end, #step),
71                    Box::new(#inner)
72                )
73            }
74        }
75        Selection::Any(inner) => {
76            let inner = selection_to_tokens(inner);
77            quote!(Selection::Any(Box::new(#inner)))
78        }
79        Selection::Intersection(a, b) => {
80            let a = selection_to_tokens(a);
81            let b = selection_to_tokens(b);
82            quote!(Selection::Intersection(Box::new(#a), Box::new(#b)))
83        }
84        Selection::Union(a, b) => {
85            let a = selection_to_tokens(a);
86            let b = selection_to_tokens(b);
87            quote!(Selection::Union(Box::new(#a), Box::new(#b)))
88        }
89        _ => unimplemented!(),
90    }
91}
92
93fn parse_expression<I>(tokens: &mut Peekable<I>) -> Result<Selection, String>
94where
95    I: Iterator<Item = TokenTree>,
96{
97    let mut lhs = parse_intersection(tokens)?;
98    while let Some(TokenTree::Punct(p)) = tokens.peek() {
99        if p.as_char() == '|' {
100            tokens.next(); // consume |
101            let rhs = parse_intersection(tokens)?;
102            lhs = dsl::union(lhs, rhs);
103        } else {
104            break;
105        }
106    }
107    Ok(lhs)
108}
109
110fn parse_intersection<I>(tokens: &mut Peekable<I>) -> Result<Selection, String>
111where
112    I: Iterator<Item = TokenTree>,
113{
114    let mut lhs = parse_dimensions(tokens)?;
115    while let Some(TokenTree::Punct(p)) = tokens.peek() {
116        if p.as_char() == '&' {
117            tokens.next(); // consume &
118            let rhs = parse_dimensions(tokens)?;
119            lhs = dsl::intersection(lhs, rhs);
120        } else {
121            break;
122        }
123    }
124    Ok(lhs)
125}
126
127fn parse_dimensions<I>(tokens: &mut Peekable<I>) -> Result<Selection, String>
128where
129    I: Iterator<Item = TokenTree>,
130{
131    let mut dims = vec![];
132
133    loop {
134        dims.push(parse_atom(tokens)?);
135
136        match tokens.peek() {
137            Some(TokenTree::Punct(p)) if p.as_char() == ',' => {
138                tokens.next(); // consume comma
139            }
140            _ => break,
141        }
142    }
143
144    let mut result = dsl::true_();
145
146    for dim in dims.into_iter().rev() {
147        result = apply_dimension_chain(dim, result)?;
148    }
149
150    Ok(result)
151}
152
153fn apply_dimension_chain(sel: Selection, tail: Selection) -> Result<Selection, String> {
154    Ok(match sel {
155        Selection::All(inner) => dsl::all(apply_dimension_chain(*inner, tail)?),
156        Selection::Any(inner) => dsl::any(apply_dimension_chain(*inner, tail)?),
157        Selection::Range(r, inner) => dsl::range(r, apply_dimension_chain(*inner, tail)?),
158        Selection::Union(a, b) => dsl::union(
159            apply_dimension_chain(*a, tail.clone())?,
160            apply_dimension_chain(*b, tail)?,
161        ),
162        Selection::Intersection(a, b) => dsl::intersection(
163            apply_dimension_chain(*a, tail.clone())?,
164            apply_dimension_chain(*b, tail)?,
165        ),
166        Selection::True => tail,
167        Selection::False => dsl::false_(),
168        other => {
169            return Err(format!(
170                "unexpected selection type in dimension chain: {:?}",
171                other
172            ));
173        }
174    })
175}
176
177fn parse_atom<I>(tokens: &mut Peekable<I>) -> Result<Selection, String>
178where
179    I: Iterator<Item = TokenTree>,
180{
181    match tokens.peek() {
182        Some(TokenTree::Punct(p)) if p.as_char() == '*' => {
183            tokens.next();
184            Ok(dsl::all(dsl::true_()))
185        }
186        Some(TokenTree::Punct(p)) if p.as_char() == '?' => {
187            tokens.next();
188            Ok(dsl::any(dsl::true_()))
189        }
190        Some(TokenTree::Punct(p)) if p.as_char() == ':' => {
191            tokens.next(); // consume ':'
192
193            // Optional end
194            let end = match tokens.peek() {
195                Some(TokenTree::Literal(_lit)) => {
196                    let lit = tokens
197                        .next()
198                        .ok_or_else(|| "expected literal after ':'".to_string())?;
199                    Some(
200                        lit.to_string()
201                            .parse::<usize>()
202                            .map_err(|e| e.to_string())?,
203                    )
204                }
205                _ => None,
206            };
207
208            // Optional step
209            let step = match tokens.peek() {
210                Some(TokenTree::Punct(p)) if p.as_char() == ':' => {
211                    tokens.next(); // consume second ':'
212                    let lit = match tokens.next() {
213                        Some(TokenTree::Literal(lit)) => lit,
214                        other => return Err(format!("expected step after ::, got {:?}", other)),
215                    };
216                    lit.to_string()
217                        .parse::<usize>()
218                        .map_err(|e| e.to_string())?
219                }
220                _ => 1,
221            };
222
223            Ok(dsl::range(shape::Range(0, end, step), dsl::true_()))
224        }
225        Some(TokenTree::Literal(_)) => {
226            // literal-prefixed range or index
227            parse_range_or_index(tokens)
228        }
229        Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Parenthesis => {
230            let group = tokens.next().unwrap(); // consume group
231            let mut inner = match group {
232                TokenTree::Group(g) => g.stream().into_iter().peekable(),
233                _ => unreachable!(),
234            };
235            parse_expression(&mut inner)
236        }
237        Some(t) => Err(format!("unexpected token: {:?}", t)),
238        None => Err("unexpected end of input".to_string()),
239    }
240}
241
242fn parse_range_or_index<I>(tokens: &mut Peekable<I>) -> Result<Selection, String>
243where
244    I: Iterator<Item = TokenTree>,
245{
246    // Peek and parse the start literal
247    let start_lit = match tokens.next() {
248        Some(TokenTree::Literal(lit)) => lit,
249        other => return Err(format!("expected number, got {:?}", other)),
250    };
251
252    let start = start_lit
253        .to_string()
254        .parse::<usize>()
255        .map_err(|e| format!("invalid number: {}", e))?;
256
257    // Check if this is a range by looking for a colon
258    if let Some(TokenTree::Punct(p)) = tokens.peek() {
259        if p.as_char() == ':' {
260            tokens.next(); // consume ':'
261
262            // Try to parse optional end
263            let end = match tokens.peek() {
264                Some(TokenTree::Literal(_lit)) => {
265                    let lit = tokens.next().unwrap();
266                    Some(
267                        lit.to_string()
268                            .parse::<usize>()
269                            .map_err(|e| format!("invalid range end: {}", e))?,
270                    )
271                }
272                Some(TokenTree::Punct(p)) if p.as_char() == ':' => None,
273                _ => None,
274            };
275
276            // Try to parse optional step
277            let step = match tokens.peek() {
278                Some(TokenTree::Punct(p)) if p.as_char() == ':' => {
279                    tokens.next(); // consume second ':'
280                    let lit = tokens.next().ok_or("expected number for step after ::")?;
281                    match lit {
282                        TokenTree::Literal(lit) => lit
283                            .to_string()
284                            .parse::<usize>()
285                            .map_err(|e| format!("invalid step size: {}", e))?,
286                        other => return Err(format!("expected literal for step, got {:?}", other)),
287                    }
288                }
289                _ => 1,
290            };
291
292            Ok(dsl::range(shape::Range(start, end, step), dsl::true_()))
293        } else {
294            // Not a range, treat as index
295            Ok(dsl::range(
296                shape::Range(start, Some(start + 1), 1),
297                dsl::true_(),
298            ))
299        }
300    } else {
301        // No colon → single index
302        Ok(dsl::range(
303            shape::Range(start, Some(start + 1), 1),
304            dsl::true_(),
305        ))
306    }
307}