use std::ffi::CString;
use std::fmt;
use std::fmt::Write;
use std::hash::Hasher;
use std::marker::PhantomData;
use std::mem::MaybeUninit;

use fxhash::FxHasher32;
use nccl_sys::*;
use serde::Deserialize;
use serde::Serialize;
use thiserror::Error;
use torch_sys::CudaDevice;
use torch_sys::DeviceType;
use torch_sys::ScalarType;
use torch_sys::Tensor;
use torch_sys::TensorCell;
use torch_sys::factory_float_tensor;
use torch_sys::is_float8_type;
use torch_sys::suggest_memory_format;

use crate::bridge::ffi::make_nccl_config;
use crate::cuda::CudaError;
use crate::cuda::Stream;
use crate::cuda::set_device;

#[derive(Debug, Error)]
pub enum RawNcclError {
    #[error("a call to a CUDA function failed")]
    UnhandledCudaError,
    #[error("a call to the system failed")]
    SystemError,
    #[error("an internal check failed; either a bug in NCCL or memory corruption")]
    InternalError,
    #[error("an argument has an invalid value")]
    InvalidArgument,
    #[error("a call to NCCL is incorrect, usually a programming error")]
    InvalidUsage,
    #[error(
        "a call failed, possibly due to a network error or a remote process exiting prematurely"
    )]
    RemoteError,
}

#[derive(Debug, Error)]
pub enum NcclError {
    #[error("an NCCL-level error: {0:?}")]
    NcclError(#[from] RawNcclError),

    #[error("a CUDA-level error: {0:?}")]
    CudaError(#[from] CudaError),

    #[error("invalid NCCL data type: {0:#?}")]
    InvalidDataType(ScalarType),

    #[error("tensor used in a collective must be contiguous")]
    NoncontiguousTensor,

    #[error("tensor must be on a CUDA device, got: {0:?}")]
    InvalidDevice(DeviceType),

    #[error("got a sparse tensor; only dense tensors are allowed")]
    InvalidSparseTensor,

    #[error("float8 dtypes are not currently supported for NCCL reductions")]
    Float8Reduction,

    #[error("output tensor must have the same type as the input tensor")]
    TypeMismatch,

    #[error("output tensor size must be equal to world size times input tensor size")]
    OutputSizeMismatch,

    #[error("input tensor size must be equal to world size times output tensor size")]
    InputSizeMismatch,

    #[error("ranks passed should be within the global world_size, got: {0:#?}")]
    InvalidSplit(Vec<i32>),

    #[error("undefined tensor used for NCCL operation")]
    UndefinedTensor,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NcclStatus {
    Success,
    InProgress,
}

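/// Configuration forwarded to NCCL when creating or splitting a communicator;
/// the fields mirror `ncclConfig_t`.
///
/// A minimal usage sketch (the field values and the `comm` handle here are
/// illustrative assumptions, not recommendations):
///
/// ```ignore
/// let config = NcclConfig {
///     blocking: false,
///     ..Default::default()
/// };
/// let sub_comm = comm.split_all(Some(config))?;
/// ```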
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NcclConfig {
    pub blocking: bool,
    pub cga_cluster_size: u8,
    pub min_ctas: u8,
    pub max_ctas: u8,
    pub net_name: Option<String>,
    pub split_share: bool,
}

impl Default for NcclConfig {
    fn default() -> Self {
        NcclConfig {
            blocking: true,
            cga_cluster_size: 4,
            min_ctas: 1,
            max_ctas: 32,
            net_name: None,
            split_share: false,
        }
    }
}

impl From<NcclConfig> for ncclConfig_t {
    fn from(config: NcclConfig) -> Self {
        let mut ret = make_nccl_config();
        ret.blocking = config.blocking.into();
        ret.cgaClusterSize = config.cga_cluster_size.into();
        ret.minCTAs = config.min_ctas.into();
        ret.maxCTAs = config.max_ctas.into();
        if let Some(net_name) = config.net_name {
            let c_string = CString::new(net_name)
                .expect("failed to create CString")
                .into_boxed_c_str();

            // Leak the CString so the raw pointer stored in the NCCL config
            // stays valid for as long as NCCL may read it.
            let ptr = Box::leak(c_string).as_ptr();
            ret.netName = ptr;
        }
        ret.splitShare = config.split_share.into();

        ret
    }
}

fn nccl_check(result: ncclResult_t) -> Result<NcclStatus, RawNcclError> {
    match result.0 {
        0 => Ok(NcclStatus::Success),
        1 => Err(RawNcclError::UnhandledCudaError),
        2 => Err(RawNcclError::SystemError),
        3 => Err(RawNcclError::InternalError),
        4 => Err(RawNcclError::InvalidArgument),
        5 => Err(RawNcclError::InvalidUsage),
        6 => Err(RawNcclError::RemoteError),
        7 => Ok(NcclStatus::InProgress),
        _ => panic!("Unknown ncclResult_t: {:?}", result.0),
    }
}

pub struct NcclGroupTicket {
    // `*const ()` is neither `Send` nor `Sync`, so the ticket must be consumed
    // on the same thread that started the group.
    unsend_marker: PhantomData<*const ()>,
}

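/// Begin a NCCL group: collective and point-to-point calls issued before the
/// matching [`group_end`] are aggregated and only launched once the group is
/// closed. The returned ticket is `!Send`, so the group must be ended on the
/// thread that started it.
///
/// A minimal sketch (assuming `comm`, `send_cell`, `recv_cell`, `stream`, and
/// a `peer` rank already exist):
///
/// ```ignore
/// let ticket = group_start()?;
/// comm.send(&send_cell, peer, &stream)?;
/// comm.recv(&recv_cell, peer, &stream)?;
/// group_end(ticket)?;
/// ```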
pub fn group_start() -> Result<NcclGroupTicket, NcclError> {
    nccl_check(unsafe { ncclGroupStart() })?;
    Ok(NcclGroupTicket {
        unsend_marker: PhantomData,
    })
}

/// End the group started by [`group_start`], consuming its ticket.
pub fn group_end(_ticket: NcclGroupTicket) -> Result<(), NcclError> {
    nccl_check(unsafe { ncclGroupEnd() })?;
    Ok(())
}

/// An opaque NCCL unique id, created on one rank and shared with every rank
/// that joins the same communicator.
#[derive(Clone, Serialize, Deserialize)]
pub struct UniqueId {
    inner: ncclUniqueId,
}

impl fmt::Debug for UniqueId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("UniqueId")
            .field(
                "inner",
                &format_args!(
                    "{}",
                    self.inner
                        .internal
                        .iter()
                        .fold(String::new(), |mut output, b| {
                            let _ = write!(output, "{:02x}", b);
                            output
                        })
                ),
            )
            .finish()
    }
}

impl UniqueId {
    pub fn new() -> Result<Self, RawNcclError> {
        let mut inner = MaybeUninit::uninit();
        let inner = unsafe {
            nccl_check(ncclGetUniqueId(inner.as_mut_ptr()))?;
            inner.assume_init()
        };
        Ok(Self { inner })
    }
}

/// NCCL data types; the discriminants mirror `ncclDataType_t`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DataType {
    Int8 = 0,
    Uint8 = 1,
    Int32 = 2,
    Uint32 = 3,
    Int64 = 4,
    Uint64 = 5,
    Float16 = 6,
    Float32 = 7,
    Float64 = 8,
    Bfloat16 = 9,
}

impl From<DataType> for ncclDataType_t {
    fn from(data_type: DataType) -> Self {
        Self(data_type as std::os::raw::c_uint)
    }
}

impl TryFrom<ScalarType> for DataType {
    type Error = NcclError;

    fn try_from(value: ScalarType) -> Result<Self, Self::Error> {
        match value {
            ScalarType::Char => Ok(DataType::Int8),
            ScalarType::Byte => Ok(DataType::Uint8),
            ScalarType::Half => Ok(DataType::Float16),
            ScalarType::Float => Ok(DataType::Float32),
            ScalarType::Double => Ok(DataType::Float64),
            ScalarType::Int => Ok(DataType::Int32),
            ScalarType::Long => Ok(DataType::Int64),
            // Bool and the float8 types have no NCCL equivalent, so they are
            // moved as raw bytes; float8 reductions are rejected separately
            // (see `NcclError::Float8Reduction`).
            ScalarType::Bool => Ok(DataType::Uint8),
            ScalarType::BFloat16 => Ok(DataType::Bfloat16),
            ScalarType::Float8_e5m2 => Ok(DataType::Uint8),
            ScalarType::Float8_e4m3fn => Ok(DataType::Uint8),
            ScalarType::Float8_e4m3fnuz => Ok(DataType::Uint8),
            ScalarType::Float8_e5m2fnuz => Ok(DataType::Uint8),
            _ => Err(NcclError::InvalidDataType(value)),
        }
    }
}

/// NCCL reduction operations; the discriminants mirror `ncclRedOp_t`.
#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq)]
pub enum ReduceOp {
    Sum = 0,
    Prod = 1,
    Max = 2,
    Min = 3,
    Avg = 4,
}

impl From<ReduceOp> for ncclRedOp_t {
    fn from(reduce_op: ReduceOp) -> Self {
        Self(reduce_op as std::os::raw::c_uint)
    }
}

/// Validate that a tensor can be used in a NCCL call: it must be defined,
/// dense, and on a CUDA device. The contiguity check is only applied to
/// collectives (`is_p2p == false`).
fn check_tensor(tensor: &Tensor, is_p2p: bool) -> Result<(), NcclError> {
    if !tensor.defined() {
        return Err(NcclError::UndefinedTensor);
    }
    if !tensor.is_cuda() {
        return Err(NcclError::InvalidDevice(tensor.device().device_type()));
    }
    if tensor.is_sparse() {
        return Err(NcclError::InvalidSparseTensor);
    }

    if !is_p2p && !tensor.is_contiguous(suggest_memory_format(tensor)) {
        return Err(NcclError::NoncontiguousTensor);
    }

    Ok(())
}

#[derive(Debug)]
pub struct Communicator {
    inner: ncclComm_t,
    /// Number of ranks in this communicator.
    world_size: i32,
    /// This process's rank within this communicator.
    rank: i32,
    /// World size of the original (unsplit) communicator.
    global_world_size: i32,
    /// This process's rank in the original (unsplit) communicator.
    global_rank: i32,
    device: CudaDevice,
}

// SAFETY: every NCCL call on the raw `ncclComm_t` handle goes through
// `&mut self`, so access to it is serialized by Rust's borrowing rules.
unsafe impl Send for Communicator {}
unsafe impl Sync for Communicator {}

/// Derive a deterministic, non-negative NCCL split "color" from a sorted list
/// of ranks, so that every member of the same split picks the same color.
fn calculate_color(ranks: &[i32]) -> i32 {
    let mut hasher = FxHasher32::default();
    ranks.iter().for_each(|r| hasher.write_i32(*r));
    let hash = hasher.finish();

    // NCCL colors must be non-negative; fold the hash into the i32 range.
    (hash % (i32::MAX as u64)) as i32
}

impl Communicator {
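    /// Create a communicator for `rank` within a group of `world_size` ranks,
    /// bound to `device`.
    ///
    /// A minimal sketch for one rank (assuming `world_size`, `rank`, and a
    /// `UniqueId` shared with all ranks out of band are already available):
    ///
    /// ```ignore
    /// let device = CudaDevice::new(DeviceIndex(0));
    /// let mut comm = Communicator::new(device, world_size, unique_id, rank)?;
    /// ```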
    pub fn new(
        device: CudaDevice,
        world_size: i32,
        unique_id: UniqueId,
        rank: i32,
    ) -> Result<Self, NcclError> {
        set_device(device)?;
        let mut inner = MaybeUninit::uninit();
        let inner = unsafe {
            nccl_check(ncclCommInitRank(
                inner.as_mut_ptr(),
                world_size,
                unique_id.inner,
                rank,
            ))?;
            inner.assume_init()
        };
        Ok(Self {
            inner,
            world_size,
            rank,
            global_rank: rank,
            global_world_size: world_size,
            device,
        })
    }

    /// Split off a communicator spanning every rank of the global communicator.
    pub fn split_all(&mut self, config: Option<NcclConfig>) -> Result<Self, NcclError> {
        let ranks = (0..self.global_world_size).collect();
        Ok(self.split_from(ranks, config)?.unwrap())
    }

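    /// Split a sub-communicator containing only `ranks` off this communicator.
    /// This is a collective over the parent: every rank must call it, and
    /// ranks not listed in `ranks` get `Ok(None)` back.
    ///
    /// A sketch (assuming a four-rank parent communicator `comm`):
    ///
    /// ```ignore
    /// // Ranks 0 and 1 receive Some(sub_comm); ranks 2 and 3 receive None.
    /// let sub_comm = comm.split_from(vec![0, 1], None)?;
    /// ```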
    pub fn split_from(
        &mut self,
        mut ranks: Vec<i32>,
        config: Option<NcclConfig>,
    ) -> Result<Option<Self>, NcclError> {
        ranks.sort();
        for rank in &ranks {
            if *rank < 0 || *rank >= self.global_world_size {
                return Err(NcclError::InvalidSplit(ranks));
            }
        }

        // Ranks that belong to the new group derive a shared color from the
        // rank list; everyone else passes NCCL_SPLIT_NOCOLOR and receives no
        // new communicator.
        let color = match ranks.binary_search(&self.rank) {
            Ok(_) => calculate_color(ranks.as_slice()),
            Err(_) => NCCL_SPLIT_NOCOLOR,
        };

        let config = config.map(ncclConfig_t::from);
        let mut new = MaybeUninit::uninit();

        let new = unsafe {
            match config {
                Some(mut config) => {
                    nccl_check(ncclCommSplit(
                        self.inner,
                        color,
                        self.rank,
                        new.as_mut_ptr(),
                        &mut config,
                    ))?;
                }
                None => {
                    nccl_check(ncclCommSplit(
                        self.inner,
                        color,
                        self.rank,
                        new.as_mut_ptr(),
                        std::ptr::null_mut(),
                    ))?;
                }
            }
            new.assume_init()
        };

        let group_rank = ranks.iter().position(|v| *v == self.rank);
        match color {
            NCCL_SPLIT_NOCOLOR => Ok(None),
            _ => Ok(Some(Self {
                inner: new,
                world_size: ranks.len() as i32,
                rank: group_rank.unwrap() as i32,
                global_rank: self.global_rank,
                global_world_size: self.global_world_size,
                device: self.device,
            })),
        }
    }

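    /// Reduce `tensor` across every rank with `reduce_op`, leaving the result
    /// in place on all ranks.
    ///
    /// A sketch (assuming `comm`, a CUDA `TensorCell` named `cell`, and a
    /// `stream`):
    ///
    /// ```ignore
    /// comm.all_reduce(&cell, ReduceOp::Sum, &stream)?;
    /// stream.synchronize();
    /// ```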
    pub fn all_reduce(
        &mut self,
        tensor: &TensorCell,
        reduce_op: ReduceOp,
        stream: &Stream,
    ) -> Result<NcclStatus, NcclError> {
        let tensor = tensor.borrow_mut();
        let data_type: DataType = tensor.scalar_type().try_into()?;

        check_tensor(&tensor, false)?;
        if is_float8_type(tensor.scalar_type()) {
            return Err(NcclError::Float8Reduction);
        }
        unsafe {
            Ok(nccl_check(ncclAllReduce(
                tensor.data_ptr(),
                tensor.mut_data_ptr(),
                tensor.numel() as usize,
                data_type.into(),
                reduce_op.into(),
                self.inner,
                stream.stream(),
            ))?)
        }
    }

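    /// Broadcast `tensor` from rank `root` to every rank, in place.
    ///
    /// A sketch (every rank must pass the same `root`):
    ///
    /// ```ignore
    /// comm.broadcast(&cell, /*root=*/ 0, &stream)?;
    /// ```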
    pub fn broadcast(
        &mut self,
        tensor: &TensorCell,
        root: i32,
        stream: &Stream,
    ) -> Result<NcclStatus, NcclError> {
        let tensor = tensor.borrow_mut();
        check_tensor(&tensor, false)?;
        let data_type: DataType = tensor.scalar_type().try_into()?;
        unsafe {
            Ok(nccl_check(ncclBroadcast(
                tensor.data_ptr(),
                tensor.mut_data_ptr(),
                tensor.numel() as usize,
                data_type.into(),
                root,
                self.inner,
                stream.stream(),
            ))?)
        }
    }

    /// Reduce `tensor` across all ranks with `reduce_op`, leaving the result
    /// on `root` only; other ranks keep their original values.
    pub fn reduce(
        &mut self,
        tensor: &TensorCell,
        reduce_op: ReduceOp,
        root: i32,
        stream: &Stream,
    ) -> Result<NcclStatus, NcclError> {
        let tensor = tensor.borrow_mut();
        check_tensor(&tensor, false)?;
        if is_float8_type(tensor.scalar_type()) {
            return Err(NcclError::Float8Reduction);
        }
        let data_type: DataType = tensor.scalar_type().try_into()?;
        unsafe {
            Ok(nccl_check(ncclReduce(
                tensor.data_ptr(),
                tensor.mut_data_ptr(),
                tensor.numel() as usize,
                data_type.into(),
                reduce_op.into(),
                root,
                self.inner,
                stream.stream(),
            ))?)
        }
    }

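    /// Gather `input_cell` from every rank into the per-rank tensors of
    /// `output_cells`, so that `output_cells[r]` ends up holding rank `r`'s
    /// input on every rank.
    ///
    /// A sketch (assuming `outputs` is a `Vec<TensorCell>` with one tensor per
    /// rank, each shaped like the input):
    ///
    /// ```ignore
    /// comm.all_gather(&outputs, &input_cell, &stream)?;
    /// ```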
    pub fn all_gather(
        &mut self,
        output_cells: &[TensorCell],
        input_cell: &TensorCell,
        stream: &Stream,
    ) -> Result<NcclStatus, NcclError> {
        let output = output_cells
            .iter()
            .map(|t| t.borrow_mut())
            .collect::<Vec<_>>();
        let input = input_cell.borrow();
        check_tensor(&input, false)?;
        let output_type = output[0].scalar_type();
        let output_numel: i64 = output.iter().map(|t| t.numel()).sum();
        for t in &output {
            if t.scalar_type() != output_type {
                return Err(NcclError::TypeMismatch);
            }
        }
        if input.scalar_type() != output_type {
            return Err(NcclError::TypeMismatch);
        }
        if input.numel() * self.world_size as i64 != output_numel {
            return Err(NcclError::OutputSizeMismatch);
        }
        let data_type: DataType = input.scalar_type().try_into()?;
        unsafe {
            // Emulate an all-gather into separate output tensors with one
            // broadcast per rank: rank `i` broadcasts its input into
            // `output[i]` on every rank. Grouping lets NCCL launch the
            // broadcasts together.
            nccl_check(ncclGroupStart())?;
            for (i, output) in output.iter().enumerate() {
                let rank = i as i32;
                let output_ptr = output.mut_data_ptr();
                if rank == self.rank {
                    // This rank is the broadcast root: send its input buffer
                    // into its own output slot.
                    nccl_check(ncclBroadcast(
                        input.data_ptr(),
                        output_ptr,
                        input.numel() as usize,
                        data_type.into(),
                        rank,
                        self.inner,
                        stream.stream(),
                    ))?;
                } else {
                    // Receive rank `i`'s data into `output[i]`.
                    nccl_check(ncclBroadcast(
                        output_ptr,
                        output_ptr,
                        output.numel() as usize,
                        data_type.into(),
                        rank,
                        self.inner,
                        stream.stream(),
                    ))?;
                }
            }
            nccl_check(ncclGroupEnd())?;
        }
        Ok(NcclStatus::Success)
    }

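    /// Gather `input_cell` from every rank into a single flat `output_cell`
    /// whose element count is `world_size` times that of the input.
    ///
    /// A sketch (e.g. two ranks each contributing a `[2, 2]` tensor into a
    /// `[2, 2, 2]` output):
    ///
    /// ```ignore
    /// comm.all_gather_into_tensor(&output_cell, &input_cell, &stream)?;
    /// ```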
    pub fn all_gather_into_tensor(
        &mut self,
        output_cell: &TensorCell,
        input_cell: &TensorCell,
        stream: &Stream,
    ) -> Result<NcclStatus, NcclError> {
        let output = output_cell.borrow_mut();
        // If input and output alias, the mutable borrow above already covers
        // the input; otherwise take a separate shared borrow to keep it alive.
        let _input_borrow = if input_cell.aliases(output_cell) {
            None
        } else {
            Some(input_cell.borrow())
        };
        let input = unsafe { input_cell.get_unchecked() };
        check_tensor(&output, false)?;
        check_tensor(input, false)?;
        if input.scalar_type() != output.scalar_type() {
            return Err(NcclError::TypeMismatch);
        }
        if input.numel() * self.world_size as i64 != output.numel() {
            return Err(NcclError::OutputSizeMismatch);
        }

        let data_type: DataType = input.scalar_type().try_into()?;
        unsafe {
            Ok(nccl_check(ncclAllGather(
                input.data_ptr(),
                output.mut_data_ptr(),
                input.numel() as usize,
                data_type.into(),
                self.inner,
                stream.stream(),
            ))?)
        }
    }

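    /// Reduce `input_cell` across all ranks with `reduce_op`, then scatter the
    /// result so each rank keeps one `world_size`-th of it in `output_cell`.
    ///
    /// A sketch (the input must have `world_size` times as many elements as
    /// the output):
    ///
    /// ```ignore
    /// comm.reduce_scatter_tensor(&output_cell, &input_cell, ReduceOp::Sum, &stream)?;
    /// ```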
    pub fn reduce_scatter_tensor(
        &mut self,
        output_cell: &TensorCell,
        input_cell: &TensorCell,
        reduce_op: ReduceOp,
        stream: &Stream,
    ) -> Result<NcclStatus, NcclError> {
        let output = output_cell.borrow_mut();
        let _input_borrow = if input_cell.aliases(output_cell) {
            None
        } else {
            Some(input_cell.borrow())
        };

        let input = unsafe { input_cell.get_unchecked() };
        check_tensor(&output, false)?;
        check_tensor(input, false)?;
        if input.scalar_type() != output.scalar_type() {
            return Err(NcclError::TypeMismatch);
        }
        if input.numel() != output.numel() * self.world_size as i64 {
            return Err(NcclError::InputSizeMismatch);
        }
        if is_float8_type(input.scalar_type()) {
            return Err(NcclError::Float8Reduction);
        }

        let data_type: DataType = input.scalar_type().try_into()?;
        unsafe {
            Ok(nccl_check(ncclReduceScatter(
                input.data_ptr(),
                output.mut_data_ptr(),
                output.numel() as usize,
                data_type.into(),
                reduce_op.into(),
                self.inner,
                stream.stream(),
            ))?)
        }
    }

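    /// Send `tensor_cell` to rank `dst`; must be matched by a `recv` on that
    /// rank.
    ///
    /// A sketch of a matched pair (assuming two ranks):
    ///
    /// ```ignore
    /// // On rank 0:
    /// comm.send(&cell, /*dst=*/ 1, &stream)?;
    /// // On rank 1:
    /// comm.recv(&cell, /*src=*/ 0, &stream)?;
    /// ```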
    pub fn send(
        &mut self,
        tensor_cell: &TensorCell,
        dst: i32,
        stream: &Stream,
    ) -> Result<NcclStatus, NcclError> {
        let tensor = tensor_cell.borrow();
        let data_type: DataType = tensor.scalar_type().try_into()?;

        check_tensor(&tensor, true)?;

        unsafe {
            Ok(nccl_check(ncclSend(
                tensor.data_ptr(),
                tensor.numel() as usize,
                data_type.into(),
                dst,
                self.inner,
                stream.stream(),
            ))?)
        }
    }

    /// Receive into `tensor_cell` from rank `src`; must match a `send` issued
    /// by that rank.
    pub fn recv(
        &mut self,
        tensor_cell: &TensorCell,
        src: i32,
        stream: &Stream,
    ) -> Result<NcclStatus, NcclError> {
        let tensor = tensor_cell.borrow_mut();
        let data_type: DataType = tensor.scalar_type().try_into()?;

        check_tensor(&tensor, true)?;

        unsafe {
            Ok(nccl_check(ncclRecv(
                tensor.mut_data_ptr(),
                tensor.numel() as usize,
                data_type.into(),
                src,
                self.inner,
                stream.stream(),
            ))?)
        }
    }

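    /// Exchange equal-sized chunks of `input_cell` with every rank: the input
    /// is split into `world_size` contiguous chunks, chunk `r` is sent to rank
    /// `r`, and the chunk received from rank `r` lands in the `r`-th chunk of
    /// `output_cell`.
    ///
    /// A two-rank sketch (rank 0 holds `[0, 1]`, rank 1 holds `[2, 3]`; after
    /// the call rank 0 holds `[0, 2]` and rank 1 holds `[1, 3]`):
    ///
    /// ```ignore
    /// comm.all_to_all_single(&output_cell, &input_cell, &stream)?;
    /// ```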
    pub fn all_to_all_single(
        &mut self,
        output_cell: &TensorCell,
        input_cell: &TensorCell,
        stream: &Stream,
    ) -> Result<NcclStatus, NcclError> {
        let output = output_cell.borrow_mut();
        let _input_borrow = if input_cell.aliases(output_cell) {
            None
        } else {
            Some(input_cell.borrow_mut())
        };
        let input = unsafe { input_cell.get_unchecked() };

        check_tensor(&output, false)?;
        check_tensor(input, false)?;
        if input.scalar_type() != output.scalar_type() {
            return Err(NcclError::TypeMismatch);
        }

        let data_type: DataType = input.scalar_type().try_into()?;
        // Each rank exchanges one equal-sized chunk with every other rank.
        let count = input.numel() as usize / self.world_size as usize;
        let rank_stride = input.nbytes() as isize / self.world_size as isize;
        unsafe {
            let send_buff = input.data_ptr();
            let recv_buff = output.mut_data_ptr();

            // Pair a send and a receive with every rank inside one NCCL group
            // so the operations are launched together.
            nccl_check(ncclGroupStart())?;
            for r in 0..self.world_size {
                nccl_check(ncclSend(
                    send_buff.offset(r as isize * rank_stride),
                    count,
                    data_type.into(),
                    r,
                    self.inner,
                    stream.stream(),
                ))?;
                nccl_check(ncclRecv(
                    recv_buff.offset(r as isize * rank_stride),
                    count,
                    data_type.into(),
                    r,
                    self.inner,
                    stream.stream(),
                ))?;
            }

            nccl_check(ncclGroupEnd())?;
        };
        Ok(NcclStatus::Success)
    }

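    /// Synchronize all ranks on `stream`, implemented as an all-reduce over a
    /// single-element tensor.
    ///
    /// A sketch (synchronize the stream afterwards to also block the host):
    ///
    /// ```ignore
    /// comm.barrier(&stream)?;
    /// stream.synchronize();
    /// ```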
    pub fn barrier(&mut self, stream: &Stream) -> Result<NcclStatus, NcclError> {
        // NCCL has no dedicated barrier primitive; all-reducing a
        // single-element tensor has the same effect on the stream.
        let tensor = factory_float_tensor(&[1.0], self.device.into());
        let data_type: DataType = tensor.scalar_type().try_into()?;

        unsafe {
            Ok(nccl_check(ncclAllReduce(
                tensor.data_ptr(),
                tensor.mut_data_ptr(),
                tensor.numel() as usize,
                data_type.into(),
                ReduceOp::Sum.into(),
                self.inner,
                stream.stream(),
            ))?)
        }
    }
}

#[cfg(test)]
mod tests {
    use torch_sys::CudaDevice;
    use torch_sys::DeviceIndex;
    use torch_sys::factory_float_tensor;
    use torch_sys::testing::allclose;
    use torch_sys::testing::cuda_full;
    use torch_sys::testing::stack;

    use super::*;
    use crate::cuda::set_device;

    #[test]
    fn all_reduce() {
        let unique_id = UniqueId::new().unwrap();
        let mut handles = Vec::new();
        for i in 0..2 {
            let unique_id = unique_id.clone();
            handles.push(std::thread::spawn(move || {
                let device = CudaDevice::new(DeviceIndex(i));
                set_device(device).unwrap();
                let stream = Stream::new();
                let tensor = cuda_full(&[2, 2], 1.0);
                let expected = cuda_full(&[2, 2], 2.0);

                let cell = TensorCell::new(tensor);
                let mut comm = Communicator::new(device, 2, unique_id, i.into()).unwrap();
                comm.all_reduce(&cell, ReduceOp::Sum, &stream).unwrap();
                stream.synchronize();
                assert!(allclose(&cell.borrow(), &expected).unwrap());
            }));
        }
        for handle in handles {
            handle.join().unwrap();
        }
    }

    #[test]
    fn broadcast() {
        let unique_id = UniqueId::new().unwrap();
        let mut handles = Vec::new();
        for i in 0..2 {
            let unique_id = unique_id.clone();
            handles.push(std::thread::spawn(move || {
                let device = CudaDevice::new(DeviceIndex(i));
                set_device(device).unwrap();
                let stream = Stream::new();
                let tensor = cuda_full(&[2, 2], i as f32);

                let cell = TensorCell::new(tensor);
                let mut comm = Communicator::new(device, 2, unique_id, i.into()).unwrap();
                comm.broadcast(&cell, 1, &stream).unwrap();
                stream.synchronize();
                assert!(allclose(&cell.borrow(), &cuda_full(&[2, 2], 1.0)).unwrap());
            }));
        }
        for handle in handles {
            handle.join().unwrap();
        }
    }

    #[test]
    fn reduce() {
        let unique_id = UniqueId::new().unwrap();
        let mut handles = Vec::new();
        for i in 0..2 {
            let unique_id = unique_id.clone();
            handles.push(std::thread::spawn(move || {
                let device = CudaDevice::new(DeviceIndex(i));
                set_device(device).unwrap();
                let stream = Stream::new();
                let tensor = cuda_full(&[2, 2], 2.0);

                let cell = TensorCell::new(tensor);
                let mut comm = Communicator::new(device, 2, unique_id, i.into()).unwrap();
                comm.reduce(&cell, ReduceOp::Sum, 0, &stream).unwrap();
                stream.synchronize();
                match i {
                    0 => assert!(allclose(&cell.borrow(), &cuda_full(&[2, 2], 4.0)).unwrap()),
                    1 => assert!(allclose(&cell.borrow(), &cuda_full(&[2, 2], 2.0)).unwrap()),
                    _ => unreachable!(),
                }
            }));
        }
        for handle in handles {
            handle.join().unwrap();
        }
    }

    #[test]
    fn all_gather_into_tensor() {
        let unique_id = UniqueId::new().unwrap();
        let mut handles = Vec::new();
        for i in 0..2 {
            let unique_id = unique_id.clone();
            handles.push(std::thread::spawn(move || {
                let device = CudaDevice::new(DeviceIndex(i));
                set_device(device).unwrap();
                let stream = Stream::new();
                let input_tensor = cuda_full(&[2, 2], i as f32);
                let output_tensor = cuda_full(&[2, 2, 2], 0.0);

                let expected = {
                    let mut tensor_list = Vec::new();
                    for i in 0..2 {
                        tensor_list.push(cuda_full(&[2, 2], i as f32));
                    }
                    stack(&tensor_list)
                };
                let input_cell = TensorCell::new(input_tensor);
                let output_cell = TensorCell::new(output_tensor);
                let mut comm = Communicator::new(device, 2, unique_id, i.into()).unwrap();
                comm.all_gather_into_tensor(&output_cell, &input_cell, &stream)
                    .unwrap();
                stream.synchronize();
                assert!(allclose(&output_cell.borrow(), &expected).unwrap());
            }));
        }
        for handle in handles {
            handle.join().unwrap();
        }
    }

    #[test]
    fn send_recv() {
        let unique_id = UniqueId::new().unwrap();
        let mut handles = Vec::new();
        let unique_id_ = unique_id.clone();
        handles.push(std::thread::spawn(move || {
            let device = CudaDevice::new(DeviceIndex(0));
            set_device(device).unwrap();
            let stream = Stream::new();
            let tensor = cuda_full(&[2, 2], 0.0);

            let cell = TensorCell::new(tensor);
            let mut comm = Communicator::new(device, 2, unique_id_, 0).unwrap();
            comm.send(&cell, 1, &stream).unwrap();
            stream.synchronize();
        }));
        let unique_id_ = unique_id.clone();
        handles.push(std::thread::spawn(move || {
            let device = CudaDevice::new(DeviceIndex(1));
            set_device(device).unwrap();
            let stream = Stream::new();
            let tensor = cuda_full(&[2, 2], 1.1);
            let expected = cuda_full(&[2, 2], 0.0);

            let cell = TensorCell::new(tensor);
            let mut comm = Communicator::new(device, 2, unique_id_, 1).unwrap();
            comm.recv(&cell, 0, &stream).unwrap();
            stream.synchronize();
            assert!(allclose(&cell.borrow(), &expected).unwrap());
        }));
        for handle in handles {
            handle.join().unwrap();
        }
    }

    #[test]
    fn all_to_all_single() {
        let unique_id = UniqueId::new().unwrap();
        let mut handles = Vec::new();
        for i in 0..2 {
            let unique_id = unique_id.clone();
            handles.push(std::thread::spawn(move || {
                let device = CudaDevice::new(DeviceIndex(i));
                set_device(device).unwrap();
                let stream = Stream::new();
                let input = match i {
                    0 => factory_float_tensor(&[0.0, 1.0], device.into()),
                    1 => factory_float_tensor(&[2.0, 3.0], device.into()),
                    _ => unreachable!(),
                };
                let output = cuda_full(&[2], 0.0);

                let input = TensorCell::new(input);
                let output = TensorCell::new(output);

                let mut comm = Communicator::new(device, 2, unique_id, i.into()).unwrap();
                comm.all_to_all_single(&output, &input, &stream).unwrap();
                stream.synchronize();

                let expected = match i {
                    0 => factory_float_tensor(&[0.0, 2.0], device.into()),
                    1 => factory_float_tensor(&[1.0, 3.0], device.into()),
                    _ => unreachable!(),
                };
                assert!(allclose(&output.borrow(), &expected).unwrap());
            }));
        }
        for handle in handles {
            handle.join().unwrap();
        }
    }

    #[test]
    fn reduce_scatter_tensor() {
        let unique_id = UniqueId::new().unwrap();
        let mut handles = Vec::new();
        for i in 0..2 {
            let unique_id = unique_id.clone();
            handles.push(std::thread::spawn(move || {
                let device = CudaDevice::new(DeviceIndex(i));
                set_device(device).unwrap();
                let stream = Stream::new();
                let input = factory_float_tensor(&[0.0, 1.0, 2.0, 3.0], device.into());
                let output = cuda_full(&[2], 1.0);

                let input = TensorCell::new(input);
                let output = TensorCell::new(output);

                let mut comm = Communicator::new(device, 2, unique_id, i.into()).unwrap();
                comm.reduce_scatter_tensor(&output, &input, ReduceOp::Sum, &stream)
                    .unwrap();
                stream.synchronize();

                let expected = match i {
                    0 => factory_float_tensor(&[0.0, 2.0], device.into()),
                    1 => factory_float_tensor(&[4.0, 6.0], device.into()),
                    _ => unreachable!(),
                };
                assert!(allclose(&output.borrow(), &expected).unwrap());
            }));
        }
        for handle in handles {
            handle.join().unwrap();
        }
    }

    #[test]
    fn split_from() {
        let unique_id = UniqueId::new().unwrap();
        let mut handles = Vec::new();
        for i in 0..2 {
            let unique_id = unique_id.clone();
            handles.push(std::thread::spawn(move || {
                let device = CudaDevice::new(DeviceIndex(i));
                set_device(device).unwrap();
                let stream = Stream::new();
                let tensor = cuda_full(&[2, 2], 1.0);
                let cell = TensorCell::new(tensor);
                let mut comm = Communicator::new(device, 2, unique_id, i.into()).unwrap();

                let split_comm = comm.split_from(vec![0], None).unwrap();

                match i {
                    0 => assert!(split_comm.is_some()),
                    1 => assert!(split_comm.is_none()),
                    _ => unreachable!(),
                };

                match i {
                    0 => {
                        split_comm
                            .unwrap()
                            .all_reduce(&cell, ReduceOp::Sum, &stream)
                            .unwrap();
                        stream.synchronize();
                        let expected = cuda_full(&[2, 2], 1.0);
                        assert!(allclose(&cell.borrow(), &expected).unwrap());
                    }
                    1 => (),
                    _ => unreachable!(),
                };
            }));
        }
        for handle in handles {
            handle.join().unwrap();
        }
    }
}