// torch_sys_cuda/nccl.rs
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8
9use std::fmt;
10use std::fmt::Write;
11use std::hash::Hasher;
12use std::marker::PhantomData;
13use std::mem::MaybeUninit;
14
15use fxhash::FxHasher32;
16use nccl_sys::*;
17use serde::Deserialize;
18use serde::Serialize;
19use thiserror::Error;
20use torch_sys2::CudaDevice;
21use torch_sys2::DeviceType;
22use torch_sys2::ScalarType;
23use torch_sys2::Tensor;
24use torch_sys2::TensorCell;
25use torch_sys2::factory_float_tensor;
26use torch_sys2::is_float8_type;
27
28use crate::cuda::CudaError;
29use crate::cuda::Stream;
30use crate::cuda::set_device;
31
/// Corresponds to ncclResult_t error cases.
///
/// The mapping from raw result codes to variants lives in [`nccl_check`];
/// codes 0 and 7 are success states and map to [`NcclStatus`] instead.
#[derive(Debug, Error)]
pub enum RawNcclError {
    /// ncclResult_t code 1.
    #[error("a call to a CUDA function failed")]
    UnhandledCudaError,
    /// ncclResult_t code 2.
    #[error("a call to the system failed")]
    SystemError,
    /// ncclResult_t code 3.
    #[error("an internal check failed; either bug in nccl or memory corruption")]
    InternalError,
    /// ncclResult_t code 4.
    #[error("an argument has an invalid value")]
    InvalidArgument,
    /// ncclResult_t code 5.
    #[error("a call to NCCL is incorrect, usually a programming error")]
    InvalidUsage,
    /// ncclResult_t code 6.
    #[error(
        "a call failed possibly due to a network error or a remote process exiting prematurely"
    )]
    RemoteError,
}
50
/// Types of errors that the safe [`Communicator`] API can return.
///
/// Wraps raw NCCL/CUDA failures and adds argument-validation errors raised
/// before any FFI call is made (see [`check_tensor`] and the size/type checks
/// in the collective methods).
#[derive(Debug, Error)]
pub enum NcclError {
    /// An error surfaced from a NCCL call (see [`RawNcclError`]).
    #[error("a NCCL-level error: {0:?}")]
    NcclError(#[from] RawNcclError),

    /// An error surfaced from a CUDA call (e.g. `set_device`).
    #[error("a CUDA-level error: {0:?}")]
    CudaError(#[from] CudaError),

    /// The tensor's scalar type has no NCCL equivalent (see `TryFrom<ScalarType>`).
    #[error("invalid NCCL data type: {0:#?}")]
    InvalidDataType(ScalarType),

    /// Collectives require contiguous memory (point-to-point ops are exempt).
    #[error("tensor used in collective must be contiguous")]
    NoncontiguousTensor,

    // TODO would be nice to get real device printouts
    #[error("tensor must be on CUDA device, got: {0:?}")]
    InvalidDevice(DeviceType),

    #[error("got sparse tensor, only dense tensors allowed")]
    InvalidSparseTensor,

    /// Raised by reduction collectives; float8 transport itself is allowed.
    #[error("float8 dtypes are not currently supported for NCCL reductions")]
    Float8Reduction,

    #[error("output tensor must have the same type as input tensor")]
    TypeMismatch,

    #[error("output tensor size must be equal to world size times input tensor size")]
    OutputSizeMismatch,

    #[error("input tensor must be the same size as output size times world size")]
    InputSizeMismatch,

    /// Returned by `split_from` when a requested rank is outside the global world.
    #[error("ranks passed should be within the global world_size, got: {0:#?}")]
    InvalidSplit(Vec<i32>),

    #[error("undefined tensor used for NCCL operation")]
    UndefinedTensor,
}
91
/// Success states of a NCCL call (`ncclResult_t` codes 0 and 7; see
/// [`nccl_check`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NcclStatus {
    /// Function succeeded.
    Success,
    /// A NCCL operation on the communicator is being enqueued and is being
    /// progressed in the background.
    InProgress,
}
100
101fn nccl_check(result: ncclResult_t) -> Result<NcclStatus, RawNcclError> {
102    match result.0 {
103        0 => Ok(NcclStatus::Success),
104        1 => Err(RawNcclError::UnhandledCudaError),
105        2 => Err(RawNcclError::SystemError),
106        3 => Err(RawNcclError::InternalError),
107        4 => Err(RawNcclError::InvalidArgument),
108        5 => Err(RawNcclError::InvalidUsage),
109        6 => Err(RawNcclError::RemoteError),
110        7 => Ok(NcclStatus::InProgress),
111        _ => panic!("Unknown ncclResult_t: {:?}", result.0),
112    }
113}
114
/// A ticket that we use to link group start/end calls. Does not implement
/// `Send`, to enforce that group start and end calls are on the same thread.
// This isn't an RAII guard because ncclGroupEnd can raise errors.
//
// TODO: technically anyone can manufacture a ticket to pass to group_end. We
// can prevent this by checking thread id or something, but seems unnecessary;
// you'd really have to be trying to mess things up.
pub struct NcclGroupTicket {
    // `*const ()` is neither Send nor Sync, so this marker strips both
    // auto-traits from the ticket without storing any data.
    unsend_marker: PhantomData<*const ()>,
}
126
127/// Start a new NCCL group. All NCCL calls within this group will be combined,
128/// provided that they were issued on the same thread.
129pub fn group_start() -> Result<NcclGroupTicket, NcclError> {
130    // SAFETY: intended use of C function.
131    nccl_check(unsafe { ncclGroupStart() })?;
132    Ok(NcclGroupTicket {
133        unsend_marker: PhantomData,
134    })
135}
136
137/// End the NCCL group.
138pub fn group_end(_ticket: NcclGroupTicket) -> Result<(), NcclError> {
139    // SAFETY: intended use of C function.
140    nccl_check(unsafe { ncclGroupEnd() })?;
141    Ok(())
142}
143
/// Binding for `ncclUniqueId`.
///
/// Serializable so one rank can generate the id and distribute it to the
/// other ranks out-of-band before they call [`Communicator::new`].
#[derive(Clone, Serialize, Deserialize)]
pub struct UniqueId {
    inner: ncclUniqueId,
}
149
150impl fmt::Debug for UniqueId {
151    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
152        f.debug_struct("UniqueId")
153            .field(
154                "inner",
155                &format_args!(
156                    "{}",
157                    self.inner
158                        .internal
159                        .iter()
160                        .fold(String::new(), |mut output, b| {
161                            let _ = write!(output, "{:02x}", b);
162                            output
163                        })
164                ),
165            )
166            .finish()
167    }
168}
169
170impl UniqueId {
171    /// Create a new `UniqueId`.
172    pub fn new() -> Result<Self, RawNcclError> {
173        let mut inner = MaybeUninit::uninit();
174        // Safety: intended usage of this function
175        let inner = unsafe {
176            nccl_check(ncclGetUniqueId(inner.as_mut_ptr()))?;
177            inner.assume_init()
178        };
179        Ok(Self { inner })
180    }
181}
182
/// Rust version of `ncclDataType_t`.
///
/// Discriminant values are cast directly to the C enum in the `From` impl
/// below, so they must stay in sync with `ncclDataType_t`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DataType {
    Int8 = 0,
    Uint8 = 1,
    Int32 = 2,
    Uint32 = 3,
    Int64 = 4,
    Uint64 = 5,
    Float16 = 6,
    Float32 = 7,
    Float64 = 8,
    Bfloat16 = 9,
}
197
198impl From<DataType> for ncclDataType_t {
199    fn from(data_type: DataType) -> Self {
200        Self(data_type as std::os::raw::c_uint)
201    }
202}
203
204impl TryFrom<ScalarType> for DataType {
205    type Error = NcclError;
206
207    fn try_from(value: ScalarType) -> Result<Self, Self::Error> {
208        match value {
209            ScalarType::Char => Ok(DataType::Int8),
210            ScalarType::Byte => Ok(DataType::Uint8),
211            ScalarType::Half => Ok(DataType::Float16),
212            ScalarType::Float => Ok(DataType::Float32),
213            ScalarType::Double => Ok(DataType::Float64),
214            ScalarType::Int => Ok(DataType::Int32),
215            ScalarType::Long => Ok(DataType::Int64),
216            ScalarType::Bool => Ok(DataType::Uint8),
217            ScalarType::BFloat16 => Ok(DataType::Bfloat16),
218            ScalarType::Float8_e5m2 => Ok(DataType::Uint8),
219            ScalarType::Float8_e4m3fn => Ok(DataType::Uint8),
220            ScalarType::Float8_e4m3fnuz => Ok(DataType::Uint8),
221            ScalarType::Float8_e5m2fnuz => Ok(DataType::Uint8),
222            _ => Err(NcclError::InvalidDataType(value)),
223        }
224    }
225}
226
/// Rust version of `ncclRedOp_t`.
///
/// Discriminant values are cast directly to the C enum in the `From` impl
/// below, so they must stay in sync with `ncclRedOp_t`.
#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq)]
pub enum ReduceOp {
    Sum = 0,
    Prod = 1,
    Max = 2,
    Min = 3,
    Avg = 4,
}
236
237impl From<ReduceOp> for ncclRedOp_t {
238    fn from(reduce_op: ReduceOp) -> Self {
239        Self(reduce_op as std::os::raw::c_uint)
240    }
241}
242
243fn check_tensor(tensor: &Tensor, is_p2p: bool) -> Result<(), NcclError> {
244    if !tensor.defined() {
245        return Err(NcclError::UndefinedTensor);
246    }
247    if !tensor.is_cuda() {
248        return Err(NcclError::InvalidDevice(tensor.device().device_type()));
249    }
250    if tensor.is_sparse() {
251        return Err(NcclError::InvalidSparseTensor);
252    }
253
254    if !is_p2p && !tensor.is_contiguous() {
255        return Err(NcclError::NoncontiguousTensor);
256    }
257
258    Ok(())
259}
/// Wraps a NCCL communicator, and provides a Tensor-based interface to it.
///
/// This implements a subset of the `c10d::ProcessGroup` API.
#[derive(Debug)]
pub struct Communicator {
    // Raw NCCL communicator handle; owned by this struct.
    inner: ncclComm_t,
    // World size of this communicator.
    world_size: i32,
    // Rank within the world, value is within 0..world_size.
    rank: i32,
    // Size of the global world. This can be different from `world_size` if this
    // communicator was split off from a larger one.
    global_world_size: i32,
    // Rank within the global world (differs from `rank` after a split).
    global_rank: i32,
    // Device this communicator is bound to; used e.g. by `barrier`.
    device: CudaDevice,
}
276
/// SAFETY: `ncclComm_t` is okay to access from multiple threads, but each
/// communicator *must* issue nccl calls in the same order. It is up to the user
/// to ensure this.
unsafe impl Send for Communicator {}
/// SAFETY: `ncclComm_t` is okay to access from multiple threads, but each
/// communicator *must* issue nccl calls in the same order. It is up to the user
/// to ensure this.
unsafe impl Sync for Communicator {}
285
286// Ported from: https://github.com/pytorch/pytorch/blob/0d6d29af380d6a639bf23127f05e439fafa640bf/torch/distributed/distributed_c10d.py#L4669
287fn calculate_color(ranks: &[i32]) -> i32 {
288    // Assumes `ranks` is sorted.
289    let mut hasher = FxHasher32::default();
290    ranks.iter().for_each(|r| hasher.write_i32(*r));
291    let hash = hasher.finish();
292
293    // Convert to positive value to fit as color arg to `ncclCommSplit`.
294    (hash % (i32::MAX as u64)) as i32
295}
296
297impl Communicator {
298    /// Create a new communicator. This function must be called by a different
299    /// thread/process per rank.
300    pub fn new(
301        device: CudaDevice,
302        world_size: i32,
303        unique_id: UniqueId,
304        rank: i32,
305    ) -> Result<Self, NcclError> {
306        set_device(device)?;
307        let mut inner = MaybeUninit::uninit();
308        // SAFETY: intended use of C function
309        let inner = unsafe {
310            nccl_check(ncclCommInitRank(
311                inner.as_mut_ptr(),
312                world_size,
313                unique_id.inner,
314                rank,
315            ))?;
316            inner.assume_init()
317        };
318        Ok(Self {
319            inner,
320            world_size,
321            rank,
322            global_rank: rank,
323            global_world_size: world_size,
324            device,
325        })
326    }
327
328    /// Split off a new communicator from this one, preserving the same world
329    /// size.
330    pub fn split_all(&mut self) -> Result<Self, NcclError> {
331        let ranks = (0..self.global_world_size).collect();
332        Ok(self.split_from(ranks)?.unwrap())
333    }
334
    /// Split off a new communicator from this one. Only `ranks` will be present
    /// on this new communicator.
    ///
    /// If `ranks` is empty, `ncclCommSplit` will be called with
    /// NCCL_SPLIT_NOCOLOR. This can be useful if ranks excluded from the split
    /// don't even know what ranks will be included.
    ///
    /// Returns `Ok(None)` when the calling rank is not in `ranks` (it still
    /// participates in the split, with NCCL_SPLIT_NOCOLOR).
    pub fn split_from(&mut self, mut ranks: Vec<i32>) -> Result<Option<Self>, NcclError> {
        // Sort so all participants hash the same sequence in
        // `calculate_color`, and so `binary_search` below is valid.
        ranks.sort();
        for rank in &ranks {
            if *rank < 0 || *rank >= self.global_world_size {
                return Err(NcclError::InvalidSplit(ranks));
            }
        }

        // NOTE(review): the bounds check above is against the *global* world
        // size but the membership test uses the *local* `self.rank` — confirm
        // the intended rank numbering when splitting a communicator that was
        // itself split off.
        let color = match ranks.binary_search(&self.rank) {
            Ok(_) => calculate_color(ranks.as_slice()),
            Err(_) => NCCL_SPLIT_NOCOLOR,
        };

        let mut new = MaybeUninit::uninit();

        // SAFETY: intended use of C function; `nccl_check` confirms success
        // before `assume_init`.
        let new = unsafe {
            nccl_check(ncclCommSplit(
                self.inner,
                color,
                self.rank,
                new.as_mut_ptr(),
                std::ptr::null_mut(),
            ))?;
            new.assume_init()
        };

        // Our index in the sorted rank list becomes our rank in the new,
        // smaller communicator.
        let group_rank = ranks.iter().position(|v| *v == self.rank);
        match color {
            NCCL_SPLIT_NOCOLOR => Ok(None),
            _ => Ok(Some(Self {
                inner: new,
                world_size: ranks.len() as i32,
                // `color != NOCOLOR` implies the binary_search hit, so
                // `group_rank` is `Some` here.
                rank: group_rank.unwrap() as i32,
                global_rank: self.global_rank,
                global_world_size: self.global_world_size,
                device: self.device,
            })),
        }
    }
381
382    /// Reduce the tensor data across all ranks, with each rank receiving the
383    /// final result in-place.
384    ///
385    /// See `torch.distributed.all_reduce` for more detailed documentation.
386    pub fn all_reduce(
387        &mut self,
388        tensor: &TensorCell,
389        reduce_op: ReduceOp,
390        stream: &Stream,
391    ) -> Result<NcclStatus, NcclError> {
392        let tensor = tensor.borrow_mut();
393        let data_type: DataType = tensor.scalar_type().try_into()?;
394
395        check_tensor(&tensor, false)?;
396        if is_float8_type(tensor.scalar_type()) {
397            return Err(NcclError::Float8Reduction);
398        }
399        // SAFETY: intended use of C function
400        unsafe {
401            Ok(nccl_check(ncclAllReduce(
402                tensor.data_ptr(),
403                tensor.mut_data_ptr(),
404                tensor.numel() as usize,
405                data_type.into(),
406                reduce_op.into(),
407                self.inner,
408                stream.stream(),
409            ))?)
410        }
411    }
412
413    /// Broadcast the tensor data on the `root` rank to all the others.
414    ///
415    /// See `torch.distributed.broadcast` for more detailed documentation.
416    pub fn broadcast(
417        &mut self,
418        tensor: &TensorCell,
419        root: i32,
420        stream: &Stream,
421    ) -> Result<NcclStatus, NcclError> {
422        let tensor = tensor.borrow_mut();
423        check_tensor(&tensor, false)?;
424        let data_type: DataType = tensor.scalar_type().try_into()?;
425        // SAFETY: intended use of C function
426        unsafe {
427            Ok(nccl_check(ncclBroadcast(
428                tensor.data_ptr(),
429                tensor.mut_data_ptr(),
430                tensor.numel() as usize,
431                data_type.into(),
432                root,
433                self.inner,
434                stream.stream(),
435            ))?)
436        }
437    }
438
439    /// Reduce the tensor data across all ranks, writing the result out to
440    /// tensor on the `root` rank.
441    ///
442    /// See `torch.distributed.reduce` for more detailed documentation.
443    pub fn reduce(
444        &mut self,
445        tensor: &TensorCell,
446        reduce_op: ReduceOp,
447        root: i32,
448        stream: &Stream,
449    ) -> Result<NcclStatus, NcclError> {
450        let tensor = tensor.borrow_mut();
451        check_tensor(&tensor, false)?;
452        if is_float8_type(tensor.scalar_type()) {
453            return Err(NcclError::Float8Reduction);
454        }
455        let data_type: DataType = tensor.scalar_type().try_into()?;
456        // SAFETY: intended use of C function
457        unsafe {
458            Ok(nccl_check(ncclReduce(
459                tensor.data_ptr(),
460                tensor.mut_data_ptr(),
461                tensor.numel() as usize,
462                data_type.into(),
463                reduce_op.into(),
464                root,
465                self.inner,
466                stream.stream(),
467            ))?)
468        }
469    }
470
471    /// Gather tensors from all ranks into a list of output tensors.
472    ///
473    /// See `torch.distributed.all_gather` for more detailed documentation.
474    pub fn all_gather(
475        &mut self,
476        output_cells: &[TensorCell],
477        input_cell: &TensorCell,
478        stream: &Stream,
479    ) -> Result<NcclStatus, NcclError> {
480        let output = output_cells
481            .iter()
482            .map(|t| t.borrow_mut())
483            .collect::<Vec<_>>();
484        let input = input_cell.borrow();
485        check_tensor(&input, false)?;
486        let output_type = output[0].scalar_type();
487        let output_numel: i64 = output.iter().map(|t| t.numel()).sum();
488        for t in &output {
489            if t.scalar_type() != output_type {
490                return Err(NcclError::TypeMismatch);
491            }
492        }
493        if input.scalar_type() != output_type {
494            return Err(NcclError::TypeMismatch);
495        }
496        if input.numel() * self.world_size as i64 != output_numel {
497            return Err(NcclError::OutputSizeMismatch);
498        }
499        let data_type: DataType = input.scalar_type().try_into()?;
500        // TODO: optimization if the output list are all the same shape, where
501        // a single allGather can be done.
502        // SAFETY: intended use of C function
503        unsafe {
504            nccl_check(ncclGroupStart())?;
505            for (i, output) in output.iter().enumerate() {
506                // auto& input = (i == rank_) ? inputTensor : output;
507                let rank = i as i32;
508                let output_ptr = output.mut_data_ptr();
509                // If the current rank is the sender, we need to broadcast the input tensor.
510                // Everything else just broadcasts the output tensor.
511                if rank == self.rank {
512                    nccl_check(ncclBroadcast(
513                        input.data_ptr(),
514                        output_ptr,
515                        input.numel() as usize,
516                        data_type.into(),
517                        rank,
518                        self.inner,
519                        stream.stream(),
520                    ))?;
521                } else {
522                    nccl_check(ncclBroadcast(
523                        output_ptr,
524                        output_ptr,
525                        output.numel() as usize,
526                        data_type.into(),
527                        rank,
528                        self.inner,
529                        stream.stream(),
530                    ))?;
531                }
532            }
533            nccl_check(ncclGroupEnd())?;
534        }
535        Ok(NcclStatus::Success)
536    }
537
    /// Gather tensors from all ranks into a single output tensor.
    ///
    /// See `torch.distributed.all_gather_into_tensor` for more detailed
    /// documentation.
    pub fn all_gather_into_tensor(
        &mut self,
        output_cell: &TensorCell,
        input_cell: &TensorCell,
        stream: &Stream,
    ) -> Result<NcclStatus, NcclError> {
        let output = output_cell.borrow_mut();
        // If input aliases output, the mutable borrow above already covers
        // it; taking a second borrow on the same storage would conflict.
        let _input_borrow = if input_cell.aliases(output_cell) {
            None
        } else {
            Some(input_cell.borrow())
        };
        // SAFETY: we either borrowed above or borrowed an alias
        let input = unsafe { input_cell.get_unchecked() };
        check_tensor(&output, false)?;
        check_tensor(input, false)?;
        if input.scalar_type() != output.scalar_type() {
            return Err(NcclError::TypeMismatch);
        }
        // Output must hold exactly one input-sized chunk per rank.
        if input.numel() * self.world_size as i64 != output.numel() {
            return Err(NcclError::OutputSizeMismatch);
        }

        let data_type: DataType = input.scalar_type().try_into()?;
        // SAFETY: intended use of C function
        unsafe {
            Ok(nccl_check(ncclAllGather(
                input.data_ptr(),
                output.mut_data_ptr(),
                input.numel() as usize,
                data_type.into(),
                self.inner,
                stream.stream(),
            ))?)
        }
    }
578
    /// Reduce, then scatters the result to all tensors in the group.
    ///
    /// See `torch.distributed.reduce_scatter_tensor` for more detailed
    /// documentation.
    pub fn reduce_scatter_tensor(
        &mut self,
        output_cell: &TensorCell,
        input_cell: &TensorCell,
        reduce_op: ReduceOp,
        stream: &Stream,
    ) -> Result<NcclStatus, NcclError> {
        let output = output_cell.borrow_mut();
        // If input aliases output, the mutable borrow above already covers
        // it; taking a second borrow on the same storage would conflict.
        let _input_borrow = if input_cell.aliases(output_cell) {
            None
        } else {
            Some(input_cell.borrow())
        };

        // SAFETY: we either borrowed above or borrowed an alias
        let input = unsafe { input_cell.get_unchecked() };

        check_tensor(&output, false)?;
        check_tensor(input, false)?;
        if input.scalar_type() != output.scalar_type() {
            return Err(NcclError::TypeMismatch);
        }
        // Input must hold exactly one output-sized chunk per rank.
        if input.numel() != output.numel() * self.world_size as i64 {
            return Err(NcclError::InputSizeMismatch);
        }
        // NCCL reductions do not support float8 inputs.
        if is_float8_type(input.scalar_type()) {
            return Err(NcclError::Float8Reduction);
        }

        let data_type: DataType = input.scalar_type().try_into()?;
        // SAFETY: intended use of C function
        unsafe {
            Ok(nccl_check(ncclReduceScatter(
                input.data_ptr(),
                output.mut_data_ptr(),
                output.numel() as usize,
                data_type.into(),
                reduce_op.into(),
                self.inner,
                stream.stream(),
            ))?)
        }
    }
626
627    /// Send a tensor to the rank `dst`.
628    pub fn send(
629        &mut self,
630        tensor_cell: &TensorCell,
631        dst: i32,
632        stream: &Stream,
633    ) -> Result<NcclStatus, NcclError> {
634        let tensor = tensor_cell.borrow();
635        let data_type: DataType = tensor.scalar_type().try_into()?;
636
637        check_tensor(&tensor, true)?;
638
639        // SAFETY: intended use of C function
640        unsafe {
641            Ok(nccl_check(ncclSend(
642                tensor.data_ptr(),
643                tensor.numel() as usize,
644                data_type.into(),
645                dst,
646                self.inner,
647                stream.stream(),
648            ))?)
649        }
650    }
651
652    /// Receive a tensor from the rank `src`.
653    pub fn recv(
654        &mut self,
655        tensor_cell: &TensorCell,
656        src: i32,
657        stream: &Stream,
658    ) -> Result<NcclStatus, NcclError> {
659        let tensor = tensor_cell.borrow_mut();
660        let data_type: DataType = tensor.scalar_type().try_into()?;
661
662        check_tensor(&tensor, true)?;
663
664        // SAFETY: intended use of C function
665        unsafe {
666            Ok(nccl_check(ncclRecv(
667                tensor.mut_data_ptr(),
668                tensor.numel() as usize,
669                data_type.into(),
670                src,
671                self.inner,
672                stream.stream(),
673            ))?)
674        }
675    }
676
    /// Split the input tensor then scatter the split list to all processes in
    /// the group. The received splits are then concatenated into the output tensor.
    ///
    /// See `torch.distributed.all_to_all_single` for more detailed
    /// documentation.
    pub fn all_to_all_single(
        &mut self,
        output_cell: &TensorCell,
        input_cell: &TensorCell,
        stream: &Stream,
    ) -> Result<NcclStatus, NcclError> {
        let output = output_cell.borrow_mut();
        // If input aliases output, the mutable borrow above already covers it.
        // NOTE(review): the non-alias path takes `borrow_mut` on the input
        // even though it is only read; other collectives here use `borrow()`
        // — confirm whether the exclusive borrow is intentional.
        let _input_borrow = if input_cell.aliases(output_cell) {
            None
        } else {
            Some(input_cell.borrow_mut())
        };
        // SAFETY: we either borrowed above or borrowed an alias
        let input = unsafe { input_cell.get_unchecked() };

        check_tensor(&output, false)?;
        check_tensor(input, false)?;
        if input.scalar_type() != output.scalar_type() {
            return Err(NcclError::TypeMismatch);
        }

        let data_type: DataType = input.scalar_type().try_into()?;
        // Elements exchanged with each rank. Assumes `input.numel()` divides
        // evenly by world_size — TODO confirm; any remainder is silently
        // dropped by this integer division.
        let count = input.numel() as usize / self.world_size as usize;
        // Byte offset between consecutive per-rank chunks.
        let rank_stride = input.nbytes() as isize / self.world_size as isize;
        // SAFETY: intended use of C functions
        unsafe {
            let send_buff = input.data_ptr();
            let recv_buff = output.mut_data_ptr();

            // Batch all sends/recvs in one group so NCCL can schedule them
            // together rather than serializing pairwise exchanges.
            nccl_check(ncclGroupStart())?;
            for r in 0..self.world_size {
                nccl_check(ncclSend(
                    send_buff.offset(r as isize * rank_stride),
                    count,
                    data_type.into(),
                    r,
                    self.inner,
                    stream.stream(),
                ))?;
                nccl_check(ncclRecv(
                    recv_buff.offset(r as isize * rank_stride),
                    count,
                    data_type.into(),
                    r,
                    self.inner,
                    stream.stream(),
                ))?;
            }

            nccl_check(ncclGroupEnd())?;
        };
        Ok(NcclStatus::Success)
    }
735
736    /// Synchronize all ranks.
737    ///
738    /// See `torch.distributed.barrier` for more detailed documentation.
739    pub fn barrier(&mut self, stream: &Stream) -> Result<NcclStatus, NcclError> {
740        let tensor = factory_float_tensor(&[1.0], self.device.into());
741        let data_type: DataType = tensor.scalar_type().try_into()?;
742
743        // NOTE(agallagher): NCCL doesn't have a native barrier impl, so use
744        // `ncclAllReduce` to implement one.
745        // SAFETY: intended use of C function
746        unsafe {
747            Ok(nccl_check(ncclAllReduce(
748                tensor.data_ptr(),
749                tensor.mut_data_ptr(),
750                tensor.numel() as usize,
751                data_type.into(),
752                ReduceOp::Sum.into(),
753                self.inner,
754                stream.stream(),
755            ))?)
756        }
757    }
758}
759
760#[cfg(test)]
761mod tests {
762    use pyo3::Python;
763    use torch_sys2::CudaDevice;
764    use torch_sys2::DeviceIndex;
765    use torch_sys2::factory_float_tensor;
766    use torch_sys2::testing::allclose;
767    use torch_sys2::testing::cuda_full;
768    use torch_sys2::testing::stack;
769
770    use super::*;
771    use crate::cuda::set_device;
772
773    /// Initialize Python and import torch in a separate thread.
774    /// This is a workaround for a pybind11 bug in PyTorch.
775    fn test_setup() {
776        Python::initialize();
777        let handle = std::thread::spawn(|| {
778            pyo3::Python::attach(|py| {
779                py.import("torch").expect("failed to import torch");
780            });
781        });
782        handle.join().expect("failed to join torch import thread");
783    }
784
785    #[test]
786    fn all_reduce() {
787        test_setup();
788        let unique_id = UniqueId::new().unwrap();
789        let mut handles = Vec::new();
790        for i in 0..2 {
791            let unique_id = unique_id.clone();
792            handles.push(std::thread::spawn(move || {
793                let device = CudaDevice::new(DeviceIndex(i));
794                set_device(device).unwrap();
795                let stream = Stream::new();
796                let tensor = cuda_full(&[2, 2], 1.0);
797                let expected = cuda_full(&[2, 2], 2.0);
798
799                let cell = TensorCell::new(tensor);
800                let mut comm = Communicator::new(device, 2, unique_id, i.into()).unwrap();
801                comm.all_reduce(&cell, ReduceOp::Sum, &stream).unwrap();
802                stream.synchronize();
803                assert!(allclose(&cell.borrow(), &expected).unwrap());
804            }));
805        }
806        for handle in handles {
807            handle.join().unwrap();
808        }
809    }
810
811    #[test]
812    fn broadcast() {
813        test_setup();
814        let unique_id = UniqueId::new().unwrap();
815        let mut handles = Vec::new();
816        for i in 0..2 {
817            let unique_id = unique_id.clone();
818            handles.push(std::thread::spawn(move || {
819                let device = CudaDevice::new(DeviceIndex(i));
820                set_device(device).unwrap();
821                let stream = Stream::new();
822                let tensor = cuda_full(&[2, 2], i as f32);
823
824                let cell = TensorCell::new(tensor);
825                let mut comm = Communicator::new(device, 2, unique_id, i.into()).unwrap();
826                comm.broadcast(&cell, 1, &stream).unwrap();
827                stream.synchronize();
828                assert!(allclose(&cell.borrow(), &cuda_full(&[2, 2], 1.0)).unwrap());
829            }));
830        }
831        for handle in handles {
832            handle.join().unwrap();
833        }
834    }
835
836    #[test]
837    fn reduce() {
838        test_setup();
839        let unique_id = UniqueId::new().unwrap();
840        let mut handles = Vec::new();
841        for i in 0..2 {
842            let unique_id = unique_id.clone();
843            handles.push(std::thread::spawn(move || {
844                let device = CudaDevice::new(DeviceIndex(i));
845                set_device(device).unwrap();
846                let stream = Stream::new();
847                let tensor = cuda_full(&[2, 2], 2.0);
848
849                let cell = TensorCell::new(tensor);
850                let mut comm = Communicator::new(device, 2, unique_id, i.into()).unwrap();
851                comm.reduce(&cell, ReduceOp::Sum, 0, &stream).unwrap();
852                stream.synchronize();
853                match i {
854                    0 => assert!(allclose(&cell.borrow(), &cuda_full(&[2, 2], 4.0)).unwrap()),
855                    1 => assert!(allclose(&cell.borrow(), &cuda_full(&[2, 2], 2.0)).unwrap()),
856                    _ => unreachable!(),
857                }
858            }));
859        }
860        for handle in handles {
861            handle.join().unwrap();
862        }
863    }
864
865    #[test]
866    fn all_gather_into_tensor() {
867        test_setup();
868        let unique_id = UniqueId::new().unwrap();
869        let mut handles = Vec::new();
870        for i in 0..2 {
871            let unique_id = unique_id.clone();
872            handles.push(std::thread::spawn(move || {
873                let device = CudaDevice::new(DeviceIndex(i));
874                set_device(device).unwrap();
875                let stream = Stream::new();
876                let input_tensor = cuda_full(&[2, 2], i as f32);
877                let output_tensor = cuda_full(&[2, 2, 2], 0.0);
878
879                let expected = {
880                    let mut tensor_list = Vec::new();
881                    for i in 0..2 {
882                        tensor_list.push(cuda_full(&[2, 2], i as f32));
883                    }
884                    stack(&tensor_list)
885                };
886                let input_cell = TensorCell::new(input_tensor);
887                let output_cell = TensorCell::new(output_tensor);
888                let mut comm = Communicator::new(device, 2, unique_id, i.into()).unwrap();
889                comm.all_gather_into_tensor(&output_cell, &input_cell, &stream)
890                    .unwrap();
891                stream.synchronize();
892                assert!(allclose(&output_cell.borrow(), &expected).unwrap());
893            }));
894        }
895        for handle in handles {
896            handle.join().unwrap();
897        }
898    }
899
900    #[test]
901    fn send_recv() {
902        test_setup();
903        let unique_id = UniqueId::new().unwrap();
904        let mut handles = Vec::new();
905        let unique_id_ = unique_id.clone();
906        handles.push(std::thread::spawn(move || {
907            let device = CudaDevice::new(DeviceIndex(0));
908            set_device(device).unwrap();
909            let stream = Stream::new();
910            let tensor = cuda_full(&[2, 2], 0.0);
911
912            let cell = TensorCell::new(tensor);
913            let mut comm = Communicator::new(device, 2, unique_id_, 0).unwrap();
914            comm.send(&cell, 1, &stream).unwrap();
915            stream.synchronize();
916        }));
917        let unique_id_ = unique_id.clone();
918        handles.push(std::thread::spawn(move || {
919            let device = CudaDevice::new(DeviceIndex(1));
920            set_device(device).unwrap();
921            let stream = Stream::new();
922            let tensor = cuda_full(&[2, 2], 1.1);
923            let expected = cuda_full(&[2, 2], 0.0);
924
925            let cell = TensorCell::new(tensor);
926            let mut comm = Communicator::new(device, 2, unique_id_, 1).unwrap();
927            comm.recv(&cell, 0, &stream).unwrap();
928            stream.synchronize();
929            assert!(allclose(&cell.borrow(), &expected).unwrap());
930        }));
931        for handle in handles {
932            handle.join().unwrap();
933        }
934    }
935
936    #[test]
937    fn all_to_all_single() {
938        test_setup();
939        let unique_id = UniqueId::new().unwrap();
940        let mut handles = Vec::new();
941        for i in 0..2 {
942            let unique_id = unique_id.clone();
943            handles.push(std::thread::spawn(move || {
944                let device = CudaDevice::new(DeviceIndex(i));
945                set_device(device).unwrap();
946                let stream = Stream::new();
947                let input = match i {
948                    0 => factory_float_tensor(&[0.0, 1.0], device.into()),
949                    1 => factory_float_tensor(&[2.0, 3.0], device.into()),
950                    _ => unreachable!(),
951                };
952                let output = cuda_full(&[2], 0.0);
953
954                let input = TensorCell::new(input);
955                let output = TensorCell::new(output);
956
957                let mut comm = Communicator::new(device, 2, unique_id, i.into()).unwrap();
958                comm.all_to_all_single(&output, &input, &stream).unwrap();
959                stream.synchronize();
960
961                let expected = match i {
962                    0 => factory_float_tensor(&[0.0, 2.0], device.into()),
963                    1 => factory_float_tensor(&[1.0, 3.0], device.into()),
964                    _ => unreachable!(),
965                };
966                assert!(allclose(&output.borrow(), &expected).unwrap());
967            }));
968        }
969        for handle in handles {
970            handle.join().unwrap();
971        }
972    }
973
974    #[test]
975    fn reduce_scatter_tensor() {
976        test_setup();
977        let unique_id = UniqueId::new().unwrap();
978        let mut handles = Vec::new();
979        for i in 0..2 {
980            let unique_id = unique_id.clone();
981            handles.push(std::thread::spawn(move || {
982                let device = CudaDevice::new(DeviceIndex(i));
983                set_device(device).unwrap();
984                let stream = Stream::new();
985                let input = factory_float_tensor(&[0.0, 1.0, 2.0, 3.0], device.into());
986                let output = cuda_full(&[2], 1.0);
987
988                let input = TensorCell::new(input);
989                let output = TensorCell::new(output);
990
991                let mut comm = Communicator::new(device, 2, unique_id, i.into()).unwrap();
992                comm.reduce_scatter_tensor(&output, &input, ReduceOp::Sum, &stream)
993                    .unwrap();
994                stream.synchronize();
995
996                let expected = match i {
997                    0 => factory_float_tensor(&[0.0, 2.0], device.into()),
998                    1 => factory_float_tensor(&[4.0, 6.0], device.into()),
999                    _ => unreachable!(),
1000                };
1001                assert!(allclose(&output.borrow(), &expected).unwrap());
1002            }));
1003        }
1004        for handle in handles {
1005            handle.join().unwrap();
1006        }
1007    }
1008
1009    #[test]
1010    fn split_from() {
1011        test_setup();
1012        let unique_id = UniqueId::new().unwrap();
1013        let mut handles = Vec::new();
1014        for i in 0..2 {
1015            let unique_id = unique_id.clone();
1016            handles.push(std::thread::spawn(move || {
1017                let device = CudaDevice::new(DeviceIndex(i));
1018                set_device(device).unwrap();
1019                let stream = Stream::new();
1020                let tensor = cuda_full(&[2, 2], 1.0);
1021                let cell = TensorCell::new(tensor);
1022                let mut comm = Communicator::new(device, 2, unique_id, i.into()).unwrap();
1023
1024                // Split a new comm with only rank 0
1025                let split_comm = comm.split_from(vec![0]).unwrap();
1026
1027                match i {
1028                    0 => assert!(split_comm.is_some()),
1029                    1 => assert!(split_comm.is_none()),
1030                    _ => unreachable!(),
1031                };
1032
1033                match i {
1034                    0 => {
1035                        split_comm
1036                            .unwrap()
1037                            .all_reduce(&cell, ReduceOp::Sum, &stream)
1038                            .unwrap();
1039                        stream.synchronize();
1040                        let expected = cuda_full(&[2, 2], 1.0);
1041                        assert!(allclose(&cell.borrow(), &expected).unwrap());
1042                    }
1043                    1 => (),
1044                    _ => unreachable!(),
1045                };
1046            }));
1047        }
1048        for handle in handles {
1049            handle.join().unwrap();
1050        }
1051    }
1052}