1use cxx::ExternType;
10use cxx::type_id;
11
12unsafe impl ExternType for CUstream_st {
14 type Id = type_id!("CUstream_st");
15 type Kind = cxx::kind::Opaque;
16}
17
18unsafe impl ExternType for ncclConfig_t {
21 type Id = type_id!("ncclConfig_t");
22 type Kind = cxx::kind::Trivial;
23}
24
25unsafe impl ExternType for ncclComm {
27 type Id = type_id!("ncclComm");
28 type Kind = cxx::kind::Opaque;
29}
30
31#[allow(non_camel_case_types)]
34#[allow(non_upper_case_globals)]
35#[allow(non_snake_case)]
36mod inner {
37 use serde::Deserialize;
38 use serde::Deserializer;
39 use serde::Serialize;
40 use serde::Serializer;
41 use serde::ser::SerializeSeq;
42 #[cfg(cargo)]
43 include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
44
45 #[repr(C)]
48 #[derive(Debug, Copy, Clone, Serialize, Deserialize)]
49 pub struct ncclUniqueId {
50 #[serde(
53 serialize_with = "serialize_array",
54 deserialize_with = "deserialize_array"
55 )]
56 pub internal: [::std::os::raw::c_char; 128usize],
57 }
58
59 fn deserialize_array<'de, D>(deserializer: D) -> Result<[i8; 128], D::Error>
60 where
61 D: Deserializer<'de>,
62 {
63 let vec: Vec<i8> = Deserialize::deserialize(deserializer)?;
64 vec.try_into().map_err(|v: Vec<i8>| {
65 serde::de::Error::invalid_length(v.len(), &"expected an array of length 128")
66 })
67 }
68
69 fn serialize_array<S>(array: &[i8; 128], serializer: S) -> Result<S::Ok, S::Error>
70 where
71 S: Serializer,
72 {
73 let mut seq = serializer.serialize_seq(Some(128))?;
74 for element in array {
75 seq.serialize_element(element)?;
76 }
77 seq.end()
78 }
79}
80
81pub use inner::*;
82
#[cfg(test)]
mod tests {
    use std::mem::MaybeUninit;

    use super::*;

    /// Smoke test: the linked NCCL library can be called, reports success,
    /// and returns a version code.
    #[test]
    fn sanity() {
        let mut version = MaybeUninit::<i32>::uninit();
        // SAFETY: `version.as_mut_ptr()` is a valid, aligned, writable pointer
        // to an `i32` that outlives the call.
        let result = unsafe { ncclGetVersion(version.as_mut_ptr()) };
        assert_eq!(result.0, 0, "ncclGetVersion returned a non-zero result");
        // SAFETY: on a successful return (asserted above), ncclGetVersion is
        // expected to have written the version through the pointer.
        // NOTE(review): this relies on NCCL's documented contract.
        let version = unsafe { version.assume_init() };
        assert!(version > 0, "unexpected NCCL version code {}", version);
    }
}