1use cxx::ExternType;
10use cxx::type_id;
11
12#[cfg(not(use_rocm))]
13mod extern_types {
14 use super::*;
15
16 unsafe impl ExternType for CUstream_st {
18 type Id = type_id!("CUstream_st");
19 type Kind = cxx::kind::Opaque;
20 }
21
22 unsafe impl ExternType for ncclComm {
24 type Id = type_id!("ncclComm");
25 type Kind = cxx::kind::Opaque;
26 }
27}
28
/// `cxx` registrations for the opaque HIP / RCCL handle types (ROCm build).
///
/// The handle types are taken straight from the bindgen module `inner`; both are
/// registered as `cxx::kind::Opaque`, so the bridge treats them as indirection-only.
#[cfg(use_rocm)]
mod extern_types {
    use super::inner::{ihipStream_t, ncclComm};
    use super::{type_id, ExternType};

    // SAFETY: the `type_id!` string must be the C++ name of the type this Rust type
    // represents in the bridge; `ihipStream_t` is the struct behind a HIP stream handle.
    unsafe impl ExternType for ihipStream_t {
        type Id = type_id!("ihipStream_t");
        type Kind = cxx::kind::Opaque;
    }

    // SAFETY: as above — RCCL keeps NCCL's `ncclComm` name for the communicator struct.
    unsafe impl ExternType for ncclComm {
        type Id = type_id!("ncclComm");
        type Kind = cxx::kind::Opaque;
    }
}
48
49#[allow(non_camel_case_types)]
52#[allow(non_upper_case_globals)]
53#[allow(non_snake_case)]
54mod inner {
55 use serde::Deserialize;
56 use serde::Deserializer;
57 use serde::Serialize;
58 use serde::Serializer;
59 use serde::ser::SerializeSeq;
60 #[cfg(cargo)]
61 include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
62
63 #[repr(C)]
66 #[derive(Debug, Copy, Clone, Serialize, Deserialize)]
67 pub struct ncclUniqueId {
68 #[serde(
71 serialize_with = "serialize_array",
72 deserialize_with = "deserialize_array"
73 )]
74 pub internal: [::std::os::raw::c_char; 128usize],
75 }
76
77 fn deserialize_array<'de, D>(deserializer: D) -> Result<[::std::os::raw::c_char; 128], D::Error>
78 where
79 D: Deserializer<'de>,
80 {
81 let vec: Vec<::std::os::raw::c_char> = Deserialize::deserialize(deserializer)?;
82 vec.try_into().map_err(|v: Vec<::std::os::raw::c_char>| {
83 serde::de::Error::invalid_length(v.len(), &"expected an array of length 128")
84 })
85 }
86
87 fn serialize_array<S>(
88 array: &[::std::os::raw::c_char; 128],
89 serializer: S,
90 ) -> Result<S::Ok, S::Error>
91 where
92 S: Serializer,
93 {
94 let mut seq = serializer.serialize_seq(Some(128))?;
95 for element in array {
96 seq.serialize_element(element)?;
97 }
98 seq.end()
99 }
100}
101
102pub use inner::*;
104
105#[cfg(use_rocm)]
107pub use self::rocm_compat::*;
108
/// CUDA-spelled aliases over the HIP bindings, so code written against the CUDA names
/// also builds when the crate is configured for ROCm.
#[cfg(use_rocm)]
#[allow(non_camel_case_types)]
mod rocm_compat {
    // Runtime entry points, re-exported under their CUDA names.
    pub use super::inner::{
        hipSetDevice as cudaSetDevice,
        hipStreamSynchronize as cudaStreamSynchronize,
    };

    // Handle and status types.
    pub type CUstream_st = super::inner::ihipStream_t;
    pub type cudaError_t = super::inner::hipError_t;
    pub type cudaStream_t = super::inner::hipStream_t;
}
126
#[cfg(test)]
mod tests {
    use std::mem::MaybeUninit;

    use super::*;

    /// Smoke test: the linked NCCL/RCCL library can be called through the bindings
    /// and reports a plausible version code.
    #[test]
    fn sanity() {
        let mut version = MaybeUninit::<i32>::uninit();
        // SAFETY: `version` is a live, writable `i32` slot for the whole call; the
        // library is expected to write only through this pointer.
        let result = unsafe { ncclGetVersion(version.as_mut_ptr()) };
        assert_eq!(result.0, 0, "ncclGetVersion failed with code {}", result.0);
        // SAFETY: the call reported success, which per the NCCL API means it stored
        // the version code into `version`.
        let version = unsafe { version.assume_init() };
        assert!(version > 0, "unexpected NCCL version code {}", version);
    }
}