nccl_sys/
lib.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use cxx::ExternType;
10use cxx::type_id;
11
12/// SAFETY: bindings
13unsafe impl ExternType for CUstream_st {
14    type Id = type_id!("CUstream_st");
15    type Kind = cxx::kind::Opaque;
16}
17
18/// SAFETY: bindings
19/// Trivial because this is POD struct
20unsafe impl ExternType for ncclConfig_t {
21    type Id = type_id!("ncclConfig_t");
22    type Kind = cxx::kind::Trivial;
23}
24
25/// SAFETY: bindings
26unsafe impl ExternType for ncclComm {
27    type Id = type_id!("ncclComm");
28    type Kind = cxx::kind::Opaque;
29}
30
31// When building with cargo, this is actually the lib.rs file for a crate.
32// Include the generated bindings.rs and suppress lints.
33#[allow(non_camel_case_types)]
34#[allow(non_upper_case_globals)]
35#[allow(non_snake_case)]
36mod inner {
37    use serde::Deserialize;
38    use serde::Deserializer;
39    use serde::Serialize;
40    use serde::Serializer;
41    use serde::ser::SerializeSeq;
42    #[cfg(cargo)]
43    include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
44
45    // This type is manually defined instead of generated because we want to dervice
46    // Serialize/Deserialize on it.
47    #[repr(C)]
48    #[derive(Debug, Copy, Clone, Serialize, Deserialize)]
49    pub struct ncclUniqueId {
50        // Custom serializer required, as serde does not provide a built-in
51        // implementation of serialization for large arrays.
52        #[serde(
53            serialize_with = "serialize_array",
54            deserialize_with = "deserialize_array"
55        )]
56        pub internal: [::std::os::raw::c_char; 128usize],
57    }
58
59    fn deserialize_array<'de, D>(deserializer: D) -> Result<[i8; 128], D::Error>
60    where
61        D: Deserializer<'de>,
62    {
63        let vec: Vec<i8> = Deserialize::deserialize(deserializer)?;
64        vec.try_into().map_err(|v: Vec<i8>| {
65            serde::de::Error::invalid_length(v.len(), &"expected an array of length 128")
66        })
67    }
68
69    fn serialize_array<S>(array: &[i8; 128], serializer: S) -> Result<S::Ok, S::Error>
70    where
71        S: Serializer,
72    {
73        let mut seq = serializer.serialize_seq(Some(128))?;
74        for element in array {
75            seq.serialize_element(element)?;
76        }
77        seq.end()
78    }
79}
80
81pub use inner::*;
82
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: NCCL can be called through the bindings and reports a
    /// version through its out-parameter.
    #[test]
    fn sanity() {
        let mut version: i32 = 0;
        // SAFETY: `version` is a valid, initialized, writable `i32` that
        // outlives the call, so passing a pointer to it is sound.
        let result = unsafe { ncclGetVersion(&mut version) };
        assert_eq!(result.0, 0, "ncclGetVersion failed with code {}", result.0);
        assert!(version > 0, "ncclGetVersion reported version {version}");
    }
}