ndslice/selection/test_utils.rs
1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::collections::HashMap;
10use std::collections::HashSet;
11use std::ops::ControlFlow;
12
13use nom::Parser as _;
14
15use crate::Slice;
16use crate::selection::Selection;
17use crate::selection::routing::RoutingAction;
18use crate::selection::routing::RoutingFrame;
19use crate::selection::routing::RoutingFrameKey;
20use crate::selection::routing::RoutingStep;
21use crate::selection::routing::resolve_routing;
22
23/// Parse an input string to a selection.
24pub fn parse(input: &str) -> Selection {
25 use nom::combinator::all_consuming;
26
27 use crate::selection::parse::expression;
28
29 let (_, selection) = all_consuming(expression).parse(input).unwrap();
30 selection
31}
32
/// Asserts that two selections are structurally equal, as decided by
/// `selection::structurally_equal`.
///
/// Each argument is evaluated exactly once, left to right, and then
/// borrowed. On mismatch, the panic message pretty-prints (`{:#?}`)
/// both the expected and the actual value.
#[macro_export]
macro_rules! assert_structurally_eq {
    ($expected:expr_2021, $actual:expr_2021) => {{
        let (lhs, rhs) = (&$expected, &$actual);
        assert!(
            $crate::selection::structurally_equal(lhs, rhs),
            "Selections do not match.\nExpected: {:#?}\nActual: {:#?}",
            lhs,
            rhs,
        );
    }};
}
46
/// Asserts that a `Selection` survives a round trip through its
/// compact textual syntax.
///
/// The argument is taken by value. It is rendered with
/// `selection::pretty::compact`, parsed back with
/// `selection::test_utils::parse`, and the result is compared against
/// the input with `selection::structurally_equal`.
///
/// # Panics
///
/// Panics if the compact form fails to parse, or if the parsed
/// selection is not structurally equal to the input.
#[macro_export]
macro_rules! assert_round_trip {
    ($selection:expr_2021) => {{
        // Fully qualified: an exported macro must not depend on the
        // caller having `Selection` imported into its own scope.
        let selection: $crate::selection::Selection = $selection; // take ownership
        // Convert `Selection` to its compact syntax representation.
        let compact = $crate::selection::pretty::compact(&selection).to_string();
        // Parse a `Selection` back from the compact syntax
        // representation.
        let parsed = $crate::selection::test_utils::parse(&compact);
        // Check that the input and parsed `Selection`s are
        // structurally equivalent.
        assert!(
            $crate::selection::structurally_equal(&selection, &parsed),
            "input: {} \n compact: {}\n parsed: {}",
            selection,
            compact,
            parsed
        );
    }};
}
68
/// Reports whether routing-frame deduplication is enabled.
///
/// Deduplication is on by default: it skips routing steps whose
/// frame key has already been seen, which keeps simulations cheap.
/// Correctness must never depend on it.
///
/// To turn it off (for debugging, or to check that correctness comes
/// from the routing algebra itself rather than from memoization or
/// key-based filtering), set:
/// ```text
/// HYPERACTOR_SELECTION_DISABLE_ROUTING_FRAME_DEDUPLICATION=1
/// ```
/// With deduplication off, every routing step is visited, including
/// structurally redundant ones, so previously seen coordinates may be
/// re-entered.
fn allow_frame_dedup() -> bool {
    // Only the exact value "1" disables deduplication. An unset
    // variable, a non-Unicode value, or any other string keeps it on.
    let disabled = matches!(
        std::env::var("HYPERACTOR_SELECTION_DISABLE_ROUTING_FRAME_DEDUPLICATION").as_deref(),
        Ok("1")
    );
    !disabled
}
91
// === Testing (`collect_routed_paths` mesh simulation) ===
93
94/// Message type used in the `collect_routed_paths` mesh routing
95/// simulation.
96///
97/// Each message tracks the current routing state (`frame`) and
98/// the full path (`path`) taken from the origin to the current
99/// node, represented as a list of flat indices.
100///
101/// As the message is forwarded, `path` is extended. This allows
102/// complete routing paths to be observed at the point of
103/// delivery.
104pub struct RoutedMessage<T> {
105 pub path: Vec<usize>,
106 pub frame: RoutingFrame,
107 pub _payload: std::marker::PhantomData<T>,
108}
109
110impl<T> RoutedMessage<T> {
111 pub fn new(path: Vec<usize>, frame: RoutingFrame) -> Self {
112 Self {
113 path,
114 frame,
115 _payload: std::marker::PhantomData,
116 }
117 }
118}
119
/// Result of a `collect_routed_paths` simulation: the path to every
/// delivered rank, plus the direct-predecessor relation between ranks.
///
/// Derives `Debug` so tests can print the whole tree when an
/// assertion fails.
#[derive(Debug, Default)]
pub struct RoutedPathTree {
    /// Map from rank → delivery path (flat indices).
    pub delivered: HashMap<usize, Vec<usize>>,

    /// Map from rank → set of direct predecessor ranks (flat
    /// indices).
    pub predecessors: HashMap<usize, HashSet<usize>>,
}
129
130/// Simulates routing from the origin through a slice using a
131/// `Selection`, collecting all delivery destinations **along with
132/// their routing paths**.
133//
134/// Each returned entry is a tuple `(dst, path)`, where `dst` is the
135/// flat index of a delivery node, and `path` is the list of flat
136/// indices representing the route taken from the origin to that node.
137//
138/// Routing begins at `[0, 0, ..., 0]` and proceeds
139/// dimension-by-dimension. At each hop, `next_steps` determines the
140/// next set of forwarding frames.
141//
142/// A node is considered a delivery target if:
143/// - its `selection` is `Selection::True`, and
144/// - it is at the final dimension.
145//
146/// Useful in tests for verifying full routing paths and ensuring
147/// correctness.
148pub fn collect_routed_paths(selection: &Selection, slice: &Slice) -> RoutedPathTree {
149 use std::collections::VecDeque;
150
151 let mut pending = VecDeque::new();
152 let mut delivered = HashMap::new();
153 let mut seen = HashSet::new();
154 let mut predecessors: HashMap<usize, HashSet<usize>> = HashMap::new();
155
156 let root_frame = RoutingFrame::root(selection.clone(), slice.clone());
157 let origin = slice.location(&root_frame.here).unwrap();
158 pending.push_back(RoutedMessage::<()>::new(vec![origin], root_frame));
159
160 while let Some(RoutedMessage { path, frame, .. }) = pending.pop_front() {
161 let mut visitor = |step: RoutingStep| {
162 if let RoutingStep::Forward(next_frame) = step {
163 let key = RoutingFrameKey::new(&next_frame);
164 let should_insert = if allow_frame_dedup() {
165 seen.insert(key) // true → not seen before
166 } else {
167 true // unconditionally insert
168 };
169 if should_insert {
170 let next_rank = slice.location(&next_frame.here).unwrap();
171 let parent_rank = *path.last().unwrap();
172 predecessors
173 .entry(next_rank)
174 .or_default()
175 .insert(parent_rank);
176
177 let mut next_path = path.clone();
178 next_path.push(next_rank);
179
180 match next_frame.action() {
181 RoutingAction::Deliver => {
182 if let Some(previous) = delivered.insert(next_rank, next_path.clone()) {
183 panic!(
184 "over-delivery detected: node {} delivered twice\nfirst: {:?}\nsecond: {:?}",
185 next_rank, previous, next_path
186 );
187 }
188 }
189 RoutingAction::Forward => {
190 pending.push_back(RoutedMessage::new(next_path, next_frame));
191 }
192 }
193 }
194 }
195 ControlFlow::Continue(())
196 };
197
198 let _ = frame.next_steps(
199 &mut |_| panic!("Choice encountered in collect_routed_nodes"),
200 &mut visitor,
201 );
202 }
203
204 RoutedPathTree {
205 delivered,
206 predecessors,
207 }
208}
209
210/// Simulates routing from the origin and returns the set of
211/// destination nodes (as flat indices) selected by the
212/// `Selection`.
213///
214/// This function discards routing paths and retains only the
215/// final delivery targets. It is useful in tests to compare
216/// routing results against selection evaluation.
217pub fn collect_routed_nodes(selection: &Selection, slice: &Slice) -> Vec<usize> {
218 collect_routed_paths(selection, slice)
219 .delivered
220 .keys()
221 .cloned()
222 .collect()
223}
224
// === Testing (`collect_commactor_routing_tree` mesh simulation) ===
226
227/// Captures the logical structure of a CommActor multicast operation.
228///
229/// This type models how a message is delivered and forwarded through
230/// a mesh under CommActor routing semantics. It is used in tests to
231/// verify path determinism and understand message propagation
232/// behavior.
233///
234/// - `delivered`: ranks where the message was delivered (`post`
235/// called)
236/// - `visited`: all ranks that participated, including forwarding
237/// only
238/// - `forwards`: maps each rank to the routing frames it forwarded
239#[derive(Default)]
240pub struct CommActorRoutingTree {
241 // Ranks that were delivered the message (i.e. called `post`). Map
242 // from rank → delivery path (flat rank indices) from root to that
243 // rank.
244 pub delivered: HashMap<usize, Vec<usize>>,
245
246 // Ranks that participated in the multicast - either by delivering
247 // the message or forwarding it to peers.
248 pub visited: HashSet<usize>,
249
250 /// Map from rank → routing frames this rank forwarded to other
251 /// ranks.
252 pub forwards: HashMap<usize, Vec<RoutingFrame>>,
253}
254
/// A single forwarding hop in the `collect_commactor_routing_tree`
/// simulation.
///
/// Each instance models a message being sent from one rank to
/// another. It carries the routing frames that the receiver must
/// evaluate and the multicast path taken so far.
///
/// The simulation seeds its queue with a self-addressed message at
/// the origin (`from == to`) holding the root frame.
#[derive(Debug)]
pub struct ForwardMessage {
    /// The rank that is forwarding the message.
    #[allow(dead_code)] // Never read: the simulation destructures it as `from: _`.
    pub from: usize,

    /// The rank receiving the message.
    pub to: usize,

    /// The routing frames for the receiver to resolve.
    pub frames: Vec<RoutingFrame>,

    /// The multicast path from the root so far. The simulation
    /// appends `to` before enqueueing, so the path ends at the
    /// receiver.
    pub path: Vec<usize>,
}
281
282/// `collect_commactor_routing_tree` simulates how messages propagate
283/// through a mesh of `CommActor`s during multicast, reconstructing
284/// the full logical routing tree.
285///
286/// This function mirrors the behavior of `CommActor::handle_message`
287/// and `CommActor::forward`, using the shared `resolve_routing` logic
288/// to determine delivery and forwarding at each step. Starting from
289/// the root frame, it simulates how the message would be forwarded
290/// peer-to-peer through the system.
291///
292/// The returned `CommActorRoutingTree` includes:
293/// - `delivered`: ranks where the message would be delivered (i.e.,
294/// `post` called)
295/// - `visited`: all ranks that received or forwarded the message
296/// - `forwards`: frames forwarded from each rank to peers
297///
298/// This model is used in tests to validate routing behavior,
299/// especially invariants like path determinism and delivery coverage.
300pub fn collect_commactor_routing_tree(
301 selection: &Selection,
302 slice: &Slice,
303) -> CommActorRoutingTree {
304 use std::collections::VecDeque;
305
306 let mut pending = VecDeque::new();
307 let mut tree = CommActorRoutingTree::default();
308
309 let root_frame = RoutingFrame::root(selection.clone(), slice.clone());
310 let origin = slice.location(&root_frame.here).unwrap();
311 pending.push_back(ForwardMessage {
312 from: origin,
313 to: origin,
314 frames: vec![root_frame],
315 path: vec![origin],
316 });
317
318 while let Some(ForwardMessage {
319 from: _,
320 to: rank,
321 frames: dests,
322 path,
323 }) = pending.pop_front()
324 {
325 // This loop models the core of `CommActor::handle(...,
326 // fwd_message: ForwardMessage)` +
327 // `CommActor::handle_message(... next_steps...)`.
328 // - `resolve_routing` corresponds to the call in `handle`
329 // - delivery and forwarding match the logic in `handle_message`
330 // - each forward step simulates `CommActor::forward`
331
332 tree.visited.insert(rank);
333
334 let (deliver_here, forwards) =
335 resolve_routing(rank, dests, &mut |_| panic!("choice unexpected")).unwrap();
336
337 if deliver_here {
338 tree.delivered.insert(rank, path.clone());
339 }
340
341 let messages: Vec<_> = forwards
342 .into_iter()
343 .map(|(peer, peer_frames)| {
344 tree.forwards
345 .entry(rank)
346 .or_default()
347 .extend(peer_frames.clone());
348
349 let mut peer_path = path.clone();
350 peer_path.push(peer);
351
352 ForwardMessage {
353 from: rank,
354 to: peer,
355 frames: peer_frames,
356 path: peer_path,
357 }
358 })
359 .collect();
360
361 for message in messages {
362 pending.push_back(message);
363 }
364 }
365
366 tree
367}