nccl_sys/
lib.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use cxx::ExternType;
10use cxx::type_id;
11
12/// SAFETY: bindings
13unsafe impl ExternType for CUstream_st {
14    type Id = type_id!("CUstream_st");
15    type Kind = cxx::kind::Opaque;
16}
17
18/// SAFETY: bindings
19unsafe impl ExternType for ncclComm {
20    type Id = type_id!("ncclComm");
21    type Kind = cxx::kind::Opaque;
22}
23
24// When building with cargo, this is actually the lib.rs file for a crate.
25// Include the generated bindings.rs and suppress lints.
26#[allow(non_camel_case_types)]
27#[allow(non_upper_case_globals)]
28#[allow(non_snake_case)]
29mod inner {
30    use serde::Deserialize;
31    use serde::Deserializer;
32    use serde::Serialize;
33    use serde::Serializer;
34    use serde::ser::SerializeSeq;
35    #[cfg(cargo)]
36    include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
37
38    // This type is manually defined instead of generated because we want to derive
39    // Serialize/Deserialize on it.
40    #[repr(C)]
41    #[derive(Debug, Copy, Clone, Serialize, Deserialize)]
42    pub struct ncclUniqueId {
43        // Custom serializer required, as serde does not provide a built-in
44        // implementation of serialization for large arrays.
45        #[serde(
46            serialize_with = "serialize_array",
47            deserialize_with = "deserialize_array"
48        )]
49        pub internal: [::std::os::raw::c_char; 128usize],
50    }
51
52    fn deserialize_array<'de, D>(deserializer: D) -> Result<[::std::os::raw::c_char; 128], D::Error>
53    where
54        D: Deserializer<'de>,
55    {
56        let vec: Vec<::std::os::raw::c_char> = Deserialize::deserialize(deserializer)?;
57        vec.try_into().map_err(|v: Vec<::std::os::raw::c_char>| {
58            serde::de::Error::invalid_length(v.len(), &"expected an array of length 128")
59        })
60    }
61
62    fn serialize_array<S>(
63        array: &[::std::os::raw::c_char; 128],
64        serializer: S,
65    ) -> Result<S::Ok, S::Error>
66    where
67        S: Serializer,
68    {
69        let mut seq = serializer.serialize_seq(Some(128))?;
70        for element in array {
71            seq.serialize_element(element)?;
72        }
73        seq.end()
74    }
75}
76
77pub use inner::*;
78
#[cfg(test)]
mod tests {
    use std::mem::MaybeUninit;

    use super::*;

    /// Smoke test: the NCCL library links, `ncclGetVersion` can be called
    /// through the bindings, and it reports a plausible version number.
    #[test]
    fn sanity() {
        let mut version = MaybeUninit::<i32>::uninit();
        // SAFETY: `version.as_mut_ptr()` is a valid, writable, properly
        // aligned pointer to an `i32` that outlives the call.
        let result = unsafe { ncclGetVersion(version.as_mut_ptr()) };
        // 0 is ncclSuccess.
        assert_eq!(result.0, 0);
        // SAFETY: the call returned success (asserted above), and on success
        // ncclGetVersion stores the library version into its out-parameter.
        let version = unsafe { version.assume_init() };
        assert!(version > 0, "unexpected NCCL version: {}", version);
    }
}