torch_sys/tensor.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

use std::fmt;

use cxx::ExternType;
use cxx::type_id;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use serde::Deserialize;
use serde::Deserializer;
use serde::Serialize;
use serde::Serializer;

use crate::DeviceType;
use crate::bridge::const_data_ptr;
use crate::bridge::cpp_incref;
use crate::bridge::ffi;
use crate::bridge::ffi::copy_;
use crate::bridge::ffi::load_tensor;
use crate::bridge::ffi::repr;
use crate::bridge::ffi::save_tensor;
use crate::bridge::ffi::sizes;
use crate::bridge::mut_data_ptr;
use crate::cell::CloneUnsafe;

/// Rust binding for the C++ type `at::Tensor`.
///
/// # Safety
/// `Tensor` will properly manage the refcount of the underlying `TensorImpl`.
///
/// `Tensor` is [`Send`]: it is safe to send across thread boundaries because
/// the underlying C++ type is atomically refcounted.
///
/// `Tensor` is [`Sync`]: it can be shared across threads. The underlying C++
/// type has interior mutability (i.e., a `const Tensor&` can be used to mutate
/// the tensor), but we are careful to expose Rust bindings that require
/// exclusive access (ownership or a mutable reference) for any C++ code that
/// can mutate a tensor.
#[repr(C)]
pub struct Tensor {
    /// This corresponds to `impl_` in the C++ `Tensor` class.
    repr: *mut std::ffi::c_void,
}

impl Drop for Tensor {
    fn drop(&mut self) {
        // Undefined tensors do not have their refcounts changed.
        if self.defined() {
            // SAFETY: decrement this tensor's refcount. This ptr is guaranteed to
            // be non-null by the C++ side.
            unsafe { crate::bridge::cpp_decref(self.repr) };
        }
    }
}

impl fmt::Debug for Tensor {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Tensor").field("data", &"<...>").finish()
    }
}

impl Tensor {
    /// This is *unsafe* as it directly accesses the underlying data pointer.
    /// Additionally, this should only be used when the user is sure the tensor
    /// is defined.
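    ///
    /// # Example
    ///
    /// A minimal sketch; assumes the tensor is defined, non-empty, and holds
    /// `f32` data (all of which are the caller's responsibility to verify):
    ///
    /// ```ignore
    /// // SAFETY: `t` is known to be defined, so the data pointer is valid.
    /// let ptr = unsafe { t.data_ptr() }.cast::<f32>();
    /// // SAFETY: `t` is non-empty, so reading the first element is in bounds.
    /// let first = unsafe { *ptr };
    /// ```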
    pub unsafe fn data_ptr(&self) -> *const std::ffi::c_void {
        // SAFETY: self.repr is guaranteed to be a non-null TensorImpl*
        unsafe { const_data_ptr(self.repr) }
    }

    /// This is *unsafe* as it directly accesses the underlying data pointer.
    /// Additionally, this should only be used when the user is sure the tensor
    /// is defined.
    pub unsafe fn mut_data_ptr(&self) -> *mut std::ffi::c_void {
        // SAFETY: self.repr is guaranteed to be a non-null TensorImpl*
        unsafe { mut_data_ptr(self.repr) }
    }

    /// Mutate this tensor in place by copying data from another tensor. The
    /// source tensor must have the same shape as this one.
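    ///
    /// # Example
    ///
    /// A minimal sketch (`test_make_tensor` is this crate's test-only helper;
    /// any two tensors of the same shape work):
    ///
    /// ```ignore
    /// let mut dst = test_make_tensor();
    /// let src = test_make_tensor();
    /// dst.copy_(&src);
    /// assert_eq!(dst, src);
    /// ```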
    pub fn copy_(&mut self, src: &Tensor) {
        copy_(self, src);
    }

    /// Return the size of each dimension in this tensor.
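    /// For example, a tensor of shape 2x3 returns `vec![2, 3]`.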
    pub fn sizes(&self) -> Vec<i32> {
        sizes(self)
    }

    /// Alias of [`Self::sizes`].
    pub fn shape(&self) -> Vec<i32> {
        self.sizes()
    }
}

impl CloneUnsafe for Tensor {
    /// This is *unsafe*: it creates an alias of the underlying tensor that is
    /// not tracked by Rust. We use this to interface with C++ functions that
    /// expect an `at::Tensor`.
    ///
    /// The contract for calling this function is that the clone is local and
    /// ephemeral. More precisely:
    /// 1. The clone must not be sent to another thread (local).
    /// 2. You must guarantee that the clone is dropped before the originating
    ///    mutable reference is dropped (ephemeral).
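    ///
    /// # Example
    ///
    /// A sketch of the intended pattern (`call_cpp_kernel` is a hypothetical
    /// FFI function that takes an `at::Tensor` by value):
    ///
    /// ```ignore
    /// fn run(t: &mut Tensor) {
    ///     // SAFETY: the alias never leaves this thread (local) and is
    ///     // consumed before `t`'s borrow ends (ephemeral).
    ///     let alias = unsafe { t.clone_unsafe() };
    ///     call_cpp_kernel(alias);
    /// }
    /// ```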
    unsafe fn clone_unsafe(&self) -> Self {
        // Undefined tensors do not have their refcounts changed.
        if self.defined() {
            // SAFETY: increment this tensor's refcount. This ptr is guaranteed to
            // be non-null by the C++ side.
            unsafe { cpp_incref(self.repr) };
        }
        Tensor { repr: self.repr }
    }
}

impl fmt::Display for Tensor {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", repr(self))
    }
}

impl PartialEq for Tensor {
    fn eq(&self, other: &Self) -> bool {
        self.equal(other)
    }
}

// SAFETY: See safety discussion in [`Tensor`]
unsafe impl Send for Tensor {}
// SAFETY: See safety discussion in [`Tensor`]
unsafe impl Sync for Tensor {}

// SAFETY: Register our custom type implementation with cxx.
// It is okay to mark as trivial, as Tensor is relocatable; see the discussion
// in `bridge.h`.
unsafe impl ExternType for Tensor {
    type Id = type_id!("torch::Tensor");
    type Kind = cxx::kind::Trivial;
}

// Simple serialize/deserialize impls for `Tensor` to support sending tensors
// over the wire, e.g. for `SendValue`. Right now we just defer to C++'s
// `torch::save` and `torch::load`, but there might be more efficient ways to
// do this.
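//
// A minimal round-trip sketch (hedged: `bincode` is just one serde format
// that works with these impls; see the `serialize` test below):
//
//     let bytes = bincode::serialize(&cpu_tensor)?;
//     let restored: Tensor = bincode::deserialize(&bytes)?;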
impl Serialize for Tensor {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // We see deadlocks in CUDA libs that appear to happen when attempting
        // to save tensors from a thread that doesn't have the corresponding
        // device active. So, detect this case and fail.
        if self.device().device_type() != DeviceType::CPU {
            return Err(serde::ser::Error::custom(format!(
                "can only save CPU tensors (found {:?})",
                self.device(),
            )));
        }

        let bytes = save_tensor(self).map_err(serde::ser::Error::custom)?;
        serializer.serialize_bytes(&bytes)
    }
}

impl<'de> Deserialize<'de> for Tensor {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let buf: &[u8] = Deserialize::deserialize(deserializer)?;
        let tensor = load_tensor(buf).map_err(serde::de::Error::custom)?;
        Ok(tensor)
    }
}

impl FromPyObject<'_> for Tensor {
    fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
        ffi::tensor_from_py_object(obj.into()).map_err(|e| {
            PyValueError::new_err(format!(
                "Failed extracting {} from py as Tensor: {}",
                obj, e
            ))
        })
    }
}

impl<'py> IntoPyObject<'py> for Tensor {
    type Target = PyAny;
    type Output = Bound<'py, Self::Target>;
    type Error = PyErr;

    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
        ffi::tensor_to_py_object(self).into_pyobject(py)
    }
}

pub type TensorCell = crate::cell::AliasTrackingRefCell<Tensor>;

impl TensorCell {
    /// Return a cell with the backing tensor on the CPU. If the backing
    /// tensor is on a GPU, this creates a new `Tensor`/`TensorCell` holding a
    /// copy of the backing tensor. If the tensor is already on the CPU, the
    /// cell itself is returned unchanged.
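    ///
    /// # Example
    ///
    /// A minimal sketch; assumes `gpu_tensor` is a CUDA-backed `Tensor`:
    ///
    /// ```ignore
    /// let cell = TensorCell::new(gpu_tensor);
    /// let cpu_cell = cell.try_cpu()?;
    /// assert_eq!(cpu_cell.try_borrow()?.device().device_type(), DeviceType::CPU);
    /// ```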
    pub fn try_cpu(self) -> Result<TensorCell, atomic_refcell::BorrowError> {
        {
            let borrow = self.try_borrow()?;
            if borrow.device().device_type() != DeviceType::CPU {
                return Ok(TensorCell::new(borrow.cpu()));
            }
        }
        Ok(self)
    }
}

#[cfg(test)]
mod tests {
    use anyhow::Result;
    use pyo3::prelude::*;

    use crate::Tensor;
    use crate::bridge::ffi::deep_clone;
    use crate::bridge::ffi::test_make_tensor;

    #[test]
    fn partial_eq() -> Result<()> {
        let t1 = test_make_tensor();
        let t2 = deep_clone(&t1);
        assert_eq!(t1, t2);
        Ok(())
    }

    #[test]
    fn serialize() -> Result<()> {
        let t1 = test_make_tensor();
        let buf = bincode::serialize(&t1)?;
        let t2: Tensor = bincode::deserialize(&buf)?;
        assert_eq!(t1, t2);
        Ok(())
    }

    #[test]
    fn convert_to_py_and_back() {
        pyo3::prepare_freethreaded_python();
        let tensor = test_make_tensor();
        let converted = Python::with_gil(|py| {
            // Import torch to ensure torch.layout types are registered.
            py.import("torch").unwrap();
            let obj = deep_clone(&tensor).into_pyobject(py).unwrap();
            obj.extract::<Tensor>().unwrap()
        });
        assert_eq!(converted, tensor);
    }
}