torch_sys/tensor.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

use std::fmt;

use cxx::ExternType;
use cxx::type_id;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use serde::Deserialize;
use serde::Deserializer;
use serde::Serialize;
use serde::Serializer;
use serde::de::Visitor;

use crate::DeviceType;
use crate::bridge::const_data_ptr;
use crate::bridge::cpp_incref;
use crate::bridge::ffi;
use crate::bridge::ffi::copy_;
use crate::bridge::ffi::load_tensor;
use crate::bridge::ffi::repr;
use crate::bridge::ffi::save_tensor;
use crate::bridge::ffi::sizes;
use crate::bridge::mut_data_ptr;
use crate::cell::CloneUnsafe;

/// Rust binding for the C++ type `at::Tensor`.
///
/// # Safety
/// `Tensor` will properly manage the refcount of the underlying `TensorImpl`.
///
/// `Tensor` is [`Send`]: it is safe to send across thread boundaries because
/// the underlying C++ type is atomically refcounted.
///
/// `Tensor` is [`Sync`]: it can be shared across threads. The underlying C++
/// type has interior mutability (i.e. a `const Tensor&` can be used to mutate
/// the tensor), but we are careful to expose Rust bindings that require
/// exclusive access (ownership or mutable reference) for any C++ code that can
/// mutate a tensor.
#[repr(C)]
pub struct Tensor {
    /// This corresponds to `impl_` in the C++ `Tensor` class.
    repr: *mut std::ffi::c_void,
}
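
// A minimal, test-only sketch of the `Send` claim documented above: a `Tensor`
// can be moved to another thread and used there. This assumes
// `test_make_tensor` builds a small, defined CPU tensor, as in the tests at
// the bottom of this file; it is an illustration, not part of the API.
#[cfg(test)]
mod send_sketch {
    use crate::bridge::ffi::test_make_tensor;

    #[test]
    fn tensor_can_move_across_threads() {
        let t = test_make_tensor();
        // Moving `t` into the closure sends it to the spawned thread, where it
        // is used and then dropped.
        let sizes = std::thread::spawn(move || t.sizes()).join().unwrap();
        let _ = sizes;
    }
}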

impl Drop for Tensor {
    fn drop(&mut self) {
        // Undefined tensors do not have their refcounts changed.
        if self.defined() {
            // SAFETY: decrement this tensor's refcount. This ptr is guaranteed to
            // be non-null by the C++ side.
            unsafe { crate::bridge::cpp_decref(self.repr) };
        }
    }
}

impl fmt::Debug for Tensor {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Tensor").field("data", &"<...>").finish()
    }
}

impl Tensor {
    /// This is *unsafe* as it directly accesses the underlying data pointer.
    /// Additionally, this should only be used when the user is sure the tensor
    /// is defined.
    pub unsafe fn data_ptr(&self) -> *const std::ffi::c_void {
        // SAFETY: self.repr is guaranteed to be a non-null TensorImpl*
        unsafe { const_data_ptr(self.repr) }
    }

    /// This is *unsafe* as it directly accesses the underlying data pointer.
    /// Additionally, this should only be used when the user is sure the tensor
    /// is defined.
    pub unsafe fn mut_data_ptr(&self) -> *mut std::ffi::c_void {
        // SAFETY: self.repr is guaranteed to be a non-null TensorImpl*
        unsafe { mut_data_ptr(self.repr) }
    }

    /// Self-modify this tensor by copying data from another tensor. The other
    /// tensor must be the same shape as this one.
    pub fn copy_(&mut self, src: &Tensor) {
        copy_(self, src);
    }

    /// Return the size of each dimension in this tensor.
    pub fn sizes(&self) -> Vec<i32> {
        sizes(self)
    }

    /// Alias of sizes.
    pub fn shape(&self) -> Vec<i32> {
        self.sizes()
    }
}
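
// A minimal, test-only sketch of the accessors above: `sizes`/`shape` report
// dimensions, `copy_` copies data between same-shaped tensors, and `data_ptr`
// exposes the raw buffer of a defined tensor. It assumes `test_make_tensor`
// and `deep_clone` behave as in the tests at the bottom of this file; this is
// an illustration, not a definitive usage guide.
#[cfg(test)]
mod accessor_sketch {
    use crate::bridge::ffi::deep_clone;
    use crate::bridge::ffi::test_make_tensor;

    #[test]
    fn sizes_copy_and_data_ptr() {
        let src = test_make_tensor();
        let mut dst = deep_clone(&src);
        // `shape` is an alias of `sizes`; a deep clone has the same shape.
        assert_eq!(src.sizes(), dst.shape());

        // `copy_` requires matching shapes, which holds for a deep clone.
        dst.copy_(&src);
        assert_eq!(src, dst);

        if src.defined() {
            // SAFETY: the tensor is defined, so its data pointer is valid for
            // the lifetime of `src`.
            let ptr = unsafe { src.data_ptr() };
            assert!(!ptr.is_null());
        }
    }
}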

impl CloneUnsafe for Tensor {
    /// This is *unsafe*: it creates an alias of the underlying Tensor that is
    /// not tracked by Rust. We use this to interface with C++ functions that
    /// expect an `at::Tensor`.
    ///
    /// The contract for calling this function is that the clone is local and
    /// ephemeral. More precisely:
    /// 1. The clone must not be sent to another thread (local).
    /// 2. You must guarantee that the clone is dropped before the originating
    ///    mutable reference is dropped (ephemeral).
    unsafe fn clone_unsafe(&self) -> Self {
        // Undefined tensors do not have their refcounts changed.
        if self.defined() {
            // SAFETY: increment this tensor's refcount. This ptr is guaranteed to
            // be non-null by the C++ side.
            unsafe { cpp_incref(self.repr) };
        }
        Tensor { repr: self.repr }
    }
}
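
// A minimal, test-only sketch of the `clone_unsafe` contract above: the alias
// stays on this thread (local) and is dropped before the original handle is
// used again or dropped (ephemeral). Assumes `test_make_tensor` as in the
// tests at the bottom of this file.
#[cfg(test)]
mod clone_unsafe_sketch {
    use crate::bridge::ffi::test_make_tensor;
    use crate::cell::CloneUnsafe;

    #[test]
    fn alias_is_local_and_ephemeral() {
        let t = test_make_tensor();
        // SAFETY: the alias never leaves this thread and is dropped below,
        // before `t` is touched again.
        let alias = unsafe { t.clone_unsafe() };
        assert_eq!(alias.sizes(), t.sizes());
        drop(alias);
        // The original handle remains valid once the alias is gone.
        let _ = t.sizes();
    }
}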

impl fmt::Display for Tensor {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", repr(self))
    }
}

impl PartialEq for Tensor {
    fn eq(&self, other: &Self) -> bool {
        self.equal(other)
    }
}

// SAFETY: See safety discussion in [`Tensor`]
unsafe impl Send for Tensor {}
// SAFETY: See safety discussion in [`Tensor`]
unsafe impl Sync for Tensor {}

// SAFETY: Register our custom type implementation with cxx.
// It is okay to mark as trivial, as Tensor is relocatable; see the discussion
// in `bridge.h`.
unsafe impl ExternType for Tensor {
    type Id = type_id!("torch::Tensor");
    type Kind = cxx::kind::Trivial;
}

// Simple serialize/deserialize impls for `Tensor` to support sending them over
// the wire for e.g. `SendValue`. Right now we just defer to C++'s `torch::save`
// and `torch::load`, but there might be more efficient ways to do this.
impl Serialize for Tensor {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // We see deadlocks in CUDA libs that appear to happen when attempting
        // to save tensors from a thread that doesn't have the corresponding
        // device active. So, try to detect this and fail.
        if self.device().device_type() != DeviceType::CPU {
            return Err(serde::ser::Error::custom(format!(
                "can only save CPU tensors (found {:?})",
                self.device(),
            )));
        }

        let bytes = save_tensor(self).map_err(serde::ser::Error::custom)?;
        serializer.serialize_bytes(&bytes)
    }
}

impl<'de> Deserialize<'de> for Tensor {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        struct TensorVisitor;

        impl<'de> Visitor<'de> for TensorVisitor {
            type Value = Tensor;

            fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
                f.write_str("raw tensor bytes")
            }

            fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                load_tensor(v).map_err(E::custom)
            }
        }

        deserializer.deserialize_bytes(TensorVisitor)
    }
}

impl FromPyObject<'_> for Tensor {
    fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
        ffi::tensor_from_py_object(obj.into()).map_err(|e| {
            PyValueError::new_err(format!(
                "Failed extracting {} from py as Tensor: {}",
                obj, e
            ))
        })
    }
}

impl<'py> IntoPyObject<'py> for Tensor {
    type Target = PyAny;
    type Output = Bound<'py, Self::Target>;
    type Error = PyErr;

    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
        ffi::tensor_to_py_object(self).into_pyobject(py)
    }
}

pub type TensorCell = crate::cell::AliasTrackingRefCell<Tensor>;

impl TensorCell {
    /// Return a cell with the backing tensor on the CPU. If the backing tensor
    /// is on a GPU, this creates a new Tensor/TensorCell holding a copy of the
    /// backing tensor. If the tensor is already on the CPU, this just returns
    /// this cell.
    pub fn try_cpu(self) -> Result<TensorCell, atomic_refcell::BorrowError> {
        {
            let borrow = self.try_borrow()?;
            if borrow.device().device_type() != DeviceType::CPU {
                return Ok(TensorCell::new(borrow.cpu()));
            }
        }
        Ok(self)
    }
}
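
// A minimal, test-only sketch tying `TensorCell::try_cpu` to the `Serialize`
// impl above: serialization only accepts CPU tensors, so move (or keep) the
// backing tensor on the CPU first. Assumes `test_make_tensor` builds a CPU
// tensor, as in the tests at the bottom of this file.
#[cfg(test)]
mod try_cpu_sketch {
    use crate::bridge::ffi::test_make_tensor;

    use super::TensorCell;

    #[test]
    fn cpu_cell_serializes() {
        let cell = TensorCell::new(test_make_tensor());
        // Already-CPU tensors are returned as-is; GPU tensors would be copied.
        let cpu_cell = cell.try_cpu().unwrap();
        let bytes = bincode::serialize(&*cpu_cell.try_borrow().unwrap()).unwrap();
        assert!(!bytes.is_empty());
    }
}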

#[cfg(test)]
mod tests {
    use pyo3::prelude::*;

    use crate::Tensor;
    use crate::bridge::ffi::deep_clone;
    use crate::bridge::ffi::test_make_tensor;

    #[test]
    fn partial_eq() {
        let t1 = test_make_tensor();
        let t2 = deep_clone(&t1);
        assert_eq!(t1, t2);
    }

    #[test]
    fn bincode_serialize() {
        let t1 = test_make_tensor();
        let buf = bincode::serialize(&t1).unwrap();
        let t2: Tensor = bincode::deserialize(&buf).unwrap();
        assert_eq!(t1, t2);
    }

    #[test]
    fn multipart_serialize() {
        let t1 = test_make_tensor();
        let buf = serde_multipart::serialize_bincode(&t1).unwrap();
        let t2: Tensor = serde_multipart::deserialize_bincode(buf).unwrap();
        assert_eq!(t1, t2);
    }

    #[test]
    fn convert_to_py_and_back() {
        pyo3::prepare_freethreaded_python();
        let tensor = test_make_tensor();
        let converted = Python::with_gil(|py| {
            // import torch to ensure torch.layout types are registered
            py.import("torch").unwrap();
            let obj = deep_clone(&tensor).into_pyobject(py).unwrap();
            obj.extract::<Tensor>().unwrap()
        });
        assert_eq!(converted, tensor);
    }
}