torch_sys/
scalar_type.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use cxx::ExternType;
10use cxx::type_id;
11use pyo3::exceptions::PyValueError;
12use pyo3::prelude::*;
13use serde::Deserialize;
14use serde::Serialize;
15
16use crate::ScalarType;
17use crate::bridge::ffi;
18
19// Bind generated types to CXX
20// SAFETY: This type is trival, just an i8.
21unsafe impl ExternType for ScalarType {
22    type Id = type_id!("c10::ScalarType");
23    type Kind = cxx::kind::Trivial;
24}
25
26impl ScalarType {
27    pub(crate) fn from_py_object_or_none(obj: &Bound<'_, PyAny>) -> Option<Self> {
28        ffi::py_object_is_scalar_type(obj.clone().into())
29            .then(|| ffi::scalar_type_from_py_object(obj.into()).unwrap())
30    }
31}
32
33impl FromPyObject<'_> for ScalarType {
34    fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
35        ffi::scalar_type_from_py_object(obj.into()).map_err(|e| {
36            PyValueError::new_err(format!(
37                "Failed extracting {} from py as ScalarType: {}",
38                obj, e
39            ))
40        })
41    }
42}
43
44impl<'py> IntoPyObject<'py> for ScalarType {
45    type Target = PyAny;
46    type Output = Bound<'py, Self::Target>;
47    type Error = PyErr;
48
49    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
50        ffi::scalar_type_to_py_object(self).into_pyobject(py)
51    }
52}
53
54// Remotely implement Serialize/Deserialize for generated types
55// TODO: we should be able to use parse_callbacks + add_derives, (see
56// https://github.com/rust-lang/rust-bindgen/pull/2059) and avoid a remote
57#[derive(Serialize, Deserialize)]
58#[serde(remote = "ScalarType")]
59pub struct ScalarTypeDef(i8);
60
#[cfg(test)]
mod tests {
    use super::*;

    /// Initializes the embedded interpreter, takes the GIL, imports `torch`
    /// (so `torch.dtype` objects are registered), then runs `body` with the
    /// imported module.
    fn with_torch<R>(body: impl FnOnce(&Bound<'_, PyModule>) -> R) -> R {
        pyo3::prepare_freethreaded_python();
        Python::with_gil(|py| {
            let torch = py.import("torch").unwrap();
            body(&torch)
        })
    }

    #[test]
    fn convert_to_py_and_back() {
        let round_tripped = with_torch(|torch| {
            let as_py = ScalarType::Float.into_pyobject(torch.py()).unwrap();
            as_py.extract::<ScalarType>().unwrap()
        });
        assert_eq!(round_tripped, ScalarType::Float);
    }

    #[test]
    fn from_py() {
        let extracted = with_torch(|torch| {
            let float32 = torch.getattr("float32").unwrap();
            float32.extract::<ScalarType>().unwrap()
        });
        assert_eq!(extracted, ScalarType::Float);
    }
}