Skip to main content

torch_sys_cuda/
nccl_common.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! Types shared between the real NCCL bindings (`nccl`) and the CPU-only stubs
10//! (`nccl_stubs`).
11//!
12//! These types have no dependency on the `nccl-sys` FFI and are always
13//! compiled. Keeping them here removes duplication and guarantees both code
14//! paths see an identical public API.
15
16use std::marker::PhantomData;
17
18use thiserror::Error;
19use torch_sys2::DeviceType;
20use torch_sys2::ScalarType;
21
22use crate::cuda::CudaError;
23
24/// Corresponds to ncclResult_t error cases.
25#[derive(Debug, Error)]
26pub enum RawNcclError {
27    #[error("a call to a CUDA function failed")]
28    UnhandledCudaError,
29    #[error("a call to the system failed")]
30    SystemError,
31    #[error("an internal check failed; either bug in nccl or memory corruption")]
32    InternalError,
33    #[error("an argument has an invalid value")]
34    InvalidArgument,
35    #[error("a call to NCCL is incorrect, usually a programming error")]
36    InvalidUsage,
37    #[error(
38        "a call failed possibly due to a network error or a remote process exiting prematurely"
39    )]
40    RemoteError,
41}
42
43/// Types of errors that the safe `Communicator` API can return.
44#[derive(Debug, Error)]
45pub enum NcclError {
46    #[error("a NCCL-level error: {0:?}")]
47    NcclError(#[from] RawNcclError),
48
49    #[error("a CUDA-level error: {0:?}")]
50    CudaError(#[from] CudaError),
51
52    #[error("invalid NCCL data type: {0:#?}")]
53    InvalidDataType(ScalarType),
54
55    #[error("tensor used in collective must be contiguous")]
56    NoncontiguousTensor,
57
58    // TODO would be nice to get real device printouts
59    #[error("tensor must be on CUDA device, got: {0:?}")]
60    InvalidDevice(DeviceType),
61
62    #[error("got sparse tensor, only dense tensors allowed")]
63    InvalidSparseTensor,
64
65    #[error("float8 dtypes are not currently supported for NCCL reductions")]
66    Float8Reduction,
67
68    #[error("output tensor must have the same type as input tensor")]
69    TypeMismatch,
70
71    #[error("output tensor size must be equal to world size times input tensor size")]
72    OutputSizeMismatch,
73
74    #[error("input tensor must be the same size as output size times world size")]
75    InputSizeMismatch,
76
77    #[error("ranks passed should be within the global world_size, got: {0:#?}")]
78    InvalidSplit(Vec<i32>),
79
80    #[error("undefined tensor used for NCCL operation")]
81    UndefinedTensor,
82}
83
/// Non-error outcome of an NCCL call.
///
/// NOTE(review): presumably mirrors the non-failure `ncclResult_t` codes;
/// confirm against the bindings that produce it.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum NcclStatus {
    /// The call completed successfully.
    Success,
    /// The operation has been enqueued on the communicator and is still
    /// being progressed in the background.
    InProgress,
}
92
/// Token tying a group-start call to its matching group-end call.
///
/// The type is deliberately not `Send`, so the ticket cannot be moved to
/// another thread: group start and group end must happen on the same thread.
//
// Why a ticket rather than an RAII guard: ncclGroupEnd can raise errors, and
// a guard's destructor would have nowhere to report them.
//
// TODO: nothing stops a caller from manufacturing a ticket by hand and passing
// it to group_end. Checking the thread id (or similar) would close that gap,
// but it seems unnecessary; you would have to be actively trying to break
// things.
pub struct NcclGroupTicket {
    // A raw pointer is neither `Send` nor `Sync`; holding one in `PhantomData`
    // opts this struct out of those auto traits.
    pub(crate) unsend_marker: PhantomData<*const ()>,
}
104
/// Rust-side counterpart of `ncclDataType_t`.
///
/// Each variant carries an explicit discriminant, so casting a variant with
/// `as` yields that number.
///
/// NOTE(review): the discriminants are presumably meant to equal the values of
/// the corresponding `ncclDataType_t` constants; confirm against the NCCL
/// headers before changing or extending them.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DataType {
    /// Signed 8-bit integer.
    Int8 = 0,
    /// Unsigned 8-bit integer.
    Uint8 = 1,
    /// Signed 32-bit integer.
    Int32 = 2,
    /// Unsigned 32-bit integer.
    Uint32 = 3,
    /// Signed 64-bit integer.
    Int64 = 4,
    /// Unsigned 64-bit integer.
    Uint64 = 5,
    /// 16-bit floating point.
    Float16 = 6,
    /// 32-bit floating point.
    Float32 = 7,
    /// 64-bit floating point.
    Float64 = 8,
    /// 16-bit bfloat.
    Bfloat16 = 9,
}
119
120impl TryFrom<ScalarType> for DataType {
121    type Error = NcclError;
122
123    fn try_from(value: ScalarType) -> Result<Self, Self::Error> {
124        match value {
125            ScalarType::Char => Ok(DataType::Int8),
126            ScalarType::Byte => Ok(DataType::Uint8),
127            ScalarType::Half => Ok(DataType::Float16),
128            ScalarType::Float => Ok(DataType::Float32),
129            ScalarType::Double => Ok(DataType::Float64),
130            ScalarType::Int => Ok(DataType::Int32),
131            ScalarType::Long => Ok(DataType::Int64),
132            ScalarType::Bool => Ok(DataType::Uint8),
133            ScalarType::BFloat16 => Ok(DataType::Bfloat16),
134            ScalarType::Float8_e5m2 => Ok(DataType::Uint8),
135            ScalarType::Float8_e4m3fn => Ok(DataType::Uint8),
136            ScalarType::Float8_e4m3fnuz => Ok(DataType::Uint8),
137            ScalarType::Float8_e5m2fnuz => Ok(DataType::Uint8),
138            _ => Err(NcclError::InvalidDataType(value)),
139        }
140    }
141}