torch_sys_cuda/nccl.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

use std::ffi::CString;
use std::fmt;
use std::fmt::Write;
use std::hash::Hasher;
use std::marker::PhantomData;
use std::mem::MaybeUninit;

use fxhash::FxHasher32;
use nccl_sys::*;
use serde::Deserialize;
use serde::Serialize;
use thiserror::Error;
use torch_sys::CudaDevice;
use torch_sys::DeviceType;
use torch_sys::ScalarType;
use torch_sys::Tensor;
use torch_sys::TensorCell;
use torch_sys::factory_float_tensor;
use torch_sys::is_float8_type;
use torch_sys::suggest_memory_format;

use crate::bridge::ffi::make_nccl_config;
use crate::cuda::CudaError;
use crate::cuda::Stream;
use crate::cuda::set_device;

/// Corresponds to the `ncclResult_t` error cases.
#[derive(Debug, Error)]
pub enum RawNcclError {
    #[error("a call to a CUDA function failed")]
    UnhandledCudaError,
    #[error("a call to the system failed")]
    SystemError,
    #[error("an internal check failed; either a bug in nccl or memory corruption")]
    InternalError,
    #[error("an argument has an invalid value")]
    InvalidArgument,
    #[error("a call to NCCL is incorrect, usually a programming error")]
    InvalidUsage,
    #[error(
        "a call failed, possibly due to a network error or a remote process exiting prematurely"
    )]
    RemoteError,
}

/// Types of errors that the safe [`Communicator`] API can return.
#[derive(Debug, Error)]
pub enum NcclError {
    #[error("a NCCL-level error: {0:?}")]
    NcclError(#[from] RawNcclError),

    #[error("a CUDA-level error: {0:?}")]
    CudaError(#[from] CudaError),

    #[error("invalid NCCL data type: {0:#?}")]
    InvalidDataType(ScalarType),

    #[error("tensor used in collective must be contiguous")]
    NoncontiguousTensor,

    // TODO: would be nice to get real device printouts
    #[error("tensor must be on a CUDA device, got: {0:?}")]
    InvalidDevice(DeviceType),

    #[error("got a sparse tensor; only dense tensors are allowed")]
    InvalidSparseTensor,

    #[error("float8 dtypes are not currently supported for NCCL reductions")]
    Float8Reduction,

    #[error("output tensor must have the same type as the input tensor")]
    TypeMismatch,

    #[error("output tensor size must be equal to world size times input tensor size")]
    OutputSizeMismatch,

    #[error("input tensor size must be equal to world size times output tensor size")]
    InputSizeMismatch,

    #[error("ranks passed should be within the global world_size, got: {0:#?}")]
    InvalidSplit(Vec<i32>),

    #[error("undefined tensor used for NCCL operation")]
    UndefinedTensor,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NcclStatus {
    /// Function succeeded.
    Success,
    /// A NCCL operation on the communicator has been enqueued and is being
    /// progressed in the background.
    InProgress,
}

/// Rust version of `ncclConfig_t`. See the nccl documentation for what each
/// field means:
/// https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig-t
///
/// Note that we don't validate field values; we rely on nccl to do that.
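///
/// A minimal construction sketch (illustrative values, not compiled): override
/// a couple of fields, take [`Default`] values for the rest, and convert to the
/// raw `ncclConfig_t` when handing it to NCCL.
///
/// ```ignore
/// let config = NcclConfig {
///     blocking: false,
///     max_ctas: 16,
///     ..Default::default()
/// };
/// let raw: ncclConfig_t = config.into();
/// ```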
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NcclConfig {
    pub blocking: bool,
    pub cga_cluster_size: u8,
    pub min_ctas: u8,
    pub max_ctas: u8,
    pub net_name: Option<String>,
    pub split_share: bool,
}

impl Default for NcclConfig {
    fn default() -> Self {
        NcclConfig {
            blocking: true,
            cga_cluster_size: 4,
            min_ctas: 1,
            max_ctas: 32,
            net_name: None,
            split_share: false,
        }
    }
}

impl From<NcclConfig> for ncclConfig_t {
    fn from(config: NcclConfig) -> Self {
        let mut ret = make_nccl_config();
        ret.blocking = config.blocking.into();
        ret.cgaClusterSize = config.cga_cluster_size.into();
        ret.minCTAs = config.min_ctas.into();
        ret.maxCTAs = config.max_ctas.into();
        if let Some(net_name) = config.net_name {
            let c_string = CString::new(net_name)
                .expect("failed to create CString")
                .into_boxed_c_str();

            // Just leak the string to avoid complicated ownership issues. I'm
            // not aware of anywhere where we actually want to specify the
            // network module name in configuration instead of letting nccl just
            // choose it for us. If this happens and we are creating tons of
            // config objects, we can revisit this.
            let ptr = Box::leak(c_string).as_ptr();
            ret.netName = ptr;
        }
        ret.splitShare = config.split_share.into();

        ret
    }
}

fn nccl_check(result: ncclResult_t) -> Result<NcclStatus, RawNcclError> {
    match result.0 {
        0 => Ok(NcclStatus::Success),
        1 => Err(RawNcclError::UnhandledCudaError),
        2 => Err(RawNcclError::SystemError),
        3 => Err(RawNcclError::InternalError),
        4 => Err(RawNcclError::InvalidArgument),
        5 => Err(RawNcclError::InvalidUsage),
        6 => Err(RawNcclError::RemoteError),
        7 => Ok(NcclStatus::InProgress),
        _ => panic!("Unknown ncclResult_t: {:?}", result.0),
    }
}

/// A ticket that we use to link group start/end calls. Does not implement
/// `Send`, to enforce that group start and end calls are on the same thread.
// This isn't an RAII guard because ncclGroupEnd can raise errors.
//
// TODO: technically anyone can manufacture a ticket to pass to group_end. We
// can prevent this by checking thread id or something, but seems unnecessary;
// you'd really have to be trying to mess things up.
pub struct NcclGroupTicket {
    // marker to disable Send on this type.
    unsend_marker: PhantomData<*const ()>,
}

/// Start a new NCCL group. All NCCL calls within this group will be combined,
/// provided that they were issued on the same thread.
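///
/// A minimal usage sketch (not compiled; `comm`, `cell_a`, `cell_b`, and
/// `stream` are assumed to already exist): batch two point-to-point sends so
/// that NCCL issues them as a single group.
///
/// ```ignore
/// let ticket = group_start()?;
/// comm.send(&cell_a, /* dst */ 1, &stream)?;
/// comm.send(&cell_b, /* dst */ 1, &stream)?;
/// group_end(ticket)?;
/// ```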
pub fn group_start() -> Result<NcclGroupTicket, NcclError> {
    // SAFETY: intended use of C function.
    nccl_check(unsafe { ncclGroupStart() })?;
    Ok(NcclGroupTicket {
        unsend_marker: PhantomData,
    })
}

/// End the NCCL group.
pub fn group_end(_ticket: NcclGroupTicket) -> Result<(), NcclError> {
    // SAFETY: intended use of C function.
    nccl_check(unsafe { ncclGroupEnd() })?;
    Ok(())
}

/// Binding for `ncclUniqueId`.
#[derive(Clone, Serialize, Deserialize)]
pub struct UniqueId {
    inner: ncclUniqueId,
}

impl fmt::Debug for UniqueId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("UniqueId")
            .field(
                "inner",
                &format_args!(
                    "{}",
                    self.inner
                        .internal
                        .iter()
                        .fold(String::new(), |mut output, b| {
                            let _ = write!(output, "{:02x}", b);
                            output
                        })
                ),
            )
            .finish()
    }
}

impl UniqueId {
    /// Create a new `UniqueId`.
    pub fn new() -> Result<Self, RawNcclError> {
        let mut inner = MaybeUninit::uninit();
        // SAFETY: intended usage of this function
        let inner = unsafe {
            nccl_check(ncclGetUniqueId(inner.as_mut_ptr()))?;
            inner.assume_init()
        };
        Ok(Self { inner })
    }
}

/// Rust version of `ncclDataType_t`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DataType {
    Int8 = 0,
    Uint8 = 1,
    Int32 = 2,
    Uint32 = 3,
    Int64 = 4,
    Uint64 = 5,
    Float16 = 6,
    Float32 = 7,
    Float64 = 8,
    Bfloat16 = 9,
}

impl From<DataType> for ncclDataType_t {
    fn from(data_type: DataType) -> Self {
        Self(data_type as std::os::raw::c_uint)
    }
}

impl TryFrom<ScalarType> for DataType {
    type Error = NcclError;

    fn try_from(value: ScalarType) -> Result<Self, Self::Error> {
        match value {
            ScalarType::Char => Ok(DataType::Int8),
            ScalarType::Byte => Ok(DataType::Uint8),
            ScalarType::Half => Ok(DataType::Float16),
            ScalarType::Float => Ok(DataType::Float32),
            ScalarType::Double => Ok(DataType::Float64),
            ScalarType::Int => Ok(DataType::Int32),
            ScalarType::Long => Ok(DataType::Int64),
            ScalarType::Bool => Ok(DataType::Uint8),
            ScalarType::BFloat16 => Ok(DataType::Bfloat16),
            ScalarType::Float8_e5m2 => Ok(DataType::Uint8),
            ScalarType::Float8_e4m3fn => Ok(DataType::Uint8),
            ScalarType::Float8_e4m3fnuz => Ok(DataType::Uint8),
            ScalarType::Float8_e5m2fnuz => Ok(DataType::Uint8),
            _ => Err(NcclError::InvalidDataType(value)),
        }
    }
}

/// Rust version of `ncclRedOp_t`.
#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq)]
pub enum ReduceOp {
    Sum = 0,
    Prod = 1,
    Max = 2,
    Min = 3,
    Avg = 4,
}

impl From<ReduceOp> for ncclRedOp_t {
    fn from(reduce_op: ReduceOp) -> Self {
        Self(reduce_op as std::os::raw::c_uint)
    }
}

fn check_tensor(tensor: &Tensor, is_p2p: bool) -> Result<(), NcclError> {
    if !tensor.defined() {
        return Err(NcclError::UndefinedTensor);
    }
    if !tensor.is_cuda() {
        return Err(NcclError::InvalidDevice(tensor.device().device_type()));
    }
    if tensor.is_sparse() {
        return Err(NcclError::InvalidSparseTensor);
    }

    if !is_p2p && !tensor.is_contiguous(suggest_memory_format(tensor)) {
        return Err(NcclError::NoncontiguousTensor);
    }

    Ok(())
}

/// Wraps a NCCL communicator and provides a Tensor-based interface to it.
///
/// This implements a subset of the `c10d::ProcessGroup` API.
#[derive(Debug)]
pub struct Communicator {
    inner: ncclComm_t,
    // World size of this communicator.
    world_size: i32,
    // Rank within the world, value is within 0..world_size.
    rank: i32,
    // Size of the global world. This can be different from `world_size` if this
    // communicator was split off from a larger one.
    global_world_size: i32,
    global_rank: i32,
    device: CudaDevice,
}

/// SAFETY: `ncclComm_t` is okay to access from multiple threads, but each
/// communicator *must* issue nccl calls in the same order. It is up to the user
/// to ensure this.
unsafe impl Send for Communicator {}
/// SAFETY: `ncclComm_t` is okay to access from multiple threads, but each
/// communicator *must* issue nccl calls in the same order. It is up to the user
/// to ensure this.
unsafe impl Sync for Communicator {}

// Ported from: https://github.com/pytorch/pytorch/blob/0d6d29af380d6a639bf23127f05e439fafa640bf/torch/distributed/distributed_c10d.py#L4669
fn calculate_color(ranks: &[i32]) -> i32 {
    // Assumes `ranks` is sorted.
    let mut hasher = FxHasher32::default();
    ranks.iter().for_each(|r| hasher.write_i32(*r));
    let hash = hasher.finish();

    // Convert to a positive value so it fits as the color arg to `ncclCommSplit`.
    (hash % (i32::MAX as u64)) as i32
}

impl Communicator {
    /// Create a new communicator. This function must be called by a different
    /// thread/process per rank.
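    ///
    /// A minimal per-rank sketch (not compiled): rank 0 creates the
    /// [`UniqueId`] and shares it with the other ranks out of band; each rank
    /// `i` then constructs its communicator with the same id. `world_size` is
    /// assumed to come from the launcher.
    ///
    /// ```ignore
    /// let device = CudaDevice::new(DeviceIndex(i));
    /// let mut comm = Communicator::new(device, world_size, unique_id, i.into())?;
    /// ```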
    pub fn new(
        device: CudaDevice,
        world_size: i32,
        unique_id: UniqueId,
        rank: i32,
    ) -> Result<Self, NcclError> {
        set_device(device)?;
        let mut inner = MaybeUninit::uninit();
        // SAFETY: intended use of C function
        let inner = unsafe {
            nccl_check(ncclCommInitRank(
                inner.as_mut_ptr(),
                world_size,
                unique_id.inner,
                rank,
            ))?;
            inner.assume_init()
        };
        Ok(Self {
            inner,
            world_size,
            rank,
            global_rank: rank,
            global_world_size: world_size,
            device,
        })
    }

    /// Split off a new communicator from this one, preserving the same world
    /// size.
    pub fn split_all(&mut self, config: Option<NcclConfig>) -> Result<Self, NcclError> {
        let ranks = (0..self.global_world_size).collect();
        Ok(self.split_from(ranks, config)?.unwrap())
    }

    /// Split off a new communicator from this one. Only the given `ranks` will
    /// be present in the new communicator.
    ///
    /// If the calling rank is not in `ranks` (in particular, if `ranks` is
    /// empty), `ncclCommSplit` is called with NCCL_SPLIT_NOCOLOR and `Ok(None)`
    /// is returned. This can be useful if ranks excluded from the split don't
    /// even know which ranks will be included.
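    ///
    /// A minimal sketch (not compiled; assumes every rank of the parent
    /// communicator makes the same call): ranks 0 and 1 form a new
    /// communicator, every other rank gets `None` back.
    ///
    /// ```ignore
    /// match comm.split_from(vec![0, 1], None)? {
    ///     Some(sub_comm) => { /* this rank is 0 or 1 and is in the new group */ }
    ///     None => { /* this rank was excluded from the split */ }
    /// }
    /// ```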
    pub fn split_from(
        &mut self,
        mut ranks: Vec<i32>,
        config: Option<NcclConfig>,
    ) -> Result<Option<Self>, NcclError> {
        ranks.sort();
        for rank in &ranks {
            if *rank < 0 || *rank >= self.global_world_size {
                return Err(NcclError::InvalidSplit(ranks));
            }
        }

        let color = match ranks.binary_search(&self.rank) {
            Ok(_) => calculate_color(ranks.as_slice()),
            Err(_) => NCCL_SPLIT_NOCOLOR,
        };

        let config = config.map(ncclConfig_t::from);
        let mut new = MaybeUninit::uninit();

        // SAFETY: intended use of C function
        let new = unsafe {
            // This rather awkward duplication is intentional; we are passing in
            // `config` as a pointer, which is only guaranteed to be valid for
            // the duration of the `Some(mut config)` match arm.
            match config {
                Some(mut config) => {
                    nccl_check(ncclCommSplit(
                        self.inner,
                        color,
                        self.rank,
                        new.as_mut_ptr(),
                        &mut config,
                    ))?;
                }
                None => {
                    nccl_check(ncclCommSplit(
                        self.inner,
                        color,
                        self.rank,
                        new.as_mut_ptr(),
                        std::ptr::null_mut(),
                    ))?;
                }
            }
            new.assume_init()
        };

        let group_rank = ranks.iter().position(|v| *v == self.rank);
        match color {
            NCCL_SPLIT_NOCOLOR => Ok(None),
            _ => Ok(Some(Self {
                inner: new,
                world_size: ranks.len() as i32,
                rank: group_rank.unwrap() as i32,
                global_rank: self.global_rank,
                global_world_size: self.global_world_size,
                device: self.device,
            })),
        }
    }

    /// Reduce the tensor data across all ranks, with each rank receiving the
    /// final result in-place.
    ///
    /// See `torch.distributed.all_reduce` for more detailed documentation.
    pub fn all_reduce(
        &mut self,
        tensor: &TensorCell,
        reduce_op: ReduceOp,
        stream: &Stream,
    ) -> Result<NcclStatus, NcclError> {
        let tensor = tensor.borrow_mut();
        let data_type: DataType = tensor.scalar_type().try_into()?;

        check_tensor(&tensor, false)?;
        if is_float8_type(tensor.scalar_type()) {
            return Err(NcclError::Float8Reduction);
        }
        // SAFETY: intended use of C function
        unsafe {
            Ok(nccl_check(ncclAllReduce(
                tensor.data_ptr(),
                tensor.mut_data_ptr(),
                tensor.numel() as usize,
                data_type.into(),
                reduce_op.into(),
                self.inner,
                stream.stream(),
            ))?)
        }
    }

    /// Broadcast the tensor data on the `root` rank to all the others.
    ///
    /// See `torch.distributed.broadcast` for more detailed documentation.
    pub fn broadcast(
        &mut self,
        tensor: &TensorCell,
        root: i32,
        stream: &Stream,
    ) -> Result<NcclStatus, NcclError> {
        let tensor = tensor.borrow_mut();
        check_tensor(&tensor, false)?;
        let data_type: DataType = tensor.scalar_type().try_into()?;
        // SAFETY: intended use of C function
        unsafe {
            Ok(nccl_check(ncclBroadcast(
                tensor.data_ptr(),
                tensor.mut_data_ptr(),
                tensor.numel() as usize,
                data_type.into(),
                root,
                self.inner,
                stream.stream(),
            ))?)
        }
    }

    /// Reduce the tensor data across all ranks, writing the result out to the
    /// tensor on the `root` rank.
    ///
    /// See `torch.distributed.reduce` for more detailed documentation.
    pub fn reduce(
        &mut self,
        tensor: &TensorCell,
        reduce_op: ReduceOp,
        root: i32,
        stream: &Stream,
    ) -> Result<NcclStatus, NcclError> {
        let tensor = tensor.borrow_mut();
        check_tensor(&tensor, false)?;
        if is_float8_type(tensor.scalar_type()) {
            return Err(NcclError::Float8Reduction);
        }
        let data_type: DataType = tensor.scalar_type().try_into()?;
        // SAFETY: intended use of C function
        unsafe {
            Ok(nccl_check(ncclReduce(
                tensor.data_ptr(),
                tensor.mut_data_ptr(),
                tensor.numel() as usize,
                data_type.into(),
                reduce_op.into(),
                root,
                self.inner,
                stream.stream(),
            ))?)
        }
    }

    /// Gather tensors from all ranks into a list of output tensors.
    ///
    /// See `torch.distributed.all_gather` for more detailed documentation.
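    ///
    /// A minimal sketch (not compiled; `cuda_full` is just an illustrative
    /// tensor factory): pass one output cell per rank, each with the same
    /// dtype and number of elements as the input. After the call,
    /// `output_cells[r]` holds rank `r`'s input.
    ///
    /// ```ignore
    /// let output_cells: Vec<TensorCell> = (0..world_size)
    ///     .map(|_| TensorCell::new(cuda_full(&[2, 2], 0.0)))
    ///     .collect();
    /// comm.all_gather(&output_cells, &input_cell, &stream)?;
    /// ```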
    pub fn all_gather(
        &mut self,
        output_cells: &[TensorCell],
        input_cell: &TensorCell,
        stream: &Stream,
    ) -> Result<NcclStatus, NcclError> {
        let output = output_cells
            .iter()
            .map(|t| t.borrow_mut())
            .collect::<Vec<_>>();
        let input = input_cell.borrow();
        check_tensor(&input, false)?;
        let output_type = output[0].scalar_type();
        let output_numel: i64 = output.iter().map(|t| t.numel()).sum();
        for t in &output {
            if t.scalar_type() != output_type {
                return Err(NcclError::TypeMismatch);
            }
        }
        if input.scalar_type() != output_type {
            return Err(NcclError::TypeMismatch);
        }
        if input.numel() * self.world_size as i64 != output_numel {
            return Err(NcclError::OutputSizeMismatch);
        }
        let data_type: DataType = input.scalar_type().try_into()?;
        // TODO: optimize the case where the output tensors all have the same
        // shape, in which case a single allGather call can be used.
        // SAFETY: intended use of C function
        unsafe {
            nccl_check(ncclGroupStart())?;
            for (i, output) in output.iter().enumerate() {
                // auto& input = (i == rank_) ? inputTensor : output;
                let rank = i as i32;
                let output_ptr = output.mut_data_ptr();
                // If the current rank is the sender, broadcast from the input
                // tensor; every other rank receives into its output tensor.
                if rank == self.rank {
                    nccl_check(ncclBroadcast(
                        input.data_ptr(),
                        output_ptr,
                        input.numel() as usize,
                        data_type.into(),
                        rank,
                        self.inner,
                        stream.stream(),
                    ))?;
                } else {
                    nccl_check(ncclBroadcast(
                        output_ptr,
                        output_ptr,
                        output.numel() as usize,
                        data_type.into(),
                        rank,
                        self.inner,
                        stream.stream(),
                    ))?;
                }
            }
            nccl_check(ncclGroupEnd())?;
        }
        Ok(NcclStatus::Success)
    }

    /// Gather tensors from all ranks into a single output tensor.
    ///
    /// See `torch.distributed.all_gather_into_tensor` for more detailed
    /// documentation.
    pub fn all_gather_into_tensor(
        &mut self,
        output_cell: &TensorCell,
        input_cell: &TensorCell,
        stream: &Stream,
    ) -> Result<NcclStatus, NcclError> {
        let output = output_cell.borrow_mut();
        let _input_borrow = if input_cell.aliases(output_cell) {
            None
        } else {
            Some(input_cell.borrow())
        };
        // SAFETY: we either borrowed above or borrowed an alias
        let input = unsafe { input_cell.get_unchecked() };
        check_tensor(&output, false)?;
        check_tensor(input, false)?;
        if input.scalar_type() != output.scalar_type() {
            return Err(NcclError::TypeMismatch);
        }
        if input.numel() * self.world_size as i64 != output.numel() {
            return Err(NcclError::OutputSizeMismatch);
        }

        let data_type: DataType = input.scalar_type().try_into()?;
        // SAFETY: intended use of C function
        unsafe {
            Ok(nccl_check(ncclAllGather(
                input.data_ptr(),
                output.mut_data_ptr(),
                input.numel() as usize,
                data_type.into(),
                self.inner,
                stream.stream(),
            ))?)
        }
    }

    /// Reduce, then scatter the result to all ranks in the group.
    ///
    /// See `torch.distributed.reduce_scatter_tensor` for more detailed
    /// documentation.
    pub fn reduce_scatter_tensor(
        &mut self,
        output_cell: &TensorCell,
        input_cell: &TensorCell,
        reduce_op: ReduceOp,
        stream: &Stream,
    ) -> Result<NcclStatus, NcclError> {
        let output = output_cell.borrow_mut();
        let _input_borrow = if input_cell.aliases(output_cell) {
            None
        } else {
            Some(input_cell.borrow())
        };

        // SAFETY: we either borrowed above or borrowed an alias
        let input = unsafe { input_cell.get_unchecked() };

        check_tensor(&output, false)?;
        check_tensor(input, false)?;
        if input.scalar_type() != output.scalar_type() {
            return Err(NcclError::TypeMismatch);
        }
        if input.numel() != output.numel() * self.world_size as i64 {
            return Err(NcclError::InputSizeMismatch);
        }
        if is_float8_type(input.scalar_type()) {
            return Err(NcclError::Float8Reduction);
        }

        let data_type: DataType = input.scalar_type().try_into()?;
        // SAFETY: intended use of C function
        unsafe {
            Ok(nccl_check(ncclReduceScatter(
                input.data_ptr(),
                output.mut_data_ptr(),
                output.numel() as usize,
                data_type.into(),
                reduce_op.into(),
                self.inner,
                stream.stream(),
            ))?)
        }
    }

    /// Send a tensor to the rank `dst`.
    pub fn send(
        &mut self,
        tensor_cell: &TensorCell,
        dst: i32,
        stream: &Stream,
    ) -> Result<NcclStatus, NcclError> {
        let tensor = tensor_cell.borrow();
        let data_type: DataType = tensor.scalar_type().try_into()?;

        check_tensor(&tensor, true)?;

        // SAFETY: intended use of C function
        unsafe {
            Ok(nccl_check(ncclSend(
                tensor.data_ptr(),
                tensor.numel() as usize,
                data_type.into(),
                dst,
                self.inner,
                stream.stream(),
            ))?)
        }
    }

    /// Receive a tensor from the rank `src`.
    pub fn recv(
        &mut self,
        tensor_cell: &TensorCell,
        src: i32,
        stream: &Stream,
    ) -> Result<NcclStatus, NcclError> {
        let tensor = tensor_cell.borrow_mut();
        let data_type: DataType = tensor.scalar_type().try_into()?;

        check_tensor(&tensor, true)?;

        // SAFETY: intended use of C function
        unsafe {
            Ok(nccl_check(ncclRecv(
                tensor.mut_data_ptr(),
                tensor.numel() as usize,
                data_type.into(),
                src,
                self.inner,
                stream.stream(),
            ))?)
        }
    }

    /// Split the input tensor, then scatter the splits to all processes in the
    /// group. The received splits are concatenated into the output tensor.
    ///
    /// See `torch.distributed.all_to_all_single` for more detailed
    /// documentation.
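    ///
    /// A minimal sketch (not compiled) for `world_size == 2`: both tensors are
    /// treated as `world_size` equal contiguous chunks; rank `r` sends chunk
    /// `j` of its input to rank `j` and receives rank `j`'s chunk `r` into
    /// chunk `j` of its output.
    ///
    /// ```ignore
    /// // rank 0 input: [0.0, 1.0], rank 1 input: [2.0, 3.0]
    /// comm.all_to_all_single(&output_cell, &input_cell, &stream)?;
    /// // rank 0 output: [0.0, 2.0], rank 1 output: [1.0, 3.0]
    /// ```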
    pub fn all_to_all_single(
        &mut self,
        output_cell: &TensorCell,
        input_cell: &TensorCell,
        stream: &Stream,
    ) -> Result<NcclStatus, NcclError> {
        let output = output_cell.borrow_mut();
        let _input_borrow = if input_cell.aliases(output_cell) {
            None
        } else {
            Some(input_cell.borrow_mut())
        };
        // SAFETY: we either borrowed above or borrowed an alias
        let input = unsafe { input_cell.get_unchecked() };

        check_tensor(&output, false)?;
        check_tensor(input, false)?;
        if input.scalar_type() != output.scalar_type() {
            return Err(NcclError::TypeMismatch);
        }

        let data_type: DataType = input.scalar_type().try_into()?;
        let count = input.numel() as usize / self.world_size as usize;
        let rank_stride = input.nbytes() as isize / self.world_size as isize;
        // SAFETY: intended use of C functions
        unsafe {
            let send_buff = input.data_ptr();
            let recv_buff = output.mut_data_ptr();

            nccl_check(ncclGroupStart())?;
            for r in 0..self.world_size {
                nccl_check(ncclSend(
                    send_buff.offset(r as isize * rank_stride),
                    count,
                    data_type.into(),
                    r,
                    self.inner,
                    stream.stream(),
                ))?;
                nccl_check(ncclRecv(
                    recv_buff.offset(r as isize * rank_stride),
                    count,
                    data_type.into(),
                    r,
                    self.inner,
                    stream.stream(),
                ))?;
            }

            nccl_check(ncclGroupEnd())?;
        };
        Ok(NcclStatus::Success)
    }

    /// Synchronize all ranks.
    ///
    /// See `torch.distributed.barrier` for more detailed documentation.
    pub fn barrier(&mut self, stream: &Stream) -> Result<NcclStatus, NcclError> {
        let tensor = factory_float_tensor(&[1.0], self.device.into());
        let data_type: DataType = tensor.scalar_type().try_into()?;

        // NOTE(agallagher): NCCL doesn't have a native barrier impl, so use
        // `ncclAllReduce` to implement one.
        // SAFETY: intended use of C function
        unsafe {
            Ok(nccl_check(ncclAllReduce(
                tensor.data_ptr(),
                tensor.mut_data_ptr(),
                tensor.numel() as usize,
                data_type.into(),
                ReduceOp::Sum.into(),
                self.inner,
                stream.stream(),
            ))?)
        }
    }
}

#[cfg(test)]
mod tests {
    use torch_sys::CudaDevice;
    use torch_sys::DeviceIndex;
    use torch_sys::factory_float_tensor;
    use torch_sys::testing::allclose;
    use torch_sys::testing::cuda_full;
    use torch_sys::testing::stack;

    use super::*;
    use crate::cuda::set_device;

    #[test]
    fn all_reduce() {
        let unique_id = UniqueId::new().unwrap();
        let mut handles = Vec::new();
        for i in 0..2 {
            let unique_id = unique_id.clone();
            handles.push(std::thread::spawn(move || {
                let device = CudaDevice::new(DeviceIndex(i));
                set_device(device).unwrap();
                let stream = Stream::new();
                let tensor = cuda_full(&[2, 2], 1.0);
                let expected = cuda_full(&[2, 2], 2.0);

                let cell = TensorCell::new(tensor);
                let mut comm = Communicator::new(device, 2, unique_id, i.into()).unwrap();
                comm.all_reduce(&cell, ReduceOp::Sum, &stream).unwrap();
                stream.synchronize();
                assert!(allclose(&cell.borrow(), &expected).unwrap());
            }));
        }
        for handle in handles {
            handle.join().unwrap();
        }
    }

    #[test]
    fn broadcast() {
        let unique_id = UniqueId::new().unwrap();
        let mut handles = Vec::new();
        for i in 0..2 {
            let unique_id = unique_id.clone();
            handles.push(std::thread::spawn(move || {
                let device = CudaDevice::new(DeviceIndex(i));
                set_device(device).unwrap();
                let stream = Stream::new();
                let tensor = cuda_full(&[2, 2], i as f32);

                let cell = TensorCell::new(tensor);
                let mut comm = Communicator::new(device, 2, unique_id, i.into()).unwrap();
                comm.broadcast(&cell, 1, &stream).unwrap();
                stream.synchronize();
                assert!(allclose(&cell.borrow(), &cuda_full(&[2, 2], 1.0)).unwrap());
            }));
        }
        for handle in handles {
            handle.join().unwrap();
        }
    }

    #[test]
    fn reduce() {
        let unique_id = UniqueId::new().unwrap();
        let mut handles = Vec::new();
        for i in 0..2 {
            let unique_id = unique_id.clone();
            handles.push(std::thread::spawn(move || {
                let device = CudaDevice::new(DeviceIndex(i));
                set_device(device).unwrap();
                let stream = Stream::new();
                let tensor = cuda_full(&[2, 2], 2.0);

                let cell = TensorCell::new(tensor);
                let mut comm = Communicator::new(device, 2, unique_id, i.into()).unwrap();
                comm.reduce(&cell, ReduceOp::Sum, 0, &stream).unwrap();
                stream.synchronize();
                match i {
                    0 => assert!(allclose(&cell.borrow(), &cuda_full(&[2, 2], 4.0)).unwrap()),
                    1 => assert!(allclose(&cell.borrow(), &cuda_full(&[2, 2], 2.0)).unwrap()),
                    _ => unreachable!(),
                }
            }));
        }
        for handle in handles {
            handle.join().unwrap();
        }
    }

    #[test]
    fn all_gather_into_tensor() {
        let unique_id = UniqueId::new().unwrap();
        let mut handles = Vec::new();
        for i in 0..2 {
            let unique_id = unique_id.clone();
            handles.push(std::thread::spawn(move || {
                let device = CudaDevice::new(DeviceIndex(i));
                set_device(device).unwrap();
                let stream = Stream::new();
                let input_tensor = cuda_full(&[2, 2], i as f32);
                let output_tensor = cuda_full(&[2, 2, 2], 0.0);

                let expected = {
                    let mut tensor_list = Vec::new();
                    for i in 0..2 {
                        tensor_list.push(cuda_full(&[2, 2], i as f32));
                    }
                    stack(&tensor_list)
                };
                let input_cell = TensorCell::new(input_tensor);
                let output_cell = TensorCell::new(output_tensor);
                let mut comm = Communicator::new(device, 2, unique_id, i.into()).unwrap();
                comm.all_gather_into_tensor(&output_cell, &input_cell, &stream)
                    .unwrap();
                stream.synchronize();
                assert!(allclose(&output_cell.borrow(), &expected).unwrap());
            }));
        }
        for handle in handles {
            handle.join().unwrap();
        }
    }

    #[test]
    fn send_recv() {
        let unique_id = UniqueId::new().unwrap();
        let mut handles = Vec::new();
        let unique_id_ = unique_id.clone();
        handles.push(std::thread::spawn(move || {
            let device = CudaDevice::new(DeviceIndex(0));
            set_device(device).unwrap();
            let stream = Stream::new();
            let tensor = cuda_full(&[2, 2], 0.0);

            let cell = TensorCell::new(tensor);
            let mut comm = Communicator::new(device, 2, unique_id_, 0).unwrap();
            comm.send(&cell, 1, &stream).unwrap();
            stream.synchronize();
        }));
        let unique_id_ = unique_id.clone();
        handles.push(std::thread::spawn(move || {
            let device = CudaDevice::new(DeviceIndex(1));
            set_device(device).unwrap();
            let stream = Stream::new();
            let tensor = cuda_full(&[2, 2], 1.1);
            let expected = cuda_full(&[2, 2], 0.0);

            let cell = TensorCell::new(tensor);
            let mut comm = Communicator::new(device, 2, unique_id_, 1).unwrap();
            comm.recv(&cell, 0, &stream).unwrap();
            stream.synchronize();
            assert!(allclose(&cell.borrow(), &expected).unwrap());
        }));
        for handle in handles {
            handle.join().unwrap();
        }
    }

    #[test]
    fn all_to_all_single() {
        let unique_id = UniqueId::new().unwrap();
        let mut handles = Vec::new();
        for i in 0..2 {
            let unique_id = unique_id.clone();
            handles.push(std::thread::spawn(move || {
                let device = CudaDevice::new(DeviceIndex(i));
                set_device(device).unwrap();
                let stream = Stream::new();
                let input = match i {
                    0 => factory_float_tensor(&[0.0, 1.0], device.into()),
                    1 => factory_float_tensor(&[2.0, 3.0], device.into()),
                    _ => unreachable!(),
                };
                let output = cuda_full(&[2], 0.0);

                let input = TensorCell::new(input);
                let output = TensorCell::new(output);

                let mut comm = Communicator::new(device, 2, unique_id, i.into()).unwrap();
                comm.all_to_all_single(&output, &input, &stream).unwrap();
                stream.synchronize();

                let expected = match i {
                    0 => factory_float_tensor(&[0.0, 2.0], device.into()),
                    1 => factory_float_tensor(&[1.0, 3.0], device.into()),
                    _ => unreachable!(),
                };
                assert!(allclose(&output.borrow(), &expected).unwrap());
            }));
        }
        for handle in handles {
            handle.join().unwrap();
        }
    }

    #[test]
    fn reduce_scatter_tensor() {
        let unique_id = UniqueId::new().unwrap();
        let mut handles = Vec::new();
        for i in 0..2 {
            let unique_id = unique_id.clone();
            handles.push(std::thread::spawn(move || {
                let device = CudaDevice::new(DeviceIndex(i));
                set_device(device).unwrap();
                let stream = Stream::new();
                let input = factory_float_tensor(&[0.0, 1.0, 2.0, 3.0], device.into());
                let output = cuda_full(&[2], 1.0);

                let input = TensorCell::new(input);
                let output = TensorCell::new(output);

                let mut comm = Communicator::new(device, 2, unique_id, i.into()).unwrap();
                comm.reduce_scatter_tensor(&output, &input, ReduceOp::Sum, &stream)
                    .unwrap();
                stream.synchronize();

                let expected = match i {
                    0 => factory_float_tensor(&[0.0, 2.0], device.into()),
                    1 => factory_float_tensor(&[4.0, 6.0], device.into()),
                    _ => unreachable!(),
                };
                assert!(allclose(&output.borrow(), &expected).unwrap());
            }));
        }
        for handle in handles {
            handle.join().unwrap();
        }
    }

    #[test]
    fn split_from() {
        let unique_id = UniqueId::new().unwrap();
        let mut handles = Vec::new();
        for i in 0..2 {
            let unique_id = unique_id.clone();
            handles.push(std::thread::spawn(move || {
                let device = CudaDevice::new(DeviceIndex(i));
                set_device(device).unwrap();
                let stream = Stream::new();
                let tensor = cuda_full(&[2, 2], 1.0);
                let cell = TensorCell::new(tensor);
                let mut comm = Communicator::new(device, 2, unique_id, i.into()).unwrap();

                // Split a new comm with only rank 0
                let split_comm = comm.split_from(vec![0], None).unwrap();

                match i {
                    0 => assert!(split_comm.is_some()),
                    1 => assert!(split_comm.is_none()),
                    _ => unreachable!(),
                };

                match i {
                    0 => {
                        split_comm
                            .unwrap()
                            .all_reduce(&cell, ReduceOp::Sum, &stream)
                            .unwrap();
                        stream.synchronize();
                        let expected = cuda_full(&[2, 2], 1.0);
                        assert!(allclose(&cell.borrow(), &expected).unwrap());
                    }
                    1 => (),
                    _ => unreachable!(),
                };
            }));
        }
        for handle in handles {
            handle.join().unwrap();
        }
    }
}