// torch_sys_cuda/nccl_common.rs

use std::marker::PhantomData;

use thiserror::Error;
use torch_sys2::DeviceType;
use torch_sys2::ScalarType;

use crate::cuda::CudaError;

24#[derive(Debug, Error)]
26pub enum RawNcclError {
27 #[error("a call to a CUDA function failed")]
28 UnhandledCudaError,
29 #[error("a call to the system failed")]
30 SystemError,
31 #[error("an internal check failed; either bug in nccl or memory corruption")]
32 InternalError,
33 #[error("an argument has an invalid value")]
34 InvalidArgument,
35 #[error("a call to NCCL is incorrect, usually a programming error")]
36 InvalidUsage,
37 #[error(
38 "a call failed possibly due to a network error or a remote process exiting prematurely"
39 )]
40 RemoteError,
41}
42
43#[derive(Debug, Error)]
45pub enum NcclError {
46 #[error("a NCCL-level error: {0:?}")]
47 NcclError(#[from] RawNcclError),
48
49 #[error("a CUDA-level error: {0:?}")]
50 CudaError(#[from] CudaError),
51
52 #[error("invalid NCCL data type: {0:#?}")]
53 InvalidDataType(ScalarType),
54
55 #[error("tensor used in collective must be contiguous")]
56 NoncontiguousTensor,
57
58 #[error("tensor must be on CUDA device, got: {0:?}")]
60 InvalidDevice(DeviceType),
61
62 #[error("got sparse tensor, only dense tensors allowed")]
63 InvalidSparseTensor,
64
65 #[error("float8 dtypes are not currently supported for NCCL reductions")]
66 Float8Reduction,
67
68 #[error("output tensor must have the same type as input tensor")]
69 TypeMismatch,
70
71 #[error("output tensor size must be equal to world size times input tensor size")]
72 OutputSizeMismatch,
73
74 #[error("input tensor must be the same size as output size times world size")]
75 InputSizeMismatch,
76
77 #[error("ranks passed should be within the global world_size, got: {0:#?}")]
78 InvalidSplit(Vec<i32>),
79
80 #[error("undefined tensor used for NCCL operation")]
81 UndefinedTensor,
82}
83
/// Non-error status of an NCCL operation.
///
/// NOTE(review): presumably reported when polling a communicator or an
/// asynchronous call — confirm against the call sites, which are not in this
/// file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NcclStatus {
    /// The operation completed successfully.
    Success,
    /// The operation has not finished yet.
    InProgress,
}

/// Zero-sized token that cannot leave the thread it was created on.
///
/// NOTE(review): judging by the name, this is handed out when an NCCL group is
/// started and consumed when it ends — confirm against the group start/end
/// functions, which are not in this file.
pub struct NcclGroupTicket {
    /// `*const ()` is neither `Send` nor `Sync`, so holding a
    /// `PhantomData` of it makes the whole ticket `!Send` and `!Sync`
    /// without occupying any space.
    pub(crate) unsend_marker: PhantomData<*const ()>,
}

/// Element types understood by NCCL collectives.
///
/// The discriminants are spelled out explicitly so the numeric value of each
/// variant is fixed regardless of declaration order.
///
/// NOTE(review): the values 0..=9 appear to mirror NCCL's `ncclDataType_t`
/// (`ncclInt8` .. `ncclBfloat16`) — confirm before relying on a raw cast.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DataType {
    /// Signed 8-bit integer.
    Int8 = 0,
    /// Unsigned 8-bit integer.
    Uint8 = 1,
    /// Signed 32-bit integer.
    Int32 = 2,
    /// Unsigned 32-bit integer.
    Uint32 = 3,
    /// Signed 64-bit integer.
    Int64 = 4,
    /// Unsigned 64-bit integer.
    Uint64 = 5,
    /// 16-bit floating point.
    Float16 = 6,
    /// 32-bit floating point.
    Float32 = 7,
    /// 64-bit floating point.
    Float64 = 8,
    /// 16-bit "brain" floating point.
    Bfloat16 = 9,
}

120impl TryFrom<ScalarType> for DataType {
121 type Error = NcclError;
122
123 fn try_from(value: ScalarType) -> Result<Self, Self::Error> {
124 match value {
125 ScalarType::Char => Ok(DataType::Int8),
126 ScalarType::Byte => Ok(DataType::Uint8),
127 ScalarType::Half => Ok(DataType::Float16),
128 ScalarType::Float => Ok(DataType::Float32),
129 ScalarType::Double => Ok(DataType::Float64),
130 ScalarType::Int => Ok(DataType::Int32),
131 ScalarType::Long => Ok(DataType::Int64),
132 ScalarType::Bool => Ok(DataType::Uint8),
133 ScalarType::BFloat16 => Ok(DataType::Bfloat16),
134 ScalarType::Float8_e5m2 => Ok(DataType::Uint8),
135 ScalarType::Float8_e4m3fn => Ok(DataType::Uint8),
136 ScalarType::Float8_e4m3fnuz => Ok(DataType::Uint8),
137 ScalarType::Float8_e5m2fnuz => Ok(DataType::Uint8),
138 _ => Err(NcclError::InvalidDataType(value)),
139 }
140 }
141}