torch_sys/
layout.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use cxx::ExternType;
10use cxx::type_id;
11use pyo3::exceptions::PyValueError;
12use pyo3::prelude::*;
13use serde::Deserialize;
14use serde::Serialize;
15
16use crate::Layout;
17use crate::bridge::ffi;
18
19// SAFETY: This type is trival, just an i8.
20unsafe impl ExternType for Layout {
21    type Id = type_id!("c10::Layout");
22    type Kind = cxx::kind::Trivial;
23}
24
25impl Layout {
26    pub(crate) fn from_py_object_or_none(obj: &Bound<'_, PyAny>) -> Option<Self> {
27        ffi::py_object_is_layout(obj.clone().into())
28            .then(|| ffi::layout_from_py_object(obj.into()).unwrap())
29    }
30}
31
32impl FromPyObject<'_> for Layout {
33    fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
34        ffi::layout_from_py_object(obj.into()).map_err(|e| {
35            PyValueError::new_err(format!(
36                "Failed extracting {} from py as Layout: {}",
37                obj, e
38            ))
39        })
40    }
41}
42
43impl<'py> IntoPyObject<'py> for Layout {
44    type Target = PyAny;
45    type Output = Bound<'py, Self::Target>;
46    type Error = PyErr;
47
48    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
49        ffi::layout_to_py_object(self).into_pyobject(py)
50    }
51}
52
53// Remotely implement Serialize/Deserialize for generated types
54// TODO: we should be able to use parse_callbacks + add_derives, (see
55// https://github.com/rust-lang/rust-bindgen/pull/2059) and avoid a remote
56// implementation, but this is not supported by rust_bindgen_library.
57#[derive(Serialize, Deserialize)]
58#[serde(remote = "Layout")]
59pub struct LayoutDef(i8);
60
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Layout;

    /// Initializes the embedded interpreter, acquires the GIL, imports `torch`
    /// (so its layout objects are available) and runs `body` with the module.
    fn with_torch<R>(body: impl FnOnce(&Bound<'_, PyModule>) -> R) -> R {
        pyo3::prepare_freethreaded_python();
        Python::with_gil(|py| {
            let torch = py.import("torch").unwrap();
            body(&torch)
        })
    }

    #[test]
    fn convert_to_py_and_back() {
        let round_tripped = with_torch(|torch| {
            let py_layout = Layout::Strided.into_pyobject(torch.py()).unwrap();
            py_layout.extract::<Layout>().unwrap()
        });
        assert_eq!(round_tripped, Layout::Strided);
    }

    #[test]
    fn from_py() {
        let extracted = with_torch(|torch| {
            let strided = torch.getattr("strided").unwrap();
            strided.extract::<Layout>().unwrap()
        });
        assert_eq!(extracted, Layout::Strided);
    }
}