1use std::ops::Index;
10use std::sync::Arc;
11
12use serde::Deserialize;
13use serde::Serialize;
14use thiserror::Error;
15
16use crate::Range;
17use crate::Slice;
18use crate::SliceIterator;
19use crate::slice::CartesianIterator;
20
/// Errors returned when constructing an [`Extent`].
#[derive(Debug, thiserror::Error)]
pub enum ExtentError {
    /// [`Extent::new`] was given a different number of labels than sizes.
    #[error("label/sizes dimension mismatch: {num_labels} != {num_sizes}")]
    DimMismatch {
        /// Number of labels supplied.
        num_labels: usize,
        /// Number of sizes supplied.
        num_sizes: usize,
    },
}
37
/// A named multi-dimensional shape: an ordered list of dimension labels, each
/// paired with a size. Ranks within an extent are numbered in row-major order.
///
/// The label/size data is stored behind an [`Arc`], so cloning an `Extent`
/// only bumps a reference count instead of copying the vectors.
#[derive(Clone, Deserialize, Serialize, PartialEq, Eq, Hash, Debug)]
pub struct Extent {
    inner: Arc<ExtentData>,
}
49
50fn _assert_extent_traits()
51where
52 Extent: Send + Sync + 'static,
53{
54}
55
/// Shared backing storage for [`Extent`]: parallel vectors of dimension labels
/// and sizes.
///
/// [`Extent::new`] checks that both vectors have the same length.
/// NOTE(review): the derived `Deserialize` does not go through `Extent::new`,
/// so deserialized data is not checked for that invariant — confirm whether
/// that matters to callers.
#[derive(Clone, Deserialize, Serialize, PartialEq, Eq, Hash, Debug)]
struct ExtentData {
    labels: Vec<String>,
    sizes: Vec<usize>,
}
65
66impl Extent {
67 pub fn new(labels: Vec<String>, sizes: Vec<usize>) -> Result<Self, ExtentError> {
69 if labels.len() != sizes.len() {
70 return Err(ExtentError::DimMismatch {
71 num_labels: labels.len(),
72 num_sizes: sizes.len(),
73 });
74 }
75
76 Ok(Self {
77 inner: Arc::new(ExtentData { labels, sizes }),
78 })
79 }
80
81 pub fn labels(&self) -> &[String] {
83 &self.inner.labels
84 }
85
86 pub fn sizes(&self) -> &[usize] {
88 &self.inner.sizes
89 }
90
91 pub fn size(&self, label: &str) -> Option<usize> {
94 self.position(label).map(|pos| self.sizes()[pos])
95 }
96
97 pub fn position(&self, label: &str) -> Option<usize> {
100 self.labels().iter().position(|l| l == label)
101 }
102
103 pub fn point(&self, coords: Vec<usize>) -> Result<Point, PointError> {
108 if coords.len() != self.len() {
109 return Err(PointError::DimMismatch {
110 expected: self.len(),
111 actual: coords.len(),
112 });
113 }
114
115 Ok(Point {
116 coords,
117 extent: Extent {
118 inner: Arc::clone(&self.inner),
119 },
120 })
121 }
122
123 pub fn point_of_rank(&self, mut rank: usize) -> Result<Point, PointError> {
125 if rank >= self.num_ranks() {
126 return Err(PointError::OutOfRange {
127 size: self.len(),
128 rank,
129 });
130 }
131
132 let mut stride: usize = self.sizes().iter().product();
133 let mut coords = vec![0; self.len()];
134 for (i, size) in self.sizes().iter().enumerate() {
135 stride /= size;
136 coords[i] = rank / stride;
137 rank %= stride;
138 }
139
140 Ok(Point {
141 coords,
142 extent: self.clone(),
143 })
144 }
145
146 pub fn len(&self) -> usize {
148 self.sizes().len()
149 }
150
151 pub fn is_empty(&self) -> bool {
153 self.sizes().is_empty()
154 }
155
156 pub fn num_ranks(&self) -> usize {
158 self.sizes().iter().product()
159 }
160
161 pub fn into_inner(self) -> (Vec<String>, Vec<usize>) {
163 match Arc::try_unwrap(self.inner) {
164 Ok(data) => (data.labels, data.sizes),
165 Err(shared) => (shared.labels.clone(), shared.sizes.clone()),
166 }
167 }
168
169 pub fn to_slice(&self) -> Slice {
171 Slice::new_row_major(self.sizes())
172 }
173
174 pub fn iter(&self) -> impl Iterator<Item = (String, usize)> + use<'_> {
176 self.labels()
177 .iter()
178 .zip(self.sizes().iter())
179 .map(|(l, s)| (l.clone(), *s))
180 }
181
182 pub fn points(&self) -> ExtentPointsIterator {
184 ExtentPointsIterator {
185 extent: self,
186 pos: CartesianIterator::new(self.sizes().to_vec()),
187 }
188 }
189}
190
191impl std::fmt::Display for Extent {
192 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
193 let n = self.sizes().len();
194 for i in 0..n {
195 write!(f, "{}={}", self.labels()[i], self.sizes()[i])?;
196 if i != n - 1 {
197 write!(f, ",")?;
198 }
199 }
200 Ok(())
201 }
202}
203
/// Iterator over all points of an [`Extent`], created by [`Extent::points`].
pub struct ExtentPointsIterator<'a> {
    /// The extent being enumerated; cloned into each yielded [`Point`].
    extent: &'a Extent,
    /// Produces successive coordinate vectors over the extent's sizes.
    pos: CartesianIterator,
}
209
210impl<'a> Iterator for ExtentPointsIterator<'a> {
211 type Item = Point;
212
213 fn next(&mut self) -> Option<Self::Item> {
214 Some(Point {
215 coords: self.pos.next()?,
216 extent: self.extent.clone(),
217 })
218 }
219}
220
/// Errors returned when constructing a [`Point`].
#[derive(Debug, Error)]
pub enum PointError {
    /// The number of coordinates does not match the extent's number of dimensions.
    #[error("dimension mismatch: expected {expected}, got {actual}")]
    DimMismatch {
        /// Number of dimensions in the extent.
        expected: usize,
        /// Number of coordinates supplied.
        actual: usize,
    },

    /// The requested rank does not lie within the extent.
    #[error("out of range: size of extent is {size}; does not contain rank {rank}")]
    OutOfRange { size: usize, rank: usize },
}
242
/// A coordinate vector within a particular [`Extent`].
///
/// Points built through [`Extent::point`], [`Extent::point_of_rank`], or
/// [`Extent::points`] have exactly one coordinate per dimension of `extent`.
#[derive(Clone, Deserialize, Serialize, PartialEq, Eq, Hash, Debug)]
pub struct Point {
    /// One coordinate per dimension, in the extent's dimension order.
    coords: Vec<usize>,
    /// The extent this point belongs to.
    extent: Extent,
}
262
263impl Index<usize> for Point {
264 type Output = usize;
265
266 fn index(&self, dim: usize) -> &Self::Output {
270 &self.coords[dim]
271 }
272}
273
274impl<'a> IntoIterator for &'a Point {
275 type Item = usize;
276 type IntoIter = std::iter::Cloned<std::slice::Iter<'a, usize>>;
277
278 fn into_iter(self) -> Self::IntoIter {
283 self.coords.iter().cloned()
284 }
285}
286
287fn _assert_point_traits()
288where
289 Point: Send + Sync + 'static,
290{
291}
292
/// Conversion of a value into a [`Point`] of a given [`Extent`], enabling
/// call sites such as `vec![1, 2, 3].in_(&extent)`.
pub trait InExtent {
    /// Interprets `self` as a point in `extent`.
    ///
    /// # Errors
    ///
    /// Returns a [`PointError`] if `self` is not a valid point of `extent`.
    fn in_(self, extent: &Extent) -> Result<Point, PointError>;
}
311
312impl InExtent for Vec<usize> {
313 fn in_(self, extent: &Extent) -> Result<Point, PointError> {
318 extent.point(self)
319 }
320}
321
322impl Point {
323 pub fn coords(&self) -> &Vec<usize> {
325 &self.coords
326 }
327
328 pub fn extent(&self) -> &Extent {
330 &self.extent
331 }
332
333 pub fn rank(&self) -> usize {
343 let mut stride = 1;
344 let mut result = 0;
345 for (c, size) in self
346 .coords
347 .iter()
348 .rev()
349 .zip(self.extent().sizes().iter().rev())
350 {
351 result += *c * stride;
352 stride *= size;
353 }
354
355 result
356 }
357
358 pub fn len(&self) -> usize {
360 self.coords.len()
361 }
362
363 pub fn is_empty(&self) -> bool {
365 self.coords.is_empty()
366 }
367}
368
369impl std::fmt::Display for Point {
370 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
371 let n = self.coords.len();
372 for i in 0..n {
373 write!(f, "{}={}", self.extent.labels()[i], self.coords[i])?;
374 if i != n - 1 {
375 write!(f, ",")?;
376 }
377 }
378 Ok(())
379 }
380}
381
/// Errors returned by [`ViewExt`] operations.
#[derive(Debug, Error)]
pub enum ViewError {
    /// No dimension has the given label.
    #[error("no such dimension: {0}")]
    InvalidDim(String),

    /// The requested range resolves to no elements of the dimension.
    #[error("empty range: {range} for dimension {dim} of size {size}")]
    EmptyRange {
        range: Range,
        dim: String,
        size: usize,
    },
}
397
/// A labeled view onto a possibly offset and strided [`Slice`] of ranks.
///
/// `labels[i]` names dimension `i` of `slice`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct View {
    labels: Vec<String>,
    slice: Slice,
}
404
405impl View {
406 pub fn extent(&self) -> Extent {
408 Extent::new(self.labels.clone(), self.slice.sizes().to_vec()).unwrap()
409 }
410
411 pub fn iter(&self) -> ViewIterator {
414 ViewIterator {
415 extent: self.extent(),
416 pos: self.slice.iter(),
417 }
418 }
419}
420
421impl std::fmt::Display for View {
422 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
423 let n = self.labels.len();
424 for i in 0..n {
425 write!(f, "{}={}", self.labels[i], self.slice.sizes()[i])?;
426 if i != n - 1 {
427 write!(f, ",")?;
428 }
429 }
430 Ok(())
431 }
432}
433
/// Iterator over the points of a [`View`], created by [`View::iter`].
pub struct ViewIterator {
    /// Extent used to build each yielded [`Point`].
    extent: Extent,
    /// Iterator over the ranks selected by the view's slice.
    pos: SliceIterator,
}
439
440impl Iterator for ViewIterator {
441 type Item = (Point, usize);
442
443 fn next(&mut self) -> Option<Self::Item> {
444 let rank = self.pos.next()?;
446 let coords = self.pos.slice.coordinates(rank).unwrap();
448 let point = coords.in_(&self.extent).unwrap();
449 Some((point, rank))
450 }
451}
452
/// Anything describable as labeled dimensions over a [`Slice`]; every
/// `Viewable` gets the [`ViewExt`] operations.
pub trait Viewable {
    /// Returns the dimension labels, one per dimension of [`Viewable::slice`].
    fn labels(&self) -> Vec<String>;

    /// Returns the slice of ranks this value describes.
    fn slice(&self) -> Slice;
}
464
465impl Viewable for View {
466 fn labels(&self) -> Vec<String> {
467 self.labels.clone()
468 }
469
470 fn slice(&self) -> Slice {
471 self.slice.clone()
472 }
473}
474
475impl Viewable for Extent {
476 fn labels(&self) -> Vec<String> {
477 self.labels().to_vec()
478 }
479
480 fn slice(&self) -> Slice {
481 self.to_slice()
482 }
483}
484
485impl From<Extent> for View {
488 fn from(extent: Extent) -> Self {
489 View {
490 labels: extent.labels().to_vec(),
491 slice: extent.slice(),
492 }
493 }
494}
495
/// Operations for deriving new [`View`]s from any [`Viewable`].
pub trait ViewExt: Viewable {
    /// Restricts the dimension labeled `dim` to the elements selected by
    /// `range`, leaving all other dimensions unchanged.
    ///
    /// # Errors
    ///
    /// Returns [`ViewError::InvalidDim`] if no dimension is labeled `dim`, and
    /// [`ViewError::EmptyRange`] if `range` resolves to no elements.
    fn range<R: Into<Range>>(&self, dim: &str, range: R) -> Result<View, ViewError>;

    /// Splits into one view per combination of coordinates in the dimensions
    /// *before* `dim`. Each yielded view spans `dim` and every dimension after it.
    ///
    /// # Errors
    ///
    /// Returns [`ViewError::InvalidDim`] if no dimension is labeled `dim`.
    fn group_by(&self, dim: &str) -> Result<impl Iterator<Item = View>, ViewError>;
}
562
563impl<T: Viewable> ViewExt for T {
564 fn range<R: Into<Range>>(&self, dim: &str, range: R) -> Result<View, ViewError> {
565 let range = range.into();
566 let dim = self
567 .labels()
568 .iter()
569 .position(|l| dim == l)
570 .ok_or_else(|| ViewError::InvalidDim(dim.to_string()))?;
571 let (mut offset, mut sizes, mut strides) = self.slice().into_inner();
572 let (begin, end, step) = range.resolve(sizes[dim]);
573 if end <= begin {
574 return Err(ViewError::EmptyRange {
575 range,
576 dim: dim.to_string(),
577 size: sizes[dim],
578 });
579 }
580
581 offset += strides[dim] * begin;
582 sizes[dim] = (end - begin).div_ceil(step);
583 strides[dim] *= step;
584 let slice = Slice::new(offset, sizes, strides).unwrap();
585
586 Ok(View {
587 labels: self.labels().clone(),
588 slice,
589 })
590 }
591
592 fn group_by(&self, dim: &str) -> Result<impl Iterator<Item = View>, ViewError> {
593 let dim = self
594 .labels()
595 .iter()
596 .position(|l| dim == l)
597 .ok_or_else(|| ViewError::InvalidDim(dim.to_string()))?;
598
599 let (offset, sizes, strides) = self.slice().into_inner();
600 let mut ranks = Slice::new(offset, sizes[..dim].to_vec(), strides[..dim].to_vec())
601 .unwrap()
602 .iter();
603
604 let labels = self.labels()[dim..].to_vec();
605 let sizes = sizes[dim..].to_vec();
606 let strides = strides[dim..].to_vec();
607
608 Ok(std::iter::from_fn(move || {
609 let rank = ranks.next()?;
610 let slice = Slice::new(rank, sizes.clone(), strides.clone()).unwrap();
611 Some(View {
612 labels: labels.clone(),
613 slice,
614 })
615 }))
616 }
617}
618
/// Builds an `Extent` from `label = size` pairs, e.g. `extent!(x = 4, y = 5)`.
/// Each label is the identifier's text verbatim. The internal `unwrap` cannot
/// fail because every pair contributes exactly one label and one size.
#[macro_export]
macro_rules! extent {
    ( $( $label:ident = $size:expr_2021 ),* $(,)? ) => {
        {
            let labels = vec![$( stringify!($label).to_string() ),*];
            let sizes = vec![$( $size ),*];
            $crate::view::Extent::new(labels, sizes).unwrap()
        }
    };
}
642
// Unit tests for extents, points, and views.
#[cfg(test)]
mod test {
    use super::*;
    use crate::Shape;
    use crate::shape;

    // Point construction, rank <-> coordinate round trips, indexing, and
    // row-major enumeration over a 3-D extent.
    #[test]
    fn test_points_basic() {
        let extent = extent!(x = 4, y = 5, z = 6);
        let _p1 = extent.point(vec![1, 2, 3]).unwrap();
        let _p2 = vec![1, 2, 3].in_(&extent).unwrap();

        assert_eq!(extent.num_ranks(), 4 * 5 * 6);

        let p3 = extent.point_of_rank(0).unwrap();
        assert_eq!(p3.coords(), &[0, 0, 0]);
        assert_eq!(p3.rank(), 0);

        let p4 = extent.point_of_rank(1).unwrap();
        assert_eq!(p4.coords(), &[0, 0, 1]);
        assert_eq!(p4.rank(), 1);

        let p5 = extent.point_of_rank(2).unwrap();
        assert_eq!(p5.coords(), &[0, 0, 2]);
        assert_eq!(p5.rank(), 2);

        let p6 = extent.point_of_rank(6 * 5 + 1).unwrap();
        assert_eq!(p6.coords(), &[1, 0, 1]);
        assert_eq!(p6.rank(), 6 * 5 + 1);
        assert_eq!(p6[0], 1);
        assert_eq!(p6[1], 0);
        assert_eq!(p6[2], 1);

        assert_eq!(extent.points().collect::<Vec<_>>().len(), 4 * 5 * 6);
        // Points must come out in row-major order (z varies fastest).
        for (rank, point) in extent.points().enumerate() {
            let &[x, y, z] = &**point.coords() else {
                panic!("invalid coords");
            };
            assert_eq!(z + y * 6 + x * 6 * 5, rank);
            assert_eq!(point.rank(), rank);
        }
    }

    // Asserts that `$view` has extent `$extent` and that iterating it yields
    // exactly the listed `coords => rank` pairs, in the listed order.
    macro_rules! assert_view {
        ($view:expr, $extent:expr, $( $($coord:expr),+ => $rank:expr );* $(;)?) => {
            let view = $view;
            assert_eq!(view.extent(), $extent);
            let expected: Vec<_> = vec![$(($extent.point(vec![$($coord),+]).unwrap(), $rank)),*];
            let actual: Vec<_> = view.iter().collect();
            assert_eq!(actual, expected);
        };
    }

    // `range` with contiguous ranges, single indices, strided ranges, and
    // chained ranges on the same and on different dimensions.
    #[test]
    fn test_view_basic() {
        let extent = extent!(x = 4, y = 4);
        assert_view!(
            extent.range("x", 0..2).unwrap(),
            extent!(x = 2, y = 4),
            0, 0 => 0;
            0, 1 => 1;
            0, 2 => 2;
            0, 3 => 3;
            1, 0 => 4;
            1, 1 => 5;
            1, 2 => 6;
            1, 3 => 7;
        );
        assert_view!(
            extent.range("x", 1).unwrap().range("y", 2..).unwrap(),
            extent!(x = 1, y = 2),
            0, 0 => 6;
            0, 1 => 7;
        );
        assert_view!(
            extent.range("y", Range(0, None, 2)).unwrap(),
            extent!(x = 4, y = 2),
            0, 0 => 0;
            0, 1 => 2;
            1, 0 => 4;
            1, 1 => 6;
            2, 0 => 8;
            2, 1 => 10;
            3, 0 => 12;
            3, 1 => 14;
        );
        assert_view!(
            extent.range("y", Range(0, None, 2)).unwrap().range("x", 2..).unwrap(),
            extent!(x = 2, y = 2),
            0, 0 => 8;
            0, 1 => 10;
            1, 0 => 12;
            1, 1 => 14;
        );

        let extent = extent!(x = 10, y = 2);
        assert_view!(
            extent.range("x", Range(0, None, 2)).unwrap(),
            extent!(x = 5, y = 2),
            0, 0 => 0;
            0, 1 => 1;
            1, 0 => 4;
            1, 1 => 5;
            2, 0 => 8;
            2, 1 => 9;
            3, 0 => 12;
            3, 1 => 13;
            4, 0 => 16;
            4, 1 => 17;
        );
        assert_view!(
            extent.range("x", Range(0, None, 2)).unwrap().range("x", 2..).unwrap().range("y", 1).unwrap(),
            extent!(x = 3, y = 1),
            0, 0 => 9;
            1, 0 => 13;
            2, 0 => 17;
        );

        let extent = extent!(zone = 4, host = 2, gpu = 8);
        assert_view!(
            extent.range("zone", 0).unwrap().range("gpu", Range(0, None, 2)).unwrap(),
            extent!(zone = 1, host = 2, gpu = 4),
            0, 0, 0 => 0;
            0, 0, 1 => 2;
            0, 0, 2 => 4;
            0, 0, 3 => 6;
            0, 1, 0 => 8;
            0, 1, 1 => 10;
            0, 1, 2 => 12;
            0, 1, 3 => 14;
        );

        // A strided range whose step does not divide the size rounds up.
        let extent = extent!(x = 3);
        assert_view!(
            extent.range("x", Range(0, None, 2)).unwrap(),
            extent!(x = 2),
            0 => 0;
            1 => 2;
        );
    }

    // `Index<usize>` returns the coordinate for each dimension.
    #[test]
    fn test_point_indexing() {
        let extent = Extent::new(vec!["x".into(), "y".into(), "z".into()], vec![4, 5, 6]).unwrap();
        let point = extent.point(vec![1, 2, 3]).unwrap();

        assert_eq!(point[0], 1);
        assert_eq!(point[1], 2);
        assert_eq!(point[2], 3);
    }

    // Indexing past the last dimension panics.
    #[test]
    #[should_panic]
    fn test_point_indexing_out_of_bounds() {
        let extent = Extent::new(vec!["x".into(), "y".into()], vec![4, 5]).unwrap();
        let point = extent.point(vec![1, 2]).unwrap();

        let _ = point[5];
    }

    // `&Point` iterates its coordinates by value.
    #[test]
    fn test_point_into_iter() {
        let extent = Extent::new(vec!["x".into(), "y".into(), "z".into()], vec![4, 5, 6]).unwrap();
        let point = extent.point(vec![1, 2, 3]).unwrap();

        let coords: Vec<usize> = (&point).into_iter().collect();
        assert_eq!(coords, vec![1, 2, 3]);

        let mut sum = 0;
        for coord in &point {
            sum += coord;
        }
        assert_eq!(sum, 6);
    }

    // `Extent::iter` yields (label, size) pairs in dimension order.
    #[test]
    fn test_extent_basic() {
        let extent = extent!(x = 10, y = 5, z = 1);
        assert_eq!(
            extent.iter().collect::<Vec<_>>(),
            vec![
                ("x".to_string(), 10),
                ("y".to_string(), 5),
                ("z".to_string(), 1)
            ]
        );
    }

    // Extents display as comma-joined `label=size` pairs; empty extents as "".
    #[test]
    fn test_extent_display() {
        let extent = Extent::new(vec!["x".into(), "y".into(), "z".into()], vec![4, 5, 6]).unwrap();
        assert_eq!(format!("{}", extent), "x=4,y=5,z=6");

        let empty_extent = Extent::new(vec![], vec![]).unwrap();
        assert_eq!(format!("{}", empty_extent), "");
    }

    // A zero-dimensional extent has exactly one point, at rank 0.
    #[test]
    fn test_extent_0d() {
        let e = Extent::new(vec![], vec![]).unwrap();
        assert_eq!(e.num_ranks(), 1);
        let points: Vec<_> = e.points().collect();
        assert_eq!(points.len(), 1);
        assert_eq!(points[0].coords(), &[]);
        assert_eq!(points[0].rank(), 0);
    }

    // Points display as comma-joined `label=coord` pairs; a coordinate count
    // that does not match the extent is rejected.
    #[test]
    fn test_point_display() {
        let extent = Extent::new(vec!["x".into(), "y".into(), "z".into()], vec![4, 5, 6]).unwrap();
        let point = extent.point(vec![1, 2, 3]).unwrap();
        assert_eq!(format!("{}", point), "x=1,y=2,z=3");

        assert!(extent.point(vec![]).is_err());

        let empty_extent = Extent::new(vec![], vec![]).unwrap();
        let empty_point = empty_extent.point(vec![]).unwrap();
        assert_eq!(format!("{}", empty_point), "");
    }

    // Root-mesh ranks of a sliced shape, mapped through `relative_point`,
    // become dense ranks 0..n in the sliced shape's own extent.
    #[test]
    fn test_relative_point() {
        // Maps a rank on the root mesh to a point in `shape`'s own extent.
        pub fn relative_point(rank_on_root_mesh: usize, shape: &Shape) -> anyhow::Result<Point> {
            let coords = shape.slice().coordinates(rank_on_root_mesh)?;
            let extent = Extent::new(shape.labels().to_vec(), shape.slice().sizes().to_vec())?;
            Ok(extent.point(coords)?)
        }

        let root_shape = shape! { replicas = 4, hosts = 4, gpus = 4 };
        let sliced_shape = root_shape
            .select("replicas", crate::Range(0, Some(4), 3))
            .unwrap()
            .select("hosts", crate::Range(1, Some(4), 2))
            .unwrap()
            .select("gpus", crate::Range(0, Some(4), 2))
            .unwrap();
        let ranks_on_root_mesh = &[4, 6, 12, 14, 52, 54, 60, 62];
        assert_eq!(
            sliced_shape.slice().iter().collect::<Vec<_>>(),
            ranks_on_root_mesh,
        );

        let ranks_on_sliced_mesh = ranks_on_root_mesh
            .iter()
            .map(|&r| relative_point(r, &sliced_shape).unwrap().rank());
        assert_eq!(
            ranks_on_sliced_mesh.collect::<Vec<_>>(),
            vec![0, 1, 2, 3, 4, 5, 6, 7]
        );
    }

    // `group_by` yields one view per coordinate combination of the dimensions
    // preceding the grouped dimension.
    #[test]
    fn test_iter_subviews() {
        let extent = extent!(zone = 4, host = 4, gpu = 8);

        assert_eq!(extent.group_by("gpu").unwrap().count(), 16);
        assert_eq!(extent.group_by("zone").unwrap().count(), 1);

        let mut parts = extent.group_by("gpu").unwrap();
        assert_view!(
            parts.next().unwrap(),
            extent!(gpu = 8),
            0 => 0;
            1 => 1;
            2 => 2;
            3 => 3;
            4 => 4;
            5 => 5;
            6 => 6;
            7 => 7;
        );
        assert_view!(
            parts.next().unwrap(),
            extent!(gpu = 8),
            0 => 8;
            1 => 9;
            2 => 10;
            3 => 11;
            4 => 12;
            5 => 13;
            6 => 14;
            7 => 15;
        );
    }
}