nccl_sys/
lib.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use cxx::ExternType;
10use cxx::type_id;
11
12/// SAFETY: bindings
13unsafe impl ExternType for CUstream_st {
14    type Id = type_id!("CUstream_st");
15    type Kind = cxx::kind::Opaque;
16}
17
18/// SAFETY: bindings
19/// Trivial because this is POD struct
20unsafe impl ExternType for ncclConfig_t {
21    type Id = type_id!("ncclConfig_t");
22    type Kind = cxx::kind::Trivial;
23}
24
25/// SAFETY: bindings
26unsafe impl ExternType for ncclComm {
27    type Id = type_id!("ncclComm");
28    type Kind = cxx::kind::Opaque;
29}
30
31// When building with cargo, this is actually the lib.rs file for a crate.
32// Include the generated bindings.rs and suppress lints.
33#[allow(non_camel_case_types)]
34#[allow(non_upper_case_globals)]
35#[allow(non_snake_case)]
36mod inner {
37    use serde::Deserialize;
38    use serde::Deserializer;
39    use serde::Serialize;
40    use serde::Serializer;
41    use serde::ser::SerializeSeq;
42    #[cfg(cargo)]
43    include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
44
45    // This type is manually defined instead of generated because we want to dervice
46    // Serialize/Deserialize on it.
47    #[repr(C)]
48    #[derive(Debug, Copy, Clone, Serialize, Deserialize)]
49    pub struct ncclUniqueId {
50        // Custom serializer required, as serde does not provide a built-in
51        // implementation of serialization for large arrays.
52        #[serde(
53            serialize_with = "serialize_array",
54            deserialize_with = "deserialize_array"
55        )]
56        pub internal: [::std::os::raw::c_char; 128usize],
57    }
58
59    fn deserialize_array<'de, D>(deserializer: D) -> Result<[::std::os::raw::c_char; 128], D::Error>
60    where
61        D: Deserializer<'de>,
62    {
63        let vec: Vec<::std::os::raw::c_char> = Deserialize::deserialize(deserializer)?;
64        vec.try_into().map_err(|v: Vec<::std::os::raw::c_char>| {
65            serde::de::Error::invalid_length(v.len(), &"expected an array of length 128")
66        })
67    }
68
69    fn serialize_array<S>(
70        array: &[::std::os::raw::c_char; 128],
71        serializer: S,
72    ) -> Result<S::Ok, S::Error>
73    where
74        S: Serializer,
75    {
76        let mut seq = serializer.serialize_seq(Some(128))?;
77        for element in array {
78            seq.serialize_element(element)?;
79        }
80        seq.end()
81    }
82}
83
84pub use inner::*;
85
#[cfg(test)]
mod tests {
    use std::mem::MaybeUninit;

    use super::*;

    /// Smoke test: the NCCL library links, and `ncclGetVersion` succeeds and
    /// reports a version code.
    #[test]
    fn sanity() {
        let mut version = MaybeUninit::<i32>::uninit();
        // SAFETY: `version.as_mut_ptr()` is a valid, writable pointer to an
        // `i32` that outlives the call.
        let result = unsafe { ncclGetVersion(version.as_mut_ptr()) };
        assert_eq!(result.0, 0, "ncclGetVersion failed with code {}", result.0);
        // SAFETY: per the NCCL API, a successful `ncclGetVersion` stores the
        // version code through the pointer, so `version` is now initialized.
        let version = unsafe { version.assume_init() };
        assert!(version > 0, "unexpected NCCL version code {}", version);
    }
}
101}