ndslice/selection/routing.rs
1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! # Routing
10//!
//! This module defines [`RoutingFrame`] and its
//! [`next_steps`](RoutingFrame::next_steps) method, which together model
//! how messages propagate through a multidimensional mesh based on a
//! [`Selection`] expression.
14//!
15//! A [`RoutingFrame`] represents the state of routing at a particular
16//! point in the mesh. It tracks the current coordinate (`here`), the
17//! remaining selection to apply (`selection`), the mesh layout
18//! (`slice`), and the current dimension of traversal (`dim`).
19//!
//! [`next_steps`](RoutingFrame::next_steps) defines a routing-specific
//! evaluation strategy for `Selection`. Unlike [`Selection::eval`],
//! which produces the flat indices that match a selection, this method
//! produces intermediate routing states — new frames or deferred steps
//! from which traversal continues.
25//!
//! Rather than returning raw frames directly,
//! [`next_steps`](RoutingFrame::next_steps) produces a stream of
//! [`RoutingStep`]s via a callback — each representing a distinct kind
//! of routing progression:
29//!
30//! - [`RoutingStep::Forward`] indicates that routing proceeds
31//! deterministically to a new [`RoutingFrame`] — the next
32//! coordinate is fully determined by the current selection and
33//! frame state.
34//!
//! - [`RoutingStep::Choice`] represents a deferred decision: it
//!   carries a set of admissible indices, and **the caller must
//!   select one** (e.g., for load balancing or policy-based routing)
//!   **before routing can proceed**.
39//!
40//! In this way, non-determinism is treated as a **first-class,
41//! policy-driven** aspect of the routing system — enabling
42//! inspection, customization, and future extensions without
43//! complicating the core traversal logic.
44//!
//! A frame is considered a delivery target if its selection is
//! [`Selection::True`] and all dimensions have been traversed, as
//! determined by [`RoutingFrame::deliver_here`]. All other frames must
//! be routed further, as reported by [`RoutingFrame::should_route`].
49//!
50//! This design enables **compositional**, **local**, and **scalable**
51//! routing:
52//! - **Compositional**: complex selection expressions decompose into
53//! simpler, independently evaluated sub-selections.
54//! - **Local**: each frame carries exactly the state needed for its
55//! next step — no global coordination or lookahead is required.
56//! - **Scalable**: routing unfolds recursively, one hop at a time,
57//! allowing for efficient traversal even in high-dimensional spaces.
58//!
59//! This module provides the foundation for building structured,
60//! recursive routing logic over multidimensional coordinate spaces.
61use std::collections::HashMap;
62use std::collections::HashSet;
63use std::fmt::Write;
64use std::hash::Hash;
65use std::ops::ControlFlow;
66use std::sync::Arc;
67
68use anyhow::Result;
69use enum_as_inner::EnumAsInner;
70use serde::Deserialize;
71use serde::Serialize;
72use serde::de::DeserializeOwned;
73
74use crate::SliceError;
75use crate::selection::NormalizedSelectionKey;
76use crate::selection::Selection;
77use crate::selection::Slice;
78
79/// Represents the outcome of evaluating a routing step.
80#[derive(Debug, Clone, Copy, PartialEq, Eq)]
81pub enum RoutingAction {
82 Deliver,
83 Forward,
84}
85
86/// `RoutingFrame` captures the state of a selection being evaluated:
87/// the current coordinate (`here`), the remaining selection to apply,
88/// the shape and layout information (`slice`), and the current
89/// dimension (`dim`).
90///
91/// Each frame represents an independent routing decision and produces
92/// zero or more new frames via `next_steps`.
93#[derive(Clone, Debug, Serialize, Deserialize)]
94pub struct RoutingFrame {
95 /// The current coordinate in the mesh where this frame is being
96 /// evaluated.
97 ///
98 /// This is the source location for the next routing step.
99 pub here: Vec<usize>,
100
101 /// The residual selection expression describing where routing
102 /// should continue.
103 ///
104 /// At each step, only the current dimension (tracked by `dim`) of
105 /// this selection is considered.
106 pub selection: Selection,
107
108 /// The shape and layout of the full multidimensional space being
109 /// routed.
110 ///
111 /// This determines the bounds and stride information used to
112 /// compute coordinates and flat indices.
113 pub slice: Arc<Slice>,
114
115 /// The current axis of traversal within the selection and slice.
116 ///
117 /// Routing proceeds dimension-by-dimension; this value tracks how
118 /// many dimensions have already been routed.
119 pub dim: usize,
120}
121
122// Compile-time check: ensure `RoutingFrame` is thread-safe and fully
123// owned.
124fn _assert_routing_frame_traits()
125where
126 RoutingFrame: Send + Sync + Serialize + DeserializeOwned + 'static,
127{
128}
129
130/// A `RoutingStep` represents a unit of progress in the routing
131/// process.
132///
133/// Emitted by [`RoutingFrame::next_steps`], each step describes
134/// how routing should proceed from a given frame:
135///
136/// - [`RoutingStep::Forward`] represents a deterministic hop to the
137/// next coordinate in the mesh, with an updated [`RoutingFrame`].
138///
139/// - [`RoutingStep::Choice`] indicates that routing cannot proceed
140/// until the caller selects one of several admissible indices. This
141/// allows for policy-driven or non-deterministic routing behavior,
142/// such as load balancing.
143#[derive(Debug, Clone, EnumAsInner)]
144pub enum RoutingStep {
145 /// A deterministic routing hop to the next coordinate. Carries an
146 /// updated [`RoutingFrame`] describing the new position and
147 /// residual selection.
148 Forward(RoutingFrame),
149
150 /// A deferred routing decision at the current dimension. Contains
151 /// a set of admissible indices and a residual [`RoutingFrame`] to
152 /// continue routing once a choice is made.
153 Choice(Choice),
154}
155
156/// A deferred routing decision as contained in a
157/// [`RoutingStep::Choice`].
158///
159/// A `Choice` contains:
160/// - `candidates`: the admissible indices at the current dimension
161/// - `frame`: the residual [`RoutingFrame`] describing how routing
162/// continues once a choice is made
163///
164/// To continue routing, the caller must select one of the
165/// `candidates` and call [`Choice::choose`] to produce the
166/// corresponding [`RoutingStep::Forward`].
167#[derive(Debug, Clone)]
168pub struct Choice {
169 pub(crate) candidates: Vec<usize>,
170 pub(crate) frame: RoutingFrame,
171}
172
173impl Choice {
174 /// Returns the list of admissible indices at the current
175 /// dimension.
176 ///
177 /// These represent the valid choices that the caller can select
178 /// from when resolving this deferred routing step.
179 pub fn candidates(&self) -> &[usize] {
180 &self.candidates
181 }
182
183 /// Returns a reference to the residual [`RoutingFrame`]
184 /// associated with this choice.
185 ///
186 /// This frame encodes the selection and mesh context to be used
187 /// once a choice is made, and routing continues at the next
188 /// dimension.
189 pub fn frame(&self) -> &RoutingFrame {
190 &self.frame
191 }
192
193 /// Resolves the choice by selecting a specific index.
194 ///
195 /// Constrains the residual selection to the chosen index at the
196 /// current dimension and returns a [`RoutingStep::Forward`] for
197 /// continued routing.
198 pub fn choose(self, index: usize) -> RoutingStep {
199 // The only thing `next()` has to do is constrain the
200 // selection to a concrete choice at the current dimension.
201 // `self.frame.selection` is the residual (inner) selection to
202 // be applied *past* the current dimension.
203 RoutingStep::Forward(RoutingFrame {
204 selection: crate::dsl::range(index..=index, self.frame.selection),
205 ..self.frame
206 })
207 }
208}
209
210/// Key used to deduplicate routing frames.
211#[derive(Debug, Hash, PartialEq, Eq)]
212pub struct RoutingFrameKey {
213 here: Vec<usize>,
214 dim: usize,
215 selection: NormalizedSelectionKey,
216}
217
218impl RoutingFrameKey {
219 /// Constructs a `RoutingFrameKey` from a `RoutingFrame`.
220 ///
221 /// This key uniquely identifies a routing frame by its coordinate
222 /// (`here`), current dimension, and normalized selection. It is
223 /// used during traversal for purposes such as deduplication and
224 /// memoization.
225 pub fn new(frame: &RoutingFrame) -> Self {
226 Self {
227 here: frame.here.clone(),
228 dim: frame.dim,
229 selection: NormalizedSelectionKey::new(&frame.selection),
230 }
231 }
232}
233
234impl RoutingFrame {
235 /// Constructs the initial frame at the root coordinate (all
236 /// zeros). Selections are expanded as necessary to ensure they
237 /// have depth equal to the slice dimensionality. See the docs for
238 /// `canonicalize_to_dimensions` for the rules.
239 ///
240 /// ### Canonical Handling of Zero-Dimensional Slices
241 ///
242 /// A `Slice` with zero dimensions represents the empty product
243 /// `∏_{i=1}^{0} Xᵢ`, which has exactly one element: the empty
244 /// tuple. To maintain uniform routing semantics, we canonically
245 /// embed such 0D slices as 1D slices of extent 1:
246 ///
247 /// ```text
248 /// Slice::new(offset, [1], [1])
249 /// ```
250 ///
251 /// This embedding preserves the correct number of addressable
252 /// points and allows the routing machinery to proceed through the
253 /// usual recursive strategy without introducing special cases. The
254 /// selected coordinate is `vec![0]`, and `dim = 0` proceeds as
255 /// usual. This makes the routing logic consistent with evaluation
256 /// and avoids edge case handling throughout the codebase.
257 pub fn root(selection: Selection, slice: Slice) -> Self {
258 // Canonically embed 0D as 1D (extent 1).
259 let slice = if slice.num_dim() > 0 {
260 Arc::new(slice)
261 } else {
262 Arc::new(Slice::new(slice.offset(), vec![1], vec![1]).unwrap())
263 };
264 let n = slice.num_dim();
265 RoutingFrame {
266 here: vec![0; n],
267 selection: selection.canonicalize_to_dimensions(n),
268 slice,
269 dim: 0,
270 }
271 }
272
273 /// Produces a new frame advanced to the next dimension with
274 /// updated position and selection.
275 pub fn advance(&self, here: Vec<usize>, selection: Selection) -> Self {
276 RoutingFrame {
277 here,
278 selection,
279 slice: Arc::clone(&self.slice),
280 dim: self.dim + 1,
281 }
282 }
283
284 /// Returns a new frame with the same position and dimension but a
285 /// different selection.
286 pub fn with_selection(&self, selection: Selection) -> Self {
287 RoutingFrame {
288 here: self.here.clone(),
289 selection,
290 slice: Arc::clone(&self.slice),
291 dim: self.dim,
292 }
293 }
294
295 /// Determines the appropriate routing action for this frame.
296 ///
297 /// Returns [`RoutingAction::Deliver`] if the message should be
298 /// delivered at this coordinate, or [`RoutingAction::Forward`] if
299 /// it should be routed further.
300 pub fn action(&self) -> RoutingAction {
301 if self.deliver_here() {
302 RoutingAction::Deliver
303 } else {
304 RoutingAction::Forward
305 }
306 }
307
308 /// Returns the location of this frame in the underlying slice.
309 pub fn location(&self) -> Result<usize, SliceError> {
310 self.slice.location(&self.here)
311 }
312}
313
314impl RoutingFrame {
315 /// Visits the next routing steps from this frame using a
316 /// callback-based traversal.
317 /// This method structurally recurses on the [`Selection`]
318 /// expression, yielding [`RoutingStep`]s via the `f` callback.
319 /// Early termination is supported via [`ControlFlow::Break`].
320 ///
321 /// Compared to the (old, now removed) [`next_steps`] method, this
322 /// avoids intermediate allocation, supports interruptibility, and
323 /// allows policy-driven handling of [`RoutingStep::Choice`]s via
324 /// the `chooser` callback.
325 ///
326 /// ---
327 ///
328 /// ### Traversal Strategy
329 ///
330 /// The traversal proceeds **dimension-by-dimension**,
331 /// structurally mirroring the shape of the selection expression:
332 ///
333 /// - [`Selection::All`] and [`Selection::Range`] iterate over a
334 /// range of coordinates, emitting one [`RoutingStep::Forward`]
335 /// per valid index.
336 /// - [`Selection::Union`] and [`Selection::Intersection`] recurse
337 /// into both branches. Intersection steps are joined at matching
338 /// coordinates and residual selections are reduced.
339 /// - [`Selection::Any`] randomly selects one index along the
340 /// current dimension and emits a single step.
341 /// - [`Selection::True`] and [`Selection::False`] emit no steps.
342 ///
343 /// At each step, only the current dimension (tracked via `self.dim`)
344 /// is evaluated. Future dimensions remain untouched until deeper
345 /// recursion.
346 ///
347 /// ---
348 ///
349 /// ### Evaluation Semantics
350 ///
351 /// - **Selection::True**
352 /// No further routing is performed — if this frame is at the
353 /// final dimension, delivery occurs.
354 /// - **Selection::False**
355 /// No match — routing halts.
356 ///
357 /// - **Selection::All / Selection::Range**
358 /// Emits one [`RoutingStep::Forward`] per matching index, each
359 /// advancing to the next dimension with the inner selection.
360 ///
361 /// - **Selection::Union**
362 /// Evaluates both branches independently and emits all
363 /// resulting steps.
364 ///
365 /// - **Selection::Intersection**
366 /// Emits only those steps where both branches produce the same
367 /// coordinate, combining the residual selections at that point.
368 ///
369 /// - **Selection::Any**
370 /// Randomly selects one index and emits a single
371 /// [`RoutingStep::Forward`].
372 ///
373 /// - **Selection::Choice**
374 /// Defers decision to the caller by invoking the `chooser`
375 /// function, which resolves the candidate index.
376 ///
377 /// ---
378 ///
379 /// ### Delivery Semantics
380 ///
381 /// Message delivery is determined by
382 /// [`RoutingFrame::deliver_here`], which returns true when:
383 ///
384 /// - The frame’s selection is [`Selection::True`], and
385 /// - All dimensions have been traversed (`dim ==
386 /// slice.num_dim()`).
387 ///
388 /// ---
389 ///
390 /// ### Panics
391 ///
392 /// Panics if `slice.num_dim() == 0`. Use a canonical embedding
393 /// (e.g., 0D → 1D) before calling this (see e.g.
394 /// `RoutingFrame::root`).
395 ///
396 /// ---
397 ///
398 /// ### Summary
399 ///
400 /// - **Structure-driven**: Mirrors the shape of the selection
401 /// expression.
402 /// - **Compositional**: Each variant defines its own traversal
403 /// behavior.
404 /// - **Interruptible**: Early termination is supported via
405 /// [`ControlFlow`].
406 /// - **Minimally allocating**: Avoids intermediate buffers in
407 /// most cases; only [`Selection::Intersection`] allocates
408 /// temporary state for pairwise matching.
409 /// - **Policy-ready**: Integrates with runtime routing policies
410 /// via the `chooser`.
411 pub fn next_steps(
412 &self,
413 _chooser: &mut dyn FnMut(&Choice) -> usize,
414 f: &mut dyn FnMut(RoutingStep) -> ControlFlow<()>,
415 ) -> ControlFlow<()> {
416 assert!(self.slice.num_dim() > 0, "next_steps requires num_dims > 0");
417
418 match &self.selection {
419 Selection::True => ControlFlow::Continue(()),
420 Selection::False => ControlFlow::Continue(()),
421 Selection::All(inner) => {
422 let size = self.slice.sizes()[self.dim];
423 for i in 0..size {
424 let mut coord = self.here.clone();
425 coord[self.dim] = i;
426 let frame = self.advance(coord, (**inner).clone());
427 if let ControlFlow::Break(_) = f(RoutingStep::Forward(frame)) {
428 return ControlFlow::Break(());
429 }
430 }
431 ControlFlow::Continue(())
432 }
433
434 Selection::Range(range, inner) => {
435 let size = self.slice.sizes()[self.dim];
436 let (min, max, step) = range.resolve(size);
437
438 for i in (min..max).step_by(step) {
439 let mut coord = self.here.clone();
440 coord[self.dim] = i;
441 let frame = self.advance(coord, (**inner).clone());
442 if let ControlFlow::Break(_) = f(RoutingStep::Forward(frame)) {
443 return ControlFlow::Break(());
444 }
445 }
446
447 ControlFlow::Continue(())
448 }
449
450 Selection::Any(inner) => {
451 let size = self.slice.sizes()[self.dim];
452 if size == 0 {
453 return ControlFlow::Continue(());
454 }
455
456 use rand::Rng;
457 let mut rng: rand::prelude::ThreadRng = rand::thread_rng();
458 let i = rng.gen_range(0..size);
459 let mut coord = self.here.clone();
460 coord[self.dim] = i;
461 let frame = self.advance(coord, (**inner).clone());
462 f(RoutingStep::Forward(frame))
463 }
464
465 Selection::Union(a, b) => {
466 if let ControlFlow::Break(_) =
467 self.with_selection((**a).clone()).next_steps(_chooser, f)
468 {
469 return ControlFlow::Break(());
470 }
471 self.with_selection((**b).clone()).next_steps(_chooser, f)
472 }
473
474 Selection::Intersection(a, b) => {
475 let mut left = vec![];
476 let mut right = vec![];
477
478 let mut collect_left = |step: RoutingStep| {
479 if let RoutingStep::Forward(frame) = step {
480 left.push(frame);
481 }
482 ControlFlow::Continue(())
483 };
484 let mut collect_right = |step: RoutingStep| {
485 if let RoutingStep::Forward(frame) = step {
486 right.push(frame);
487 }
488 ControlFlow::Continue(())
489 };
490
491 self.with_selection((**a).clone())
492 .next_steps(_chooser, &mut collect_left)?;
493 self.with_selection((**b).clone())
494 .next_steps(_chooser, &mut collect_right)?;
495
496 for fa in &left {
497 for fb in &right {
498 if fa.here == fb.here {
499 let residual = fa
500 .selection
501 .clone()
502 .reduce_intersection(fb.selection.clone());
503 let frame = self.advance(fa.here.clone(), residual);
504 if let ControlFlow::Break(_) = f(RoutingStep::Forward(frame)) {
505 return ControlFlow::Break(());
506 }
507 }
508 }
509 }
510
511 ControlFlow::Continue(())
512 }
513
514 // TODO(SF, 2025-04-30): This term is not in the algebra
515 // yet.
516 // Selection::LoadBalanced(inner) => {
517 // let size = self.slice.sizes()[self.dim];
518 // if size == 0 {
519 // ControlFlow::Continue(())
520 // } else {
521 // let candidates = (0..size).collect();
522 // let choice = Choice {
523 // candidates,
524 // frame: self.with_selection((*inner).clone()),
525 // };
526 // let index = chooser(&choice);
527 // f(choice.choose(index))
528 // }
529 // }
530
531 // Catch-all for future combinators (e.g., Label).
532 _ => unimplemented!(),
533 }
534 }
535
536 /// Returns true if this frame represents a terminal delivery
537 /// point — i.e., the selection is `True` and all dimensions have
538 /// been traversed.
539 pub fn deliver_here(&self) -> bool {
540 matches!(self.selection, Selection::True) && self.dim == self.slice.num_dim()
541 }
542
543 /// Returns true if the message has not yet reached its final
544 /// destination and should be forwarded to the next routing step.
545 pub fn should_route(&self) -> bool {
546 !self.deliver_here()
547 }
548}
549
550impl RoutingFrame {
551 /// Traces the unique routing path to the given destination
552 /// coordinate.
553 ///
554 /// Returns `Some(vec![root, ..., dest])` if `dest` is selected,
555 /// or `None` if not.
556 pub fn trace_route(&self, dest: &[usize]) -> Option<Vec<Vec<usize>>> {
557 use std::collections::HashSet;
558 use std::ops::ControlFlow;
559
560 use crate::selection::routing::RoutingFrameKey;
561
562 fn go(
563 frame: RoutingFrame,
564 dest: &[usize],
565 mut path: Vec<Vec<usize>>,
566 seen: &mut HashSet<RoutingFrameKey>,
567 ) -> Option<Vec<Vec<usize>>> {
568 let key = RoutingFrameKey::new(&frame);
569 if !seen.insert(key) {
570 return None;
571 }
572
573 path.push(frame.here.clone());
574
575 if frame.deliver_here() && frame.here == dest {
576 return Some(path);
577 }
578
579 let mut found = None;
580 let _ = frame.next_steps(
581 &mut |_| panic!("Choice encountered in trace_route"),
582 &mut |step: RoutingStep| {
583 let next = step.into_forward().unwrap();
584 if let Some(result) = go(next, dest, path.clone(), seen) {
585 found = Some(result);
586 ControlFlow::Break(())
587 } else {
588 ControlFlow::Continue(())
589 }
590 },
591 );
592
593 found
594 }
595
596 let mut seen = HashSet::new();
597 go(self.clone(), dest, Vec::new(), &mut seen)
598 }
599}
600
601/// Formats a routing path as a string, showing each hop in order.
602///
603/// Each line shows the hop index, an arrow (`→` for intermediate
604/// steps, `⇨` for the final destination), and the coordinate as a
605/// tuple (e.g., `(0, 1)`).
606/// # Example
607///
608/// ```text
609/// 0 → (0, 0)
610/// 1 → (0, 1)
611/// 2 ⇨ (1, 1)
612/// ```
613#[track_caller]
614#[allow(dead_code)]
615pub fn format_route(route: &[Vec<usize>]) -> String {
616 let mut out = String::new();
617 for (i, hop) in route.iter().enumerate() {
618 let arrow = if i == route.len() - 1 { "⇨" } else { "→" };
619 let coord = format!(
620 "({})",
621 hop.iter()
622 .map(ToString::to_string)
623 .collect::<Vec<_>>()
624 .join(", ")
625 );
626 let _ = writeln!(&mut out, "{:>2} {} {}", i, arrow, coord);
627 }
628 out
629}
630
631/// Formats a routing tree as an indented string.
632///
633/// Traverses the tree of `RoutingFrame`s starting from the root,
634/// displaying each step with indentation by dimension. Delivery
635/// targets are marked `✅`.
636///
637/// # Example
638/// ```text
639/// (0, 0)
640/// (0, 1) ✅
641/// (1, 0)
642/// (1, 1) ✅
643/// ```
644#[track_caller]
645#[allow(dead_code)]
646pub fn format_routing_tree(selection: Selection, slice: &Slice) -> String {
647 let root = RoutingFrame::root(selection, slice.clone());
648 let mut out = String::new();
649 let mut seen = HashSet::new();
650 format_routing_tree_rec(&root, 0, &mut out, &mut seen).unwrap();
651 out
652}
653
654fn format_routing_tree_rec(
655 frame: &RoutingFrame,
656 indent: usize,
657 out: &mut String,
658 seen: &mut HashSet<RoutingFrameKey>,
659) -> std::fmt::Result {
660 use crate::selection::routing::RoutingFrameKey;
661
662 let key = RoutingFrameKey::new(frame);
663 if !seen.insert(key) {
664 return Ok(()); // already visited
665 }
666
667 let indent_str = " ".repeat(indent);
668 let coord_str = format!(
669 "({})",
670 frame
671 .here
672 .iter()
673 .map(ToString::to_string)
674 .collect::<Vec<_>>()
675 .join(", ")
676 );
677
678 match frame.action() {
679 RoutingAction::Deliver => {
680 writeln!(out, "{}{} ✅", indent_str, coord_str)?;
681 }
682 RoutingAction::Forward => {
683 writeln!(out, "{}{}", indent_str, coord_str)?;
684 let _ = frame.next_steps(
685 &mut |_| panic!("Choice encountered in format_routing_tree_rec"),
686 &mut |step| {
687 let next = step.into_forward().unwrap();
688 format_routing_tree_rec(&next, indent + 1, out, seen).unwrap();
689 ControlFlow::Continue(())
690 },
691 );
692 }
693 }
694
695 Ok(())
696}
697
698// Pretty-prints a routing path from source to destination.
699//
700// Each hop is shown as a numbered step with directional arrows.
701#[track_caller]
702#[allow(dead_code)]
703pub fn print_route(route: &[Vec<usize>]) {
704 println!("{}", format_route(route));
705}
706
707/// Prints the routing tree for a selection over a slice.
708///
709/// Traverses the routing structure from the root, printing each step
710/// with indentation by dimension. Delivery points are marked with
711/// `✅`.
712#[track_caller]
713#[allow(dead_code)]
714pub fn print_routing_tree(selection: Selection, slice: &Slice) {
715 println!("{}", format_routing_tree(selection, slice));
716}
717
718// == "CommActor multicast" routing ==
719
720/// Resolves the current set of routing frames (`dests`) to determine
721/// whether the message should be delivered at this rank, and which
722/// routing frames should be forwarded to peer ranks.
723///
724/// This is the continuation of a multicast operation: each forwarded
725/// message contains one or more `RoutingFrame`s that represent
726/// partial routing state. This call determines how those frames
727/// propagate next.
728///
729/// `deliver_here` is true if any frame targets this rank and
730/// indicates delivery. `next_steps` contains the peer ranks and frames
731/// to forward.
732///
733/// This is also the top-level entry point for CommActor's routing
734/// logic.
735pub fn resolve_routing(
736 rank: usize,
737 frames: impl IntoIterator<Item = RoutingFrame>,
738 chooser: &mut dyn FnMut(&Choice) -> usize,
739) -> Result<(bool, HashMap<usize, Vec<RoutingFrame>>)> {
740 let mut deliver_here = false;
741 let mut next_steps = HashMap::new();
742 for frame in frames {
743 resolve_routing_one(rank, frame, chooser, &mut deliver_here, &mut next_steps)?;
744 }
745 Ok((deliver_here, next_steps))
746}
747
748/// Recursively resolves routing for a single `RoutingFrame` at the
749/// given rank, determining whether the message should be delivered
750/// locally and which frames should be forwarded to peer ranks.
751///
752/// - If the frame targets the local `rank` and is a delivery point,
753/// `deliver_here` is set to `true`.
754/// - If the frame targets the local `rank` but is not a delivery
755/// point, the function recurses on its forward steps.
756/// - If the frame targets a different rank, it is added to
757/// `next_steps`.
758///
759/// Deduplication is handled by `get_next_steps`. Dynamic constructs
760/// such as `Any` or `First` must be resolved by the provided
761/// `chooser`, which selects an index from a `Choice`.
762///
763/// Traversal is depth-first within a rank and breadth-first across
764/// ranks. This defines the exact routing behavior used by
765/// `CommActor`: it exhaustively evaluates all local routing structure
766/// before forwarding to peers.
767///
768/// The resulting `next_steps` map contains all non-local ranks that
769/// should receive forwarded frames, where each entry maps a peer rank
770/// to a list of routing continuations to evaluate at that peer.
771/// `deliver_here` is set to `true` if the current rank is a final
772/// delivery point.
773pub(crate) fn resolve_routing_one(
774 rank: usize,
775 frame: RoutingFrame,
776 chooser: &mut dyn FnMut(&Choice) -> usize,
777 deliver_here: &mut bool,
778 next_steps: &mut HashMap<usize, Vec<RoutingFrame>>,
779) -> Result<()> {
780 let frame_rank = frame.slice.location(&frame.here)?;
781 if frame_rank == rank {
782 if frame.deliver_here() {
783 *deliver_here = true;
784 } else {
785 for next in get_next_steps(frame, chooser)? {
786 resolve_routing_one(rank, next, chooser, deliver_here, next_steps)?;
787 }
788 }
789 } else {
790 next_steps.entry(frame_rank).or_default().push(frame);
791 }
792 Ok(())
793}
794
795/// Computes the set of `Forward` routing frames reachable from the
796/// given `RoutingFrame`.
797///
798/// This function traverses the result of `frame.next_steps(...)`,
799/// collecting only `RoutingStep::Forward(_)` steps. The caller
800/// provides a `chooser` function to resolve dynamic constructs such
801/// as `Any` or `First`.
802///
803/// Some obviously redundant steps may be filtered, but no strict
804/// guarantee is made about structural uniqueness.
805fn get_next_steps(
806 dest: RoutingFrame,
807 chooser: &mut dyn FnMut(&Choice) -> usize,
808) -> Result<Vec<RoutingFrame>> {
809 let mut seen = HashSet::new();
810 let mut unique_steps = vec![];
811 let _ = dest.next_steps(chooser, &mut |step| {
812 if let RoutingStep::Forward(frame) = step {
813 let key = RoutingFrameKey::new(&frame);
814 if seen.insert(key) {
815 unique_steps.push(frame);
816 }
817 }
818 ControlFlow::Continue(())
819 });
820 Ok(unique_steps)
821}
822
823// == Testing (`collect_commactor_routing_tree` mesh simulation) ===
824
825/// Captures the logical structure of a CommActor multicast operation.
826///
827/// This type models how a message is delivered and forwarded through
828/// a mesh under CommActor routing semantics. It is used in tests to
829/// verify path determinism and understand message propagation
830/// behavior.
831///
832/// - `delivered`: ranks where the message was delivered (`post`
833/// called)
834/// - `visited`: all ranks that participated, including forwarding
835/// only
836/// - `forwards`: maps each rank to the routing frames it forwarded
837#[cfg(test)]
838#[allow(dead_code)]
839#[derive(Default)]
840pub(crate) struct CommActorRoutingTree {
841 // Ranks that were delivered the message (i.e. called `post`). Map
842 // from rank → delivery path (flat rank indices) from root to that
843 // rank.
844 pub delivered: HashMap<usize, Vec<usize>>,
845
846 // Ranks that participated in the multicast - either by delivering
847 // the message or forwarding it to peers.
848 pub visited: HashSet<usize>,
849
850 /// Map from rank → routing frames this rank forwarded to other
851 /// ranks.
852 pub forwards: HashMap<usize, Vec<RoutingFrame>>,
853}
854
855#[cfg(test)]
856mod tests {
857 use std::collections::HashSet;
858 use std::collections::VecDeque;
859
860 use super::RoutingAction;
861 use super::RoutingFrame;
862 use super::print_route;
863 use super::print_routing_tree;
864 use crate::Slice;
865 use crate::selection::EvalOpts;
866 use crate::selection::Selection;
867 use crate::selection::dsl::*;
868 use crate::selection::test_utils::RoutedMessage;
869 use crate::selection::test_utils::collect_commactor_routing_tree;
870 use crate::selection::test_utils::collect_routed_nodes;
871 use crate::selection::test_utils::collect_routed_paths;
872 use crate::shape;
873
874 // A test slice: (zones = 2, hosts = 4, gpus = 8).
875 fn test_slice() -> Slice {
876 Slice::new(0usize, vec![2, 4, 8], vec![32, 8, 1]).unwrap()
877 }
878
879 /// Asserts that a routing strategy produces the same set of nodes
880 /// as `Selection::eval`.
881 ///
882 /// This macro compares the result of evaluating a `Selection`
883 /// using the given `collector` against the reference
884 /// implementation `Selection::eval` (with lenient options).
885 ///
886 /// The `collector` should be a function or closure of type
887 /// `Fn(&Selection, &Slice) -> Vec<usize>`, such as
888 /// `collect_routed_nodes` or a CommActor-based simulation.
889 ///
890 /// Panics if the two sets of routed nodes differ.
891 ///
892 /// # Example
893 /// ```
894 /// assert_routing_eq_with!(slice, selection, collect_routed_nodes);
895 /// ```
896 macro_rules! assert_routing_eq_with {
897 ($slice:expr_2021, $sel:expr_2021, $collector:expr_2021) => {{
898 let sel = $sel;
899 let slice = $slice.clone();
900 let mut expected: Vec<_> = sel.eval(&EvalOpts::lenient(), &slice).unwrap().collect();
901 expected.sort();
902 let mut actual: Vec<_> = ($collector)(&sel, &slice);
903 actual.sort();
904 assert_eq!(actual, expected, "Mismatch for selection: {}", sel);
905 }};
906 }
907
908 /// Asserts that `collect_routed_nodes` matches `Selection::eval`
909 /// on the given slice.
910 macro_rules! assert_collect_routed_nodes_eq {
911 ($slice:expr_2021, $sel:expr_2021) => {
912 assert_routing_eq_with!($slice, $sel, collect_routed_nodes)
913 };
914 }
915
916 /// Asserts that CommActor routing delivers to the same nodes as
917 /// `Selection::eval`.
918 macro_rules! assert_commactor_routing_eq {
919 ($slice:expr_2021, $sel:expr_2021) => {
920 assert_routing_eq_with!($slice, $sel, |s, sl| {
921 collect_commactor_routing_tree(s, sl)
922 .delivered
923 .into_keys()
924 .collect()
925 });
926 };
927 }
928
929 /// Asserts that all routing strategies produce the same set of
930 /// routed nodes as `Selection::eval`.
931 ///
932 /// Compares both the direct strategy (`collect_routed_nodes`) and
933 /// the CommActor routing simulation
934 /// (`collect_commactor_routing_tree`) against the expected output
935 /// from `Selection::eval`.
936 macro_rules! assert_all_routing_strategies_eq {
937 ($slice:expr_2021, $sel:expr_2021) => {
938 assert_collect_routed_nodes_eq!($slice, $sel);
939 assert_commactor_routing_eq!($slice, $sel);
940 };
941 }
942
943 #[test]
944 fn test_routing_04() {
945 use crate::selection::dsl::*;
946
947 let slice = test_slice(); // [2, 4, 8], strides [32, 8, 1]
948
949 // Destination: GPU 2 on host 2 in zone 1.
950 let dest = vec![1, 2, 2];
951 let selection = range(1, range(2, range(2, true_())));
952 let root = RoutingFrame::root(selection.clone(), slice.clone());
953 let path = root.trace_route(&dest).expect("no route found");
954 println!(
955 "\ndest: {:?}, (singleton-)selection: ({})\n",
956 &dest, &selection
957 );
958 print_route(&path);
959 println!("\n");
960 assert_eq!(path.last(), Some(&dest));
961
962 // Destination: "Right back where we started from 🙂".
963 let dest = vec![0, 0, 0];
964 let selection = range(0, range(0, range(0, true_())));
965 let root = RoutingFrame::root(selection.clone(), slice.clone());
966 let path = root.trace_route(&dest).expect("no route found");
967 println!(
968 "\ndest: {:?}, (singleton-)selection: ({})\n",
969 &dest, &selection
970 );
971 print_route(&path);
972 println!("\n");
973 assert_eq!(path.last(), Some(&dest));
974 }
975
976 #[test]
977 fn test_routing_05() {
978 use crate::selection::dsl::*;
979
980 // "Jun's example" -- a 2 x 2 row major mesh.
981 let slice = Slice::new(0usize, vec![2, 2], vec![2, 1]).unwrap();
982 // Thats is,
983 // (0, 0) (0, 1)
984 // (0, 1) (1, 0)
985 //
986 // and we want to cast to {(0, 1), (1, 0) and (1, 1)}:
987 //
988 // (0, 0)❌ (0, 1)✅
989 // (0, 1)✅ (1, 0)✅
990 //
991 // One reasonable selection expression describing the
992 // destination set.
993 let selection = union(range(0, range(1, true_())), range(1, all(true_())));
994
995 // Now print the routing tree.
996 print_routing_tree(selection, &slice);
997
998 // Prints:
999 // (0, 0)
1000 // (0, 0)
1001 // (0, 1) ✅
1002 // (1, 0)
1003 // (1, 0) ✅
1004 // (1, 1) ✅
1005
1006 // Another example: (zones = 2, hosts = 4, gpus = 8).
1007 let slice = Slice::new(0usize, vec![2, 4, 8], vec![32, 8, 1]).unwrap();
1008 // Let's have all the odd GPUs on hosts 1, 2 and 3 in zone 0.
1009 let selection = range(
1010 0,
1011 range(1..4, range(shape::Range(1, None, /*step*/ 2), true_())),
1012 );
1013
1014 // Now print the routing tree.
1015 print_routing_tree(selection, &slice);
1016
1017 // Prints:
1018 // (0, 0, 0)
1019 // (0, 0, 0)
1020 // (0, 1, 0)
1021 // (0, 1, 1) ✅
1022 // (0, 1, 3) ✅
1023 // (0, 1, 5) ✅
1024 // (0, 1, 7) ✅
1025 // (0, 2, 0)
1026 // (0, 2, 1) ✅
1027 // (0, 2, 3) ✅
1028 // (0, 2, 5) ✅
1029 // (0, 2, 7) ✅
1030 // (0, 3, 0)
1031 // (0, 3, 1) ✅
1032 // (0, 3, 3) ✅
1033 // (0, 3, 5) ✅
1034 // (0, 3, 7) ✅
1035 }
1036
1037 #[test]
1038 fn test_routing_00() {
1039 let slice = test_slice();
1040
1041 assert_all_routing_strategies_eq!(slice, false_());
1042 assert_all_routing_strategies_eq!(slice, true_());
1043 assert_all_routing_strategies_eq!(slice, all(true_()));
1044 assert_all_routing_strategies_eq!(slice, all(all(true_())));
1045 assert_all_routing_strategies_eq!(slice, all(all(false_())));
1046 assert_all_routing_strategies_eq!(slice, all(all(all(true_()))));
1047 assert_all_routing_strategies_eq!(slice, all(range(0..=0, all(true_()))));
1048 assert_all_routing_strategies_eq!(slice, all(all(range(0..4, true_()))));
1049 assert_all_routing_strategies_eq!(slice, all(range(1..=2, all(true_()))));
1050 assert_all_routing_strategies_eq!(slice, all(all(range(2..6, true_()))));
1051 assert_all_routing_strategies_eq!(slice, all(all(range(3..=3, true_()))));
1052 assert_all_routing_strategies_eq!(slice, all(range(1..3, all(true_()))));
1053 assert_all_routing_strategies_eq!(slice, all(all(range(0..=0, true_()))));
1054 assert_all_routing_strategies_eq!(slice, range(1..=1, range(3..=3, range(0..=2, true_()))));
1055 assert_all_routing_strategies_eq!(
1056 slice,
1057 all(all(range(shape::Range(0, Some(8), 2), true_())))
1058 );
1059 assert_all_routing_strategies_eq!(
1060 slice,
1061 all(range(shape::Range(1, Some(4), 2), all(true_())))
1062 );
1063 }
1064
1065 #[test]
1066 fn test_routing_03() {
1067 let slice = test_slice();
1068
1069 assert_all_routing_strategies_eq!(
1070 slice,
1071 // sel!(0 & (0,(1|3), *))
1072 intersection(
1073 range(0, true_()),
1074 range(0, union(range(1, all(true_())), range(3, all(true_()))))
1075 )
1076 );
1077 assert_all_routing_strategies_eq!(
1078 slice,
1079 // sel!(0 & (0, (3|1), *)),
1080 intersection(
1081 range(0, true_()),
1082 range(0, union(range(3, all(true_())), range(1, all(true_()))))
1083 )
1084 );
1085 assert_all_routing_strategies_eq!(
1086 slice,
1087 // sel!((*, *, *) & (*, *, (2 | 4)))
1088 intersection(
1089 all(all(all(true_()))),
1090 all(all(union(range(2, true_()), range(4, true_()))))
1091 )
1092 );
1093 assert_all_routing_strategies_eq!(
1094 slice,
1095 // sel!((*, *, *) & (*, *, (4 | 2)))
1096 intersection(
1097 all(all(all(true_()))),
1098 all(all(union(range(4, true_()), range(2, true_()))))
1099 )
1100 );
1101 assert_all_routing_strategies_eq!(
1102 slice,
1103 // sel!((*, (1 | 2)) & (*, (2 | 1)))
1104 intersection(
1105 all(union(range(1, true_()), range(2, true_()))),
1106 all(union(range(2, true_()), range(1, true_())))
1107 )
1108 );
1109 assert_all_routing_strategies_eq!(
1110 slice,
1111 intersection(all(all(all(true_()))), all(true_()))
1112 );
1113 assert_all_routing_strategies_eq!(slice, intersection(true_(), all(all(all(true_())))));
1114 assert_all_routing_strategies_eq!(slice, intersection(all(all(all(true_()))), false_()));
1115 assert_all_routing_strategies_eq!(slice, intersection(false_(), all(all(all(true_())))));
1116 assert_all_routing_strategies_eq!(
1117 slice,
1118 intersection(
1119 all(all(range(0..4, true_()))),
1120 all(all(range(0..4, true_())))
1121 )
1122 );
1123 assert_all_routing_strategies_eq!(
1124 slice,
1125 intersection(all(all(range(1, true_()))), all(all(range(2, true_()))))
1126 );
1127 assert_all_routing_strategies_eq!(
1128 slice,
1129 intersection(all(all(range(2, true_()))), all(all(range(1, true_()))))
1130 );
1131 assert_all_routing_strategies_eq!(
1132 slice,
1133 intersection(
1134 all(all(range(1, true_()))),
1135 intersection(all(all(true_())), all(all(range(1, true_()))))
1136 )
1137 );
1138 assert_all_routing_strategies_eq!(
1139 slice,
1140 intersection(
1141 range(0, true_()),
1142 range(0, all(union(range(1, true_()), range(3, true_()))))
1143 )
1144 );
1145 assert_all_routing_strategies_eq!(
1146 slice,
1147 range(
1148 0,
1149 intersection(true_(), all(union(range(1, true_()), range(3, true_()))))
1150 )
1151 );
1152 assert_all_routing_strategies_eq!(
1153 slice,
1154 intersection(all(range(1..=2, true_())), all(range(2..=3, true_())))
1155 );
1156 assert_all_routing_strategies_eq!(
1157 slice,
1158 intersection(
1159 range(0, true_()),
1160 intersection(range(0, all(true_())), range(0, range(1, all(true_()))))
1161 )
1162 );
1163 assert_all_routing_strategies_eq!(
1164 slice,
1165 intersection(
1166 range(0, range(1, all(true_()))),
1167 intersection(range(0, all(true_())), range(0, true_()))
1168 )
1169 );
1170 assert_all_routing_strategies_eq!(
1171 slice,
1172 // sel!( (*, *, *) & ((*, *, *) & (*, *, *)) ),
1173 intersection(
1174 all(all(all(true_()))),
1175 intersection(all(all(all(true_()))), all(all(all(true_()))))
1176 )
1177 );
1178 assert_all_routing_strategies_eq!(
1179 slice,
1180 union(
1181 intersection(range(0, true_()), range(0, range(1, all(true_())))),
1182 range(1, all(all(true_())))
1183 )
1184 );
1185 assert_all_routing_strategies_eq!(
1186 slice,
1187 // sel!((1, *, *) | (0 & (0, 3, *)))
1188 union(
1189 range(1, all(all(true_()))),
1190 intersection(range(0, true_()), range(0, range(3, all(true_()))))
1191 )
1192 );
1193 assert_all_routing_strategies_eq!(
1194 slice,
1195 intersection(
1196 union(range(0, true_()), range(1, true_())),
1197 union(range(1, true_()), range(0, true_()))
1198 )
1199 );
1200 assert_all_routing_strategies_eq!(
1201 slice,
1202 union(
1203 intersection(range(0, range(1, true_())), range(0, range(1, true_()))),
1204 intersection(range(1, range(3, true_())), range(1, range(3, true_())))
1205 )
1206 );
1207 assert_all_routing_strategies_eq!(
1208 slice,
1209 // sel!(*, 8 : 8)
1210 all(range(8..8, true_()))
1211 );
1212 assert_all_routing_strategies_eq!(
1213 slice,
1214 // sel!((*, 1) & (*, 8 : 8))
1215 intersection(all(range(1..2, true_())), all(range(8..8, true_())))
1216 );
1217 assert_all_routing_strategies_eq!(
1218 slice,
1219 // sel!((*, 8 : 8) | (*, 1))
1220 union(all(range(8..8, true_())), all(range(1..2, true_())))
1221 );
1222 assert_all_routing_strategies_eq!(
1223 slice,
1224 // sel!((*, 1) | (*, 2:8))
1225 union(all(range(1..2, true_())), all(range(2..8, true_())))
1226 );
1227 assert_all_routing_strategies_eq!(
1228 slice,
1229 // sel!((*, *, *) & (*, *, 2:8))
1230 intersection(all(all(all(true_()))), all(all(range(2..8, true_()))))
1231 );
1232 }
1233
1234 #[test]
1235 fn test_routing_02() {
1236 let slice = test_slice();
1237
1238 // zone 0 or 1: sel!(0 | 1, *, *)
1239 assert_all_routing_strategies_eq!(slice, union(range(0, true_()), range(1, true_())));
1240 assert_all_routing_strategies_eq!(
1241 slice,
1242 union(range(0, all(true_())), range(1, all(true_())))
1243 );
1244 // hosts 1 and 3 in zone 0: sel!(0, (1 | 3), *)
1245 assert_all_routing_strategies_eq!(
1246 slice,
1247 range(0, union(range(1, all(true_())), range(3, all(true_()))))
1248 );
1249 // sel!(0, 1:3 | 5:7, *)
1250 assert_all_routing_strategies_eq!(
1251 slice,
1252 range(
1253 0,
1254 union(
1255 range(shape::Range(1, Some(3), 1), all(true_())),
1256 range(shape::Range(5, Some(7), 1), all(true_()))
1257 )
1258 )
1259 );
1260
1261 // sel!(* | *): We start with `union(true_(), true_())`.
1262 //
1263 // Evaluating the left branch generates routing frames
1264 // recursively. Evaluating the right branch generates the same
1265 // frames again.
1266 //
1267 // As a result, we produce duplicate `RoutingFrame`s that
1268 // have:
1269 // - the same `here` coordinate,
1270 // - the same dimension (`dim`), and
1271 // - the same residual selection (`True`).
1272 //
1273 // When both frames reach the delivery condition, the second
1274 // call to `delivered.insert()` returns `false`. If we put an
1275 // `assert!` on that line this would trigger assertion failure
1276 // in the routing simulation.
1277 //
1278 // TODO: We need memoization to avoid redundant work.
1279 //
1280 // This can be achieved without transforming the algebra itself.
1281 // However, adding normalization will make memoization more
1282 // effective, so we should plan to implement both.
1283 //
1284 // Once that's done, we can safely restore the `assert!`.
1285 assert_all_routing_strategies_eq!(slice, union(true_(), true_()));
1286 // sel!(*, *, * | *, *, *)
1287 assert_all_routing_strategies_eq!(
1288 slice,
1289 union(all(all(all(true_()))), all(all(all(true_()))))
1290 );
1291 // no 'false' support in sel!
1292 assert_all_routing_strategies_eq!(slice, union(false_(), all(all(all(true_())))));
1293 assert_all_routing_strategies_eq!(slice, union(all(all(all(true_()))), false_()));
1294 // sel!(0, 0:4, 0 | 1 | 2)
1295 assert_all_routing_strategies_eq!(
1296 slice,
1297 range(
1298 0,
1299 range(
1300 shape::Range(0, Some(4), 1),
1301 union(
1302 range(0, true_()),
1303 union(range(1, true_()), range(2, true_()))
1304 )
1305 )
1306 )
1307 );
1308 assert_all_routing_strategies_eq!(
1309 slice,
1310 range(
1311 0,
1312 union(range(2, range(4, true_())), range(3, range(5, true_())),),
1313 )
1314 );
1315 assert_all_routing_strategies_eq!(
1316 slice,
1317 range(0, range(2, union(range(4, true_()), range(5, true_()),),),)
1318 );
1319 assert_all_routing_strategies_eq!(
1320 slice,
1321 range(
1322 0,
1323 union(range(2, range(4, true_())), range(3, range(5, true_())),),
1324 )
1325 );
1326 assert_all_routing_strategies_eq!(
1327 slice,
1328 union(
1329 range(
1330 0,
1331 union(range(2, range(4, true_())), range(3, range(5, true_())))
1332 ),
1333 range(
1334 1,
1335 union(range(2, range(4, true_())), range(3, range(5, true_())))
1336 )
1337 )
1338 );
1339 }
1340
1341 #[test]
1342 fn test_routing_01() {
1343 use std::ops::ControlFlow;
1344
1345 let slice = test_slice();
1346 let sel = range(0..=0, all(true_()));
1347
1348 let expected_fanouts: &[&[&[usize]]] = &[
1349 &[&[0, 0, 0]],
1350 &[&[0, 0, 0], &[0, 1, 0], &[0, 2, 0], &[0, 3, 0]],
1351 &[
1352 &[0, 0, 0],
1353 &[0, 0, 1],
1354 &[0, 0, 2],
1355 &[0, 0, 3],
1356 &[0, 0, 4],
1357 &[0, 0, 5],
1358 &[0, 0, 6],
1359 &[0, 0, 7],
1360 ],
1361 &[
1362 &[0, 1, 0],
1363 &[0, 1, 1],
1364 &[0, 1, 2],
1365 &[0, 1, 3],
1366 &[0, 1, 4],
1367 &[0, 1, 5],
1368 &[0, 1, 6],
1369 &[0, 1, 7],
1370 ],
1371 &[
1372 &[0, 2, 0],
1373 &[0, 2, 1],
1374 &[0, 2, 2],
1375 &[0, 2, 3],
1376 &[0, 2, 4],
1377 &[0, 2, 5],
1378 &[0, 2, 6],
1379 &[0, 2, 7],
1380 ],
1381 &[
1382 &[0, 3, 0],
1383 &[0, 3, 1],
1384 &[0, 3, 2],
1385 &[0, 3, 3],
1386 &[0, 3, 4],
1387 &[0, 3, 5],
1388 &[0, 3, 6],
1389 &[0, 3, 7],
1390 ],
1391 ];
1392
1393 let expected_deliveries: &[bool] = &[
1394 false, false, false, false, false, false, // Steps 0–5
1395 true, true, true, true, true, true, true, true, true, true, true, true, true, true,
1396 true, true, true, true, true, true, true, true, true, true, true, true, true, true,
1397 true, true, true, true, true, // Steps 6–38
1398 ];
1399
1400 let mut step = 0;
1401 let mut pending = VecDeque::new();
1402
1403 pending.push_back(RoutingFrame::root(sel.clone(), slice.clone()));
1404
1405 println!("Fan-out trace for selection: {}", sel);
1406
1407 while let Some(frame) = pending.pop_front() {
1408 let mut next_coords = vec![];
1409
1410 let deliver_here = frame.deliver_here();
1411
1412 let _ = frame.next_steps(
1413 &mut |_| panic!("Choice encountered in test_routing_01"),
1414 &mut |step| {
1415 let next = step.into_forward().unwrap();
1416 next_coords.push(next.here.clone());
1417 pending.push_back(next);
1418 ControlFlow::Continue(())
1419 },
1420 );
1421
1422 println!(
1423 "Step {:>2}: from {:?} (flat = {:>2}) | deliver = {} | fan-out count = {} | selection = {:?}",
1424 step,
1425 frame.here,
1426 frame.slice.location(&frame.here).unwrap(),
1427 deliver_here,
1428 next_coords.len(),
1429 format!("{}", frame.selection),
1430 );
1431
1432 for next in &next_coords {
1433 println!(" → {:?}", next);
1434 }
1435
1436 if step < expected_fanouts.len() {
1437 let expected = expected_fanouts[step]
1438 .iter()
1439 .map(|v| v.to_vec())
1440 .collect::<Vec<_>>();
1441 assert_eq!(
1442 next_coords, expected,
1443 "Mismatch in next_coords at step {}",
1444 step
1445 );
1446 }
1447
1448 if step < expected_deliveries.len() {
1449 assert_eq!(
1450 deliver_here, expected_deliveries[step],
1451 "Mismatch in deliver_here at step {} (coord = {:?})",
1452 step, frame.here
1453 );
1454 }
1455
1456 step += 1;
1457 }
1458 }
1459
1460 #[test]
1461 fn test_routing_06() {
1462 use std::ops::ControlFlow;
1463
1464 use crate::selection::dsl::*;
1465 use crate::selection::routing::RoutingFrameKey;
1466 use crate::selection::routing::RoutingStep;
1467
1468 let slice = test_slice();
1469 let selection = union(all(true_()), all(true_()));
1470
1471 let mut pending = VecDeque::new();
1472 let mut dedup_delivered = Vec::new();
1473 let mut nodup_delivered = Vec::new();
1474 let mut seen = HashSet::new();
1475
1476 let root = RoutingFrame::root(selection.clone(), slice.clone());
1477 pending.push_back(RoutedMessage::<()>::new(root.here.clone(), root));
1478
1479 while let Some(RoutedMessage { frame, .. }) = pending.pop_front() {
1480 let mut visitor = |step: RoutingStep| {
1481 let next = step.into_forward().unwrap();
1482
1483 if next.action() == RoutingAction::Deliver {
1484 nodup_delivered.push(next.slice.location(&next.here).unwrap());
1485 }
1486
1487 let key = RoutingFrameKey::new(&next);
1488 if seen.insert(key) && next.action() == RoutingAction::Deliver {
1489 dedup_delivered.push(next.slice.location(&next.here).unwrap());
1490 }
1491
1492 if next.action() == RoutingAction::Forward {
1493 pending.push_back(RoutedMessage::new(frame.here.clone(), next));
1494 }
1495
1496 ControlFlow::Continue(())
1497 };
1498
1499 let _ = frame.next_steps(
1500 &mut |_| panic!("Choice encountered in test_routing_06"),
1501 &mut visitor,
1502 );
1503 }
1504
1505 assert_eq!(dedup_delivered.len(), 64);
1506 assert_eq!(nodup_delivered.len(), 128);
1507 }
1508
1509 #[test]
1510 fn test_routing_07() {
1511 use std::ops::ControlFlow;
1512
1513 use crate::selection::dsl::*;
1514 use crate::selection::routing::RoutingFrame;
1515 use crate::selection::routing::RoutingStep;
1516
1517 let slice = test_slice(); // shape: [2, 4, 8]
1518
1519 // Selection: any zone, all hosts, all gpus.
1520 let selection = any(all(all(true_())));
1521 let frame = RoutingFrame::root(selection, slice.clone());
1522
1523 let mut steps = vec![];
1524 let _ = frame.next_steps(
1525 &mut |_| panic!("Choice encountered in test_routing_07"),
1526 &mut |step: RoutingStep| {
1527 steps.push(step);
1528 ControlFlow::Continue(())
1529 },
1530 );
1531
1532 // Only one hop should be produced at the `any` dimension.
1533 assert_eq!(steps.len(), 1);
1534
1535 // Reject choices.
1536 let hop = &steps[0].as_forward().unwrap();
1537
1538 // There should be 3 components to the frame's coordinate.
1539 assert_eq!(hop.here.len(), 3);
1540
1541 // The selected zone (dim 0) should be in bounds.
1542 let zone = hop.here[0];
1543 assert!(zone < 2, "zone out of bounds: {}", zone);
1544
1545 // Inner selection should still be All(All(True))
1546 assert!(matches!(hop.selection, Selection::All(_)));
1547 }
1548
    // This test relies on a deep structural property of the routing
    // semantics:
    //
    // Overdelivery is prevented not by ad hoc guards, but by the
    // structure of the traversal itself — particularly in the
    // presence of routing frame deduplication.
    //
    // When a frame reaches the final dimension with `selection ==
    // True`, it becomes a delivery frame. If multiple such frames
    // target the same coordinate, then:
    //
    // - They must share the same coordinate `here`
    // - They must have reached it via the same routing path (by the
    //   Unique Path Theorem)
    // - Their `RoutingFrame` state is thus structurally identical:
    //   - Same `here`
    //   - Same `dim` (equal to `slice.num_dim()`)
    //   - Same residual `selection == True`
    //
    // The deduplication logic (via `RoutingFrameKey`) collapses such
    // structurally equivalent frames. As a result, only one frame
    // delivers to the target coordinate, and overdelivery is
    // structurally ruled out.
    //
    // This test verifies that behavior holds as expected — and, when
    // deduplication is disabled, confirms that overdelivery becomes
    // observable.
    //
    // Mechanism: deduplication is switched off by setting the
    // `HYPERACTOR_SELECTION_DISABLE_ROUTING_FRAME_DEDUPLICATION`
    // environment variable. Overdelivery is observed as a panic from
    // `collect_routed_paths`, captured with `catch_unwind`.
    //
    // NOTE(review): the environment is process-global, and the Rust
    // test harness runs tests on several threads by default. While
    // the variable is set, other tests that (presumably) consult it
    // may see deduplication disabled and fail spuriously. Consider
    // serializing these tests or running with `--test-threads=1`;
    // verify which helpers actually read the variable.
    #[test]
    fn test_routing_deduplication_precludes_overdelivery() {
        // Ensure the environment is clean — this test depends on a
        // known configuration of deduplication behavior.
        let var = "HYPERACTOR_SELECTION_DISABLE_ROUTING_FRAME_DEDUPLICATION";
        assert!(
            std::env::var_os(var).is_none(),
            "env var `{}` should not be set prior to test",
            var
        );
        let slice = test_slice();

        // Construct a structurally duplicated selection.
        //
        // The union duplicates a singleton selection expression
        // (coordinate (0, 0, 0)). Without deduplication, this would
        // result in two logically identical frames targeting the same
        // node — which should trigger an over-delivery panic in the
        // simulation.
        let a = range(0, range(0, range(0, true_())));
        let sel = union(a.clone(), a.clone());

        // Sanity check: with deduplication enabled (default), this
        // selection does not cause overdelivery.
        let result = std::panic::catch_unwind(|| {
            let _ = collect_routed_paths(&sel, &slice);
        });
        assert!(result.is_ok(), "Unexpected panic due to overdelivery");

        // Now explicitly disable deduplication.
        // SAFETY: TODO: Audit that the environment access only
        // happens in single-threaded code.
        unsafe { std::env::set_var(var, "1") };

        // Expect overdelivery: the duplicated union arms will each
        // produce a delivery to the same coordinate. The panic is
        // caught here so the cleanup below always runs.
        let result = std::panic::catch_unwind(|| {
            let _ = collect_routed_paths(&sel, &slice);
        });

        // Clean up: restore environment to avoid affecting other
        // tests. This happens before the assertion so a failing
        // assertion cannot leave the variable set.
        // SAFETY: TODO: Audit that the environment access only
        // happens in single-threaded code.
        unsafe { std::env::remove_var(var) };

        assert!(
            result.is_err(),
            "Expected panic due to overdelivery, but no panic occurred"
        );
    }
1626
1627 #[test]
1628 fn test_next_steps_zero_dim_slice() {
1629 use std::ops::ControlFlow;
1630
1631 use crate::selection::dsl::*;
1632
1633 let slice = Slice::new(42, vec![], vec![]).unwrap();
1634
1635 let selection = true_();
1636 let frame = RoutingFrame::root(selection, slice.clone());
1637 let mut steps = vec![];
1638 let _ = frame.next_steps(
1639 &mut |_| panic!("Unexpected Choice in 0D test"),
1640 &mut |step| {
1641 steps.push(step);
1642 ControlFlow::Continue(())
1643 },
1644 );
1645
1646 assert_eq!(steps.len(), 1);
1647 let step = steps[0].as_forward().unwrap();
1648 assert_eq!(step.here, vec![0]);
1649 assert!(step.deliver_here());
1650 assert_eq!(step.slice.location(&step.here).unwrap(), 42);
1651
1652 let selection = all(true_());
1653 let frame = RoutingFrame::root(selection, slice.clone());
1654 let mut steps = vec![];
1655 let _ = frame.next_steps(
1656 &mut |_| panic!("Unexpected Choice in 0D test"),
1657 &mut |step| {
1658 steps.push(step);
1659 ControlFlow::Continue(())
1660 },
1661 );
1662
1663 assert_eq!(steps.len(), 1);
1664 let step = steps[0].as_forward().unwrap();
1665 assert_eq!(step.here, vec![0]);
1666 assert!(step.deliver_here());
1667 assert_eq!(step.slice.location(&step.here).unwrap(), 42);
1668
1669 let selection = false_();
1670 let frame = RoutingFrame::root(selection, slice.clone());
1671 let mut steps = vec![];
1672 let _ = frame.next_steps(
1673 &mut |_| panic!("Unexpected Choice in 0D test"),
1674 &mut |step| {
1675 steps.push(step);
1676 ControlFlow::Continue(())
1677 },
1678 );
1679
1680 assert_eq!(steps.len(), 1);
1681 let step = steps[0].as_forward().unwrap();
1682 assert_eq!(step.here, vec![0]);
1683 assert!(!step.deliver_here());
1684 assert_eq!(step.slice.location(&step.here).unwrap(), 42);
1685
1686 let selection = all(false_());
1687 let frame = RoutingFrame::root(selection, slice.clone());
1688 let mut steps = vec![];
1689 let _ = frame.next_steps(
1690 &mut |_| panic!("Unexpected Choice in 0D test"),
1691 &mut |step| {
1692 steps.push(step);
1693 ControlFlow::Continue(())
1694 },
1695 );
1696 assert_eq!(steps.len(), 1);
1697 let step = steps[0].as_forward().unwrap();
1698 assert_eq!(step.here, vec![0]);
1699 assert!(!step.deliver_here());
1700 assert_eq!(step.slice.location(&step.here).unwrap(), 42);
1701 }
1702}