1use std::fmt;
10
11use cxx::ExternType;
12use cxx::type_id;
13use pyo3::exceptions::PyValueError;
14use pyo3::prelude::*;
15use serde::Deserialize;
16use serde::Deserializer;
17use serde::Serialize;
18use serde::Serializer;
19
20use crate::DeviceType;
21use crate::bridge::const_data_ptr;
22use crate::bridge::cpp_incref;
23use crate::bridge::ffi;
24use crate::bridge::ffi::copy_;
25use crate::bridge::ffi::load_tensor;
26use crate::bridge::ffi::repr;
27use crate::bridge::ffi::save_tensor;
28use crate::bridge::ffi::sizes;
29use crate::bridge::mut_data_ptr;
30use crate::cell::CloneUnsafe;
31
/// Rust-side handle to a C++ `torch::Tensor`.
///
/// The struct is `#[repr(C)]` and holds a single raw pointer (`repr`). It is
/// registered with cxx as `torch::Tensor` with `cxx::kind::Trivial`, which lets
/// cxx pass it across the FFI boundary by value.
///
/// Ownership: `Drop` releases one reference via `cpp_decref` when the tensor is
/// defined, and `CloneUnsafe::clone_unsafe` acquires one via `cpp_incref`.
///
/// NOTE(review): the field layout presumably mirrors the single intrusive
/// pointer stored inside `torch::Tensor` — confirm before changing it.
#[repr(C)]
pub struct Tensor {
    /// Opaque pointer to the underlying C++ tensor. In the code visible here it
    /// is only copied or handed to bridge functions (`cpp_incref`, `cpp_decref`,
    /// `const_data_ptr`, `mut_data_ptr`), never dereferenced directly.
    repr: *mut std::ffi::c_void,
}
50
51impl Drop for Tensor {
52 fn drop(&mut self) {
53 if self.defined() {
55 unsafe { crate::bridge::cpp_decref(self.repr) };
58 }
59 }
60}
61
62impl fmt::Debug for Tensor {
63 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
64 f.debug_struct("Tensor").field("data", &"<...>").finish()
65 }
66}
67
68impl Tensor {
69 pub unsafe fn data_ptr(&self) -> *const std::ffi::c_void {
73 unsafe { const_data_ptr(self.repr) }
75 }
76
77 pub unsafe fn mut_data_ptr(&self) -> *mut std::ffi::c_void {
81 unsafe { mut_data_ptr(self.repr) }
83 }
84
85 pub fn copy_(&mut self, src: &Tensor) {
88 copy_(self, src);
89 }
90
91 pub fn sizes(&self) -> Vec<i32> {
93 sizes(self)
94 }
95
96 pub fn shape(&self) -> Vec<i32> {
98 self.sizes()
99 }
100}
101
102impl CloneUnsafe for Tensor {
103 unsafe fn clone_unsafe(&self) -> Self {
113 if self.defined() {
115 unsafe { cpp_incref(self.repr) };
118 }
119 Tensor { repr: self.repr }
120 }
121}
122
123impl fmt::Display for Tensor {
124 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
125 write!(f, "{}", repr(self))
126 }
127}
128
129impl PartialEq for Tensor {
130 fn eq(&self, other: &Self) -> bool {
131 self.equal(other)
132 }
133}
134
// SAFETY: NOTE(review): `Tensor` only holds a raw pointer to a C++ tensor, and
// its reference count is changed only through `cpp_incref` / `cpp_decref`.
// Moving a handle to another thread is presumably sound because those
// reference-count operations are thread-safe on the C++ side — confirm in the
// bridge.
unsafe impl Send for Tensor {}
// SAFETY: NOTE(review): `&Tensor` exposes `mut_data_ptr`, which returns a
// writable pointer from a shared reference. Data races on the storage are
// therefore excluded only by the `# Safety` contracts of `data_ptr` /
// `mut_data_ptr`, and presumably by callers going through `TensorCell` —
// confirm this is the intended invariant.
unsafe impl Sync for Tensor {}
139
140unsafe impl ExternType for Tensor {
144 type Id = type_id!("torch::Tensor");
145 type Kind = cxx::kind::Trivial;
146}
147
148impl Serialize for Tensor {
152 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
153 where
154 S: Serializer,
155 {
156 if self.device().device_type() != DeviceType::CPU {
160 return Err(serde::ser::Error::custom(format!(
161 "can only save CPU tensors (found {:?})",
162 self.device(),
163 )));
164 }
165
166 let bytes = save_tensor(self).map_err(serde::ser::Error::custom)?;
167 serializer.serialize_bytes(&bytes)
168 }
169}
170
171impl<'de> Deserialize<'de> for Tensor {
172 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
173 where
174 D: Deserializer<'de>,
175 {
176 let buf: &[u8] = Deserialize::deserialize(deserializer)?;
177 let tensor = load_tensor(buf).map_err(serde::de::Error::custom)?;
178 Ok(tensor)
179 }
180}
181
182impl FromPyObject<'_> for Tensor {
183 fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
184 ffi::tensor_from_py_object(obj.into()).map_err(|e| {
185 PyValueError::new_err(format!(
186 "Failed extracting {} from py as Tensor: {}",
187 obj, e
188 ))
189 })
190 }
191}
192
193impl<'py> IntoPyObject<'py> for Tensor {
194 type Target = PyAny;
195 type Output = Bound<'py, Self::Target>;
196 type Error = PyErr;
197
198 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
199 ffi::tensor_to_py_object(self).into_pyobject(py)
200 }
201}
202
/// A [`Tensor`] wrapped in the crate's `AliasTrackingRefCell`.
///
/// Borrowing is fallible: `try_borrow` reports failure as an
/// `atomic_refcell::BorrowError`.
pub type TensorCell = crate::cell::AliasTrackingRefCell<Tensor>;
204
205impl TensorCell {
206 pub fn try_cpu(self) -> Result<TensorCell, atomic_refcell::BorrowError> {
211 {
212 let borrow = self.try_borrow()?;
213 if borrow.device().device_type() != DeviceType::CPU {
214 return Ok(TensorCell::new(borrow.cpu()));
215 }
216 }
217 Ok(self)
218 }
219}
220
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use pyo3::prelude::*;

    use crate::Tensor;
    use crate::bridge::ffi::deep_clone;
    use crate::bridge::ffi::test_make_tensor;

    /// A deep copy must compare equal to its source.
    #[test]
    fn partial_eq() -> Result<()> {
        let original = test_make_tensor();
        let duplicate = deep_clone(&original);
        assert_eq!(original, duplicate);
        Ok(())
    }

    /// Encoding with bincode and decoding the same buffer must round-trip.
    #[test]
    fn serialize() -> Result<()> {
        let original = test_make_tensor();
        let encoded = bincode::serialize(&original)?;
        let decoded: Tensor = bincode::deserialize(&encoded)?;
        assert_eq!(original, decoded);
        Ok(())
    }

    /// Converting to a Python object and extracting it back must yield an
    /// equal tensor.
    #[test]
    fn convert_to_py_and_back() {
        pyo3::prepare_freethreaded_python();
        let expected = test_make_tensor();
        let round_tripped = Python::with_gil(|py| {
            // Import `torch` up front; presumably the bridge's conversion
            // expects the module to be loaded — confirm.
            py.import("torch").unwrap();
            let py_tensor = deep_clone(&expected).into_pyobject(py).unwrap();
            py_tensor.extract::<Tensor>().unwrap()
        });
        assert_eq!(round_tripped, expected);
    }
}