ndslice/selection/
routing.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! # Routing
10//!
11//! This module defines [`RoutingFrame`] and its [`next_steps`]
12//! method, which model how messages propagate through a
13//! multidimensional mesh based on a [`Selection`] expression.
14//!
15//! A [`RoutingFrame`] represents the state of routing at a particular
16//! point in the mesh. It tracks the current coordinate (`here`), the
17//! remaining selection to apply (`selection`), the mesh layout
18//! (`slice`), and the current dimension of traversal (`dim`).
19//!
20//! [`next_steps`] defines a routing-specific evaluation
21//! strategy for `Selection`. Unlike [`Selection::eval`], which
22//! produces flat indices that match a selection, this method produces
23//! intermediate routing states — new frames or deferred steps to
24//! continue traversing.
25//!
26//! Rather than returning raw frames directly, [`next_steps`]
27//! produces a stream of [`RoutingStep`]s via a callback — each
28//! representing a distinct kind of routing progression:
29//!
30//! - [`RoutingStep::Forward`] indicates that routing proceeds
31//!   deterministically to a new [`RoutingFrame`] — the next
32//!   coordinate is fully determined by the current selection and
33//!   frame state.
34//!
35//! - [`RoutingStep::Choice`] represents a deferred decision: it
36//!   returns a set of admissible indices, and **the caller must
37//!   select one** (e.g., for load balancing or policy-based routing)
38//!   **before routing can proceed**.
39//!
40//! In this way, non-determinism is treated as a **first-class,
41//! policy-driven** aspect of the routing system — enabling
42//! inspection, customization, and future extensions without
43//! complicating the core traversal logic.
44//!
45//! A frame is considered a delivery target if its selection is
46//! [`Selection::True`] and all dimensions have been traversed, as
47//! determined by [`RoutingFrame::deliver_here`]. All other frames are
48//! forwarded further using [`RoutingFrame::should_route`].
49//!
50//! This design enables **compositional**, **local**, and **scalable**
51//! routing:
52//! - **Compositional**: complex selection expressions decompose into
53//!   simpler, independently evaluated sub-selections.
54//! - **Local**: each frame carries exactly the state needed for its
55//!   next step — no global coordination or lookahead is required.
56//! - **Scalable**: routing unfolds recursively, one hop at a time,
57//!   allowing for efficient traversal even in high-dimensional spaces.
58//!
59//! This module provides the foundation for building structured,
60//! recursive routing logic over multidimensional coordinate spaces.
61use std::collections::HashMap;
62use std::collections::HashSet;
63use std::fmt::Write;
64use std::hash::Hash;
65use std::ops::ControlFlow;
66use std::sync::Arc;
67
68use anyhow::Result;
69use enum_as_inner::EnumAsInner;
70use serde::Deserialize;
71use serde::Serialize;
72use serde::de::DeserializeOwned;
73
74use crate::SliceError;
75use crate::selection::NormalizedSelectionKey;
76use crate::selection::Selection;
77use crate::selection::Slice;
78
/// Represents the outcome of evaluating a routing step.
///
/// Produced by [`RoutingFrame::action`], which maps the result of
/// [`RoutingFrame::deliver_here`] onto one of the two variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RoutingAction {
    /// The frame is a delivery point: `deliver_here()` returned
    /// `true`.
    Deliver,
    /// The frame is not a delivery point and should be routed
    /// further.
    Forward,
}
85
/// `RoutingFrame` captures the state of a selection being evaluated:
/// the current coordinate (`here`), the remaining selection to apply,
/// the shape and layout information (`slice`), and the current
/// dimension (`dim`).
///
/// Each frame represents an independent routing decision and produces
/// zero or more new frames via `next_steps`.
///
/// Frames derived through `advance` and `with_selection` share the
/// same `slice` allocation (they clone the `Arc`, not the `Slice`).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RoutingFrame {
    /// The current coordinate in the mesh where this frame is being
    /// evaluated.
    ///
    /// This is the source location for the next routing step.
    pub here: Vec<usize>,

    /// The residual selection expression describing where routing
    /// should continue.
    ///
    /// At each step, only the current dimension (tracked by `dim`) of
    /// this selection is considered.
    pub selection: Selection,

    /// The shape and layout of the full multidimensional space being
    /// routed.
    ///
    /// This determines the bounds and stride information used to
    /// compute coordinates and flat indices.
    pub slice: Arc<Slice>,

    /// The current axis of traversal within the selection and slice.
    ///
    /// Routing proceeds dimension-by-dimension; this value tracks how
    /// many dimensions have already been routed. `root` starts it at
    /// `0` and `advance` increments it by one.
    pub dim: usize,
}
121
// Compile-time check: ensure `RoutingFrame` is thread-safe and fully
// owned.
//
// The body is intentionally empty: the bounds listed in the `where`
// clause are the whole point of this function.
fn _assert_routing_frame_traits()
where
    RoutingFrame: Send + Sync + Serialize + DeserializeOwned + 'static,
{
}
129
/// A `RoutingStep` represents a unit of progress in the routing
/// process.
///
/// Emitted by [`RoutingFrame::next_steps`], each step describes
/// how routing should proceed from a given frame:
///
/// - [`RoutingStep::Forward`] represents a deterministic hop to the
///   next coordinate in the mesh, with an updated [`RoutingFrame`].
///
/// - [`RoutingStep::Choice`] indicates that routing cannot proceed
///   until the caller selects one of several admissible indices. This
///   allows for policy-driven or non-deterministic routing behavior,
///   such as load balancing.
///
/// The `EnumAsInner` derive supplies the variant accessors used
/// elsewhere in this file (e.g. `into_forward`).
#[derive(Debug, Clone, EnumAsInner)]
pub enum RoutingStep {
    /// A deterministic routing hop to the next coordinate. Carries an
    /// updated [`RoutingFrame`] describing the new position and
    /// residual selection.
    Forward(RoutingFrame),

    /// A deferred routing decision at the current dimension. Contains
    /// a set of admissible indices and a residual [`RoutingFrame`] to
    /// continue routing once a choice is made.
    Choice(Choice),
}
155
/// A deferred routing decision as contained in a
/// [`RoutingStep::Choice`].
///
/// A `Choice` contains:
/// - `candidates`: the admissible indices at the current dimension
/// - `frame`: the residual [`RoutingFrame`] describing how routing
///   continues once a choice is made
///
/// To continue routing, the caller must select one of the
/// `candidates` and call [`Choice::choose`] to produce the
/// corresponding [`RoutingStep::Forward`].
#[derive(Debug, Clone)]
pub struct Choice {
    /// Admissible indices at the current dimension; exposed read-only
    /// through [`Choice::candidates`].
    pub(crate) candidates: Vec<usize>,
    /// Residual frame whose selection applies *past* the current
    /// dimension; exposed read-only through [`Choice::frame`].
    pub(crate) frame: RoutingFrame,
}
172
173impl Choice {
174    /// Returns the list of admissible indices at the current
175    /// dimension.
176    ///
177    /// These represent the valid choices that the caller can select
178    /// from when resolving this deferred routing step.
179    pub fn candidates(&self) -> &[usize] {
180        &self.candidates
181    }
182
183    /// Returns a reference to the residual [`RoutingFrame`]
184    /// associated with this choice.
185    ///
186    /// This frame encodes the selection and mesh context to be used
187    /// once a choice is made, and routing continues at the next
188    /// dimension.
189    pub fn frame(&self) -> &RoutingFrame {
190        &self.frame
191    }
192
193    /// Resolves the choice by selecting a specific index.
194    ///
195    /// Constrains the residual selection to the chosen index at the
196    /// current dimension and returns a [`RoutingStep::Forward`] for
197    /// continued routing.
198    pub fn choose(self, index: usize) -> RoutingStep {
199        // The only thing `next()` has to do is constrain the
200        // selection to a concrete choice at the current dimension.
201        // `self.frame.selection` is the residual (inner) selection to
202        // be applied *past* the current dimension.
203        RoutingStep::Forward(RoutingFrame {
204            selection: crate::dsl::range(index..=index, self.frame.selection),
205            ..self.frame
206        })
207    }
208}
209
/// Key used to deduplicate routing frames.
///
/// Built from a frame by [`RoutingFrameKey::new`]; two keys compare
/// equal when all three fields below are equal. The frame's `slice`
/// is not part of the key.
#[derive(Debug, Hash, PartialEq, Eq)]
pub struct RoutingFrameKey {
    /// Copy of the frame's coordinate (`RoutingFrame::here`).
    here: Vec<usize>,
    /// The frame's current dimension (`RoutingFrame::dim`).
    dim: usize,
    /// Normalized form of the frame's selection.
    selection: NormalizedSelectionKey,
}
217
218impl RoutingFrameKey {
219    /// Constructs a `RoutingFrameKey` from a `RoutingFrame`.
220    ///
221    /// This key uniquely identifies a routing frame by its coordinate
222    /// (`here`), current dimension, and normalized selection. It is
223    /// used during traversal for purposes such as deduplication and
224    /// memoization.
225    pub fn new(frame: &RoutingFrame) -> Self {
226        Self {
227            here: frame.here.clone(),
228            dim: frame.dim,
229            selection: NormalizedSelectionKey::new(&frame.selection),
230        }
231    }
232}
233
234impl RoutingFrame {
235    /// Constructs the initial frame at the root coordinate (all
236    /// zeros). Selections are expanded as necessary to ensure they
237    /// have depth equal to the slice dimensionality. See the docs for
238    /// `canonicalize_to_dimensions` for the rules.
239    ///
240    /// ### Canonical Handling of Zero-Dimensional Slices
241    ///
242    /// A `Slice` with zero dimensions represents the empty product
243    /// `∏_{i=1}^{0} Xᵢ`, which has exactly one element: the empty
244    /// tuple. To maintain uniform routing semantics, we canonically
245    /// embed such 0D slices as 1D slices of extent 1:
246    ///
247    /// ```text
248    /// Slice::new(offset, [1], [1])
249    /// ```
250    ///
251    /// This embedding preserves the correct number of addressable
252    /// points and allows the routing machinery to proceed through the
253    /// usual recursive strategy without introducing special cases. The
254    /// selected coordinate is `vec![0]`, and `dim = 0` proceeds as
255    /// usual. This makes the routing logic consistent with evaluation
256    /// and avoids edge case handling throughout the codebase.
257    pub fn root(selection: Selection, slice: Slice) -> Self {
258        // Canonically embed 0D as 1D (extent 1).
259        let slice = if slice.num_dim() > 0 {
260            Arc::new(slice)
261        } else {
262            Arc::new(Slice::new(slice.offset(), vec![1], vec![1]).unwrap())
263        };
264        let n = slice.num_dim();
265        RoutingFrame {
266            here: vec![0; n],
267            selection: selection.canonicalize_to_dimensions(n),
268            slice,
269            dim: 0,
270        }
271    }
272
273    /// Produces a new frame advanced to the next dimension with
274    /// updated position and selection.
275    pub fn advance(&self, here: Vec<usize>, selection: Selection) -> Self {
276        RoutingFrame {
277            here,
278            selection,
279            slice: Arc::clone(&self.slice),
280            dim: self.dim + 1,
281        }
282    }
283
284    /// Returns a new frame with the same position and dimension but a
285    /// different selection.
286    pub fn with_selection(&self, selection: Selection) -> Self {
287        RoutingFrame {
288            here: self.here.clone(),
289            selection,
290            slice: Arc::clone(&self.slice),
291            dim: self.dim,
292        }
293    }
294
295    /// Determines the appropriate routing action for this frame.
296    ///
297    /// Returns [`RoutingAction::Deliver`] if the message should be
298    /// delivered at this coordinate, or [`RoutingAction::Forward`] if
299    /// it should be routed further.
300    pub fn action(&self) -> RoutingAction {
301        if self.deliver_here() {
302            RoutingAction::Deliver
303        } else {
304            RoutingAction::Forward
305        }
306    }
307
308    /// Returns the location of this frame in the underlying slice.
309    pub fn location(&self) -> Result<usize, SliceError> {
310        self.slice.location(&self.here)
311    }
312}
313
314impl RoutingFrame {
315    /// Visits the next routing steps from this frame using a
316    /// callback-based traversal.
317    /// This method structurally recurses on the [`Selection`]
318    /// expression, yielding [`RoutingStep`]s via the `f` callback.
319    /// Early termination is supported via [`ControlFlow::Break`].
320    ///
321    /// Compared to the (old, now removed) [`next_steps`] method, this
322    /// avoids intermediate allocation, supports interruptibility, and
323    /// allows policy-driven handling of [`RoutingStep::Choice`]s via
324    /// the `chooser` callback.
325    ///
326    /// ---
327    ///
328    /// ### Traversal Strategy
329    ///
330    /// The traversal proceeds **dimension-by-dimension**,
331    /// structurally mirroring the shape of the selection expression:
332    ///
333    /// - [`Selection::All`] and [`Selection::Range`] iterate over a
334    ///   range of coordinates, emitting one [`RoutingStep::Forward`]
335    ///   per valid index.
336    /// - [`Selection::Union`] and [`Selection::Intersection`] recurse
337    ///   into both branches. Intersection steps are joined at matching
338    ///   coordinates and residual selections are reduced.
339    /// - [`Selection::Any`] randomly selects one index along the
340    ///   current dimension and emits a single step.
341    /// - [`Selection::True`] and [`Selection::False`] emit no steps.
342    ///
343    /// At each step, only the current dimension (tracked via `self.dim`)
344    /// is evaluated. Future dimensions remain untouched until deeper
345    /// recursion.
346    ///
347    /// ---
348    ///
349    /// ### Evaluation Semantics
350    ///
351    /// - **Selection::True**
352    ///   No further routing is performed — if this frame is at the
353    ///   final dimension, delivery occurs.
354    /// - **Selection::False**
355    ///   No match — routing halts.
356    ///
357    /// - **Selection::All / Selection::Range**
358    ///   Emits one [`RoutingStep::Forward`] per matching index, each
359    ///   advancing to the next dimension with the inner selection.
360    ///
361    /// - **Selection::Union**
362    ///   Evaluates both branches independently and emits all
363    ///   resulting steps.
364    ///
365    /// - **Selection::Intersection**
366    ///   Emits only those steps where both branches produce the same
367    ///   coordinate, combining the residual selections at that point.
368    ///
369    /// - **Selection::Any**
370    ///   Randomly selects one index and emits a single
371    ///   [`RoutingStep::Forward`].
372    ///
373    /// - **Selection::Choice**
374    ///   Defers decision to the caller by invoking the `chooser`
375    ///   function, which resolves the candidate index.
376    ///
377    /// ---
378    ///
379    /// ### Delivery Semantics
380    ///
381    /// Message delivery is determined by
382    /// [`RoutingFrame::deliver_here`], which returns true when:
383    ///
384    /// - The frame’s selection is [`Selection::True`], and
385    /// - All dimensions have been traversed (`dim ==
386    ///   slice.num_dim()`).
387    ///
388    /// ---
389    ///
390    /// ### Panics
391    ///
392    /// Panics if `slice.num_dim() == 0`. Use a canonical embedding
393    /// (e.g., 0D → 1D) before calling this (see e.g.
394    /// `RoutingFrame::root`).
395    ///
396    /// ---
397    ///
398    /// ### Summary
399    ///
400    /// - **Structure-driven**: Mirrors the shape of the selection
401    ///   expression.
402    /// - **Compositional**: Each variant defines its own traversal
403    ///   behavior.
404    /// - **Interruptible**: Early termination is supported via
405    ///   [`ControlFlow`].
406    /// - **Minimally allocating**: Avoids intermediate buffers in
407    ///   most cases; only [`Selection::Intersection`] allocates
408    ///   temporary state for pairwise matching.
409    /// - **Policy-ready**: Integrates with runtime routing policies
410    ///   via the `chooser`.
411    pub fn next_steps(
412        &self,
413        _chooser: &mut dyn FnMut(&Choice) -> usize,
414        f: &mut dyn FnMut(RoutingStep) -> ControlFlow<()>,
415    ) -> ControlFlow<()> {
416        assert!(self.slice.num_dim() > 0, "next_steps requires num_dims > 0");
417
418        match &self.selection {
419            Selection::True => ControlFlow::Continue(()),
420            Selection::False => ControlFlow::Continue(()),
421            Selection::All(inner) => {
422                let size = self.slice.sizes()[self.dim];
423                for i in 0..size {
424                    let mut coord = self.here.clone();
425                    coord[self.dim] = i;
426                    let frame = self.advance(coord, (**inner).clone());
427                    if let ControlFlow::Break(_) = f(RoutingStep::Forward(frame)) {
428                        return ControlFlow::Break(());
429                    }
430                }
431                ControlFlow::Continue(())
432            }
433
434            Selection::Range(range, inner) => {
435                let size = self.slice.sizes()[self.dim];
436                let (min, max, step) = range.resolve(size);
437
438                for i in (min..max).step_by(step) {
439                    let mut coord = self.here.clone();
440                    coord[self.dim] = i;
441                    let frame = self.advance(coord, (**inner).clone());
442                    if let ControlFlow::Break(_) = f(RoutingStep::Forward(frame)) {
443                        return ControlFlow::Break(());
444                    }
445                }
446
447                ControlFlow::Continue(())
448            }
449
450            Selection::Any(inner) => {
451                let size = self.slice.sizes()[self.dim];
452                if size == 0 {
453                    return ControlFlow::Continue(());
454                }
455
456                use rand::Rng;
457                let mut rng: rand::prelude::ThreadRng = rand::thread_rng();
458                let i = rng.gen_range(0..size);
459                let mut coord = self.here.clone();
460                coord[self.dim] = i;
461                let frame = self.advance(coord, (**inner).clone());
462                f(RoutingStep::Forward(frame))
463            }
464
465            Selection::Union(a, b) => {
466                if let ControlFlow::Break(_) =
467                    self.with_selection((**a).clone()).next_steps(_chooser, f)
468                {
469                    return ControlFlow::Break(());
470                }
471                self.with_selection((**b).clone()).next_steps(_chooser, f)
472            }
473
474            Selection::Intersection(a, b) => {
475                let mut left = vec![];
476                let mut right = vec![];
477
478                let mut collect_left = |step: RoutingStep| {
479                    if let RoutingStep::Forward(frame) = step {
480                        left.push(frame);
481                    }
482                    ControlFlow::Continue(())
483                };
484                let mut collect_right = |step: RoutingStep| {
485                    if let RoutingStep::Forward(frame) = step {
486                        right.push(frame);
487                    }
488                    ControlFlow::Continue(())
489                };
490
491                self.with_selection((**a).clone())
492                    .next_steps(_chooser, &mut collect_left)?;
493                self.with_selection((**b).clone())
494                    .next_steps(_chooser, &mut collect_right)?;
495
496                for fa in &left {
497                    for fb in &right {
498                        if fa.here == fb.here {
499                            let residual = fa
500                                .selection
501                                .clone()
502                                .reduce_intersection(fb.selection.clone());
503                            let frame = self.advance(fa.here.clone(), residual);
504                            if let ControlFlow::Break(_) = f(RoutingStep::Forward(frame)) {
505                                return ControlFlow::Break(());
506                            }
507                        }
508                    }
509                }
510
511                ControlFlow::Continue(())
512            }
513
514            // TODO(SF, 2025-04-30): This term is not in the algebra
515            // yet.
516            // Selection::LoadBalanced(inner) => {
517            //     let size = self.slice.sizes()[self.dim];
518            //     if size == 0 {
519            //         ControlFlow::Continue(())
520            //     } else {
521            //         let candidates = (0..size).collect();
522            //         let choice = Choice {
523            //             candidates,
524            //             frame: self.with_selection((*inner).clone()),
525            //         };
526            //         let index = chooser(&choice);
527            //         f(choice.choose(index))
528            //     }
529            // }
530
531            // Catch-all for future combinators (e.g., Label).
532            _ => unimplemented!(),
533        }
534    }
535
536    /// Returns true if this frame represents a terminal delivery
537    /// point — i.e., the selection is `True` and all dimensions have
538    /// been traversed.
539    pub fn deliver_here(&self) -> bool {
540        matches!(self.selection, Selection::True) && self.dim == self.slice.num_dim()
541    }
542
543    /// Returns true if the message has not yet reached its final
544    /// destination and should be forwarded to the next routing step.
545    pub fn should_route(&self) -> bool {
546        !self.deliver_here()
547    }
548}
549
550impl RoutingFrame {
551    /// Traces the unique routing path to the given destination
552    /// coordinate.
553    ///
554    /// Returns `Some(vec![root, ..., dest])` if `dest` is selected,
555    /// or `None` if not.
556    pub fn trace_route(&self, dest: &[usize]) -> Option<Vec<Vec<usize>>> {
557        use std::collections::HashSet;
558        use std::ops::ControlFlow;
559
560        use crate::selection::routing::RoutingFrameKey;
561
562        fn go(
563            frame: RoutingFrame,
564            dest: &[usize],
565            mut path: Vec<Vec<usize>>,
566            seen: &mut HashSet<RoutingFrameKey>,
567        ) -> Option<Vec<Vec<usize>>> {
568            let key = RoutingFrameKey::new(&frame);
569            if !seen.insert(key) {
570                return None;
571            }
572
573            path.push(frame.here.clone());
574
575            if frame.deliver_here() && frame.here == dest {
576                return Some(path);
577            }
578
579            let mut found = None;
580            let _ = frame.next_steps(
581                &mut |_| panic!("Choice encountered in trace_route"),
582                &mut |step: RoutingStep| {
583                    let next = step.into_forward().unwrap();
584                    if let Some(result) = go(next, dest, path.clone(), seen) {
585                        found = Some(result);
586                        ControlFlow::Break(())
587                    } else {
588                        ControlFlow::Continue(())
589                    }
590                },
591            );
592
593            found
594        }
595
596        let mut seen = HashSet::new();
597        go(self.clone(), dest, Vec::new(), &mut seen)
598    }
599}
600
/// Renders a routing path as text, one line per hop.
///
/// Every line consists of the hop index right-aligned to width two,
/// an arrow (`→` for intermediate hops, `⇨` for the last one), and
/// the hop's coordinate written as a tuple such as `(0, 1)`. An empty
/// route yields an empty string.
///
/// # Example
///
/// ```text
///  0 → (0, 0)
///  1 → (0, 1)
///  2 ⇨ (1, 1)
/// ```
#[track_caller]
#[allow(dead_code)]
pub fn format_route(route: &[Vec<usize>]) -> String {
    let hop_count = route.len();
    let mut rendered = String::new();
    for (index, hop) in route.iter().enumerate() {
        let is_last = index + 1 == hop_count;
        let marker = if is_last { "⇨" } else { "→" };
        let parts: Vec<String> = hop.iter().map(|c| c.to_string()).collect();
        // Writing into a `String` is not expected to fail; the result
        // is discarded on purpose.
        let _ = writeln!(rendered, "{:>2} {} ({})", index, marker, parts.join(", "));
    }
    rendered
}
630
631/// Formats a routing tree as an indented string.
632///
633/// Traverses the tree of `RoutingFrame`s starting from the root,
634/// displaying each step with indentation by dimension. Delivery
635/// targets are marked `✅`.
636///
637/// # Example
638/// ```text
639/// (0, 0)
640///   (0, 1) ✅
641/// (1, 0)
642///   (1, 1) ✅
643/// ```
644#[track_caller]
645#[allow(dead_code)]
646pub fn format_routing_tree(selection: Selection, slice: &Slice) -> String {
647    let root = RoutingFrame::root(selection, slice.clone());
648    let mut out = String::new();
649    let mut seen = HashSet::new();
650    format_routing_tree_rec(&root, 0, &mut out, &mut seen).unwrap();
651    out
652}
653
654fn format_routing_tree_rec(
655    frame: &RoutingFrame,
656    indent: usize,
657    out: &mut String,
658    seen: &mut HashSet<RoutingFrameKey>,
659) -> std::fmt::Result {
660    use crate::selection::routing::RoutingFrameKey;
661
662    let key = RoutingFrameKey::new(frame);
663    if !seen.insert(key) {
664        return Ok(()); // already visited
665    }
666
667    let indent_str = "  ".repeat(indent);
668    let coord_str = format!(
669        "({})",
670        frame
671            .here
672            .iter()
673            .map(ToString::to_string)
674            .collect::<Vec<_>>()
675            .join(", ")
676    );
677
678    match frame.action() {
679        RoutingAction::Deliver => {
680            writeln!(out, "{}{} ✅", indent_str, coord_str)?;
681        }
682        RoutingAction::Forward => {
683            writeln!(out, "{}{}", indent_str, coord_str)?;
684            let _ = frame.next_steps(
685                &mut |_| panic!("Choice encountered in format_routing_tree_rec"),
686                &mut |step| {
687                    let next = step.into_forward().unwrap();
688                    format_routing_tree_rec(&next, indent + 1, out, seen).unwrap();
689                    ControlFlow::Continue(())
690                },
691            );
692        }
693    }
694
695    Ok(())
696}
697
698// Pretty-prints a routing path from source to destination.
699//
700// Each hop is shown as a numbered step with directional arrows.
701#[track_caller]
702#[allow(dead_code)]
703pub fn print_route(route: &[Vec<usize>]) {
704    println!("{}", format_route(route));
705}
706
707/// Prints the routing tree for a selection over a slice.
708///
709/// Traverses the routing structure from the root, printing each step
710/// with indentation by dimension. Delivery points are marked with
711/// `✅`.
712#[track_caller]
713#[allow(dead_code)]
714pub fn print_routing_tree(selection: Selection, slice: &Slice) {
715    println!("{}", format_routing_tree(selection, slice));
716}
717
718// == "CommActor multicast" routing ==
719
720/// Resolves the current set of routing frames (`dests`) to determine
721/// whether the message should be delivered at this rank, and which
722/// routing frames should be forwarded to peer ranks.
723///
724/// This is the continuation of a multicast operation: each forwarded
725/// message contains one or more `RoutingFrame`s that represent
726/// partial routing state. This call determines how those frames
727/// propagate next.
728///
729/// `deliver_here` is true if any frame targets this rank and
730/// indicates delivery. `next_steps` contains the peer ranks and frames
731/// to forward.
732///
733/// This is also the top-level entry point for CommActor's routing
734/// logic.
735pub fn resolve_routing(
736    rank: usize,
737    frames: impl IntoIterator<Item = RoutingFrame>,
738    chooser: &mut dyn FnMut(&Choice) -> usize,
739) -> Result<(bool, HashMap<usize, Vec<RoutingFrame>>)> {
740    let mut deliver_here = false;
741    let mut next_steps = HashMap::new();
742    for frame in frames {
743        resolve_routing_one(rank, frame, chooser, &mut deliver_here, &mut next_steps)?;
744    }
745    Ok((deliver_here, next_steps))
746}
747
748/// Recursively resolves routing for a single `RoutingFrame` at the
749/// given rank, determining whether the message should be delivered
750/// locally and which frames should be forwarded to peer ranks.
751///
752/// - If the frame targets the local `rank` and is a delivery point,
753///   `deliver_here` is set to `true`.
754/// - If the frame targets the local `rank` but is not a delivery
755///   point, the function recurses on its forward steps.
756/// - If the frame targets a different rank, it is added to
757///   `next_steps`.
758///
759/// Deduplication is handled by `get_next_steps`. Dynamic constructs
760/// such as `Any` or `First` must be resolved by the provided
761/// `chooser`, which selects an index from a `Choice`.
762///
763/// Traversal is depth-first within a rank and breadth-first across
764/// ranks. This defines the exact routing behavior used by
765/// `CommActor`: it exhaustively evaluates all local routing structure
766/// before forwarding to peers.
767///
768/// The resulting `next_steps` map contains all non-local ranks that
769/// should receive forwarded frames, where each entry maps a peer rank
770/// to a list of routing continuations to evaluate at that peer.
771/// `deliver_here` is set to `true` if the current rank is a final
772/// delivery point.
773pub(crate) fn resolve_routing_one(
774    rank: usize,
775    frame: RoutingFrame,
776    chooser: &mut dyn FnMut(&Choice) -> usize,
777    deliver_here: &mut bool,
778    next_steps: &mut HashMap<usize, Vec<RoutingFrame>>,
779) -> Result<()> {
780    let frame_rank = frame.slice.location(&frame.here)?;
781    if frame_rank == rank {
782        if frame.deliver_here() {
783            *deliver_here = true;
784        } else {
785            for next in get_next_steps(frame, chooser)? {
786                resolve_routing_one(rank, next, chooser, deliver_here, next_steps)?;
787            }
788        }
789    } else {
790        next_steps.entry(frame_rank).or_default().push(frame);
791    }
792    Ok(())
793}
794
795/// Computes the set of `Forward` routing frames reachable from the
796/// given `RoutingFrame`.
797///
798/// This function traverses the result of `frame.next_steps(...)`,
799/// collecting only `RoutingStep::Forward(_)` steps. The caller
800/// provides a `chooser` function to resolve dynamic constructs such
801/// as `Any` or `First`.
802///
803/// Some obviously redundant steps may be filtered, but no strict
804/// guarantee is made about structural uniqueness.
805fn get_next_steps(
806    dest: RoutingFrame,
807    chooser: &mut dyn FnMut(&Choice) -> usize,
808) -> Result<Vec<RoutingFrame>> {
809    let mut seen = HashSet::new();
810    let mut unique_steps = vec![];
811    let _ = dest.next_steps(chooser, &mut |step| {
812        if let RoutingStep::Forward(frame) = step {
813            let key = RoutingFrameKey::new(&frame);
814            if seen.insert(key) {
815                unique_steps.push(frame);
816            }
817        }
818        ControlFlow::Continue(())
819    });
820    Ok(unique_steps)
821}
822
823// == Testing (`collect_commactor_routing_tree` mesh simulation) ===
824
/// Captures the logical structure of a CommActor multicast operation.
///
/// This type models how a message is delivered and forwarded through
/// a mesh under CommActor routing semantics. It is used in tests to
/// verify path determinism and understand message propagation
/// behavior.
///
/// - `delivered`: ranks where the message was delivered (`post`
///   called)
/// - `visited`: all ranks that participated, including forwarding
///   only
/// - `forwards`: maps each rank to the routing frames it forwarded
#[cfg(test)]
#[allow(dead_code)]
#[derive(Default)]
pub(crate) struct CommActorRoutingTree {
    /// Ranks that were delivered the message (i.e. called `post`).
    /// Map from rank → delivery path (flat rank indices) from root to
    /// that rank.
    pub delivered: HashMap<usize, Vec<usize>>,

    /// Ranks that participated in the multicast - either by
    /// delivering the message or forwarding it to peers.
    pub visited: HashSet<usize>,

    /// Map from rank → routing frames this rank forwarded to other
    /// ranks.
    pub forwards: HashMap<usize, Vec<RoutingFrame>>,
}
854
855#[cfg(test)]
856mod tests {
857    use std::collections::HashSet;
858    use std::collections::VecDeque;
859
860    use super::RoutingAction;
861    use super::RoutingFrame;
862    use super::print_route;
863    use super::print_routing_tree;
864    use crate::Slice;
865    use crate::selection::EvalOpts;
866    use crate::selection::Selection;
867    use crate::selection::dsl::*;
868    use crate::selection::test_utils::RoutedMessage;
869    use crate::selection::test_utils::collect_commactor_routing_tree;
870    use crate::selection::test_utils::collect_routed_nodes;
871    use crate::selection::test_utils::collect_routed_paths;
872    use crate::shape;
873
874    // A test slice: (zones = 2, hosts = 4, gpus = 8).
875    fn test_slice() -> Slice {
876        Slice::new(0usize, vec![2, 4, 8], vec![32, 8, 1]).unwrap()
877    }
878
879    /// Asserts that a routing strategy produces the same set of nodes
880    /// as `Selection::eval`.
881    ///
882    /// This macro compares the result of evaluating a `Selection`
883    /// using the given `collector` against the reference
884    /// implementation `Selection::eval` (with lenient options).
885    ///
886    /// The `collector` should be a function or closure of type
887    /// `Fn(&Selection, &Slice) -> Vec<usize>`, such as
888    /// `collect_routed_nodes` or a CommActor-based simulation.
889    ///
890    /// Panics if the two sets of routed nodes differ.
891    ///
892    /// # Example
893    /// ```
894    /// assert_routing_eq_with!(slice, selection, collect_routed_nodes);
895    /// ```
896    macro_rules! assert_routing_eq_with {
897        ($slice:expr_2021, $sel:expr_2021, $collector:expr_2021) => {{
898            let sel = $sel;
899            let slice = $slice.clone();
900            let mut expected: Vec<_> = sel.eval(&EvalOpts::lenient(), &slice).unwrap().collect();
901            expected.sort();
902            let mut actual: Vec<_> = ($collector)(&sel, &slice);
903            actual.sort();
904            assert_eq!(actual, expected, "Mismatch for selection: {}", sel);
905        }};
906    }
907
    /// Asserts that `collect_routed_nodes` matches `Selection::eval`
    /// on the given slice.
    ///
    /// Thin wrapper: forwards `$slice` and `$sel` unchanged to
    /// `assert_routing_eq_with!`, supplying `collect_routed_nodes` as
    /// the collector.
    macro_rules! assert_collect_routed_nodes_eq {
        ($slice:expr_2021, $sel:expr_2021) => {
            assert_routing_eq_with!($slice, $sel, collect_routed_nodes)
        };
    }
915
916    /// Asserts that CommActor routing delivers to the same nodes as
917    /// `Selection::eval`.
918    macro_rules! assert_commactor_routing_eq {
919        ($slice:expr_2021, $sel:expr_2021) => {
920            assert_routing_eq_with!($slice, $sel, |s, sl| {
921                collect_commactor_routing_tree(s, sl)
922                    .delivered
923                    .into_keys()
924                    .collect()
925            });
926        };
927    }
928
    /// Asserts that all routing strategies produce the same set of
    /// routed nodes as `Selection::eval`.
    ///
    /// Compares both the direct strategy (`collect_routed_nodes`) and
    /// the CommActor routing simulation
    /// (`collect_commactor_routing_tree`) against the expected output
    /// from `Selection::eval`.
    ///
    /// Note that `$slice` and `$sel` are each expanded twice (once per
    /// strategy), so the argument expressions are evaluated twice.
    macro_rules! assert_all_routing_strategies_eq {
        ($slice:expr_2021, $sel:expr_2021) => {
            assert_collect_routed_nodes_eq!($slice, $sel);
            assert_commactor_routing_eq!($slice, $sel);
        };
    }
942
943    #[test]
944    fn test_routing_04() {
945        use crate::selection::dsl::*;
946
947        let slice = test_slice(); // [2, 4, 8], strides [32, 8, 1]
948
949        // Destination: GPU 2 on host 2 in zone 1.
950        let dest = vec![1, 2, 2];
951        let selection = range(1, range(2, range(2, true_())));
952        let root = RoutingFrame::root(selection.clone(), slice.clone());
953        let path = root.trace_route(&dest).expect("no route found");
954        println!(
955            "\ndest: {:?}, (singleton-)selection: ({})\n",
956            &dest, &selection
957        );
958        print_route(&path);
959        println!("\n");
960        assert_eq!(path.last(), Some(&dest));
961
962        // Destination: "Right back where we started from 🙂".
963        let dest = vec![0, 0, 0];
964        let selection = range(0, range(0, range(0, true_())));
965        let root = RoutingFrame::root(selection.clone(), slice.clone());
966        let path = root.trace_route(&dest).expect("no route found");
967        println!(
968            "\ndest: {:?}, (singleton-)selection: ({})\n",
969            &dest, &selection
970        );
971        print_route(&path);
972        println!("\n");
973        assert_eq!(path.last(), Some(&dest));
974    }
975
976    #[test]
977    fn test_routing_05() {
978        use crate::selection::dsl::*;
979
980        // "Jun's example" -- a 2 x 2 row major mesh.
981        let slice = Slice::new(0usize, vec![2, 2], vec![2, 1]).unwrap();
982        // Thats is,
983        //  (0, 0)    (0, 1)
984        //  (0, 1)    (1, 0)
985        //
986        // and we want to cast to {(0, 1), (1, 0) and (1, 1)}:
987        //
988        //  (0, 0)❌    (0, 1)✅
989        //  (0, 1)✅    (1, 0)✅
990        //
991        // One reasonable selection expression describing the
992        // destination set.
993        let selection = union(range(0, range(1, true_())), range(1, all(true_())));
994
995        // Now print the routing tree.
996        print_routing_tree(selection, &slice);
997
998        // Prints:
999        // (0, 0)
1000        //   (0, 0)
1001        //     (0, 1) ✅
1002        //   (1, 0)
1003        //     (1, 0) ✅
1004        //     (1, 1) ✅
1005
1006        // Another example: (zones = 2, hosts = 4, gpus = 8).
1007        let slice = Slice::new(0usize, vec![2, 4, 8], vec![32, 8, 1]).unwrap();
1008        // Let's have all the odd GPUs on hosts 1, 2 and 3 in zone 0.
1009        let selection = range(
1010            0,
1011            range(1..4, range(shape::Range(1, None, /*step*/ 2), true_())),
1012        );
1013
1014        // Now print the routing tree.
1015        print_routing_tree(selection, &slice);
1016
1017        // Prints:
1018        // (0, 0, 0)
1019        //   (0, 0, 0)
1020        //     (0, 1, 0)
1021        //       (0, 1, 1) ✅
1022        //       (0, 1, 3) ✅
1023        //       (0, 1, 5) ✅
1024        //       (0, 1, 7) ✅
1025        //     (0, 2, 0)
1026        //       (0, 2, 1) ✅
1027        //       (0, 2, 3) ✅
1028        //       (0, 2, 5) ✅
1029        //       (0, 2, 7) ✅
1030        //     (0, 3, 0)
1031        //       (0, 3, 1) ✅
1032        //       (0, 3, 3) ✅
1033        //       (0, 3, 5) ✅
1034        //       (0, 3, 7) ✅
1035    }
1036
1037    #[test]
1038    fn test_routing_00() {
1039        let slice = test_slice();
1040
1041        assert_all_routing_strategies_eq!(slice, false_());
1042        assert_all_routing_strategies_eq!(slice, true_());
1043        assert_all_routing_strategies_eq!(slice, all(true_()));
1044        assert_all_routing_strategies_eq!(slice, all(all(true_())));
1045        assert_all_routing_strategies_eq!(slice, all(all(false_())));
1046        assert_all_routing_strategies_eq!(slice, all(all(all(true_()))));
1047        assert_all_routing_strategies_eq!(slice, all(range(0..=0, all(true_()))));
1048        assert_all_routing_strategies_eq!(slice, all(all(range(0..4, true_()))));
1049        assert_all_routing_strategies_eq!(slice, all(range(1..=2, all(true_()))));
1050        assert_all_routing_strategies_eq!(slice, all(all(range(2..6, true_()))));
1051        assert_all_routing_strategies_eq!(slice, all(all(range(3..=3, true_()))));
1052        assert_all_routing_strategies_eq!(slice, all(range(1..3, all(true_()))));
1053        assert_all_routing_strategies_eq!(slice, all(all(range(0..=0, true_()))));
1054        assert_all_routing_strategies_eq!(slice, range(1..=1, range(3..=3, range(0..=2, true_()))));
1055        assert_all_routing_strategies_eq!(
1056            slice,
1057            all(all(range(shape::Range(0, Some(8), 2), true_())))
1058        );
1059        assert_all_routing_strategies_eq!(
1060            slice,
1061            all(range(shape::Range(1, Some(4), 2), all(true_())))
1062        );
1063    }
1064
1065    #[test]
1066    fn test_routing_03() {
1067        let slice = test_slice();
1068
1069        assert_all_routing_strategies_eq!(
1070            slice,
1071            // sel!(0 & (0,(1|3), *))
1072            intersection(
1073                range(0, true_()),
1074                range(0, union(range(1, all(true_())), range(3, all(true_()))))
1075            )
1076        );
1077        assert_all_routing_strategies_eq!(
1078            slice,
1079            // sel!(0 & (0, (3|1), *)),
1080            intersection(
1081                range(0, true_()),
1082                range(0, union(range(3, all(true_())), range(1, all(true_()))))
1083            )
1084        );
1085        assert_all_routing_strategies_eq!(
1086            slice,
1087            // sel!((*, *, *) & (*, *, (2 | 4)))
1088            intersection(
1089                all(all(all(true_()))),
1090                all(all(union(range(2, true_()), range(4, true_()))))
1091            )
1092        );
1093        assert_all_routing_strategies_eq!(
1094            slice,
1095            // sel!((*, *, *) & (*, *, (4 | 2)))
1096            intersection(
1097                all(all(all(true_()))),
1098                all(all(union(range(4, true_()), range(2, true_()))))
1099            )
1100        );
1101        assert_all_routing_strategies_eq!(
1102            slice,
1103            // sel!((*, (1 | 2)) & (*, (2 | 1)))
1104            intersection(
1105                all(union(range(1, true_()), range(2, true_()))),
1106                all(union(range(2, true_()), range(1, true_())))
1107            )
1108        );
1109        assert_all_routing_strategies_eq!(
1110            slice,
1111            intersection(all(all(all(true_()))), all(true_()))
1112        );
1113        assert_all_routing_strategies_eq!(slice, intersection(true_(), all(all(all(true_())))));
1114        assert_all_routing_strategies_eq!(slice, intersection(all(all(all(true_()))), false_()));
1115        assert_all_routing_strategies_eq!(slice, intersection(false_(), all(all(all(true_())))));
1116        assert_all_routing_strategies_eq!(
1117            slice,
1118            intersection(
1119                all(all(range(0..4, true_()))),
1120                all(all(range(0..4, true_())))
1121            )
1122        );
1123        assert_all_routing_strategies_eq!(
1124            slice,
1125            intersection(all(all(range(1, true_()))), all(all(range(2, true_()))))
1126        );
1127        assert_all_routing_strategies_eq!(
1128            slice,
1129            intersection(all(all(range(2, true_()))), all(all(range(1, true_()))))
1130        );
1131        assert_all_routing_strategies_eq!(
1132            slice,
1133            intersection(
1134                all(all(range(1, true_()))),
1135                intersection(all(all(true_())), all(all(range(1, true_()))))
1136            )
1137        );
1138        assert_all_routing_strategies_eq!(
1139            slice,
1140            intersection(
1141                range(0, true_()),
1142                range(0, all(union(range(1, true_()), range(3, true_()))))
1143            )
1144        );
1145        assert_all_routing_strategies_eq!(
1146            slice,
1147            range(
1148                0,
1149                intersection(true_(), all(union(range(1, true_()), range(3, true_()))))
1150            )
1151        );
1152        assert_all_routing_strategies_eq!(
1153            slice,
1154            intersection(all(range(1..=2, true_())), all(range(2..=3, true_())))
1155        );
1156        assert_all_routing_strategies_eq!(
1157            slice,
1158            intersection(
1159                range(0, true_()),
1160                intersection(range(0, all(true_())), range(0, range(1, all(true_()))))
1161            )
1162        );
1163        assert_all_routing_strategies_eq!(
1164            slice,
1165            intersection(
1166                range(0, range(1, all(true_()))),
1167                intersection(range(0, all(true_())), range(0, true_()))
1168            )
1169        );
1170        assert_all_routing_strategies_eq!(
1171            slice,
1172            // sel!( (*, *, *) & ((*, *, *) & (*, *, *)) ),
1173            intersection(
1174                all(all(all(true_()))),
1175                intersection(all(all(all(true_()))), all(all(all(true_()))))
1176            )
1177        );
1178        assert_all_routing_strategies_eq!(
1179            slice,
1180            union(
1181                intersection(range(0, true_()), range(0, range(1, all(true_())))),
1182                range(1, all(all(true_())))
1183            )
1184        );
1185        assert_all_routing_strategies_eq!(
1186            slice,
1187            // sel!((1, *, *) | (0 & (0, 3, *)))
1188            union(
1189                range(1, all(all(true_()))),
1190                intersection(range(0, true_()), range(0, range(3, all(true_()))))
1191            )
1192        );
1193        assert_all_routing_strategies_eq!(
1194            slice,
1195            intersection(
1196                union(range(0, true_()), range(1, true_())),
1197                union(range(1, true_()), range(0, true_()))
1198            )
1199        );
1200        assert_all_routing_strategies_eq!(
1201            slice,
1202            union(
1203                intersection(range(0, range(1, true_())), range(0, range(1, true_()))),
1204                intersection(range(1, range(3, true_())), range(1, range(3, true_())))
1205            )
1206        );
1207        assert_all_routing_strategies_eq!(
1208            slice,
1209            // sel!(*, 8 : 8)
1210            all(range(8..8, true_()))
1211        );
1212        assert_all_routing_strategies_eq!(
1213            slice,
1214            // sel!((*, 1) & (*, 8 : 8))
1215            intersection(all(range(1..2, true_())), all(range(8..8, true_())))
1216        );
1217        assert_all_routing_strategies_eq!(
1218            slice,
1219            // sel!((*, 8 : 8) | (*, 1))
1220            union(all(range(8..8, true_())), all(range(1..2, true_())))
1221        );
1222        assert_all_routing_strategies_eq!(
1223            slice,
1224            // sel!((*, 1) | (*, 2:8))
1225            union(all(range(1..2, true_())), all(range(2..8, true_())))
1226        );
1227        assert_all_routing_strategies_eq!(
1228            slice,
1229            // sel!((*, *, *) & (*, *, 2:8))
1230            intersection(all(all(all(true_()))), all(all(range(2..8, true_()))))
1231        );
1232    }
1233
1234    #[test]
1235    fn test_routing_02() {
1236        let slice = test_slice();
1237
1238        // zone 0 or 1: sel!(0 | 1, *, *)
1239        assert_all_routing_strategies_eq!(slice, union(range(0, true_()), range(1, true_())));
1240        assert_all_routing_strategies_eq!(
1241            slice,
1242            union(range(0, all(true_())), range(1, all(true_())))
1243        );
1244        // hosts 1 and 3 in zone 0: sel!(0, (1 | 3), *)
1245        assert_all_routing_strategies_eq!(
1246            slice,
1247            range(0, union(range(1, all(true_())), range(3, all(true_()))))
1248        );
1249        // sel!(0, 1:3 | 5:7, *)
1250        assert_all_routing_strategies_eq!(
1251            slice,
1252            range(
1253                0,
1254                union(
1255                    range(shape::Range(1, Some(3), 1), all(true_())),
1256                    range(shape::Range(5, Some(7), 1), all(true_()))
1257                )
1258            )
1259        );
1260
1261        // sel!(* | *): We start with `union(true_(), true_())`.
1262        //
1263        // Evaluating the left branch generates routing frames
1264        // recursively. Evaluating the right branch generates the same
1265        // frames again.
1266        //
1267        // As a result, we produce duplicate `RoutingFrame`s that
1268        // have:
1269        // - the same `here` coordinate,
1270        // - the same dimension (`dim`), and
1271        // - the same residual selection (`True`).
1272        //
1273        // When both frames reach the delivery condition, the second
1274        // call to `delivered.insert()` returns `false`. If we put an
1275        // `assert!` on that line this would trigger assertion failure
1276        // in the routing simulation.
1277        //
1278        // TODO: We need memoization to avoid redundant work.
1279        //
1280        // This can be achieved without transforming the algebra itself.
1281        // However, adding normalization will make memoization more
1282        // effective, so we should plan to implement both.
1283        //
1284        // Once that's done, we can safely restore the `assert!`.
1285        assert_all_routing_strategies_eq!(slice, union(true_(), true_()));
1286        // sel!(*, *, * | *, *, *)
1287        assert_all_routing_strategies_eq!(
1288            slice,
1289            union(all(all(all(true_()))), all(all(all(true_()))))
1290        );
1291        // no 'false' support in sel!
1292        assert_all_routing_strategies_eq!(slice, union(false_(), all(all(all(true_())))));
1293        assert_all_routing_strategies_eq!(slice, union(all(all(all(true_()))), false_()));
1294        // sel!(0, 0:4, 0 | 1 | 2)
1295        assert_all_routing_strategies_eq!(
1296            slice,
1297            range(
1298                0,
1299                range(
1300                    shape::Range(0, Some(4), 1),
1301                    union(
1302                        range(0, true_()),
1303                        union(range(1, true_()), range(2, true_()))
1304                    )
1305                )
1306            )
1307        );
1308        assert_all_routing_strategies_eq!(
1309            slice,
1310            range(
1311                0,
1312                union(range(2, range(4, true_())), range(3, range(5, true_())),),
1313            )
1314        );
1315        assert_all_routing_strategies_eq!(
1316            slice,
1317            range(0, range(2, union(range(4, true_()), range(5, true_()),),),)
1318        );
1319        assert_all_routing_strategies_eq!(
1320            slice,
1321            range(
1322                0,
1323                union(range(2, range(4, true_())), range(3, range(5, true_())),),
1324            )
1325        );
1326        assert_all_routing_strategies_eq!(
1327            slice,
1328            union(
1329                range(
1330                    0,
1331                    union(range(2, range(4, true_())), range(3, range(5, true_())))
1332                ),
1333                range(
1334                    1,
1335                    union(range(2, range(4, true_())), range(3, range(5, true_())))
1336                )
1337            )
1338        );
1339    }
1340
1341    #[test]
1342    fn test_routing_01() {
1343        use std::ops::ControlFlow;
1344
1345        let slice = test_slice();
1346        let sel = range(0..=0, all(true_()));
1347
1348        let expected_fanouts: &[&[&[usize]]] = &[
1349            &[&[0, 0, 0]],
1350            &[&[0, 0, 0], &[0, 1, 0], &[0, 2, 0], &[0, 3, 0]],
1351            &[
1352                &[0, 0, 0],
1353                &[0, 0, 1],
1354                &[0, 0, 2],
1355                &[0, 0, 3],
1356                &[0, 0, 4],
1357                &[0, 0, 5],
1358                &[0, 0, 6],
1359                &[0, 0, 7],
1360            ],
1361            &[
1362                &[0, 1, 0],
1363                &[0, 1, 1],
1364                &[0, 1, 2],
1365                &[0, 1, 3],
1366                &[0, 1, 4],
1367                &[0, 1, 5],
1368                &[0, 1, 6],
1369                &[0, 1, 7],
1370            ],
1371            &[
1372                &[0, 2, 0],
1373                &[0, 2, 1],
1374                &[0, 2, 2],
1375                &[0, 2, 3],
1376                &[0, 2, 4],
1377                &[0, 2, 5],
1378                &[0, 2, 6],
1379                &[0, 2, 7],
1380            ],
1381            &[
1382                &[0, 3, 0],
1383                &[0, 3, 1],
1384                &[0, 3, 2],
1385                &[0, 3, 3],
1386                &[0, 3, 4],
1387                &[0, 3, 5],
1388                &[0, 3, 6],
1389                &[0, 3, 7],
1390            ],
1391        ];
1392
1393        let expected_deliveries: &[bool] = &[
1394            false, false, false, false, false, false, // Steps 0–5
1395            true, true, true, true, true, true, true, true, true, true, true, true, true, true,
1396            true, true, true, true, true, true, true, true, true, true, true, true, true, true,
1397            true, true, true, true, true, // Steps 6–38
1398        ];
1399
1400        let mut step = 0;
1401        let mut pending = VecDeque::new();
1402
1403        pending.push_back(RoutingFrame::root(sel.clone(), slice.clone()));
1404
1405        println!("Fan-out trace for selection: {}", sel);
1406
1407        while let Some(frame) = pending.pop_front() {
1408            let mut next_coords = vec![];
1409
1410            let deliver_here = frame.deliver_here();
1411
1412            let _ = frame.next_steps(
1413                &mut |_| panic!("Choice encountered in test_routing_01"),
1414                &mut |step| {
1415                    let next = step.into_forward().unwrap();
1416                    next_coords.push(next.here.clone());
1417                    pending.push_back(next);
1418                    ControlFlow::Continue(())
1419                },
1420            );
1421
1422            println!(
1423                "Step {:>2}: from {:?} (flat = {:>2}) | deliver = {} | fan-out count = {} | selection = {:?}",
1424                step,
1425                frame.here,
1426                frame.slice.location(&frame.here).unwrap(),
1427                deliver_here,
1428                next_coords.len(),
1429                format!("{}", frame.selection),
1430            );
1431
1432            for next in &next_coords {
1433                println!("         → {:?}", next);
1434            }
1435
1436            if step < expected_fanouts.len() {
1437                let expected = expected_fanouts[step]
1438                    .iter()
1439                    .map(|v| v.to_vec())
1440                    .collect::<Vec<_>>();
1441                assert_eq!(
1442                    next_coords, expected,
1443                    "Mismatch in next_coords at step {}",
1444                    step
1445                );
1446            }
1447
1448            if step < expected_deliveries.len() {
1449                assert_eq!(
1450                    deliver_here, expected_deliveries[step],
1451                    "Mismatch in deliver_here at step {} (coord = {:?})",
1452                    step, frame.here
1453                );
1454            }
1455
1456            step += 1;
1457        }
1458    }
1459
1460    #[test]
1461    fn test_routing_06() {
1462        use std::ops::ControlFlow;
1463
1464        use crate::selection::dsl::*;
1465        use crate::selection::routing::RoutingFrameKey;
1466        use crate::selection::routing::RoutingStep;
1467
1468        let slice = test_slice();
1469        let selection = union(all(true_()), all(true_()));
1470
1471        let mut pending = VecDeque::new();
1472        let mut dedup_delivered = Vec::new();
1473        let mut nodup_delivered = Vec::new();
1474        let mut seen = HashSet::new();
1475
1476        let root = RoutingFrame::root(selection.clone(), slice.clone());
1477        pending.push_back(RoutedMessage::<()>::new(root.here.clone(), root));
1478
1479        while let Some(RoutedMessage { frame, .. }) = pending.pop_front() {
1480            let mut visitor = |step: RoutingStep| {
1481                let next = step.into_forward().unwrap();
1482
1483                if next.action() == RoutingAction::Deliver {
1484                    nodup_delivered.push(next.slice.location(&next.here).unwrap());
1485                }
1486
1487                let key = RoutingFrameKey::new(&next);
1488                if seen.insert(key) && next.action() == RoutingAction::Deliver {
1489                    dedup_delivered.push(next.slice.location(&next.here).unwrap());
1490                }
1491
1492                if next.action() == RoutingAction::Forward {
1493                    pending.push_back(RoutedMessage::new(frame.here.clone(), next));
1494                }
1495
1496                ControlFlow::Continue(())
1497            };
1498
1499            let _ = frame.next_steps(
1500                &mut |_| panic!("Choice encountered in test_routing_06"),
1501                &mut visitor,
1502            );
1503        }
1504
1505        assert_eq!(dedup_delivered.len(), 64);
1506        assert_eq!(nodup_delivered.len(), 128);
1507    }
1508
1509    #[test]
1510    fn test_routing_07() {
1511        use std::ops::ControlFlow;
1512
1513        use crate::selection::dsl::*;
1514        use crate::selection::routing::RoutingFrame;
1515        use crate::selection::routing::RoutingStep;
1516
1517        let slice = test_slice(); // shape: [2, 4, 8]
1518
1519        // Selection: any zone, all hosts, all gpus.
1520        let selection = any(all(all(true_())));
1521        let frame = RoutingFrame::root(selection, slice.clone());
1522
1523        let mut steps = vec![];
1524        let _ = frame.next_steps(
1525            &mut |_| panic!("Choice encountered in test_routing_07"),
1526            &mut |step: RoutingStep| {
1527                steps.push(step);
1528                ControlFlow::Continue(())
1529            },
1530        );
1531
1532        // Only one hop should be produced at the `any` dimension.
1533        assert_eq!(steps.len(), 1);
1534
1535        // Reject choices.
1536        let hop = &steps[0].as_forward().unwrap();
1537
1538        // There should be 3 components to the frame's coordinate.
1539        assert_eq!(hop.here.len(), 3);
1540
1541        // The selected zone (dim 0) should be in bounds.
1542        let zone = hop.here[0];
1543        assert!(zone < 2, "zone out of bounds: {}", zone);
1544
1545        // Inner selection should still be All(All(True))
1546        assert!(matches!(hop.selection, Selection::All(_)));
1547    }
1548
1549    // This test relies on a deep structural property of the routing
1550    // semantics:
1551    //
1552    //   Overdelivery is prevented not by ad hoc guards, but by the
1553    //   structure of the traversal itself — particularly in the
1554    //   presence of routing frame deduplication.
1555    //
1556    // When a frame reaches the final dimension with `selection ==
1557    // True`, it becomes a delivery frame. If multiple such frames
1558    // target the same coordinate, then:
1559    //
1560    //   - They must share the same coordinate `here`
1561    //   - They must have reached it via the same routing path (by the
1562    //     Unique Path Theorem)
1563    //   - Their `RoutingFrame` state is thus structurally identical:
1564    //       - Same `here`
1565    //       - Same `dim` (equal to `slice.num_dim()`)
1566    //       - Same residual `selection == True`
1567    //
1568    // The deduplication logic (via `RoutingFrameKey`) collapses such
1569    // structurally equivalent frames. As a result, only one frame
1570    // delivers to the target coordinate, and overdelivery is
1571    // structurally ruled out.
1572    //
1573    // This test verifies that behavior holds as expected — and, when
1574    // deduplication is disabled, confirms that overdelivery becomes
1575    // observable.
1576    #[test]
1577    fn test_routing_deduplication_precludes_overdelivery() {
1578        // Ensure the environment is clean — this test depends on a
1579        // known configuration of deduplication behavior.
1580        let var = "HYPERACTOR_SELECTION_DISABLE_ROUTING_FRAME_DEDUPLICATION";
1581        assert!(
1582            std::env::var_os(var).is_none(),
1583            "env var `{}` should not be set prior to test",
1584            var
1585        );
1586        let slice = test_slice();
1587
1588        // Construct a structurally duplicated selection.
1589        //
1590        // The union duplicates a singleton selection expression.
1591        // Without deduplication, this would result in two logically
1592        // identical frames targeting the same node — which should
1593        // trigger an over-delivery panic in the simulation.
1594        let a = range(0, range(0, range(0, true_())));
1595        let sel = union(a.clone(), a.clone());
1596
1597        // Sanity check: with deduplication enabled (default), this
1598        // selection does not cause overdelivery.
1599        let result = std::panic::catch_unwind(|| {
1600            let _ = collect_routed_paths(&sel, &slice);
1601        });
1602        assert!(result.is_ok(), "Unexpected panic due to overdelivery");
1603
1604        // Now explicitly disable deduplication.
1605        // SAFETY: TODO: Audit that the environment access only
1606        // happens in single-threaded code.
1607        unsafe { std::env::set_var(var, "1") };
1608
1609        // Expect overdelivery: the duplicated union arms will each
1610        // produce a delivery to the same coordinate.
1611        let result = std::panic::catch_unwind(|| {
1612            let _ = collect_routed_paths(&sel, &slice);
1613        });
1614
1615        // Clean up: restore environment to avoid affecting other
1616        // tests.
1617        // SAFETY: TODO: Audit that the environment access only
1618        // happens in single-threaded code.
1619        unsafe { std::env::remove_var(var) };
1620
1621        assert!(
1622            result.is_err(),
1623            "Expected panic due to overdelivery, but no panic occurred"
1624        );
1625    }
1626
/// Checks `RoutingFrame::next_steps` on a zero-dimensional slice
/// (`Slice::new(42, vec![], vec![])`).
///
/// Every selection in the table must yield exactly one `Forward`
/// step at coordinate `[0]` whose slice location is 42; the cases
/// differ only in whether that step is a delivery target. A
/// `Choice` step is never expected and fails the test immediately.
#[test]
fn test_next_steps_zero_dim_slice() {
    use std::ops::ControlFlow;

    use crate::selection::dsl::*;

    let slice = Slice::new(42, vec![], vec![]).unwrap();

    // (label used in failure messages, selection under test,
    //  expected result of `deliver_here()` on the single step)
    let cases = [
        ("true_()", true_(), true),
        ("all(true_())", all(true_()), true),
        ("false_()", false_(), false),
        ("all(false_())", all(false_()), false),
    ];

    for (label, selection, expect_delivery) in cases {
        let frame = RoutingFrame::root(selection, slice.clone());
        let mut steps = vec![];
        // The return value of `next_steps` is not needed: the step
        // callback always continues.
        let _ = frame.next_steps(
            &mut |_| panic!("Unexpected Choice in 0D test"),
            &mut |step| {
                steps.push(step);
                ControlFlow::Continue(())
            },
        );

        assert_eq!(
            steps.len(),
            1,
            "{}: expected exactly one routing step",
            label
        );
        let step = steps[0].as_forward().unwrap();
        assert_eq!(step.here, vec![0], "{}: unexpected coordinate", label);
        assert_eq!(
            step.deliver_here(),
            expect_delivery,
            "{}: unexpected delivery flag",
            label
        );
        assert_eq!(
            step.slice.location(&step.here).unwrap(),
            42,
            "{}: unexpected slice location",
            label
        );
    }
}
1702}