1use std::fmt;
10
11use cxx::ExternType;
12use cxx::type_id;
13use pyo3::exceptions::PyValueError;
14use pyo3::prelude::*;
15use serde::Deserialize;
16use serde::Deserializer;
17use serde::Serialize;
18use serde::Serializer;
19use serde::de::Visitor;
20
21use crate::DeviceType;
22use crate::bridge::const_data_ptr;
23use crate::bridge::cpp_incref;
24use crate::bridge::ffi;
25use crate::bridge::ffi::copy_;
26use crate::bridge::ffi::load_tensor;
27use crate::bridge::ffi::repr;
28use crate::bridge::ffi::save_tensor;
29use crate::bridge::ffi::sizes;
30use crate::bridge::mut_data_ptr;
31use crate::cell::CloneUnsafe;
32
/// Rust-side handle for a tensor that lives on the C++ side of the bridge.
///
/// The struct is `#[repr(C)]` and consists of a single raw pointer, and the
/// `ExternType` impl in this file declares it to cxx as `torch::Tensor` with
/// `Trivial` kind. The layout is therefore part of the FFI contract: do not
/// add, remove, or reorder fields.
#[repr(C)]
pub struct Tensor {
    /// Opaque pointer handed to the bridge helpers (`cpp_incref`,
    /// `cpp_decref`, `const_data_ptr`, `mut_data_ptr`).
    /// NOTE(review): presumably the C++ tensor's implementation pointer —
    /// confirm against the bridge code.
    repr: *mut std::ffi::c_void,
}
51
52impl Drop for Tensor {
53 fn drop(&mut self) {
54 if self.defined() {
56 unsafe { crate::bridge::cpp_decref(self.repr) };
59 }
60 }
61}
62
63impl fmt::Debug for Tensor {
64 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
65 f.debug_struct("Tensor").field("data", &"<...>").finish()
66 }
67}
68
69impl Tensor {
70 pub unsafe fn data_ptr(&self) -> *const std::ffi::c_void {
74 unsafe { const_data_ptr(self.repr) }
76 }
77
78 pub unsafe fn mut_data_ptr(&self) -> *mut std::ffi::c_void {
82 unsafe { mut_data_ptr(self.repr) }
84 }
85
86 pub fn copy_(&mut self, src: &Tensor) {
89 copy_(self, src);
90 }
91
92 pub fn sizes(&self) -> Vec<i32> {
94 sizes(self)
95 }
96
97 pub fn shape(&self) -> Vec<i32> {
99 self.sizes()
100 }
101}
102
103impl CloneUnsafe for Tensor {
104 unsafe fn clone_unsafe(&self) -> Self {
114 if self.defined() {
116 unsafe { cpp_incref(self.repr) };
119 }
120 Tensor { repr: self.repr }
121 }
122}
123
124impl fmt::Display for Tensor {
125 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
126 write!(f, "{}", repr(self))
127 }
128}
129
/// Equality is delegated entirely to `Tensor::equal`.
impl PartialEq for Tensor {
    /// Returns whatever `self.equal(other)` reports.
    /// NOTE(review): `equal` is defined outside this file section; presumably
    /// it compares sizes and element values rather than handle identity —
    /// confirm.
    fn eq(&self, other: &Self) -> bool {
        self.equal(other)
    }
}
135
// SAFETY(review): `Tensor` wraps a raw pointer, so it is not `Send`
// automatically; this impl asserts by hand that moving the handle to another
// thread is sound. The justification is not visible in this file — presumably
// the C++ side manages the reference count atomically. Confirm.
unsafe impl Send for Tensor {}
// SAFETY(review): this impl asserts by hand that `&Tensor` may be shared
// between threads. Note that `mut_data_ptr` takes `&self`, so writes through
// the returned pointer are not synchronized by the type system; presumably
// callers (e.g. via `TensorCell`) are expected to coordinate access. Confirm.
unsafe impl Sync for Tensor {}
140
// Tells cxx that this Rust type stands for the C++ type `torch::Tensor`.
//
// SAFETY(review): implementing `ExternType` is a promise that the Rust and
// C++ types are layout-compatible, and `Kind = Trivial` additionally allows
// cxx to pass the value by value across the boundary. Both promises rest on
// `Tensor` being `#[repr(C)]` with a single pointer field; confirm that this
// matches the C++ definition.
unsafe impl ExternType for Tensor {
    // Fully-qualified C++ name that cxx matches against its bridge declarations.
    type Id = type_id!("torch::Tensor");
    // `Trivial` allows the type to be held and passed by value across the bridge.
    type Kind = cxx::kind::Trivial;
}
148
149impl Serialize for Tensor {
153 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
154 where
155 S: Serializer,
156 {
157 if self.device().device_type() != DeviceType::CPU {
161 return Err(serde::ser::Error::custom(format!(
162 "can only save CPU tensors (found {:?})",
163 self.device(),
164 )));
165 }
166
167 let bytes = save_tensor(self).map_err(serde::ser::Error::custom)?;
168 serializer.serialize_bytes(&bytes)
169 }
170}
171
172impl<'de> Deserialize<'de> for Tensor {
173 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
174 where
175 D: Deserializer<'de>,
176 {
177 struct TensorVisitor;
178
179 impl<'de> Visitor<'de> for TensorVisitor {
180 type Value = Tensor;
181
182 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
183 f.write_str("raw tensor bytes")
184 }
185
186 fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
187 where
188 E: serde::de::Error,
189 {
190 load_tensor(v).map_err(E::custom)
191 }
192 }
193
194 deserializer.deserialize_bytes(TensorVisitor)
195 }
196}
197
198impl FromPyObject<'_> for Tensor {
199 fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
200 ffi::tensor_from_py_object(obj.into()).map_err(|e| {
201 PyValueError::new_err(format!(
202 "Failed extracting {} from py as Tensor: {}",
203 obj, e
204 ))
205 })
206 }
207}
208
209impl<'py> IntoPyObject<'py> for Tensor {
210 type Target = PyAny;
211 type Output = Bound<'py, Self::Target>;
212 type Error = PyErr;
213
214 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
215 ffi::tensor_to_py_object(self).into_pyobject(py)
216 }
217}
218
/// A [`Tensor`] wrapped in the crate's `AliasTrackingRefCell`.
/// NOTE(review): judging by the name and by `try_borrow` returning
/// `atomic_refcell::BorrowError`, this presumably provides runtime-checked
/// borrowing that is aware of tensors aliasing the same storage — confirm
/// against `crate::cell`.
pub type TensorCell = crate::cell::AliasTrackingRefCell<Tensor>;
220
221impl TensorCell {
222 pub fn try_cpu(self) -> Result<TensorCell, atomic_refcell::BorrowError> {
227 {
228 let borrow = self.try_borrow()?;
229 if borrow.device().device_type() != DeviceType::CPU {
230 return Ok(TensorCell::new(borrow.cpu()));
231 }
232 }
233 Ok(self)
234 }
235}
236
#[cfg(test)]
mod tests {
    use pyo3::prelude::*;

    use crate::Tensor;
    use crate::bridge::ffi::deep_clone;
    use crate::bridge::ffi::test_make_tensor;

    /// A deep clone must compare equal to the tensor it was cloned from.
    #[test]
    fn partial_eq() {
        let original = test_make_tensor();
        let duplicate = deep_clone(&original);
        assert_eq!(original, duplicate);
    }

    /// Round trip through bincode using the serde impls.
    #[test]
    fn bincode_serialize() {
        let original = test_make_tensor();
        let encoded = bincode::serialize(&original).unwrap();
        let decoded: Tensor = bincode::deserialize(&encoded).unwrap();
        assert_eq!(original, decoded);
    }

    /// Round trip through `serde_multipart`'s bincode entry points.
    #[test]
    fn multipart_serialize() {
        let original = test_make_tensor();
        let message = serde_multipart::serialize_bincode(&original).unwrap();
        let decoded: Tensor = serde_multipart::deserialize_bincode(message).unwrap();
        assert_eq!(original, decoded);
    }

    /// Converting to a Python object and extracting it again yields an equal
    /// tensor.
    #[test]
    fn convert_to_py_and_back() {
        pyo3::prepare_freethreaded_python();
        let tensor = test_make_tensor();
        let converted = Python::with_gil(|py| {
            // NOTE(review): `torch` is imported before converting; presumably
            // the bridge needs the Python module loaded first — confirm.
            py.import("torch").unwrap();
            // A deep clone is converted so that `tensor` stays available for
            // the comparison below.
            deep_clone(&tensor)
                .into_pyobject(py)
                .unwrap()
                .extract::<Tensor>()
                .unwrap()
        });
        assert_eq!(converted, tensor);
    }
}