torch_sys/
rvalue.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use derive_more::From;
10use derive_more::TryInto;
11use hyperactor::Named;
12use monarch_types::PickledPyObject;
13use monarch_types::TryIntoPyObjectUnsafe;
14use pyo3::IntoPyObjectExt;
15use pyo3::exceptions::PyValueError;
16use pyo3::prelude::*;
17use pyo3::types::PyNone;
18use serde::Deserialize;
19use serde::Serialize;
20
21use crate::Device;
22use crate::IValue;
23use crate::IValueKind;
24use crate::Layout;
25use crate::MemoryFormat;
26use crate::ScalarType;
27use crate::TensorCell;
28use crate::cell::CloneUnsafe;
29use crate::ivalue::OpaqueIValueCell;
30
31/// A pure Rust equivalent for [`IValue`]. This is safe to treat like a normal
32/// Rust value.
33#[derive(Debug, Clone, From, TryInto, Serialize, Deserialize, Named)]
34#[try_into(owned, ref, ref_mut)]
35pub enum RValue {
36    Tensor(TensorCell),
37    TensorList(Vec<TensorCell>),
38    Int(i64),
39    IntList(Vec<i64>),
40    Double(f64),
41    Bool(bool),
42    String(String),
43    Device(Device),
44    Layout(#[serde(with = "crate::LayoutDef")] Layout),
45    ScalarType(#[serde(with = "crate::ScalarTypeDef")] ScalarType),
46    MemoryFormat(#[serde(with = "crate::MemoryFormatDef")] MemoryFormat),
47    None,
48    PyObject(PickledPyObject),
49    /// This is meant to be a catch-all for types that we don't support
50    /// natively in Rust.
51    Opaque(OpaqueIValueCell),
52}
53
54// SAFETY: this function creates untracked aliases of tensors. The caller is
55// responsible for having acquired the suitable borrows and holding them for the
56// entire lifetime of the returned IValue.
57pub unsafe fn rvalue_to_ivalue(rvalue: &RValue) -> IValue {
58    match rvalue {
59        // TODO fix unwrap
60        RValue::Tensor(cell) => {
61            // SAFETY: caller is responsible for holding a borrow, so the outer
62            // function is marked unsafe.
63            IValue::from(unsafe { cell.get_unchecked().clone_unsafe() })
64        }
65        RValue::TensorList(cells) => {
66            let mut tensors = Vec::new();
67            for cell in cells {
68                // SAFETY: caller is responsible for holding a borrow, so the outer
69                // function is marked unsafe.
70                tensors.push(unsafe { cell.get_unchecked().clone_unsafe() });
71            }
72            IValue::from(tensors)
73        }
74        RValue::Int(val) => IValue::from(*val),
75        RValue::IntList(val) => IValue::from(val.as_slice()),
76        RValue::Double(val) => IValue::from(*val),
77        RValue::Bool(val) => IValue::from(*val),
78        RValue::String(val) => IValue::from(val),
79        RValue::Device(val) => IValue::from(*val),
80        // It appears that the enums for Layout/ScalarType/MemoryFormat are just
81        // stored as raw ints in `IValue` and that we lose all info about how
82        // to convert them back.
83        RValue::Layout(val) => IValue::from(val.0 as i64),
84        RValue::ScalarType(val) => IValue::from(val.0 as i64),
85        RValue::MemoryFormat(val) => IValue::from(val.0 as i64),
86        RValue::None => IValue::from(()),
87        RValue::PyObject(val) => {
88            Python::with_gil(|py| val.unpickle(py).unwrap().extract::<IValue>())
89                .expect("unable to convert PyObject to IValue")
90        }
91        RValue::Opaque(cell) => {
92            // SAFETY: caller is responsible for holding a borrow, so the outer
93            // function is marked unsafe.
94            unsafe { cell.get_unchecked().ivalue() }
95        }
96    }
97}
98
99impl From<IValue> for RValue {
100    fn from(ivalue: IValue) -> Self {
101        match ivalue.kind() {
102            IValueKind::Tensor => RValue::Tensor(TensorCell::new(ivalue.to_tensor().unwrap())),
103            IValueKind::Bool => RValue::Bool(ivalue.to_bool().unwrap()),
104            IValueKind::Int => RValue::Int(ivalue.to_int().unwrap()),
105            IValueKind::IntList => RValue::IntList(ivalue.to_int_list().unwrap()),
106            IValueKind::Double => RValue::Double(ivalue.to_double().unwrap()),
107            IValueKind::String => RValue::String(ivalue.to_string().unwrap()),
108            IValueKind::TensorList => RValue::TensorList(
109                ivalue
110                    .to_tensor_list()
111                    .unwrap()
112                    .into_iter()
113                    .map(TensorCell::new)
114                    .collect(),
115            ),
116            IValueKind::Device => RValue::Device(ivalue.to_device().unwrap()),
117            IValueKind::None => RValue::None,
118            IValueKind::Other => RValue::Opaque(OpaqueIValueCell::new(ivalue.to_opaque().unwrap())),
119        }
120    }
121}
122
123impl<'py> IntoPyObject<'py> for RValue {
124    type Target = PyAny;
125    type Output = Bound<'py, Self::Target>;
126    type Error = PyErr;
127
128    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
129        match self {
130            RValue::Int(val) => IValue::from(val).into_pyobject(py),
131            RValue::IntList(val) => IValue::from(val.as_slice()).into_pyobject(py),
132            RValue::Double(val) => IValue::from(val).into_pyobject(py),
133            RValue::Bool(val) => IValue::from(val).into_pyobject(py),
134            RValue::String(val) => IValue::from(&val).into_pyobject(py),
135            RValue::Device(val) => IValue::from(val).into_pyobject(py),
136            // Avoid converting layout and scalar type into ivalues, as it appears
137            // they just get converted to ints.
138            RValue::Layout(val) => val.clone().into_pyobject(py),
139            RValue::ScalarType(val) => val.clone().into_pyobject(py),
140            RValue::MemoryFormat(val) => val.clone().into_pyobject(py),
141            RValue::None => PyNone::get(py).into_bound_py_any(py),
142            RValue::PyObject(val) => val.unpickle(py),
143            _ => Err(PyErr::new::<PyValueError, _>(format!(
144                "cannot safely create py object from {:?}",
145                self
146            ))),
147        }
148    }
149}
150
151/// Convert into a `PyObject`.
152impl<'py> TryIntoPyObjectUnsafe<'py, PyAny> for &RValue {
153    unsafe fn try_to_object_unsafe(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
154        match self {
155            // Avoid converting layout, scalar type, memory format into ivalues, as it appears
156            // they just get converted to ints.
157            // None and PyObject are also not converted as there is no need to do so.
158            RValue::Layout(val) => val.clone().into_pyobject(py),
159            RValue::ScalarType(val) => val.clone().into_pyobject(py),
160            RValue::MemoryFormat(val) => val.clone().into_pyobject(py),
161            RValue::None => PyNone::get(py).into_bound_py_any(py),
162            RValue::PyObject(val) => val.unpickle(py),
163            // SAFETY: This inherits the unsafety of `rvalue_to_ivalue` (see comment
164            // above).
165            _ => unsafe { rvalue_to_ivalue(self).into_pyobject(py) },
166        }
167    }
168}
169
170impl FromPyObject<'_> for RValue {
171    fn extract_bound(obj: &pyo3::Bound<'_, pyo3::PyAny>) -> pyo3::PyResult<Self> {
172        // It's crucial for correctness to try converting to IValue after we've
173        // tried the other non-PyObject variants, because the IValue conversion
174        // will actually succeed when obj is a ScalarType, Layout, or MemoryFormat.
175        if let Some(val) = ScalarType::from_py_object_or_none(obj) {
176            Ok(RValue::ScalarType(val))
177        } else if let Some(val) = Layout::from_py_object_or_none(obj) {
178            Ok(RValue::Layout(val))
179        } else if let Some(val) = MemoryFormat::from_py_object_or_none(obj) {
180            Ok(RValue::MemoryFormat(val))
181        } else if let Some(val) = IValue::from_py_object_or_none(obj) {
182            Ok(val.into())
183        } else {
184            Ok(RValue::PyObject(PickledPyObject::pickle(obj)?))
185        }
186    }
187}
188
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;

    use anyhow::Result;
    use pyo3::ffi::c_str;
    use pyo3::prelude::*;

    use super::*;

    /// An instance of a user-defined Python class has no native conversion
    /// and must fall back to the pickled `RValue::PyObject` variant.
    #[test]
    fn test_py_object() -> Result<()> {
        pyo3::prepare_freethreaded_python();
        let rval = Python::with_gil(|py| {
            // Needed to initialize torch.
            py.import("torch")?;

            // Define the Custom class inline.
            py.run(c_str!("class Custom:\n    pass"), None, None)?;

            let obj = py.eval(c_str!("Custom()"), None, None)?;
            RValue::extract_bound(&obj)
        })?;
        // NOTE(agallagher): Among other things, verify this isn't accidentally
        // extracted as an `IValue`.
        assert_matches!(rval, RValue::PyObject(_));
        Ok(())
    }

    /// `extract_bound` must try the ScalarType/Layout/MemoryFormat conversions
    /// before the generic IValue conversion; this pins that ordering.
    #[test]
    fn test_torch_enums_are_not_extracted_as_ivalues() -> Result<()> {
        pyo3::prepare_freethreaded_python();
        let (dtype, layout, memory_format) = Python::with_gil(|py| -> PyResult<_> {
            let torch = py.import("torch")?;
            let dtype = RValue::extract_bound(&torch.getattr("float32")?)?;
            let layout = RValue::extract_bound(&torch.getattr("strided")?)?;
            let memory_format = RValue::extract_bound(&torch.getattr("contiguous_format")?)?;
            Ok((dtype, layout, memory_format))
        })?;
        assert_matches!(dtype, RValue::ScalarType(_));
        assert_matches!(layout, RValue::Layout(_));
        assert_matches!(memory_format, RValue::MemoryFormat(_));
        Ok(())
    }
}