ndslice/selection/
test_utils.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::collections::HashMap;
10use std::collections::HashSet;
11use std::ops::ControlFlow;
12
13use nom::Parser as _;
14
15use crate::Slice;
16use crate::selection::Selection;
17use crate::selection::routing::RoutingAction;
18use crate::selection::routing::RoutingFrame;
19use crate::selection::routing::RoutingFrameKey;
20use crate::selection::routing::RoutingStep;
21use crate::selection::routing::resolve_routing;
22
23/// Parse an input string to a selection.
24pub fn parse(input: &str) -> Selection {
25    use nom::combinator::all_consuming;
26
27    use crate::selection::parse::expression;
28
29    let (_, selection) = all_consuming(expression).parse(input).unwrap();
30    selection
31}
32
/// Asserts that two `Selection`s are structurally equal, as decided
/// by `$crate::selection::structurally_equal`.
///
/// Each argument is evaluated exactly once and then borrowed. On
/// mismatch, the panic message pretty-prints both selections with
/// `{:#?}`. A trailing comma after the second argument is accepted,
/// matching `assert_eq!`.
#[macro_export]
macro_rules! assert_structurally_eq {
    ($expected:expr_2021, $actual:expr_2021 $(,)?) => {{
        let expected = &$expected;
        let actual = &$actual;
        assert!(
            $crate::selection::structurally_equal(expected, actual),
            "Selections do not match.\nExpected: {:#?}\nActual:   {:#?}",
            expected,
            actual,
        );
    }};
}
46
/// Asserts that a `Selection` survives a round trip through its
/// compact textual syntax.
///
/// The selection, which is taken by value, is rendered with
/// `$crate::selection::pretty::compact` and re-parsed with
/// `$crate::selection::test_utils::parse`. The result is then compared
/// against the input with `$crate::selection::structurally_equal`.
/// A trailing comma after the argument is accepted.
///
/// # Panics
///
/// Panics if the compact form fails to parse, or if the parsed
/// selection is not structurally equal to the input.
#[macro_export]
macro_rules! assert_round_trip {
    ($selection:expr_2021 $(,)?) => {{
        // Name the type through `$crate` so the macro works at call
        // sites that have not imported `Selection` themselves.
        let selection: $crate::selection::Selection = $selection; // take ownership
        // Render the selection in its compact textual syntax.
        let compact = $crate::selection::pretty::compact(&selection).to_string();
        // Parse it back; `parse` panics if the compact form is not
        // valid selection syntax.
        let parsed = $crate::selection::test_utils::parse(&compact);
        // The round trip must preserve structure exactly.
        assert!(
            $crate::selection::structurally_equal(&selection, &parsed),
            "input: {} \n compact: {}\n parsed: {}",
            selection,
            compact,
            parsed
        );
    }};
}
68
/// Reports whether routing-frame deduplication should be applied.
///
/// Deduplication is on by default: it prunes structurally redundant
/// routing steps so simulations do less work. Correctness must never
/// rely on it, so it can be switched off for debugging or testing by
/// setting:
/// ```ignore
/// HYPERACTOR_SELECTION_DISABLE_ROUTING_FRAME_DEDUPLICATION = 1
/// ```
/// With deduplication off, every routing step is visited, including
/// redundant ones, which may re-enter coordinates seen earlier. This
/// checks that correctness comes from the routing algebra itself
/// rather than from memoization or key-based filtering.
///
/// Returns `false` only when the variable holds exactly `"1"`. If the
/// variable is unset, is not valid Unicode, or holds any other value,
/// deduplication stays enabled.
fn allow_frame_dedup() -> bool {
    const DISABLE_VAR: &str = "HYPERACTOR_SELECTION_DISABLE_ROUTING_FRAME_DEDUPLICATION";

    let disabled = matches!(std::env::var(DISABLE_VAR).as_deref(), Ok("1"));
    !disabled
}
91
92// == Testing (`collect_routed_paths` mesh simulation) ===
93
94/// Message type used in the `collect_routed_paths` mesh routing
95/// simulation.
96///
97/// Each message tracks the current routing state (`frame`) and
98/// the full path (`path`) taken from the origin to the current
99/// node, represented as a list of flat indices.
100///
101/// As the message is forwarded, `path` is extended. This allows
102/// complete routing paths to be observed at the point of
103/// delivery.
104pub struct RoutedMessage<T> {
105    pub path: Vec<usize>,
106    pub frame: RoutingFrame,
107    pub _payload: std::marker::PhantomData<T>,
108}
109
110impl<T> RoutedMessage<T> {
111    pub fn new(path: Vec<usize>, frame: RoutingFrame) -> Self {
112        Self {
113            path,
114            frame,
115            _payload: std::marker::PhantomData,
116        }
117    }
118}
119
/// Result of the `collect_routed_paths` simulation: where a message
/// was delivered and which ranks forwarded it there.
#[derive(Debug, Default)]
pub struct RoutedPathTree {
    /// Map from rank → delivery path (flat indices), from the origin
    /// to that rank.
    pub delivered: HashMap<usize, Vec<usize>>,

    /// Map from rank → set of direct predecessor ranks (flat
    /// indices).
    pub predecessors: HashMap<usize, HashSet<usize>>,
}
129
130/// Simulates routing from the origin through a slice using a
131/// `Selection`, collecting all delivery destinations **along with
132/// their routing paths**.
133//
134/// Each returned entry is a tuple `(dst, path)`, where `dst` is the
135/// flat index of a delivery node, and `path` is the list of flat
136/// indices representing the route taken from the origin to that node.
137//
138/// Routing begins at `[0, 0, ..., 0]` and proceeds
139/// dimension-by-dimension. At each hop, `next_steps` determines the
140/// next set of forwarding frames.
141//
142/// A node is considered a delivery target if:
143/// - its `selection` is `Selection::True`, and
144/// - it is at the final dimension.
145//
146///   Useful in tests for verifying full routing paths and ensuring
147///   correctness.
148pub fn collect_routed_paths(selection: &Selection, slice: &Slice) -> RoutedPathTree {
149    use std::collections::VecDeque;
150
151    let mut pending = VecDeque::new();
152    let mut delivered = HashMap::new();
153    let mut seen = HashSet::new();
154    let mut predecessors: HashMap<usize, HashSet<usize>> = HashMap::new();
155
156    let root_frame = RoutingFrame::root(selection.clone(), slice.clone());
157    let origin = slice.location(&root_frame.here).unwrap();
158    pending.push_back(RoutedMessage::<()>::new(vec![origin], root_frame));
159
160    while let Some(RoutedMessage { path, frame, .. }) = pending.pop_front() {
161        let mut visitor = |step: RoutingStep| {
162            if let RoutingStep::Forward(next_frame) = step {
163                let key = RoutingFrameKey::new(&next_frame);
164                let should_insert = if allow_frame_dedup() {
165                    seen.insert(key) // true → not seen before
166                } else {
167                    true // unconditionally insert
168                };
169                if should_insert {
170                    let next_rank = slice.location(&next_frame.here).unwrap();
171                    let parent_rank = *path.last().unwrap();
172                    predecessors
173                        .entry(next_rank)
174                        .or_default()
175                        .insert(parent_rank);
176
177                    let mut next_path = path.clone();
178                    next_path.push(next_rank);
179
180                    match next_frame.action() {
181                        RoutingAction::Deliver => {
182                            if let Some(previous) = delivered.insert(next_rank, next_path.clone()) {
183                                panic!(
184                                    "over-delivery detected: node {} delivered twice\nfirst: {:?}\nsecond: {:?}",
185                                    next_rank, previous, next_path
186                                );
187                            }
188                        }
189                        RoutingAction::Forward => {
190                            pending.push_back(RoutedMessage::new(next_path, next_frame));
191                        }
192                    }
193                }
194            }
195            ControlFlow::Continue(())
196        };
197
198        let _ = frame.next_steps(
199            &mut |_| panic!("Choice encountered in collect_routed_nodes"),
200            &mut visitor,
201        );
202    }
203
204    RoutedPathTree {
205        delivered,
206        predecessors,
207    }
208}
209
210/// Simulates routing from the origin and returns the set of
211/// destination nodes (as flat indices) selected by the
212/// `Selection`.
213///
214/// This function discards routing paths and retains only the
215/// final delivery targets. It is useful in tests to compare
216/// routing results against selection evaluation.
217pub fn collect_routed_nodes(selection: &Selection, slice: &Slice) -> Vec<usize> {
218    collect_routed_paths(selection, slice)
219        .delivered
220        .keys()
221        .cloned()
222        .collect()
223}
224
225// == Testing (`collect_commactor_routing_tree` mesh simulation) ===
226
227/// Captures the logical structure of a CommActor multicast operation.
228///
229/// This type models how a message is delivered and forwarded through
230/// a mesh under CommActor routing semantics. It is used in tests to
231/// verify path determinism and understand message propagation
232/// behavior.
233///
234/// - `delivered`: ranks where the message was delivered (`post`
235///   called)
236/// - `visited`: all ranks that participated, including forwarding
237///   only
238/// - `forwards`: maps each rank to the routing frames it forwarded
239#[derive(Default)]
240pub struct CommActorRoutingTree {
241    // Ranks that were delivered the message (i.e. called `post`). Map
242    // from rank → delivery path (flat rank indices) from root to that
243    // rank.
244    pub delivered: HashMap<usize, Vec<usize>>,
245
246    // Ranks that participated in the multicast - either by delivering
247    // the message or forwarding it to peers.
248    pub visited: HashSet<usize>,
249
250    /// Map from rank → routing frames this rank forwarded to other
251    /// ranks.
252    pub forwards: HashMap<usize, Vec<RoutingFrame>>,
253}
254
/// One hop in the `collect_commactor_routing_tree` simulation: a
/// message being forwarded from one rank to another.
///
/// Each instance carries the routing frames the receiver must
/// evaluate and the multicast path accumulated so far.
///
/// - `from`: the sender rank
/// - `to`: the receiver rank
/// - `frames`: routing frames to evaluate at the receiver
/// - `path`: flat ranks from the root up to and including `to`
#[derive(Debug)]
pub struct ForwardMessage {
    /// The rank that is forwarding the message.
    // The simulation sets this field but destructures it as `from: _`,
    // so it is never read; the allowance silences the dead-code lint.
    #[allow(dead_code)] // Never read.
    pub from: usize,

    /// The rank receiving the message.
    pub to: usize,

    /// The routing frames being forwarded, to be resolved at `to`.
    pub frames: Vec<RoutingFrame>,

    /// The multicast path taken so far (the root message's path is
    /// just the origin).
    pub path: Vec<usize>,
}
281
282/// `collect_commactor_routing_tree` simulates how messages propagate
283/// through a mesh of `CommActor`s during multicast, reconstructing
284/// the full logical routing tree.
285///
286/// This function mirrors the behavior of `CommActor::handle_message`
287/// and `CommActor::forward`, using the shared `resolve_routing` logic
288/// to determine delivery and forwarding at each step. Starting from
289/// the root frame, it simulates how the message would be forwarded
290/// peer-to-peer through the system.
291///
292/// The returned `CommActorRoutingTree` includes:
293/// - `delivered`: ranks where the message would be delivered (i.e.,
294///   `post` called)
295/// - `visited`: all ranks that received or forwarded the message
296/// - `forwards`: frames forwarded from each rank to peers
297///
298/// This model is used in tests to validate routing behavior,
299/// especially invariants like path determinism and delivery coverage.
300pub fn collect_commactor_routing_tree(
301    selection: &Selection,
302    slice: &Slice,
303) -> CommActorRoutingTree {
304    use std::collections::VecDeque;
305
306    let mut pending = VecDeque::new();
307    let mut tree = CommActorRoutingTree::default();
308
309    let root_frame = RoutingFrame::root(selection.clone(), slice.clone());
310    let origin = slice.location(&root_frame.here).unwrap();
311    pending.push_back(ForwardMessage {
312        from: origin,
313        to: origin,
314        frames: vec![root_frame],
315        path: vec![origin],
316    });
317
318    while let Some(ForwardMessage {
319        from: _,
320        to: rank,
321        frames: dests,
322        path,
323    }) = pending.pop_front()
324    {
325        // This loop models the core of `CommActor::handle(...,
326        // fwd_message: ForwardMessage)` +
327        // `CommActor::handle_message(... next_steps...)`.
328        // - `resolve_routing` corresponds to the call in `handle`
329        // - delivery and forwarding match the logic in `handle_message`
330        // - each forward step simulates `CommActor::forward`
331
332        tree.visited.insert(rank);
333
334        let (deliver_here, forwards) =
335            resolve_routing(rank, dests, &mut |_| panic!("choice unexpected")).unwrap();
336
337        if deliver_here {
338            tree.delivered.insert(rank, path.clone());
339        }
340
341        let messages: Vec<_> = forwards
342            .into_iter()
343            .map(|(peer, peer_frames)| {
344                tree.forwards
345                    .entry(rank)
346                    .or_default()
347                    .extend(peer_frames.clone());
348
349                let mut peer_path = path.clone();
350                peer_path.push(peer);
351
352                ForwardMessage {
353                    from: rank,
354                    to: peer,
355                    frames: peer_frames,
356                    path: peer_path,
357                }
358            })
359            .collect();
360
361        for message in messages {
362            pending.push_back(message);
363        }
364    }
365
366    tree
367}