1use derive_more::From;
10use derive_more::TryInto;
11use monarch_types::PickledPyObject;
12use monarch_types::TryIntoPyObjectUnsafe;
13use pyo3::IntoPyObjectExt;
14use pyo3::exceptions::PyValueError;
15use pyo3::prelude::*;
16use pyo3::types::PyNone;
17use serde::Deserialize;
18use serde::Serialize;
19
20use crate::Device;
21use crate::IValue;
22use crate::IValueKind;
23use crate::Layout;
24use crate::MemoryFormat;
25use crate::ScalarType;
26use crate::TensorCell;
27use crate::cell::CloneUnsafe;
28use crate::ivalue::OpaqueIValueCell;
29
/// A serializable value that can be converted to and from an [`IValue`]
/// (see `rvalue_to_ivalue` and `impl From<IValue> for RValue`), and to and
/// from Python objects.
///
/// `From` and `TryInto` are derived with `derive_more`. The `try_into`
/// attribute requests owned, `&`, and `&mut` per-variant conversions.
///
/// NOTE(review): reordering variants may change the serialized form for serde
/// formats that encode the variant index.
#[derive(Debug, Clone, From, TryInto, Serialize, Deserialize)]
#[try_into(owned, ref, ref_mut)]
pub enum RValue {
    /// A single tensor held in a `TensorCell`.
    Tensor(TensorCell),
    /// A list of tensors, each held in its own `TensorCell`.
    TensorList(Vec<TensorCell>),
    Int(i64),
    IntList(Vec<i64>),
    Double(f64),
    Bool(bool),
    String(String),
    Device(Device),
    /// Serialized through the `crate::LayoutDef` serde shim.
    Layout(#[serde(with = "crate::LayoutDef")] Layout),
    /// Serialized through the `crate::ScalarTypeDef` serde shim.
    ScalarType(#[serde(with = "crate::ScalarTypeDef")] ScalarType),
    /// Serialized through the `crate::MemoryFormatDef` serde shim.
    MemoryFormat(#[serde(with = "crate::MemoryFormatDef")] MemoryFormat),
    /// Maps to `IValue::from(())` and to Python `None`.
    None,
    /// A Python object kept in pickled form. This is the fallback used when
    /// extraction from Python matches none of the typed variants.
    PyObject(PickledPyObject),
    /// An `IValue` whose kind is `IValueKind::Other`, held in an opaque cell.
    Opaque(OpaqueIValueCell),
}
52
53pub unsafe fn rvalue_to_ivalue(rvalue: &RValue) -> IValue {
57 match rvalue {
58 RValue::Tensor(cell) => {
60 IValue::from(unsafe { cell.get_unchecked().clone_unsafe() })
63 }
64 RValue::TensorList(cells) => {
65 let mut tensors = Vec::new();
66 for cell in cells {
67 tensors.push(unsafe { cell.get_unchecked().clone_unsafe() });
70 }
71 IValue::from(tensors)
72 }
73 RValue::Int(val) => IValue::from(*val),
74 RValue::IntList(val) => IValue::from(val.as_slice()),
75 RValue::Double(val) => IValue::from(*val),
76 RValue::Bool(val) => IValue::from(*val),
77 RValue::String(val) => IValue::from(val),
78 RValue::Device(val) => IValue::from(*val),
79 RValue::Layout(val) => IValue::from(val.0 as i64),
83 RValue::ScalarType(val) => IValue::from(val.0 as i64),
84 RValue::MemoryFormat(val) => IValue::from(val.0 as i64),
85 RValue::None => IValue::from(()),
86 RValue::PyObject(val) => {
87 Python::with_gil(|py| val.unpickle(py).unwrap().extract::<IValue>())
88 .expect("unable to convert PyObject to IValue")
89 }
90 RValue::Opaque(cell) => {
91 unsafe { cell.get_unchecked().ivalue() }
94 }
95 }
96}
97
98impl From<IValue> for RValue {
99 fn from(ivalue: IValue) -> Self {
100 match ivalue.kind() {
101 IValueKind::Tensor => RValue::Tensor(TensorCell::new(ivalue.to_tensor().unwrap())),
102 IValueKind::Bool => RValue::Bool(ivalue.to_bool().unwrap()),
103 IValueKind::Int => RValue::Int(ivalue.to_int().unwrap()),
104 IValueKind::IntList => RValue::IntList(ivalue.to_int_list().unwrap()),
105 IValueKind::Double => RValue::Double(ivalue.to_double().unwrap()),
106 IValueKind::String => RValue::String(ivalue.to_string().unwrap()),
107 IValueKind::TensorList => RValue::TensorList(
108 ivalue
109 .to_tensor_list()
110 .unwrap()
111 .into_iter()
112 .map(TensorCell::new)
113 .collect(),
114 ),
115 IValueKind::Device => RValue::Device(ivalue.to_device().unwrap()),
116 IValueKind::None => RValue::None,
117 IValueKind::Other => RValue::Opaque(OpaqueIValueCell::new(ivalue.to_opaque().unwrap())),
118 }
119 }
120}
121
122impl<'py> IntoPyObject<'py> for RValue {
123 type Target = PyAny;
124 type Output = Bound<'py, Self::Target>;
125 type Error = PyErr;
126
127 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
128 match self {
129 RValue::Int(val) => IValue::from(val).into_pyobject(py),
130 RValue::IntList(val) => IValue::from(val.as_slice()).into_pyobject(py),
131 RValue::Double(val) => IValue::from(val).into_pyobject(py),
132 RValue::Bool(val) => IValue::from(val).into_pyobject(py),
133 RValue::String(val) => IValue::from(&val).into_pyobject(py),
134 RValue::Device(val) => IValue::from(val).into_pyobject(py),
135 RValue::Layout(val) => val.clone().into_pyobject(py),
138 RValue::ScalarType(val) => val.clone().into_pyobject(py),
139 RValue::MemoryFormat(val) => val.clone().into_pyobject(py),
140 RValue::None => PyNone::get(py).into_bound_py_any(py),
141 RValue::PyObject(val) => val.unpickle(py),
142 _ => Err(PyErr::new::<PyValueError, _>(format!(
143 "cannot safely create py object from {:?}",
144 self
145 ))),
146 }
147 }
148}
149
150impl<'py> TryIntoPyObjectUnsafe<'py, PyAny> for &RValue {
152 unsafe fn try_to_object_unsafe(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
153 match self {
154 RValue::Layout(val) => val.clone().into_pyobject(py),
158 RValue::ScalarType(val) => val.clone().into_pyobject(py),
159 RValue::MemoryFormat(val) => val.clone().into_pyobject(py),
160 RValue::None => PyNone::get(py).into_bound_py_any(py),
161 RValue::PyObject(val) => val.unpickle(py),
162 _ => unsafe { rvalue_to_ivalue(self).into_pyobject(py) },
165 }
166 }
167}
168
169impl FromPyObject<'_> for RValue {
170 fn extract_bound(obj: &pyo3::Bound<'_, pyo3::PyAny>) -> pyo3::PyResult<Self> {
171 if let Some(val) = ScalarType::from_py_object_or_none(obj) {
175 Ok(RValue::ScalarType(val))
176 } else if let Some(val) = Layout::from_py_object_or_none(obj) {
177 Ok(RValue::Layout(val))
178 } else if let Some(val) = MemoryFormat::from_py_object_or_none(obj) {
179 Ok(RValue::MemoryFormat(val))
180 } else if let Some(val) = IValue::from_py_object_or_none(obj) {
181 Ok(val.into())
182 } else {
183 Ok(RValue::PyObject(PickledPyObject::pickle(obj)?))
184 }
185 }
186}
187
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;

    use anyhow::Result;
    use pyo3::ffi::c_str;
    use pyo3::prelude::*;

    use super::*;

    /// An instance of a user-defined Python class matches none of the typed
    /// extractors, so it must be captured as a pickled `RValue::PyObject`.
    #[test]
    fn test_py_object() -> Result<()> {
        pyo3::prepare_freethreaded_python();

        let rval = Python::with_gil(|py| -> PyResult<RValue> {
            // NOTE(review): torch is presumably imported so the torch-backed
            // extractors can resolve their Python types. Confirm.
            py.import("torch")?;

            // Define the class and instantiate it in the default (`__main__`) scope.
            py.run(c_str!("class Custom:\n pass"), None, None)?;
            let custom = py.eval(c_str!("Custom()"), None, None)?;

            RValue::extract_bound(&custom)
        })?;

        assert_matches!(rval, RValue::PyObject(_));
        Ok(())
    }
}