nccl_sys/
lib.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use cxx::ExternType;
10use cxx::type_id;
11
12#[cfg(not(use_rocm))]
13mod extern_types {
14    use super::*;
15
16    /// SAFETY: bindings
17    unsafe impl ExternType for CUstream_st {
18        type Id = type_id!("CUstream_st");
19        type Kind = cxx::kind::Opaque;
20    }
21
22    /// SAFETY: bindings
23    unsafe impl ExternType for ncclComm {
24        type Id = type_id!("ncclComm");
25        type Kind = cxx::kind::Opaque;
26    }
27}
28
#[cfg(use_rocm)]
mod extern_types {
    use super::inner::{ihipStream_t, ncclComm};
    // Pulls in `ExternType` and `type_id!` from the parent module.
    use super::*;

    /// SAFETY: bindings. The impl ties the Rust type `ncclComm` to the C++
    /// type id "ncclComm" and declares it `cxx::kind::Opaque`.
    /// NOTE(review): soundness relies on the id string matching the C++ type
    /// name the bindings were generated from — confirm against the bindings.
    unsafe impl ExternType for ncclComm {
        type Kind = cxx::kind::Opaque;
        type Id = type_id!("ncclComm");
    }

    /// SAFETY: bindings. The impl ties the Rust type `ihipStream_t` to the
    /// C++ type id "ihipStream_t" and declares it `cxx::kind::Opaque`.
    /// Note: HIP uses ihipStream_t as the opaque type behind hipStream_t pointer
    unsafe impl ExternType for ihipStream_t {
        type Kind = cxx::kind::Opaque;
        type Id = type_id!("ihipStream_t");
    }
}
48
49// When building with cargo, this is actually the lib.rs file for a crate.
50// Include the generated bindings.rs and suppress lints.
51#[allow(non_camel_case_types)]
52#[allow(non_upper_case_globals)]
53#[allow(non_snake_case)]
54mod inner {
55    use serde::Deserialize;
56    use serde::Deserializer;
57    use serde::Serialize;
58    use serde::Serializer;
59    use serde::ser::SerializeSeq;
60    #[cfg(cargo)]
61    include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
62
63    // This type is manually defined instead of generated because we want to derive
64    // Serialize/Deserialize on it.
65    #[repr(C)]
66    #[derive(Debug, Copy, Clone, Serialize, Deserialize)]
67    pub struct ncclUniqueId {
68        // Custom serializer required, as serde does not provide a built-in
69        // implementation of serialization for large arrays.
70        #[serde(
71            serialize_with = "serialize_array",
72            deserialize_with = "deserialize_array"
73        )]
74        pub internal: [::std::os::raw::c_char; 128usize],
75    }
76
77    fn deserialize_array<'de, D>(deserializer: D) -> Result<[::std::os::raw::c_char; 128], D::Error>
78    where
79        D: Deserializer<'de>,
80    {
81        let vec: Vec<::std::os::raw::c_char> = Deserialize::deserialize(deserializer)?;
82        vec.try_into().map_err(|v: Vec<::std::os::raw::c_char>| {
83            serde::de::Error::invalid_length(v.len(), &"expected an array of length 128")
84        })
85    }
86
87    fn serialize_array<S>(
88        array: &[::std::os::raw::c_char; 128],
89        serializer: S,
90    ) -> Result<S::Ok, S::Error>
91    where
92        S: Serializer,
93    {
94        let mut seq = serializer.serialize_seq(Some(128))?;
95        for element in array {
96            seq.serialize_element(element)?;
97        }
98        seq.end()
99    }
100}
101
102// Export all inner bindings for both CUDA and ROCm builds
103pub use inner::*;
104
105// For ROCm: also export compatibility aliases that map CUDA names to HIP
106#[cfg(use_rocm)]
107pub use self::rocm_compat::*;
108
#[cfg(use_rocm)]
#[allow(non_camel_case_types)]
mod rocm_compat {
    // ROCm/HIP compatibility layer
    //
    // Hipify rewrites CUDA APIs to HIP in the C++ code, so bindgen emits HIP names.
    // Everything below re-exposes those HIP items under their CUDA names so that
    // Rust code written against the CUDA names keeps compiling.

    // Functions: hipify converts cudaSetDevice -> hipSetDevice, etc.
    pub use super::inner::hipSetDevice as cudaSetDevice;
    pub use super::inner::hipStreamSynchronize as cudaStreamSynchronize;

    // Types: CUDA names mapped onto their HIP equivalents.
    pub type CUstream_st = super::inner::ihipStream_t;
    pub type cudaStream_t = super::inner::hipStream_t;
    pub type cudaError_t = super::inner::hipError_t;
}
126
#[cfg(test)]
mod tests {
    use std::mem::MaybeUninit;

    use super::*;

    /// Smoke test for the bindings: calls `ncclGetVersion` and checks that the
    /// returned status wraps 0 (presumably the success code — confirm against
    /// the NCCL headers).
    #[test]
    fn sanity() {
        let mut version_slot = MaybeUninit::<i32>::uninit();
        // SAFETY: testing bindings. The pointer comes from a live local
        // `MaybeUninit<i32>`, and the slot is never read afterwards, so it
        // does not matter whether the call initialized it.
        let status = unsafe { ncclGetVersion(version_slot.as_mut_ptr()) };
        assert_eq!(status.0, 0);
    }
}