torch_sys/
memory_format.rs1use cxx::ExternType;
10use cxx::type_id;
11use pyo3::exceptions::PyValueError;
12use pyo3::prelude::*;
13use serde::Deserialize;
14use serde::Serialize;
15
16use crate::MemoryFormat;
17use crate::bridge::ffi;
18
19unsafe impl ExternType for MemoryFormat {
21 type Id = type_id!("c10::MemoryFormat");
22 type Kind = cxx::kind::Trivial;
23}
24
25impl MemoryFormat {
26 pub(crate) fn from_py_object_or_none(obj: &Bound<'_, PyAny>) -> Option<Self> {
27 ffi::py_object_is_memory_format(obj.clone().into())
28 .then(|| ffi::memory_format_from_py_object(obj.into()).unwrap())
29 }
30}
31
32impl FromPyObject<'_> for MemoryFormat {
33 fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
34 ffi::memory_format_from_py_object(obj.into()).map_err(|e| {
35 PyValueError::new_err(format!(
36 "Failed extracting {} from py as Layout: {}",
37 obj, e
38 ))
39 })
40 }
41}
42
43impl<'py> IntoPyObject<'py> for MemoryFormat {
44 type Target = PyAny;
45 type Output = Bound<'py, Self::Target>;
46 type Error = PyErr;
47
48 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
49 ffi::memory_format_to_py_object(self).into_pyobject(py)
50 }
51}
52
53#[derive(Serialize, Deserialize)]
58#[serde(remote = "MemoryFormat")]
59pub struct MemoryFormatDef(i8);
60
#[cfg(test)]
mod tests {
    use super::*;

    /// Rust -> Python -> Rust should yield the value we started with.
    #[test]
    fn convert_to_py_and_back() {
        pyo3::prepare_freethreaded_python();
        let original = MemoryFormat::Contiguous;
        let round_tripped = Python::with_gil(|py| {
            // Load torch before converting; the FFI conversion presumably
            // depends on torch's Python bindings being initialized — TODO confirm.
            py.import("torch").unwrap();
            original
                .into_pyobject(py)
                .unwrap()
                .extract::<MemoryFormat>()
                .unwrap()
        });
        assert_eq!(round_tripped, MemoryFormat::Contiguous);
    }

    /// `torch.preserve_format` should extract as `MemoryFormat::Preserve`.
    #[test]
    fn from_py() {
        pyo3::prepare_freethreaded_python();
        let extracted = Python::with_gil(|py| {
            let torch = py.import("torch").unwrap();
            let preserve_format = torch.getattr("preserve_format").unwrap();
            preserve_format.extract::<MemoryFormat>().unwrap()
        });
        assert_eq!(extracted, MemoryFormat::Preserve);
    }
}