1use std::fmt;
10use std::fmt::Write;
11use std::hash::Hasher;
12use std::marker::PhantomData;
13use std::mem::MaybeUninit;
14
15use fxhash::FxHasher32;
16use nccl_sys::*;
17use serde::Deserialize;
18use serde::Serialize;
19use thiserror::Error;
20use torch_sys2::CudaDevice;
21use torch_sys2::DeviceType;
22use torch_sys2::ScalarType;
23use torch_sys2::Tensor;
24use torch_sys2::TensorCell;
25use torch_sys2::factory_float_tensor;
26use torch_sys2::is_float8_type;
27
28use crate::cuda::CudaError;
29use crate::cuda::Stream;
30use crate::cuda::set_device;
31
/// Errors reported directly by the NCCL library; these mirror the non-success
/// values of `ncclResult_t` (see `nccl_check` for the numeric mapping).
#[derive(Debug, Error)]
pub enum RawNcclError {
    #[error("a call to a CUDA function failed")]
    UnhandledCudaError,
    #[error("a call to the system failed")]
    SystemError,
    #[error("an internal check failed; either bug in nccl or memory corruption")]
    InternalError,
    #[error("an argument has an invalid value")]
    InvalidArgument,
    #[error("a call to NCCL is incorrect, usually a programming error")]
    InvalidUsage,
    #[error(
        "a call failed possibly due to a network error or a remote process exiting prematurely"
    )]
    RemoteError,
}
50
/// Errors surfaced by this module's safe NCCL wrappers: either a raw NCCL /
/// CUDA failure, or a validation error detected before calling into NCCL.
#[derive(Debug, Error)]
pub enum NcclError {
    #[error("a NCCL-level error: {0:?}")]
    NcclError(#[from] RawNcclError),

    #[error("a CUDA-level error: {0:?}")]
    CudaError(#[from] CudaError),

    #[error("invalid NCCL data type: {0:#?}")]
    InvalidDataType(ScalarType),

    #[error("tensor used in collective must be contiguous")]
    NoncontiguousTensor,

    #[error("tensor must be on CUDA device, got: {0:?}")]
    InvalidDevice(DeviceType),

    #[error("got sparse tensor, only dense tensors allowed")]
    InvalidSparseTensor,

    #[error("float8 dtypes are not currently supported for NCCL reductions")]
    Float8Reduction,

    #[error("output tensor must have the same type as input tensor")]
    TypeMismatch,

    #[error("output tensor size must be equal to world size times input tensor size")]
    OutputSizeMismatch,

    #[error("input tensor must be the same size as output size times world size")]
    InputSizeMismatch,

    #[error("ranks passed should be within the global world_size, got: {0:#?}")]
    InvalidSplit(Vec<i32>),

    #[error("undefined tensor used for NCCL operation")]
    UndefinedTensor,
}
91
/// Non-error outcomes of a NCCL call.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NcclStatus {
    /// The call completed (`ncclSuccess`).
    Success,
    /// The call was accepted and is still progressing (`ncclInProgress`).
    InProgress,
}
100
101fn nccl_check(result: ncclResult_t) -> Result<NcclStatus, RawNcclError> {
102 match result.0 {
103 0 => Ok(NcclStatus::Success),
104 1 => Err(RawNcclError::UnhandledCudaError),
105 2 => Err(RawNcclError::SystemError),
106 3 => Err(RawNcclError::InternalError),
107 4 => Err(RawNcclError::InvalidArgument),
108 5 => Err(RawNcclError::InvalidUsage),
109 6 => Err(RawNcclError::RemoteError),
110 7 => Ok(NcclStatus::InProgress),
111 _ => panic!("Unknown ncclResult_t: {:?}", result.0),
112 }
113}
114
/// Proof that [`group_start`] was called; consume it with [`group_end`] to
/// close the group. The raw-pointer `PhantomData` makes this type `!Send`,
/// so the group must be ended on the same thread that started it.
pub struct NcclGroupTicket {
    // `*const ()` opts this type out of `Send` without storing any data.
    unsend_marker: PhantomData<*const ()>,
}
126
/// Open a NCCL group, batching subsequent NCCL calls until [`group_end`].
///
/// The returned ticket is `!Send`; pass it back to [`group_end`] on this
/// same thread to close the group.
pub fn group_start() -> Result<NcclGroupTicket, NcclError> {
    // SAFETY: `ncclGroupStart` takes no arguments; failures are reported
    // through the returned status code, which `nccl_check` inspects.
    nccl_check(unsafe { ncclGroupStart() })?;
    Ok(NcclGroupTicket {
        unsend_marker: PhantomData,
    })
}
136
/// Close the NCCL group opened by [`group_start`], consuming its ticket.
pub fn group_end(_ticket: NcclGroupTicket) -> Result<(), NcclError> {
    // SAFETY: the ticket proves `ncclGroupStart` was called on this thread.
    nccl_check(unsafe { ncclGroupEnd() })?;
    Ok(())
}
143
/// Serializable wrapper around `ncclUniqueId`: the token every rank must
/// share (e.g. via serde) to join the same communicator in
/// [`Communicator::new`].
#[derive(Clone, Serialize, Deserialize)]
pub struct UniqueId {
    inner: ncclUniqueId,
}
149
150impl fmt::Debug for UniqueId {
151 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
152 f.debug_struct("UniqueId")
153 .field(
154 "inner",
155 &format_args!(
156 "{}",
157 self.inner
158 .internal
159 .iter()
160 .fold(String::new(), |mut output, b| {
161 let _ = write!(output, "{:02x}", b);
162 output
163 })
164 ),
165 )
166 .finish()
167 }
168}
169
impl UniqueId {
    /// Generate a fresh id via `ncclGetUniqueId`.
    ///
    /// Create it on one rank and distribute the (serializable) value to all
    /// other ranks before constructing their communicators.
    ///
    /// # Errors
    /// Returns the raw NCCL error if id generation fails.
    pub fn new() -> Result<Self, RawNcclError> {
        let mut inner = MaybeUninit::uninit();
        // SAFETY: `ncclGetUniqueId` fills the buffer on success, and
        // `assume_init` is only reached after `nccl_check` returns Ok.
        let inner = unsafe {
            nccl_check(ncclGetUniqueId(inner.as_mut_ptr()))?;
            inner.assume_init()
        };
        Ok(Self { inner })
    }
}
182
/// NCCL wire data types. Discriminants are chosen to match the numeric
/// values of `ncclDataType_t`, so conversion is a plain integer cast.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DataType {
    Int8 = 0,
    Uint8 = 1,
    Int32 = 2,
    Uint32 = 3,
    Int64 = 4,
    Uint64 = 5,
    Float16 = 6,
    Float32 = 7,
    Float64 = 8,
    Bfloat16 = 9,
}
197
impl From<DataType> for ncclDataType_t {
    fn from(data_type: DataType) -> Self {
        // Discriminants mirror NCCL's enum values, so the cast is the
        // entire conversion.
        Self(data_type as std::os::raw::c_uint)
    }
}
203
204impl TryFrom<ScalarType> for DataType {
205 type Error = NcclError;
206
207 fn try_from(value: ScalarType) -> Result<Self, Self::Error> {
208 match value {
209 ScalarType::Char => Ok(DataType::Int8),
210 ScalarType::Byte => Ok(DataType::Uint8),
211 ScalarType::Half => Ok(DataType::Float16),
212 ScalarType::Float => Ok(DataType::Float32),
213 ScalarType::Double => Ok(DataType::Float64),
214 ScalarType::Int => Ok(DataType::Int32),
215 ScalarType::Long => Ok(DataType::Int64),
216 ScalarType::Bool => Ok(DataType::Uint8),
217 ScalarType::BFloat16 => Ok(DataType::Bfloat16),
218 ScalarType::Float8_e5m2 => Ok(DataType::Uint8),
219 ScalarType::Float8_e4m3fn => Ok(DataType::Uint8),
220 ScalarType::Float8_e4m3fnuz => Ok(DataType::Uint8),
221 ScalarType::Float8_e5m2fnuz => Ok(DataType::Uint8),
222 _ => Err(NcclError::InvalidDataType(value)),
223 }
224 }
225}
226
/// Reduction operators for collectives. Discriminants match the numeric
/// values of `ncclRedOp_t`, so conversion is a plain integer cast.
#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq)]
pub enum ReduceOp {
    Sum = 0,
    Prod = 1,
    Max = 2,
    Min = 3,
    Avg = 4,
}
236
impl From<ReduceOp> for ncclRedOp_t {
    fn from(reduce_op: ReduceOp) -> Self {
        // Discriminants mirror NCCL's enum values; cast is the conversion.
        Self(reduce_op as std::os::raw::c_uint)
    }
}
242
/// Validate that `tensor` can be handed to NCCL: defined, resident on a
/// CUDA device, and dense. Contiguity is additionally required for
/// collectives but waived for point-to-point transfers (`is_p2p == true`).
/// Checks run in a fixed order, so callers see a deterministic error.
fn check_tensor(tensor: &Tensor, is_p2p: bool) -> Result<(), NcclError> {
    if !tensor.defined() {
        return Err(NcclError::UndefinedTensor);
    }
    if !tensor.is_cuda() {
        return Err(NcclError::InvalidDevice(tensor.device().device_type()));
    }
    if tensor.is_sparse() {
        return Err(NcclError::InvalidSparseTensor);
    }

    // NOTE(review): p2p transfers skip the contiguity check, yet send/recv
    // pass raw data pointers with flat element counts — confirm callers
    // guarantee matching layouts on both ends.
    if !is_p2p && !tensor.is_contiguous() {
        return Err(NcclError::NoncontiguousTensor);
    }

    Ok(())
}
/// Owning wrapper around an `ncclComm_t` communicator handle, tracking both
/// the communicator-local and the original (global) rank/world bookkeeping
/// so split communicators can be split again.
#[derive(Debug)]
pub struct Communicator {
    inner: ncclComm_t,
    // Number of ranks in this (possibly split) communicator.
    world_size: i32,
    // This process's rank within this communicator.
    rank: i32,
    // World size of the original, top-level communicator.
    global_world_size: i32,
    // This process's rank in the original communicator.
    global_rank: i32,
    device: CudaDevice,
}
276
// SAFETY: NOTE(review) — `ncclComm_t` is a raw handle, so these impls are
// asserting that cross-thread use is sound. All collective methods take
// `&mut self`, which serializes concurrent use through Rust's borrow rules;
// confirm this matches NCCL's documented thread-safety requirements.
unsafe impl Send for Communicator {}
unsafe impl Sync for Communicator {}
285
286fn calculate_color(ranks: &[i32]) -> i32 {
288 let mut hasher = FxHasher32::default();
290 ranks.iter().for_each(|r| hasher.write_i32(*r));
291 let hash = hasher.finish();
292
293 (hash % (i32::MAX as u64)) as i32
295}
296
297impl Communicator {
    /// Create the communicator for `rank` (of `world_size` total) on
    /// `device`. Every participating rank must call this with the same
    /// `unique_id`.
    ///
    /// # Errors
    /// Propagates CUDA device-selection failures and NCCL init failures.
    pub fn new(
        device: CudaDevice,
        world_size: i32,
        unique_id: UniqueId,
        rank: i32,
    ) -> Result<Self, NcclError> {
        set_device(device)?;
        let mut inner = MaybeUninit::uninit();
        // SAFETY: `ncclCommInitRank` writes the handle on success, and
        // `assume_init` is only reached after `nccl_check` returns Ok.
        let inner = unsafe {
            nccl_check(ncclCommInitRank(
                inner.as_mut_ptr(),
                world_size,
                unique_id.inner,
                rank,
            ))?;
            inner.assume_init()
        };
        // A freshly created communicator is its own "global" reference
        // frame; splits derived from it inherit these global fields.
        Ok(Self {
            inner,
            world_size,
            rank,
            global_rank: rank,
            global_world_size: world_size,
            device,
        })
    }
327
328 pub fn split_all(&mut self) -> Result<Self, NcclError> {
331 let ranks = (0..self.global_world_size).collect();
332 Ok(self.split_from(ranks)?.unwrap())
333 }
334
    /// Split this communicator into a sub-communicator over `ranks`.
    ///
    /// This is a collective call: every rank of the parent communicator
    /// must invoke it. Ranks not listed in `ranks` participate with
    /// `NCCL_SPLIT_NOCOLOR` and receive `Ok(None)`.
    ///
    /// # Errors
    /// `InvalidSplit` if any requested rank falls outside the global world.
    pub fn split_from(&mut self, mut ranks: Vec<i32>) -> Result<Option<Self>, NcclError> {
        ranks.sort();
        for rank in &ranks {
            if *rank < 0 || *rank >= self.global_world_size {
                return Err(NcclError::InvalidSplit(ranks));
            }
        }

        // Members hash the (sorted) rank set into a shared color; everyone
        // else opts out with NCCL_SPLIT_NOCOLOR.
        // NOTE(review): membership tests `self.rank` while the bounds check
        // above uses `global_world_size` — confirm both are meant to share
        // an index space when splitting an already-split communicator.
        let color = match ranks.binary_search(&self.rank) {
            Ok(_) => calculate_color(ranks.as_slice()),
            Err(_) => NCCL_SPLIT_NOCOLOR,
        };

        let mut new = MaybeUninit::uninit();

        // SAFETY: `ncclCommSplit` writes the new handle on success;
        // `assume_init` is only reached after `nccl_check` returns Ok.
        let new = unsafe {
            nccl_check(ncclCommSplit(
                self.inner,
                color,
                self.rank,
                new.as_mut_ptr(),
                std::ptr::null_mut(),
            ))?;
            new.assume_init()
        };

        // The new group rank is this rank's position in the sorted list.
        let group_rank = ranks.iter().position(|v| *v == self.rank);
        match color {
            NCCL_SPLIT_NOCOLOR => Ok(None),
            _ => Ok(Some(Self {
                inner: new,
                world_size: ranks.len() as i32,
                rank: group_rank.unwrap() as i32,
                global_rank: self.global_rank,
                global_world_size: self.global_world_size,
                device: self.device,
            })),
        }
    }
381
    /// In-place all-reduce of `tensor` across all ranks with `reduce_op`,
    /// enqueued on `stream`.
    ///
    /// # Errors
    /// Rejects float8 tensors (unsupported for NCCL reductions) and any
    /// tensor failing `check_tensor` validation.
    pub fn all_reduce(
        &mut self,
        tensor: &TensorCell,
        reduce_op: ReduceOp,
        stream: &Stream,
    ) -> Result<NcclStatus, NcclError> {
        let tensor = tensor.borrow_mut();
        let data_type: DataType = tensor.scalar_type().try_into()?;

        check_tensor(&tensor, false)?;
        if is_float8_type(tensor.scalar_type()) {
            return Err(NcclError::Float8Reduction);
        }
        // SAFETY: in-place operation — the same tensor supplies send and
        // recv buffers, and the exclusive borrow is held for the call.
        unsafe {
            Ok(nccl_check(ncclAllReduce(
                tensor.data_ptr(),
                tensor.mut_data_ptr(),
                tensor.numel() as usize,
                data_type.into(),
                reduce_op.into(),
                self.inner,
                stream.stream(),
            ))?)
        }
    }
412
    /// Broadcast `tensor` in place from rank `root` to all ranks on
    /// `stream`; on non-root ranks the tensor is overwritten with root's
    /// data.
    pub fn broadcast(
        &mut self,
        tensor: &TensorCell,
        root: i32,
        stream: &Stream,
    ) -> Result<NcclStatus, NcclError> {
        let tensor = tensor.borrow_mut();
        check_tensor(&tensor, false)?;
        let data_type: DataType = tensor.scalar_type().try_into()?;
        // SAFETY: in-place operation guarded by the exclusive borrow above.
        unsafe {
            Ok(nccl_check(ncclBroadcast(
                tensor.data_ptr(),
                tensor.mut_data_ptr(),
                tensor.numel() as usize,
                data_type.into(),
                root,
                self.inner,
                stream.stream(),
            ))?)
        }
    }
438
    /// Reduce `tensor` across ranks with `reduce_op`, leaving the result in
    /// rank `root`'s tensor; other ranks' tensors only contribute inputs
    /// (and are left unmodified — see the `reduce` test below).
    ///
    /// # Errors
    /// Rejects float8 tensors and any tensor failing validation.
    pub fn reduce(
        &mut self,
        tensor: &TensorCell,
        reduce_op: ReduceOp,
        root: i32,
        stream: &Stream,
    ) -> Result<NcclStatus, NcclError> {
        let tensor = tensor.borrow_mut();
        check_tensor(&tensor, false)?;
        if is_float8_type(tensor.scalar_type()) {
            return Err(NcclError::Float8Reduction);
        }
        let data_type: DataType = tensor.scalar_type().try_into()?;
        // SAFETY: in-place operation guarded by the exclusive borrow above.
        unsafe {
            Ok(nccl_check(ncclReduce(
                tensor.data_ptr(),
                tensor.mut_data_ptr(),
                tensor.numel() as usize,
                data_type.into(),
                reduce_op.into(),
                root,
                self.inner,
                stream.stream(),
            ))?)
        }
    }
470
471 pub fn all_gather(
475 &mut self,
476 output_cells: &[TensorCell],
477 input_cell: &TensorCell,
478 stream: &Stream,
479 ) -> Result<NcclStatus, NcclError> {
480 let output = output_cells
481 .iter()
482 .map(|t| t.borrow_mut())
483 .collect::<Vec<_>>();
484 let input = input_cell.borrow();
485 check_tensor(&input, false)?;
486 let output_type = output[0].scalar_type();
487 let output_numel: i64 = output.iter().map(|t| t.numel()).sum();
488 for t in &output {
489 if t.scalar_type() != output_type {
490 return Err(NcclError::TypeMismatch);
491 }
492 }
493 if input.scalar_type() != output_type {
494 return Err(NcclError::TypeMismatch);
495 }
496 if input.numel() * self.world_size as i64 != output_numel {
497 return Err(NcclError::OutputSizeMismatch);
498 }
499 let data_type: DataType = input.scalar_type().try_into()?;
500 unsafe {
504 nccl_check(ncclGroupStart())?;
505 for (i, output) in output.iter().enumerate() {
506 let rank = i as i32;
508 let output_ptr = output.mut_data_ptr();
509 if rank == self.rank {
512 nccl_check(ncclBroadcast(
513 input.data_ptr(),
514 output_ptr,
515 input.numel() as usize,
516 data_type.into(),
517 rank,
518 self.inner,
519 stream.stream(),
520 ))?;
521 } else {
522 nccl_check(ncclBroadcast(
523 output_ptr,
524 output_ptr,
525 output.numel() as usize,
526 data_type.into(),
527 rank,
528 self.inner,
529 stream.stream(),
530 ))?;
531 }
532 }
533 nccl_check(ncclGroupEnd())?;
534 }
535 Ok(NcclStatus::Success)
536 }
537
    /// All-gather every rank's `input_cell` into the single contiguous
    /// `output_cell`; per `ncclAllGather`, rank `r`'s data lands at offset
    /// `r * input.numel()`. `input_cell` may alias `output_cell`.
    ///
    /// # Errors
    /// `TypeMismatch` on dtype disagreement, `OutputSizeMismatch` if the
    /// output does not hold exactly `world_size` inputs.
    pub fn all_gather_into_tensor(
        &mut self,
        output_cell: &TensorCell,
        input_cell: &TensorCell,
        stream: &Stream,
    ) -> Result<NcclStatus, NcclError> {
        let output = output_cell.borrow_mut();
        // When the cells alias, a second borrow would conflict with the
        // mutable borrow above, which already guards the shared storage.
        let _input_borrow = if input_cell.aliases(output_cell) {
            None
        } else {
            Some(input_cell.borrow())
        };
        // SAFETY: the storage is guarded either by `output` (aliasing case)
        // or by `_input_borrow`, both held until this call returns.
        let input = unsafe { input_cell.get_unchecked() };
        check_tensor(&output, false)?;
        check_tensor(input, false)?;
        if input.scalar_type() != output.scalar_type() {
            return Err(NcclError::TypeMismatch);
        }
        if input.numel() * self.world_size as i64 != output.numel() {
            return Err(NcclError::OutputSizeMismatch);
        }

        let data_type: DataType = input.scalar_type().try_into()?;
        // SAFETY: buffer sizes were validated above; borrows stay live.
        unsafe {
            Ok(nccl_check(ncclAllGather(
                input.data_ptr(),
                output.mut_data_ptr(),
                input.numel() as usize,
                data_type.into(),
                self.inner,
                stream.stream(),
            ))?)
        }
    }
578
579 pub fn reduce_scatter_tensor(
584 &mut self,
585 output_cell: &TensorCell,
586 input_cell: &TensorCell,
587 reduce_op: ReduceOp,
588 stream: &Stream,
589 ) -> Result<NcclStatus, NcclError> {
590 let output = output_cell.borrow_mut();
591 let _input_borrow = if input_cell.aliases(output_cell) {
592 None
593 } else {
594 Some(input_cell.borrow())
595 };
596
597 let input = unsafe { input_cell.get_unchecked() }; check_tensor(&output, false)?;
601 check_tensor(input, false)?;
602 if input.scalar_type() != output.scalar_type() {
603 return Err(NcclError::TypeMismatch);
604 }
605 if input.numel() != output.numel() * self.world_size as i64 {
606 return Err(NcclError::InputSizeMismatch);
607 }
608 if is_float8_type(input.scalar_type()) {
609 return Err(NcclError::Float8Reduction);
610 }
611
612 let data_type: DataType = input.scalar_type().try_into()?;
613 unsafe {
615 Ok(nccl_check(ncclReduceScatter(
616 input.data_ptr(),
617 output.mut_data_ptr(),
618 output.numel() as usize,
619 data_type.into(),
620 reduce_op.into(),
621 self.inner,
622 stream.stream(),
623 ))?)
624 }
625 }
626
    /// Point-to-point send of `tensor_cell` to rank `dst` on `stream`; must
    /// be matched by a [`Self::recv`] call on `dst`.
    pub fn send(
        &mut self,
        tensor_cell: &TensorCell,
        dst: i32,
        stream: &Stream,
    ) -> Result<NcclStatus, NcclError> {
        let tensor = tensor_cell.borrow();
        let data_type: DataType = tensor.scalar_type().try_into()?;

        // p2p mode: contiguity check is skipped (see check_tensor).
        check_tensor(&tensor, true)?;

        // SAFETY: the shared borrow keeps the buffer alive for the call.
        unsafe {
            Ok(nccl_check(ncclSend(
                tensor.data_ptr(),
                tensor.numel() as usize,
                data_type.into(),
                dst,
                self.inner,
                stream.stream(),
            ))?)
        }
    }
651
    /// Point-to-point receive into `tensor_cell` from rank `src` on
    /// `stream`; must be matched by a [`Self::send`] call on `src`.
    pub fn recv(
        &mut self,
        tensor_cell: &TensorCell,
        src: i32,
        stream: &Stream,
    ) -> Result<NcclStatus, NcclError> {
        let tensor = tensor_cell.borrow_mut();
        let data_type: DataType = tensor.scalar_type().try_into()?;

        // p2p mode: contiguity check is skipped (see check_tensor).
        check_tensor(&tensor, true)?;

        // SAFETY: the exclusive borrow keeps the buffer alive for the call.
        unsafe {
            Ok(nccl_check(ncclRecv(
                tensor.mut_data_ptr(),
                tensor.numel() as usize,
                data_type.into(),
                src,
                self.inner,
                stream.stream(),
            ))?)
        }
    }
676
677 pub fn all_to_all_single(
683 &mut self,
684 output_cell: &TensorCell,
685 input_cell: &TensorCell,
686 stream: &Stream,
687 ) -> Result<NcclStatus, NcclError> {
688 let output = output_cell.borrow_mut();
689 let _input_borrow = if input_cell.aliases(output_cell) {
690 None
691 } else {
692 Some(input_cell.borrow_mut())
693 };
694 let input = unsafe { input_cell.get_unchecked() };
696
697 check_tensor(&output, false)?;
698 check_tensor(input, false)?;
699 if input.scalar_type() != output.scalar_type() {
700 return Err(NcclError::TypeMismatch);
701 }
702
703 let data_type: DataType = input.scalar_type().try_into()?;
704 let count = input.numel() as usize / self.world_size as usize;
705 let rank_stride = input.nbytes() as isize / self.world_size as isize;
706 unsafe {
708 let send_buff = input.data_ptr();
709 let recv_buff = output.mut_data_ptr();
710
711 nccl_check(ncclGroupStart())?;
712 for r in 0..self.world_size {
713 nccl_check(ncclSend(
714 send_buff.offset(r as isize * rank_stride),
715 count,
716 data_type.into(),
717 r,
718 self.inner,
719 stream.stream(),
720 ))?;
721 nccl_check(ncclRecv(
722 recv_buff.offset(r as isize * rank_stride),
723 count,
724 data_type.into(),
725 r,
726 self.inner,
727 stream.stream(),
728 ))?;
729 }
730
731 nccl_check(ncclGroupEnd())?;
732 };
733 Ok(NcclStatus::Success)
734 }
735
    /// Stream-level barrier: all-reduce a single scalar so that work queued
    /// on `stream` after this call cannot run until every rank reaches it.
    pub fn barrier(&mut self, stream: &Stream) -> Result<NcclStatus, NcclError> {
        // Throwaway one-element tensor, reduced purely for its
        // synchronization side effect.
        let tensor = factory_float_tensor(&[1.0], self.device.into());
        let data_type: DataType = tensor.scalar_type().try_into()?;

        // NOTE(review): `tensor` is dropped when this function returns,
        // while the collective may still be queued on `stream` — confirm
        // the tensor allocator is stream-ordered so the memory is not
        // reused early.
        unsafe {
            Ok(nccl_check(ncclAllReduce(
                tensor.data_ptr(),
                tensor.mut_data_ptr(),
                tensor.numel() as usize,
                data_type.into(),
                ReduceOp::Sum.into(),
                self.inner,
                stream.stream(),
            ))?)
        }
    }
758}
759
// Integration tests: each spawns one OS thread per rank and requires at
// least two CUDA devices plus an importable Python `torch` runtime.
#[cfg(test)]
mod tests {
    use pyo3::Python;
    use torch_sys2::CudaDevice;
    use torch_sys2::DeviceIndex;
    use torch_sys2::factory_float_tensor;
    use torch_sys2::testing::allclose;
    use torch_sys2::testing::cuda_full;
    use torch_sys2::testing::stack;

    use super::*;
    use crate::cuda::set_device;

    // Initialize the Python interpreter and import torch (on a helper
    // thread) so the CUDA tensor factories used below are available.
    fn test_setup() {
        Python::initialize();
        let handle = std::thread::spawn(|| {
            pyo3::Python::attach(|py| {
                py.import("torch").expect("failed to import torch");
            });
        });
        handle.join().expect("failed to join torch import thread");
    }

    // Sum all-reduce of ones across 2 ranks must yield twos everywhere.
    #[test]
    fn all_reduce() {
        test_setup();
        let unique_id = UniqueId::new().unwrap();
        let mut handles = Vec::new();
        for i in 0..2 {
            let unique_id = unique_id.clone();
            handles.push(std::thread::spawn(move || {
                let device = CudaDevice::new(DeviceIndex(i));
                set_device(device).unwrap();
                let stream = Stream::new();
                let tensor = cuda_full(&[2, 2], 1.0);
                let expected = cuda_full(&[2, 2], 2.0);

                let cell = TensorCell::new(tensor);
                let mut comm = Communicator::new(device, 2, unique_id, i.into()).unwrap();
                comm.all_reduce(&cell, ReduceOp::Sum, &stream).unwrap();
                stream.synchronize();
                assert!(allclose(&cell.borrow(), &expected).unwrap());
            }));
        }
        for handle in handles {
            handle.join().unwrap();
        }
    }

    // Broadcast from root rank 1: every rank ends up with rank 1's data.
    #[test]
    fn broadcast() {
        test_setup();
        let unique_id = UniqueId::new().unwrap();
        let mut handles = Vec::new();
        for i in 0..2 {
            let unique_id = unique_id.clone();
            handles.push(std::thread::spawn(move || {
                let device = CudaDevice::new(DeviceIndex(i));
                set_device(device).unwrap();
                let stream = Stream::new();
                let tensor = cuda_full(&[2, 2], i as f32);

                let cell = TensorCell::new(tensor);
                let mut comm = Communicator::new(device, 2, unique_id, i.into()).unwrap();
                comm.broadcast(&cell, 1, &stream).unwrap();
                stream.synchronize();
                assert!(allclose(&cell.borrow(), &cuda_full(&[2, 2], 1.0)).unwrap());
            }));
        }
        for handle in handles {
            handle.join().unwrap();
        }
    }

    // Rooted reduce: only root (rank 0) holds the sum; rank 1 is untouched.
    #[test]
    fn reduce() {
        test_setup();
        let unique_id = UniqueId::new().unwrap();
        let mut handles = Vec::new();
        for i in 0..2 {
            let unique_id = unique_id.clone();
            handles.push(std::thread::spawn(move || {
                let device = CudaDevice::new(DeviceIndex(i));
                set_device(device).unwrap();
                let stream = Stream::new();
                let tensor = cuda_full(&[2, 2], 2.0);

                let cell = TensorCell::new(tensor);
                let mut comm = Communicator::new(device, 2, unique_id, i.into()).unwrap();
                comm.reduce(&cell, ReduceOp::Sum, 0, &stream).unwrap();
                stream.synchronize();
                match i {
                    0 => assert!(allclose(&cell.borrow(), &cuda_full(&[2, 2], 4.0)).unwrap()),
                    1 => assert!(allclose(&cell.borrow(), &cuda_full(&[2, 2], 2.0)).unwrap()),
                    _ => unreachable!(),
                }
            }));
        }
        for handle in handles {
            handle.join().unwrap();
        }
    }

    // Gather each rank's [2,2] input into a stacked [2,2,2] output.
    #[test]
    fn all_gather_into_tensor() {
        test_setup();
        let unique_id = UniqueId::new().unwrap();
        let mut handles = Vec::new();
        for i in 0..2 {
            let unique_id = unique_id.clone();
            handles.push(std::thread::spawn(move || {
                let device = CudaDevice::new(DeviceIndex(i));
                set_device(device).unwrap();
                let stream = Stream::new();
                let input_tensor = cuda_full(&[2, 2], i as f32);
                let output_tensor = cuda_full(&[2, 2, 2], 0.0);

                let expected = {
                    let mut tensor_list = Vec::new();
                    for i in 0..2 {
                        tensor_list.push(cuda_full(&[2, 2], i as f32));
                    }
                    stack(&tensor_list)
                };
                let input_cell = TensorCell::new(input_tensor);
                let output_cell = TensorCell::new(output_tensor);
                let mut comm = Communicator::new(device, 2, unique_id, i.into()).unwrap();
                comm.all_gather_into_tensor(&output_cell, &input_cell, &stream)
                    .unwrap();
                stream.synchronize();
                assert!(allclose(&output_cell.borrow(), &expected).unwrap());
            }));
        }
        for handle in handles {
            handle.join().unwrap();
        }
    }

    // Rank 0 sends zeros to rank 1, which verifies the received payload.
    #[test]
    fn send_recv() {
        test_setup();
        let unique_id = UniqueId::new().unwrap();
        let mut handles = Vec::new();
        let unique_id_ = unique_id.clone();
        handles.push(std::thread::spawn(move || {
            let device = CudaDevice::new(DeviceIndex(0));
            set_device(device).unwrap();
            let stream = Stream::new();
            let tensor = cuda_full(&[2, 2], 0.0);

            let cell = TensorCell::new(tensor);
            let mut comm = Communicator::new(device, 2, unique_id_, 0).unwrap();
            comm.send(&cell, 1, &stream).unwrap();
            stream.synchronize();
        }));
        let unique_id_ = unique_id.clone();
        handles.push(std::thread::spawn(move || {
            let device = CudaDevice::new(DeviceIndex(1));
            set_device(device).unwrap();
            let stream = Stream::new();
            let tensor = cuda_full(&[2, 2], 1.1);
            let expected = cuda_full(&[2, 2], 0.0);

            let cell = TensorCell::new(tensor);
            let mut comm = Communicator::new(device, 2, unique_id_, 1).unwrap();
            comm.recv(&cell, 0, &stream).unwrap();
            stream.synchronize();
            assert!(allclose(&cell.borrow(), &expected).unwrap());
        }));
        for handle in handles {
            handle.join().unwrap();
        }
    }

    // 2-rank shard exchange: rank r keeps shard r from each peer.
    #[test]
    fn all_to_all_single() {
        test_setup();
        let unique_id = UniqueId::new().unwrap();
        let mut handles = Vec::new();
        for i in 0..2 {
            let unique_id = unique_id.clone();
            handles.push(std::thread::spawn(move || {
                let device = CudaDevice::new(DeviceIndex(i));
                set_device(device).unwrap();
                let stream = Stream::new();
                let input = match i {
                    0 => factory_float_tensor(&[0.0, 1.0], device.into()),
                    1 => factory_float_tensor(&[2.0, 3.0], device.into()),
                    _ => unreachable!(),
                };
                let output = cuda_full(&[2], 0.0);

                let input = TensorCell::new(input);
                let output = TensorCell::new(output);

                let mut comm = Communicator::new(device, 2, unique_id, i.into()).unwrap();
                comm.all_to_all_single(&output, &input, &stream).unwrap();
                stream.synchronize();

                let expected = match i {
                    0 => factory_float_tensor(&[0.0, 2.0], device.into()),
                    1 => factory_float_tensor(&[1.0, 3.0], device.into()),
                    _ => unreachable!(),
                };
                assert!(allclose(&output.borrow(), &expected).unwrap());
            }));
        }
        for handle in handles {
            handle.join().unwrap();
        }
    }

    // Sum + scatter: each rank receives its reduced half of the input.
    #[test]
    fn reduce_scatter_tensor() {
        test_setup();
        let unique_id = UniqueId::new().unwrap();
        let mut handles = Vec::new();
        for i in 0..2 {
            let unique_id = unique_id.clone();
            handles.push(std::thread::spawn(move || {
                let device = CudaDevice::new(DeviceIndex(i));
                set_device(device).unwrap();
                let stream = Stream::new();
                let input = factory_float_tensor(&[0.0, 1.0, 2.0, 3.0], device.into());
                let output = cuda_full(&[2], 1.0);

                let input = TensorCell::new(input);
                let output = TensorCell::new(output);

                let mut comm = Communicator::new(device, 2, unique_id, i.into()).unwrap();
                comm.reduce_scatter_tensor(&output, &input, ReduceOp::Sum, &stream)
                    .unwrap();
                stream.synchronize();

                let expected = match i {
                    0 => factory_float_tensor(&[0.0, 2.0], device.into()),
                    1 => factory_float_tensor(&[4.0, 6.0], device.into()),
                    _ => unreachable!(),
                };
                assert!(allclose(&output.borrow(), &expected).unwrap());
            }));
        }
        for handle in handles {
            handle.join().unwrap();
        }
    }

    // Split to {rank 0}: rank 0 gets a singleton comm, rank 1 gets None.
    #[test]
    fn split_from() {
        test_setup();
        let unique_id = UniqueId::new().unwrap();
        let mut handles = Vec::new();
        for i in 0..2 {
            let unique_id = unique_id.clone();
            handles.push(std::thread::spawn(move || {
                let device = CudaDevice::new(DeviceIndex(i));
                set_device(device).unwrap();
                let stream = Stream::new();
                let tensor = cuda_full(&[2, 2], 1.0);
                let cell = TensorCell::new(tensor);
                let mut comm = Communicator::new(device, 2, unique_id, i.into()).unwrap();

                let split_comm = comm.split_from(vec![0]).unwrap();

                match i {
                    0 => assert!(split_comm.is_some()),
                    1 => assert!(split_comm.is_none()),
                    _ => unreachable!(),
                };

                match i {
                    0 => {
                        // Singleton all-reduce is a no-op: values unchanged.
                        split_comm
                            .unwrap()
                            .all_reduce(&cell, ReduceOp::Sum, &stream)
                            .unwrap();
                        stream.synchronize();
                        let expected = cuda_full(&[2, 2], 1.0);
                        assert!(allclose(&cell.borrow(), &expected).unwrap());
                    }
                    1 => (),
                    _ => unreachable!(),
                };
            }));
        }
        for handle in handles {
            handle.join().unwrap();
        }
    }
}