ndslice/selection/routing.rs
1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! # Routing
10//!
11//! This module defines [`RoutingFrame`] and its [`next_steps`]
12//! method, which model how messages propagate through a
13//! multidimensional mesh based on a [`Selection`] expression.
14//!
15//! A [`RoutingFrame`] represents the state of routing at a particular
16//! point in the mesh. It tracks the current coordinate (`here`), the
17//! remaining selection to apply (`selection`), the mesh layout
18//! (`slice`), and the current dimension of traversal (`dim`).
19//!
20//! [`next_steps`] defines a routing-specific evaluation
21//! strategy for `Selection`. Unlike [`Selection::eval`], which
22//! produces flat indices that match a selection, this method produces
23//! intermediate routing states — new frames or deferred steps to
24//! continue traversing.
25//!
26//! Rather than returning raw frames directly, [`next_steps`]
27//! produces a stream of [`RoutingStep`]s via a callback — each
28//! representing a distinct kind of routing progression:
29//!
30//! - [`RoutingStep::Forward`] indicates that routing proceeds
31//! deterministically to a new [`RoutingFrame`] — the next
32//! coordinate is fully determined by the current selection and
33//! frame state.
34//!
35//! - [`RoutingStep::Choice`] represents a deferred decision: it
36//! returns a set of admissible indices, and **the caller must
37//! select one** (e.g., for load balancing or policy-based routing)
38//! **before routing can proceed**.
39//!
40//! In this way, non-determinism is treated as a **first-class,
41//! policy-driven** aspect of the routing system — enabling
42//! inspection, customization, and future extensions without
43//! complicating the core traversal logic.
44//!
45//! A frame is considered a delivery target if its selection is
46//! [`Selection::True`] and all dimensions have been traversed, as
47//! determined by [`RoutingFrame::deliver_here`]. All other frames are
48//! forwarded further using [`RoutingFrame::should_route`].
49//!
50//! This design enables **compositional**, **local**, and **scalable**
51//! routing:
52//! - **Compositional**: complex selection expressions decompose into
53//! simpler, independently evaluated sub-selections.
54//! - **Local**: each frame carries exactly the state needed for its
55//! next step — no global coordination or lookahead is required.
56//! - **Scalable**: routing unfolds recursively, one hop at a time,
57//! allowing for efficient traversal even in high-dimensional spaces.
58//!
//! This module provides the foundation for building structured,
//! recursive routing logic over multidimensional coordinate spaces.
//!
//! [`next_steps`]: RoutingFrame::next_steps
61use std::collections::HashMap;
62use std::collections::HashSet;
63use std::fmt::Write;
64use std::hash::Hash;
65use std::ops::ControlFlow;
66use std::sync::Arc;
67
68use anyhow::Result;
69use enum_as_inner::EnumAsInner;
70use serde::Deserialize;
71use serde::Serialize;
72use serde::de::DeserializeOwned;
73
74use crate::SliceError;
75use crate::selection::NormalizedSelectionKey;
76use crate::selection::Selection;
77use crate::selection::Slice;
78
/// Represents the outcome of evaluating a routing step.
///
/// This is the value returned by [`RoutingFrame::action`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RoutingAction {
    /// The frame is a delivery point: [`RoutingFrame::deliver_here`]
    /// returned `true` for it.
    Deliver,
    /// The frame is not a delivery point and must be routed further.
    Forward,
}
85
/// `RoutingFrame` captures the state of a selection being evaluated:
/// the current coordinate (`here`), the remaining selection to apply
/// (`selection`), the shape and layout information (`slice`), and the
/// current dimension (`dim`).
///
/// Each frame represents an independent routing decision and produces
/// zero or more new frames via `next_steps`.
///
/// The `slice` is held behind an [`Arc`], so frames derived from one
/// another share the same layout rather than copying it.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RoutingFrame {
    /// The current coordinate in the mesh where this frame is being
    /// evaluated.
    ///
    /// This is the source location for the next routing step.
    pub here: Vec<usize>,

    /// The residual selection expression describing where routing
    /// should continue.
    ///
    /// At each step, only the current dimension (tracked by `dim`) of
    /// this selection is considered.
    pub selection: Selection,

    /// The shape and layout of the full multidimensional space being
    /// routed.
    ///
    /// This determines the bounds and stride information used to
    /// compute coordinates and flat indices.
    pub slice: Arc<Slice>,

    /// The current axis of traversal within the selection and slice.
    ///
    /// Routing proceeds dimension-by-dimension; this value tracks how
    /// many dimensions have already been routed.
    pub dim: usize,
}
121
// Compile-time check: ensure `RoutingFrame` is thread-safe and fully
// owned.
//
// The function is never meant to do anything at runtime (its body is
// empty); the `where` clause alone states the required bounds
// (`Send`, `Sync`, serde round-tripping and `'static`) on
// `RoutingFrame`.
fn _assert_routing_frame_traits()
where
    RoutingFrame: Send + Sync + Serialize + DeserializeOwned + 'static,
{
}
129
/// A `RoutingStep` represents a unit of progress in the routing
/// process.
///
/// Emitted by [`RoutingFrame::next_steps`], each step describes
/// how routing should proceed from a given frame:
///
/// - [`RoutingStep::Forward`] represents a deterministic hop to the
///   next coordinate in the mesh, with an updated [`RoutingFrame`].
///
/// - [`RoutingStep::Choice`] indicates that routing cannot proceed
///   until the caller selects one of several admissible indices. This
///   allows for policy-driven or non-deterministic routing behavior,
///   such as load balancing.
///
/// The `EnumAsInner` derive supplies the per-variant accessors used in
/// this file (e.g. `into_forward()`).
#[derive(Debug, Clone, EnumAsInner)]
pub enum RoutingStep {
    /// A deterministic routing hop to the next coordinate. Carries an
    /// updated [`RoutingFrame`] describing the new position and
    /// residual selection.
    Forward(RoutingFrame),

    /// A deferred routing decision at the current dimension. Contains
    /// a set of admissible indices and a residual [`RoutingFrame`] to
    /// continue routing once a choice is made.
    Choice(Choice),
}
155
156/// A deferred routing decision as contained in a
157/// [`RoutingStep::Choice`].
158///
159/// A `Choice` contains:
160/// - `candidates`: the admissible indices at the current dimension
161/// - `frame`: the residual [`RoutingFrame`] describing how routing
162/// continues once a choice is made
163///
164/// To continue routing, the caller must select one of the
165/// `candidates` and call [`Choice::choose`] to produce the
166/// corresponding [`RoutingStep::Forward`].
167#[derive(Debug, Clone)]
168pub struct Choice {
169 pub(crate) candidates: Vec<usize>,
170 pub(crate) frame: RoutingFrame,
171}
172
173impl Choice {
174 /// Returns the list of admissible indices at the current
175 /// dimension.
176 ///
177 /// These represent the valid choices that the caller can select
178 /// from when resolving this deferred routing step.
179 pub fn candidates(&self) -> &[usize] {
180 &self.candidates
181 }
182
183 /// Returns a reference to the residual [`RoutingFrame`]
184 /// associated with this choice.
185 ///
186 /// This frame encodes the selection and mesh context to be used
187 /// once a choice is made, and routing continues at the next
188 /// dimension.
189 pub fn frame(&self) -> &RoutingFrame {
190 &self.frame
191 }
192
193 /// Resolves the choice by selecting a specific index.
194 ///
195 /// Constrains the residual selection to the chosen index at the
196 /// current dimension and returns a [`RoutingStep::Forward`] for
197 /// continued routing.
198 pub fn choose(self, index: usize) -> RoutingStep {
199 // The only thing `next()` has to do is constrain the
200 // selection to a concrete choice at the current dimension.
201 // `self.frame.selection` is the residual (inner) selection to
202 // be applied *past* the current dimension.
203 RoutingStep::Forward(RoutingFrame {
204 selection: crate::dsl::range(index..=index, self.frame.selection),
205 ..self.frame
206 })
207 }
208}
209
210/// Key used to deduplicate routing frames.
211#[derive(Debug, Hash, PartialEq, Eq)]
212pub struct RoutingFrameKey {
213 here: Vec<usize>,
214 dim: usize,
215 selection: NormalizedSelectionKey,
216}
217
218impl RoutingFrameKey {
219 /// Constructs a `RoutingFrameKey` from a `RoutingFrame`.
220 ///
221 /// This key uniquely identifies a routing frame by its coordinate
222 /// (`here`), current dimension, and normalized selection. It is
223 /// used during traversal for purposes such as deduplication and
224 /// memoization.
225 pub fn new(frame: &RoutingFrame) -> Self {
226 Self {
227 here: frame.here.clone(),
228 dim: frame.dim,
229 selection: NormalizedSelectionKey::new(&frame.selection),
230 }
231 }
232}
233
234impl RoutingFrame {
235 /// Constructs the initial frame at the root coordinate (all
236 /// zeros). Selections are expanded as necessary to ensure they
237 /// have depth equal to the slice dimensionality. See the docs for
238 /// `canonicalize_to_dimensions` for the rules.
239 ///
240 /// ### Canonical Handling of Zero-Dimensional Slices
241 ///
242 /// A `Slice` with zero dimensions represents the empty product
243 /// `∏_{i=1}^{0} Xᵢ`, which has exactly one element: the empty
244 /// tuple. To maintain uniform routing semantics, we canonically
245 /// embed such 0D slices as 1D slices of extent 1:
246 ///
247 /// ```text
248 /// Slice::new(offset, [1], [1])
249 /// ```
250 ///
251 /// This embedding preserves the correct number of addressable
252 /// points and allows the routing machinery to proceed through the
253 /// usual recursive strategy without introducing special cases. The
254 /// selected coordinate is `vec![0]`, and `dim = 0` proceeds as
255 /// usual. This makes the routing logic consistent with evaluation
256 /// and avoids edge case handling throughout the codebase.
257 pub fn root(selection: Selection, slice: Slice) -> Self {
258 // Canonically embed 0D as 1D (extent 1).
259 let slice = if slice.num_dim() > 0 {
260 Arc::new(slice)
261 } else {
262 Arc::new(Slice::new(slice.offset(), vec![1], vec![1]).unwrap())
263 };
264 let n = slice.num_dim();
265 RoutingFrame {
266 here: vec![0; n],
267 selection: selection.canonicalize_to_dimensions(n),
268 slice,
269 dim: 0,
270 }
271 }
272
273 /// Produces a new frame advanced to the next dimension with
274 /// updated position and selection.
275 pub fn advance(&self, here: Vec<usize>, selection: Selection) -> Self {
276 RoutingFrame {
277 here,
278 selection,
279 slice: Arc::clone(&self.slice),
280 dim: self.dim + 1,
281 }
282 }
283
284 /// Returns a new frame with the same position and dimension but a
285 /// different selection.
286 pub fn with_selection(&self, selection: Selection) -> Self {
287 RoutingFrame {
288 here: self.here.clone(),
289 selection,
290 slice: Arc::clone(&self.slice),
291 dim: self.dim,
292 }
293 }
294
295 /// Determines the appropriate routing action for this frame.
296 ///
297 /// Returns [`RoutingAction::Deliver`] if the message should be
298 /// delivered at this coordinate, or [`RoutingAction::Forward`] if
299 /// it should be routed further.
300 pub fn action(&self) -> RoutingAction {
301 if self.deliver_here() {
302 RoutingAction::Deliver
303 } else {
304 RoutingAction::Forward
305 }
306 }
307
308 /// Returns the location of this frame in the underlying slice.
309 pub fn location(&self) -> Result<usize, SliceError> {
310 self.slice.location(&self.here)
311 }
312}
313
314impl RoutingFrame {
315 /// Visits the next routing steps from this frame using a
316 /// callback-based traversal.
317 /// This method structurally recurses on the [`Selection`]
318 /// expression, yielding [`RoutingStep`]s via the `f` callback.
319 /// Early termination is supported via [`ControlFlow::Break`].
320 ///
321 /// Compared to the (old, now removed) [`next_steps`] method, this
322 /// avoids intermediate allocation, supports interruptibility, and
323 /// allows policy-driven handling of [`RoutingStep::Choice`]s via
324 /// the `chooser` callback.
325 ///
326 /// ---
327 ///
328 /// ### Traversal Strategy
329 ///
330 /// The traversal proceeds **dimension-by-dimension**,
331 /// structurally mirroring the shape of the selection expression:
332 ///
333 /// - [`Selection::All`] and [`Selection::Range`] iterate over a
334 /// range of coordinates, emitting one [`RoutingStep::Forward`]
335 /// per valid index.
336 /// - [`Selection::Union`] and [`Selection::Intersection`] recurse
337 /// into both branches. Intersection steps are joined at matching
338 /// coordinates and residual selections are reduced.
339 /// - [`Selection::Any`] randomly selects one index along the
340 /// current dimension and emits a single step.
341 /// - [`Selection::True`] and [`Selection::False`] emit no steps.
342 ///
343 /// At each step, only the current dimension (tracked via `self.dim`)
344 /// is evaluated. Future dimensions remain untouched until deeper
345 /// recursion.
346 ///
347 /// ---
348 ///
349 /// ### Evaluation Semantics
350 ///
351 /// - **Selection::True**
352 /// No further routing is performed — if this frame is at the
353 /// final dimension, delivery occurs.
354 /// - **Selection::False**
355 /// No match — routing halts.
356 ///
357 /// - **Selection::All / Selection::Range**
358 /// Emits one [`RoutingStep::Forward`] per matching index, each
359 /// advancing to the next dimension with the inner selection.
360 ///
361 /// - **Selection::Union**
362 /// Evaluates both branches independently and emits all
363 /// resulting steps.
364 ///
365 /// - **Selection::Intersection**
366 /// Emits only those steps where both branches produce the same
367 /// coordinate, combining the residual selections at that point.
368 ///
369 /// - **Selection::Any**
370 /// Randomly selects one index and emits a single
371 /// [`RoutingStep::Forward`].
372 ///
373 /// - **Selection::Choice**
374 /// Defers decision to the caller by invoking the `chooser`
375 /// function, which resolves the candidate index.
376 ///
377 /// ---
378 ///
379 /// ### Delivery Semantics
380 ///
381 /// Message delivery is determined by
382 /// [`RoutingFrame::deliver_here`], which returns true when:
383 ///
384 /// - The frame’s selection is [`Selection::True`], and
385 /// - All dimensions have been traversed (`dim ==
386 /// slice.num_dim()`).
387 ///
388 /// ---
389 ///
390 /// ### Panics
391 ///
392 /// Panics if `slice.num_dim() == 0`. Use a canonical embedding
393 /// (e.g., 0D → 1D) before calling this (see e.g.
394 /// `RoutingFrame::root`).
395 ///
396 /// ---
397 ///
398 /// ### Summary
399 ///
400 /// - **Structure-driven**: Mirrors the shape of the selection
401 /// expression.
402 /// - **Compositional**: Each variant defines its own traversal
403 /// behavior.
404 /// - **Interruptible**: Early termination is supported via
405 /// [`ControlFlow`].
406 /// - **Minimally allocating**: Avoids intermediate buffers in
407 /// most cases; only [`Selection::Intersection`] allocates
408 /// temporary state for pairwise matching.
409 /// - **Policy-ready**: Integrates with runtime routing policies
410 /// via the `chooser`.
411 pub fn next_steps(
412 &self,
413 _chooser: &mut dyn FnMut(&Choice) -> usize,
414 f: &mut dyn FnMut(RoutingStep) -> ControlFlow<()>,
415 ) -> ControlFlow<()> {
416 assert!(self.slice.num_dim() > 0, "next_steps requires num_dims > 0");
417
418 match &self.selection {
419 Selection::True => ControlFlow::Continue(()),
420 Selection::False => ControlFlow::Continue(()),
421 Selection::All(inner) => {
422 let size = self.slice.sizes()[self.dim];
423 for i in 0..size {
424 let mut coord = self.here.clone();
425 coord[self.dim] = i;
426 let frame = self.advance(coord, (**inner).clone());
427 if let ControlFlow::Break(_) = f(RoutingStep::Forward(frame)) {
428 return ControlFlow::Break(());
429 }
430 }
431 ControlFlow::Continue(())
432 }
433
434 Selection::Range(range, inner) => {
435 let size = self.slice.sizes()[self.dim];
436 let (min, max, step) = range.resolve(size);
437
438 for i in (min..max).step_by(step) {
439 let mut coord = self.here.clone();
440 coord[self.dim] = i;
441 let frame = self.advance(coord, (**inner).clone());
442 if let ControlFlow::Break(_) = f(RoutingStep::Forward(frame)) {
443 return ControlFlow::Break(());
444 }
445 }
446
447 ControlFlow::Continue(())
448 }
449
450 Selection::Any(inner) => {
451 let size = self.slice.sizes()[self.dim];
452 if size == 0 {
453 return ControlFlow::Continue(());
454 }
455
456 let i = rand::random_range(0..size);
457 let mut coord = self.here.clone();
458 coord[self.dim] = i;
459 let frame = self.advance(coord, (**inner).clone());
460 f(RoutingStep::Forward(frame))
461 }
462
463 Selection::Union(a, b) => {
464 if let ControlFlow::Break(_) =
465 self.with_selection((**a).clone()).next_steps(_chooser, f)
466 {
467 return ControlFlow::Break(());
468 }
469 self.with_selection((**b).clone()).next_steps(_chooser, f)
470 }
471
472 Selection::Intersection(a, b) => {
473 let mut left = vec![];
474 let mut right = vec![];
475
476 let mut collect_left = |step: RoutingStep| {
477 if let RoutingStep::Forward(frame) = step {
478 left.push(frame);
479 }
480 ControlFlow::Continue(())
481 };
482 let mut collect_right = |step: RoutingStep| {
483 if let RoutingStep::Forward(frame) = step {
484 right.push(frame);
485 }
486 ControlFlow::Continue(())
487 };
488
489 self.with_selection((**a).clone())
490 .next_steps(_chooser, &mut collect_left)?;
491 self.with_selection((**b).clone())
492 .next_steps(_chooser, &mut collect_right)?;
493
494 for fa in &left {
495 for fb in &right {
496 if fa.here == fb.here {
497 let residual = fa
498 .selection
499 .clone()
500 .reduce_intersection(fb.selection.clone());
501 let frame = self.advance(fa.here.clone(), residual);
502 if let ControlFlow::Break(_) = f(RoutingStep::Forward(frame)) {
503 return ControlFlow::Break(());
504 }
505 }
506 }
507 }
508
509 ControlFlow::Continue(())
510 }
511
512 // TODO(SF, 2025-04-30): This term is not in the algebra
513 // yet.
514 // Selection::LoadBalanced(inner) => {
515 // let size = self.slice.sizes()[self.dim];
516 // if size == 0 {
517 // ControlFlow::Continue(())
518 // } else {
519 // let candidates = (0..size).collect();
520 // let choice = Choice {
521 // candidates,
522 // frame: self.with_selection((*inner).clone()),
523 // };
524 // let index = chooser(&choice);
525 // f(choice.choose(index))
526 // }
527 // }
528
529 // Catch-all for future combinators (e.g., Label).
530 _ => unimplemented!(),
531 }
532 }
533
534 /// Returns true if this frame represents a terminal delivery
535 /// point — i.e., the selection is `True` and all dimensions have
536 /// been traversed.
537 pub fn deliver_here(&self) -> bool {
538 matches!(self.selection, Selection::True) && self.dim == self.slice.num_dim()
539 }
540
541 /// Returns true if the message has not yet reached its final
542 /// destination and should be forwarded to the next routing step.
543 pub fn should_route(&self) -> bool {
544 !self.deliver_here()
545 }
546}
547
548impl RoutingFrame {
549 /// Traces the unique routing path to the given destination
550 /// coordinate.
551 ///
552 /// Returns `Some(vec![root, ..., dest])` if `dest` is selected,
553 /// or `None` if not.
554 pub fn trace_route(&self, dest: &[usize]) -> Option<Vec<Vec<usize>>> {
555 use std::collections::HashSet;
556 use std::ops::ControlFlow;
557
558 use crate::selection::routing::RoutingFrameKey;
559
560 fn go(
561 frame: RoutingFrame,
562 dest: &[usize],
563 mut path: Vec<Vec<usize>>,
564 seen: &mut HashSet<RoutingFrameKey>,
565 ) -> Option<Vec<Vec<usize>>> {
566 let key = RoutingFrameKey::new(&frame);
567 if !seen.insert(key) {
568 return None;
569 }
570
571 path.push(frame.here.clone());
572
573 if frame.deliver_here() && frame.here == dest {
574 return Some(path);
575 }
576
577 let mut found = None;
578 let _ = frame.next_steps(
579 &mut |_| panic!("Choice encountered in trace_route"),
580 &mut |step: RoutingStep| {
581 let next = step.into_forward().unwrap();
582 if let Some(result) = go(next, dest, path.clone(), seen) {
583 found = Some(result);
584 ControlFlow::Break(())
585 } else {
586 ControlFlow::Continue(())
587 }
588 },
589 );
590
591 found
592 }
593
594 let mut seen = HashSet::new();
595 go(self.clone(), dest, Vec::new(), &mut seen)
596 }
597}
598
/// Renders a routing path as text, one hop per line.
///
/// Every line consists of the hop index (right-aligned to width 2),
/// an arrow, and the hop's coordinate written as a tuple such as
/// `(0, 1)`. The arrow is `⇨` on the last hop and `→` on all others.
/// An empty route renders as the empty string.
///
/// # Example
///
/// ```text
///  0 → (0, 0)
///  1 → (0, 1)
///  2 ⇨ (1, 1)
/// ```
#[track_caller]
#[allow(dead_code)]
pub fn format_route(route: &[Vec<usize>]) -> String {
    // Index of the final hop; only consulted when the route is
    // non-empty.
    let last = route.len().saturating_sub(1);

    route
        .iter()
        .enumerate()
        .fold(String::new(), |mut rendered, (idx, hop)| {
            let marker = if idx == last { "⇨" } else { "→" };
            let fields: Vec<String> = hop.iter().map(|c| c.to_string()).collect();
            let _ = writeln!(
                &mut rendered,
                "{:>2} {} ({})",
                idx,
                marker,
                fields.join(", ")
            );
            rendered
        })
}
628
629/// Formats a routing tree as an indented string.
630///
631/// Traverses the tree of `RoutingFrame`s starting from the root,
632/// displaying each step with indentation by dimension. Delivery
633/// targets are marked `✅`.
634///
635/// # Example
636/// ```text
637/// (0, 0)
638/// (0, 1) ✅
639/// (1, 0)
640/// (1, 1) ✅
641/// ```
642#[track_caller]
643#[allow(dead_code)]
644pub fn format_routing_tree(selection: Selection, slice: &Slice) -> String {
645 let root = RoutingFrame::root(selection, slice.clone());
646 let mut out = String::new();
647 let mut seen = HashSet::new();
648 format_routing_tree_rec(&root, 0, &mut out, &mut seen).unwrap();
649 out
650}
651
652fn format_routing_tree_rec(
653 frame: &RoutingFrame,
654 indent: usize,
655 out: &mut String,
656 seen: &mut HashSet<RoutingFrameKey>,
657) -> std::fmt::Result {
658 use crate::selection::routing::RoutingFrameKey;
659
660 let key = RoutingFrameKey::new(frame);
661 if !seen.insert(key) {
662 return Ok(()); // already visited
663 }
664
665 let indent_str = " ".repeat(indent);
666 let coord_str = format!(
667 "({})",
668 frame
669 .here
670 .iter()
671 .map(ToString::to_string)
672 .collect::<Vec<_>>()
673 .join(", ")
674 );
675
676 match frame.action() {
677 RoutingAction::Deliver => {
678 writeln!(out, "{}{} ✅", indent_str, coord_str)?;
679 }
680 RoutingAction::Forward => {
681 writeln!(out, "{}{}", indent_str, coord_str)?;
682 let _ = frame.next_steps(
683 &mut |_| panic!("Choice encountered in format_routing_tree_rec"),
684 &mut |step| {
685 let next = step.into_forward().unwrap();
686 format_routing_tree_rec(&next, indent + 1, out, seen).unwrap();
687 ControlFlow::Continue(())
688 },
689 );
690 }
691 }
692
693 Ok(())
694}
695
696// Pretty-prints a routing path from source to destination.
697//
698// Each hop is shown as a numbered step with directional arrows.
699#[track_caller]
700#[allow(dead_code)]
701pub fn print_route(route: &[Vec<usize>]) {
702 println!("{}", format_route(route));
703}
704
705/// Prints the routing tree for a selection over a slice.
706///
707/// Traverses the routing structure from the root, printing each step
708/// with indentation by dimension. Delivery points are marked with
709/// `✅`.
710#[track_caller]
711#[allow(dead_code)]
712pub fn print_routing_tree(selection: Selection, slice: &Slice) {
713 println!("{}", format_routing_tree(selection, slice));
714}
715
716// == "CommActor multicast" routing ==
717
718/// Resolves the current set of routing frames (`dests`) to determine
719/// whether the message should be delivered at this rank, and which
720/// routing frames should be forwarded to peer ranks.
721///
722/// This is the continuation of a multicast operation: each forwarded
723/// message contains one or more `RoutingFrame`s that represent
724/// partial routing state. This call determines how those frames
725/// propagate next.
726///
727/// `deliver_here` is true if any frame targets this rank and
728/// indicates delivery. `next_steps` contains the peer ranks and frames
729/// to forward.
730///
731/// This is also the top-level entry point for CommActor's routing
732/// logic.
733pub fn resolve_routing(
734 rank: usize,
735 frames: impl IntoIterator<Item = RoutingFrame>,
736 chooser: &mut dyn FnMut(&Choice) -> usize,
737) -> Result<(bool, HashMap<usize, Vec<RoutingFrame>>)> {
738 let mut deliver_here = false;
739 let mut next_steps = HashMap::new();
740 for frame in frames {
741 resolve_routing_one(rank, frame, chooser, &mut deliver_here, &mut next_steps)?;
742 }
743 Ok((deliver_here, next_steps))
744}
745
746/// Recursively resolves routing for a single `RoutingFrame` at the
747/// given rank, determining whether the message should be delivered
748/// locally and which frames should be forwarded to peer ranks.
749///
750/// - If the frame targets the local `rank` and is a delivery point,
751/// `deliver_here` is set to `true`.
752/// - If the frame targets the local `rank` but is not a delivery
753/// point, the function recurses on its forward steps.
754/// - If the frame targets a different rank, it is added to
755/// `next_steps`.
756///
757/// Deduplication is handled by `get_next_steps`. Dynamic constructs
758/// such as `Any` or `First` must be resolved by the provided
759/// `chooser`, which selects an index from a `Choice`.
760///
761/// Traversal is depth-first within a rank and breadth-first across
762/// ranks. This defines the exact routing behavior used by
763/// `CommActor`: it exhaustively evaluates all local routing structure
764/// before forwarding to peers.
765///
766/// The resulting `next_steps` map contains all non-local ranks that
767/// should receive forwarded frames, where each entry maps a peer rank
768/// to a list of routing continuations to evaluate at that peer.
769/// `deliver_here` is set to `true` if the current rank is a final
770/// delivery point.
771pub(crate) fn resolve_routing_one(
772 rank: usize,
773 frame: RoutingFrame,
774 chooser: &mut dyn FnMut(&Choice) -> usize,
775 deliver_here: &mut bool,
776 next_steps: &mut HashMap<usize, Vec<RoutingFrame>>,
777) -> Result<()> {
778 let frame_rank = frame.slice.location(&frame.here)?;
779 if frame_rank == rank {
780 if frame.deliver_here() {
781 *deliver_here = true;
782 } else {
783 for next in get_next_steps(frame, chooser)? {
784 resolve_routing_one(rank, next, chooser, deliver_here, next_steps)?;
785 }
786 }
787 } else {
788 next_steps.entry(frame_rank).or_default().push(frame);
789 }
790 Ok(())
791}
792
793/// Computes the set of `Forward` routing frames reachable from the
794/// given `RoutingFrame`.
795///
796/// This function traverses the result of `frame.next_steps(...)`,
797/// collecting only `RoutingStep::Forward(_)` steps. The caller
798/// provides a `chooser` function to resolve dynamic constructs such
799/// as `Any` or `First`.
800///
801/// Some obviously redundant steps may be filtered, but no strict
802/// guarantee is made about structural uniqueness.
803fn get_next_steps(
804 dest: RoutingFrame,
805 chooser: &mut dyn FnMut(&Choice) -> usize,
806) -> Result<Vec<RoutingFrame>> {
807 let mut seen = HashSet::new();
808 let mut unique_steps = vec![];
809 let _ = dest.next_steps(chooser, &mut |step| {
810 if let RoutingStep::Forward(frame) = step {
811 let key = RoutingFrameKey::new(&frame);
812 if seen.insert(key) {
813 unique_steps.push(frame);
814 }
815 }
816 ControlFlow::Continue(())
817 });
818 Ok(unique_steps)
819}
820
821// == Testing (`collect_commactor_routing_tree` mesh simulation) ===
822
/// Captures the logical structure of a CommActor multicast operation.
///
/// This type models how a message is delivered and forwarded through
/// a mesh under CommActor routing semantics. It is used in tests to
/// verify path determinism and understand message propagation
/// behavior.
///
/// - `delivered`: ranks where the message was delivered (`post`
///   called)
/// - `visited`: all ranks that participated, including forwarding
///   only
/// - `forwards`: maps each rank to the routing frames it forwarded
///
/// Only compiled for tests; `Default` yields an empty tree.
#[cfg(test)]
#[allow(dead_code)]
#[derive(Default)]
pub(crate) struct CommActorRoutingTree {
    /// Ranks that were delivered the message (i.e. called `post`). Map
    /// from rank → delivery path (flat rank indices) from root to that
    /// rank.
    pub delivered: HashMap<usize, Vec<usize>>,

    /// Ranks that participated in the multicast - either by delivering
    /// the message or forwarding it to peers.
    pub visited: HashSet<usize>,

    /// Map from rank → routing frames this rank forwarded to other
    /// ranks.
    pub forwards: HashMap<usize, Vec<RoutingFrame>>,
}
852
853#[cfg(test)]
854mod tests {
855 use std::collections::HashSet;
856 use std::collections::VecDeque;
857
858 use super::RoutingAction;
859 use super::RoutingFrame;
860 use super::print_route;
861 use super::print_routing_tree;
862 use crate::Slice;
863 use crate::selection::EvalOpts;
864 use crate::selection::Selection;
865 use crate::selection::dsl::*;
866 use crate::selection::test_utils::RoutedMessage;
867 use crate::selection::test_utils::collect_commactor_routing_tree;
868 use crate::selection::test_utils::collect_routed_nodes;
869 use crate::selection::test_utils::collect_routed_paths;
870 use crate::shape;
871
872 // A test slice: (zones = 2, hosts = 4, gpus = 8).
873 fn test_slice() -> Slice {
874 Slice::new(0usize, vec![2, 4, 8], vec![32, 8, 1]).unwrap()
875 }
876
877 /// Asserts that a routing strategy produces the same set of nodes
878 /// as `Selection::eval`.
879 ///
880 /// This macro compares the result of evaluating a `Selection`
881 /// using the given `collector` against the reference
882 /// implementation `Selection::eval` (with lenient options).
883 ///
884 /// The `collector` should be a function or closure of type
885 /// `Fn(&Selection, &Slice) -> Vec<usize>`, such as
886 /// `collect_routed_nodes` or a CommActor-based simulation.
887 ///
888 /// Panics if the two sets of routed nodes differ.
889 ///
890 /// # Example
891 /// ```
892 /// assert_routing_eq_with!(slice, selection, collect_routed_nodes);
893 /// ```
894 macro_rules! assert_routing_eq_with {
895 ($slice:expr, $sel:expr, $collector:expr) => {{
896 let sel = $sel;
897 let slice = $slice.clone();
898 let mut expected: Vec<_> = sel.eval(&EvalOpts::lenient(), &slice).unwrap().collect();
899 expected.sort();
900 let mut actual: Vec<_> = ($collector)(&sel, &slice);
901 actual.sort();
902 assert_eq!(actual, expected, "Mismatch for selection: {}", sel);
903 }};
904 }
905
906 /// Asserts that `collect_routed_nodes` matches `Selection::eval`
907 /// on the given slice.
908 macro_rules! assert_collect_routed_nodes_eq {
909 ($slice:expr, $sel:expr) => {
910 assert_routing_eq_with!($slice, $sel, collect_routed_nodes)
911 };
912 }
913
914 /// Asserts that CommActor routing delivers to the same nodes as
915 /// `Selection::eval`.
916 macro_rules! assert_commactor_routing_eq {
917 ($slice:expr, $sel:expr) => {
918 assert_routing_eq_with!($slice, $sel, |s, sl| {
919 collect_commactor_routing_tree(s, sl)
920 .delivered
921 .into_keys()
922 .collect()
923 });
924 };
925 }
926
927 /// Asserts that all routing strategies produce the same set of
928 /// routed nodes as `Selection::eval`.
929 ///
930 /// Compares both the direct strategy (`collect_routed_nodes`) and
931 /// the CommActor routing simulation
932 /// (`collect_commactor_routing_tree`) against the expected output
933 /// from `Selection::eval`.
934 macro_rules! assert_all_routing_strategies_eq {
935 ($slice:expr, $sel:expr) => {
936 assert_collect_routed_nodes_eq!($slice, $sel);
937 assert_commactor_routing_eq!($slice, $sel);
938 };
939 }
940
#[test]
fn test_routing_04() {
    use crate::selection::dsl::*;

    let slice = test_slice(); // [2, 4, 8], strides [32, 8, 1]

    // Each case pairs a destination coordinate with the singleton
    // selection addressing exactly that coordinate.
    let cases = [
        // Destination: GPU 2 on host 2 in zone 1.
        (vec![1, 2, 2], range(1, range(2, range(2, true_())))),
        // Destination: "Right back where we started from 🙂".
        (vec![0, 0, 0], range(0, range(0, range(0, true_())))),
    ];

    for (dest, selection) in cases {
        let root = RoutingFrame::root(selection.clone(), slice.clone());
        let path = root.trace_route(&dest).expect("no route found");

        println!(
            "\ndest: {:?}, (singleton-)selection: ({})\n",
            &dest, &selection
        );
        print_route(&path);
        println!("\n");

        // The traced route must terminate at the requested destination.
        assert_eq!(path.last(), Some(&dest));
    }
}
973
#[test]
fn test_routing_05() {
    use crate::selection::dsl::*;

    // "Jun's example" -- a 2 x 2 row major mesh.
    let slice = Slice::new(0usize, vec![2, 2], vec![2, 1]).unwrap();
    // That is,
    //  (0, 0)    (0, 1)
    //  (1, 0)    (1, 1)
    //
    // and we want to cast to {(0, 1), (1, 0) and (1, 1)}:
    //
    //  (0, 0)❌    (0, 1)✅
    //  (1, 0)✅    (1, 1)✅
    //
    // One reasonable selection expression describing the
    // destination set.
    let selection = union(range(0, range(1, true_())), range(1, all(true_())));

    // Besides printing, verify that every routing strategy reaches
    // exactly the nodes `Selection::eval` selects.
    assert_all_routing_strategies_eq!(slice, selection.clone());

    // Now print the routing tree.
    print_routing_tree(selection, &slice);

    // Prints:
    // (0, 0)
    //   (0, 0)
    //     (0, 1) ✅
    //   (1, 0)
    //     (1, 0) ✅
    //     (1, 1) ✅

    // Another example: (zones = 2, hosts = 4, gpus = 8).
    let slice = Slice::new(0usize, vec![2, 4, 8], vec![32, 8, 1]).unwrap();
    // Let's have all the odd GPUs on hosts 1, 2 and 3 in zone 0.
    let selection = range(
        0,
        range(1..4, range(shape::Range(1, None, /*step*/ 2), true_())),
    );

    // Same consistency check for the stepped-range selection.
    assert_all_routing_strategies_eq!(slice, selection.clone());

    // Now print the routing tree.
    print_routing_tree(selection, &slice);

    // Prints:
    // (0, 0, 0)
    //   (0, 0, 0)
    //     (0, 1, 0)
    //       (0, 1, 1) ✅
    //       (0, 1, 3) ✅
    //       (0, 1, 5) ✅
    //       (0, 1, 7) ✅
    //     (0, 2, 0)
    //       (0, 2, 1) ✅
    //       (0, 2, 3) ✅
    //       (0, 2, 5) ✅
    //       (0, 2, 7) ✅
    //     (0, 3, 0)
    //       (0, 3, 1) ✅
    //       (0, 3, 3) ✅
    //       (0, 3, 5) ✅
    //       (0, 3, 7) ✅
}
1034
#[test]
fn test_routing_00() {
    let slice = test_slice();

    // Selections built only from `true`, `false`, `all` and `range`.
    // Each one is checked, in order, against every routing strategy.
    let selections = vec![
        false_(),
        true_(),
        all(true_()),
        all(all(true_())),
        all(all(false_())),
        all(all(all(true_()))),
        all(range(0..=0, all(true_()))),
        all(all(range(0..4, true_()))),
        all(range(1..=2, all(true_()))),
        all(all(range(2..6, true_()))),
        all(all(range(3..=3, true_()))),
        all(range(1..3, all(true_()))),
        all(all(range(0..=0, true_()))),
        range(1..=1, range(3..=3, range(0..=2, true_()))),
        // Stepped ranges.
        all(all(range(shape::Range(0, Some(8), 2), true_()))),
        all(range(shape::Range(1, Some(4), 2), all(true_()))),
    ];

    for selection in selections {
        assert_all_routing_strategies_eq!(slice, selection.clone());
    }
}
1062
#[test]
fn test_routing_03() {
    let slice = test_slice();

    // Selections exercising `intersection`, alone and combined with
    // `union`. Each one is checked, in order, against every routing
    // strategy.
    let selections = vec![
        // sel!(0 & (0,(1|3), *))
        intersection(
            range(0, true_()),
            range(0, union(range(1, all(true_())), range(3, all(true_())))),
        ),
        // sel!(0 & (0, (3|1), *)),
        intersection(
            range(0, true_()),
            range(0, union(range(3, all(true_())), range(1, all(true_())))),
        ),
        // sel!((*, *, *) & (*, *, (2 | 4)))
        intersection(
            all(all(all(true_()))),
            all(all(union(range(2, true_()), range(4, true_())))),
        ),
        // sel!((*, *, *) & (*, *, (4 | 2)))
        intersection(
            all(all(all(true_()))),
            all(all(union(range(4, true_()), range(2, true_())))),
        ),
        // sel!((*, (1 | 2)) & (*, (2 | 1)))
        intersection(
            all(union(range(1, true_()), range(2, true_()))),
            all(union(range(2, true_()), range(1, true_()))),
        ),
        intersection(all(all(all(true_()))), all(true_())),
        intersection(true_(), all(all(all(true_())))),
        intersection(all(all(all(true_()))), false_()),
        intersection(false_(), all(all(all(true_())))),
        intersection(
            all(all(range(0..4, true_()))),
            all(all(range(0..4, true_()))),
        ),
        intersection(all(all(range(1, true_()))), all(all(range(2, true_())))),
        intersection(all(all(range(2, true_()))), all(all(range(1, true_())))),
        intersection(
            all(all(range(1, true_()))),
            intersection(all(all(true_())), all(all(range(1, true_())))),
        ),
        intersection(
            range(0, true_()),
            range(0, all(union(range(1, true_()), range(3, true_())))),
        ),
        range(
            0,
            intersection(true_(), all(union(range(1, true_()), range(3, true_())))),
        ),
        intersection(all(range(1..=2, true_())), all(range(2..=3, true_()))),
        intersection(
            range(0, true_()),
            intersection(range(0, all(true_())), range(0, range(1, all(true_())))),
        ),
        intersection(
            range(0, range(1, all(true_()))),
            intersection(range(0, all(true_())), range(0, true_())),
        ),
        // sel!( (*, *, *) & ((*, *, *) & (*, *, *)) ),
        intersection(
            all(all(all(true_()))),
            intersection(all(all(all(true_()))), all(all(all(true_())))),
        ),
        union(
            intersection(range(0, true_()), range(0, range(1, all(true_())))),
            range(1, all(all(true_()))),
        ),
        // sel!((1, *, *) | (0 & (0, 3, *)))
        union(
            range(1, all(all(true_()))),
            intersection(range(0, true_()), range(0, range(3, all(true_())))),
        ),
        intersection(
            union(range(0, true_()), range(1, true_())),
            union(range(1, true_()), range(0, true_())),
        ),
        union(
            intersection(range(0, range(1, true_())), range(0, range(1, true_()))),
            intersection(range(1, range(3, true_())), range(1, range(3, true_()))),
        ),
        // sel!(*, 8 : 8)
        all(range(8..8, true_())),
        // sel!((*, 1) & (*, 8 : 8))
        intersection(all(range(1..2, true_())), all(range(8..8, true_()))),
        // sel!((*, 8 : 8) | (*, 1))
        union(all(range(8..8, true_())), all(range(1..2, true_()))),
        // sel!((*, 1) | (*, 2:8))
        union(all(range(1..2, true_())), all(range(2..8, true_()))),
        // sel!((*, *, *) & (*, *, 2:8))
        intersection(all(all(all(true_()))), all(all(range(2..8, true_())))),
    ];

    for selection in selections {
        assert_all_routing_strategies_eq!(slice, selection.clone());
    }
}
1231
#[test]
fn test_routing_02() {
    let slice = test_slice();

    // Builds `(2, 4) | (3, 5)`, a sub-selection reused by several
    // cases below.
    let two_four_or_three_five =
        || union(range(2, range(4, true_())), range(3, range(5, true_())));

    // Selections exercising `union`. Each one is checked, in order,
    // against every routing strategy.
    let selections = vec![
        // zone 0 or 1: sel!(0 | 1, *, *)
        union(range(0, true_()), range(1, true_())),
        union(range(0, all(true_())), range(1, all(true_()))),
        // hosts 1 and 3 in zone 0: sel!(0, (1 | 3), *)
        range(0, union(range(1, all(true_())), range(3, all(true_())))),
        // sel!(0, 1:3 | 5:7, *)
        range(
            0,
            union(
                range(shape::Range(1, Some(3), 1), all(true_())),
                range(shape::Range(5, Some(7), 1), all(true_())),
            ),
        ),
        // sel!(* | *): We start with `union(true_(), true_())`.
        //
        // Evaluating the left branch generates routing frames
        // recursively; evaluating the right branch generates the
        // very same frames again.
        //
        // The result is duplicate `RoutingFrame`s sharing:
        // - the same `here` coordinate,
        // - the same dimension (`dim`), and
        // - the same residual selection (`True`).
        //
        // When both frames reach the delivery condition, the second
        // call to `delivered.insert()` returns `false`. An `assert!`
        // on that line would therefore fail in the routing
        // simulation.
        //
        // TODO: We need memoization to avoid redundant work.
        //
        // This can be achieved without transforming the algebra
        // itself. However, adding normalization will make
        // memoization more effective, so we should plan to implement
        // both.
        //
        // Once that's done, we can safely restore the `assert!`.
        union(true_(), true_()),
        // sel!(*, *, * | *, *, *)
        union(all(all(all(true_()))), all(all(all(true_())))),
        // no 'false' support in sel!
        union(false_(), all(all(all(true_())))),
        union(all(all(all(true_()))), false_()),
        // sel!(0, 0:4, 0 | 1 | 2)
        range(
            0,
            range(
                shape::Range(0, Some(4), 1),
                union(
                    range(0, true_()),
                    union(range(1, true_()), range(2, true_())),
                ),
            ),
        ),
        range(0, two_four_or_three_five()),
        range(0, range(2, union(range(4, true_()), range(5, true_())))),
        // Same selection as two entries above; the repetition is
        // preserved from the original case list.
        range(0, two_four_or_three_five()),
        union(
            range(0, two_four_or_three_five()),
            range(1, two_four_or_three_five()),
        ),
    ];

    for selection in selections {
        assert_all_routing_strategies_eq!(slice, selection.clone());
    }
}
1338
#[test]
fn test_routing_01() {
    use std::ops::ControlFlow;

    let slice = test_slice();
    let sel = range(0..=0, all(true_()));

    // Expected fan-out (coordinates of the emitted frames) for the
    // first six frames popped from the queue:
    //   step 0    -> the single zone-0 frame,
    //   step 1    -> one frame per host (4),
    //   steps 2-5 -> one frame per GPU (8) on hosts 0..=3.
    let mut expected_fanouts: Vec<Vec<Vec<usize>>> = vec![
        vec![vec![0, 0, 0]],
        (0..4).map(|host| vec![0, host, 0]).collect(),
    ];
    expected_fanouts.extend(
        (0..4).map(|host| (0..8).map(|gpu| vec![0, host, gpu]).collect::<Vec<_>>()),
    );

    // Expected `deliver_here` per step: six non-delivering frames
    // followed by delivering ones.
    // NOTE(review): this table holds 39 entries (6 + 33), while the
    // fan-out table above implies 1 + 1 + 4 + 32 = 38 frames; the
    // surplus entry is never consulted — confirm the intended count.
    let expected_deliveries: Vec<bool> = [vec![false; 6], vec![true; 33]].concat();

    let mut frames = VecDeque::new();
    frames.push_back(RoutingFrame::root(sel.clone(), slice.clone()));

    println!("Fan-out trace for selection: {}", sel);

    // Breadth-first walk: `step` counts frames in the order they are
    // popped from the queue.
    let mut step = 0;
    while let Some(frame) = frames.pop_front() {
        let deliver_here = frame.deliver_here();

        // Gather the coordinates this frame forwards to, queueing
        // each emitted frame for a later iteration.
        let mut next_coords = vec![];
        let _ = frame.next_steps(
            &mut |_| panic!("Choice encountered in test_routing_01"),
            &mut |routing_step| {
                let next = routing_step.into_forward().unwrap();
                next_coords.push(next.here.clone());
                frames.push_back(next);
                ControlFlow::Continue(())
            },
        );

        println!(
            "Step {:>2}: from {:?} (flat = {:>2}) | deliver = {} | fan-out count = {} | selection = {:?}",
            step,
            frame.here,
            frame.slice.location(&frame.here).unwrap(),
            deliver_here,
            next_coords.len(),
            frame.selection.to_string(),
        );
        next_coords
            .iter()
            .for_each(|coord| println!(" → {:?}", coord));

        // Steps beyond the end of either table are left unchecked.
        if let Some(expected) = expected_fanouts.get(step) {
            assert_eq!(
                &next_coords, expected,
                "Mismatch in next_coords at step {}",
                step
            );
        }

        if let Some(&expected) = expected_deliveries.get(step) {
            assert_eq!(
                deliver_here, expected,
                "Mismatch in deliver_here at step {} (coord = {:?})",
                step, frame.here
            );
        }

        step += 1;
    }
}
1457
#[test]
fn test_routing_06() {
    use std::ops::ControlFlow;

    use crate::selection::dsl::*;
    use crate::selection::routing::RoutingFrameKey;
    use crate::selection::routing::RoutingStep;

    let slice = test_slice();
    // A union whose two arms are the same expression.
    let selection = union(all(true_()), all(true_()));

    let mut queue = VecDeque::new();
    // Keys of every frame emitted so far.
    let mut seen = HashSet::new();
    // Deliveries counted only the first time a frame key is seen.
    let mut dedup_delivered = Vec::new();
    // Every delivery, repeats included.
    let mut nodup_delivered = Vec::new();

    let root = RoutingFrame::root(selection.clone(), slice.clone());
    queue.push_back(RoutedMessage::<()>::new(root.here.clone(), root));

    while let Some(RoutedMessage { frame, .. }) = queue.pop_front() {
        let _ = frame.next_steps(
            &mut |_| panic!("Choice encountered in test_routing_06"),
            &mut |step: RoutingStep| {
                let next = step.into_forward().unwrap();
                let action = next.action();

                if action == RoutingAction::Deliver {
                    nodup_delivered.push(next.slice.location(&next.here).unwrap());
                }

                // The key is recorded unconditionally, whether or not
                // this frame is a delivery.
                let first_sighting = seen.insert(RoutingFrameKey::new(&next));
                if first_sighting && action == RoutingAction::Deliver {
                    dedup_delivered.push(next.slice.location(&next.here).unwrap());
                }

                // Frames that still need forwarding go back on the queue.
                if action == RoutingAction::Forward {
                    queue.push_back(RoutedMessage::new(frame.here.clone(), next));
                }

                ControlFlow::Continue(())
            },
        );
    }

    // Each of the 64 nodes is delivered to twice without deduplication.
    assert_eq!(dedup_delivered.len(), 64);
    assert_eq!(nodup_delivered.len(), 128);
}
1506
#[test]
fn test_routing_07() {
    use std::ops::ControlFlow;

    use crate::selection::dsl::*;
    use crate::selection::routing::RoutingFrame;
    use crate::selection::routing::RoutingStep;

    let slice = test_slice(); // shape: [2, 4, 8]

    // Selection: any zone, all hosts, all gpus.
    let root = RoutingFrame::root(any(all(all(true_()))), slice.clone());

    // Record every step emitted from the root frame.
    let mut emitted = Vec::new();
    let _ = root.next_steps(
        &mut |_| panic!("Choice encountered in test_routing_07"),
        &mut |step: RoutingStep| {
            emitted.push(step);
            ControlFlow::Continue(())
        },
    );

    // The `any` dimension must produce exactly one hop.
    assert_eq!(emitted.len(), 1);

    // That hop must be a forward step; `unwrap` fails the test
    // otherwise.
    let hop = emitted[0].as_forward().unwrap();

    // The hop's coordinate has one component per dimension.
    assert_eq!(hop.here.len(), 3);

    // Whichever zone (dim 0) was picked must be in bounds.
    let zone = hop.here[0];
    assert!(zone < 2, "zone out of bounds: {}", zone);

    // The residual selection should still be All(All(True)); only
    // the outermost `All` is checked here.
    assert!(matches!(hop.selection, Selection::All(_)));
}
1546
// This test leans on a structural property of the routing
// semantics:
//
//   Overdelivery is ruled out by the shape of the traversal itself
//   rather than by ad hoc guards — in particular once routing frame
//   deduplication is in play.
//
// A frame that reaches the final dimension with `selection ==
// True` is a delivery frame. Should several such frames target one
// coordinate, then:
//
// - They share the same coordinate `here`
// - They arrived along the same routing path (by the Unique Path
//   Theorem)
// - Their `RoutingFrame` state is therefore structurally identical:
//     - Same `here`
//     - Same `dim` (equal to `slice.num_dim()`)
//     - Same residual `selection == True`
//
// Deduplication (via `RoutingFrameKey`) collapses structurally
// equivalent frames, so only one frame delivers to the target
// coordinate and overdelivery cannot occur.
//
// The test checks that this holds with deduplication enabled and
// that overdelivery becomes observable once it is disabled.
#[test]
fn test_routing_deduplication_precludes_overdelivery() {
    // The test depends on a known deduplication configuration, so
    // the controlling variable must start out unset.
    // NOTE(review): the variable is process-wide; if the test harness
    // runs tests on several threads, toggling it below may be visible
    // to concurrently running tests — confirm.
    let var = "HYPERACTOR_SELECTION_DISABLE_ROUTING_FRAME_DEDUPLICATION";
    assert!(
        std::env::var_os(var).is_none(),
        "env var `{}` should not be set prior to test",
        var
    );
    let slice = test_slice();

    // Build a structurally duplicated selection: a union whose two
    // arms are the same singleton expression. Without deduplication
    // this yields two logically identical frames aimed at one node,
    // which should trigger an over-delivery panic in the simulation.
    let singleton = range(0, range(0, range(0, true_())));
    let sel = union(singleton.clone(), singleton);

    // Sanity check: with deduplication enabled (the default) the
    // selection routes without overdelivery.
    let routed_cleanly = std::panic::catch_unwind(|| {
        let _ = collect_routed_paths(&sel, &slice);
    })
    .is_ok();
    assert!(routed_cleanly, "Unexpected panic due to overdelivery");

    // Now explicitly disable deduplication.
    // SAFETY: TODO: Audit that the environment access only
    // happens in single-threaded code.
    unsafe { std::env::set_var(var, "1") };

    // Overdelivery is expected now: each duplicated union arm
    // produces a delivery to the same coordinate.
    let overdelivered = std::panic::catch_unwind(|| {
        let _ = collect_routed_paths(&sel, &slice);
    })
    .is_err();

    // Restore the environment before asserting, so that a failure
    // below cannot leave the variable set for other tests.
    // SAFETY: TODO: Audit that the environment access only
    // happens in single-threaded code.
    unsafe { std::env::remove_var(var) };

    assert!(
        overdelivered,
        "Expected panic due to overdelivery, but no panic occurred"
    );
}
1624
#[test]
fn test_next_steps_zero_dim_slice() {
    use std::ops::ControlFlow;

    use crate::selection::dsl::*;

    // A slice with no dimensions: empty sizes and strides, offset 42.
    let slice = Slice::new(42, vec![], vec![]).unwrap();

    // (selection, whether the single emitted frame should deliver).
    let cases = [
        (true_(), true),
        (all(true_()), true),
        (false_(), false),
        (all(false_()), false),
    ];

    for (selection, expect_delivery) in cases {
        let frame = RoutingFrame::root(selection, slice.clone());

        let mut steps = vec![];
        let _ = frame.next_steps(
            &mut |_| panic!("Unexpected Choice in 0D test"),
            &mut |step| {
                steps.push(step);
                ControlFlow::Continue(())
            },
        );

        // Every case emits exactly one forward step, located at
        // coordinate [0], which maps to the slice offset.
        assert_eq!(steps.len(), 1);
        let step = steps[0].as_forward().unwrap();
        assert_eq!(step.here, vec![0]);
        assert_eq!(step.deliver_here(), expect_delivery);
        assert_eq!(step.slice.location(&step.here).unwrap(), 42);
    }
}
1700}