torch_sys/
rvalue.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use derive_more::From;
10use derive_more::TryInto;
11use monarch_types::PickledPyObject;
12use monarch_types::TryIntoPyObjectUnsafe;
13use pyo3::IntoPyObjectExt;
14use pyo3::exceptions::PyValueError;
15use pyo3::prelude::*;
16use pyo3::types::PyNone;
17use serde::Deserialize;
18use serde::Serialize;
19
20use crate::Device;
21use crate::IValue;
22use crate::IValueKind;
23use crate::Layout;
24use crate::MemoryFormat;
25use crate::ScalarType;
26use crate::TensorCell;
27use crate::cell::CloneUnsafe;
28use crate::ivalue::OpaqueIValueCell;
29
30/// A pure Rust equivalent for [`IValue`]. This is safe to treat like a normal
31/// Rust value.
32#[derive(Debug, Clone, From, TryInto, Serialize, Deserialize)]
33#[try_into(owned, ref, ref_mut)]
34pub enum RValue {
35    Tensor(TensorCell),
36    TensorList(Vec<TensorCell>),
37    Int(i64),
38    IntList(Vec<i64>),
39    Double(f64),
40    Bool(bool),
41    String(String),
42    Device(Device),
43    Layout(#[serde(with = "crate::LayoutDef")] Layout),
44    ScalarType(#[serde(with = "crate::ScalarTypeDef")] ScalarType),
45    MemoryFormat(#[serde(with = "crate::MemoryFormatDef")] MemoryFormat),
46    None,
47    PyObject(PickledPyObject),
48    /// This is meant to be a catch-all for types that we don't support
49    /// natively in Rust.
50    Opaque(OpaqueIValueCell),
51}
52
53// SAFETY: this function creates untracked aliases of tensors. The caller is
54// responsible for having acquired the suitable borrows and holding them for the
55// entire lifetime of the returned IValue.
56pub unsafe fn rvalue_to_ivalue(rvalue: &RValue) -> IValue {
57    match rvalue {
58        // TODO fix unwrap
59        RValue::Tensor(cell) => {
60            // SAFETY: caller is responsible for holding a borrow, so the outer
61            // function is marked unsafe.
62            IValue::from(unsafe { cell.get_unchecked().clone_unsafe() })
63        }
64        RValue::TensorList(cells) => {
65            let mut tensors = Vec::new();
66            for cell in cells {
67                // SAFETY: caller is responsible for holding a borrow, so the outer
68                // function is marked unsafe.
69                tensors.push(unsafe { cell.get_unchecked().clone_unsafe() });
70            }
71            IValue::from(tensors)
72        }
73        RValue::Int(val) => IValue::from(*val),
74        RValue::IntList(val) => IValue::from(val.as_slice()),
75        RValue::Double(val) => IValue::from(*val),
76        RValue::Bool(val) => IValue::from(*val),
77        RValue::String(val) => IValue::from(val),
78        RValue::Device(val) => IValue::from(*val),
79        // It appears that the enums for Layout/ScalarType/MemoryFormat are just
80        // stored as raw ints in `IValue` and that we lose all info about how
81        // to convert them back.
82        RValue::Layout(val) => IValue::from(val.0 as i64),
83        RValue::ScalarType(val) => IValue::from(val.0 as i64),
84        RValue::MemoryFormat(val) => IValue::from(val.0 as i64),
85        RValue::None => IValue::from(()),
86        RValue::PyObject(val) => {
87            Python::with_gil(|py| val.unpickle(py).unwrap().extract::<IValue>())
88                .expect("unable to convert PyObject to IValue")
89        }
90        RValue::Opaque(cell) => {
91            // SAFETY: caller is responsible for holding a borrow, so the outer
92            // function is marked unsafe.
93            unsafe { cell.get_unchecked().ivalue() }
94        }
95    }
96}
97
98impl From<IValue> for RValue {
99    fn from(ivalue: IValue) -> Self {
100        match ivalue.kind() {
101            IValueKind::Tensor => RValue::Tensor(TensorCell::new(ivalue.to_tensor().unwrap())),
102            IValueKind::Bool => RValue::Bool(ivalue.to_bool().unwrap()),
103            IValueKind::Int => RValue::Int(ivalue.to_int().unwrap()),
104            IValueKind::IntList => RValue::IntList(ivalue.to_int_list().unwrap()),
105            IValueKind::Double => RValue::Double(ivalue.to_double().unwrap()),
106            IValueKind::String => RValue::String(ivalue.to_string().unwrap()),
107            IValueKind::TensorList => RValue::TensorList(
108                ivalue
109                    .to_tensor_list()
110                    .unwrap()
111                    .into_iter()
112                    .map(TensorCell::new)
113                    .collect(),
114            ),
115            IValueKind::Device => RValue::Device(ivalue.to_device().unwrap()),
116            IValueKind::None => RValue::None,
117            IValueKind::Other => RValue::Opaque(OpaqueIValueCell::new(ivalue.to_opaque().unwrap())),
118        }
119    }
120}
121
122impl<'py> IntoPyObject<'py> for RValue {
123    type Target = PyAny;
124    type Output = Bound<'py, Self::Target>;
125    type Error = PyErr;
126
127    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
128        match self {
129            RValue::Int(val) => IValue::from(val).into_pyobject(py),
130            RValue::IntList(val) => IValue::from(val.as_slice()).into_pyobject(py),
131            RValue::Double(val) => IValue::from(val).into_pyobject(py),
132            RValue::Bool(val) => IValue::from(val).into_pyobject(py),
133            RValue::String(val) => IValue::from(&val).into_pyobject(py),
134            RValue::Device(val) => IValue::from(val).into_pyobject(py),
135            // Avoid converting layout and scalar type into ivalues, as it appears
136            // they just get converted to ints.
137            RValue::Layout(val) => val.clone().into_pyobject(py),
138            RValue::ScalarType(val) => val.clone().into_pyobject(py),
139            RValue::MemoryFormat(val) => val.clone().into_pyobject(py),
140            RValue::None => PyNone::get(py).into_bound_py_any(py),
141            RValue::PyObject(val) => val.unpickle(py),
142            _ => Err(PyErr::new::<PyValueError, _>(format!(
143                "cannot safely create py object from {:?}",
144                self
145            ))),
146        }
147    }
148}
149
150/// Convert into a `PyObject`.
151impl<'py> TryIntoPyObjectUnsafe<'py, PyAny> for &RValue {
152    unsafe fn try_to_object_unsafe(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
153        match self {
154            // Avoid converting layout, scalar type, memory format into ivalues, as it appears
155            // they just get converted to ints.
156            // None and PyObject are also not converted as there is no need to do so.
157            RValue::Layout(val) => val.clone().into_pyobject(py),
158            RValue::ScalarType(val) => val.clone().into_pyobject(py),
159            RValue::MemoryFormat(val) => val.clone().into_pyobject(py),
160            RValue::None => PyNone::get(py).into_bound_py_any(py),
161            RValue::PyObject(val) => val.unpickle(py),
162            // SAFETY: This inherits the unsafety of `rvalue_to_ivalue` (see comment
163            // above).
164            _ => unsafe { rvalue_to_ivalue(self).into_pyobject(py) },
165        }
166    }
167}
168
169impl FromPyObject<'_> for RValue {
170    fn extract_bound(obj: &pyo3::Bound<'_, pyo3::PyAny>) -> pyo3::PyResult<Self> {
171        // It's crucial for correctness to try converting to IValue after we've
172        // tried the other non-PyObject variants, because the IValue conversion
173        // will actually succeed when obj is a ScalarType, Layout, or MemoryFormat.
174        if let Some(val) = ScalarType::from_py_object_or_none(obj) {
175            Ok(RValue::ScalarType(val))
176        } else if let Some(val) = Layout::from_py_object_or_none(obj) {
177            Ok(RValue::Layout(val))
178        } else if let Some(val) = MemoryFormat::from_py_object_or_none(obj) {
179            Ok(RValue::MemoryFormat(val))
180        } else if let Some(val) = IValue::from_py_object_or_none(obj) {
181            Ok(val.into())
182        } else {
183            Ok(RValue::PyObject(PickledPyObject::pickle(obj)?))
184        }
185    }
186}
187
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;

    use anyhow::Result;
    use pyo3::ffi::c_str;
    use pyo3::prelude::*;

    use super::*;

    /// An instance of a plain user-defined Python class has no native
    /// representation. It must therefore land in `RValue::PyObject`, and in
    /// particular must not be accidentally extracted as an `IValue`.
    #[test]
    fn test_py_object() -> Result<()> {
        pyo3::prepare_freethreaded_python();

        let extracted = Python::with_gil(|py| -> PyResult<RValue> {
            // Needed to initialize torch.
            py.import("torch")?;

            // Define a throwaway class, then instantiate it.
            py.run(c_str!("class Custom:\n    pass"), None, None)?;
            let instance = py.eval(c_str!("Custom()"), None, None)?;

            RValue::extract_bound(&instance)
        })?;

        assert_matches!(extracted, RValue::PyObject(_));
        Ok(())
    }
}