1use derive_more::From;
10use derive_more::TryInto;
11use hyperactor::Named;
12use monarch_types::PickledPyObject;
13use monarch_types::TryIntoPyObjectUnsafe;
14use pyo3::IntoPyObjectExt;
15use pyo3::exceptions::PyValueError;
16use pyo3::prelude::*;
17use pyo3::types::PyNone;
18use serde::Deserialize;
19use serde::Serialize;
20
21use crate::Device;
22use crate::IValue;
23use crate::IValueKind;
24use crate::Layout;
25use crate::MemoryFormat;
26use crate::ScalarType;
27use crate::TensorCell;
28use crate::cell::CloneUnsafe;
29use crate::ivalue::OpaqueIValueCell;
30
/// A serializable value that can be converted to an [`IValue`] (see
/// `rvalue_to_ivalue`), built from an [`IValue`], and exchanged with Python.
///
/// `derive_more` generates `From` conversions from the variants' payloads into
/// `RValue` and, via `#[try_into(owned, ref, ref_mut)]`, fallible conversions
/// from `RValue`, `&RValue`, and `&mut RValue` back into those payloads.
///
/// NOTE(review): some serde formats encode enum variants by position, so
/// reordering or inserting variants may break compatibility with values that
/// were already serialized — confirm which wire format is used before editing.
#[derive(Debug, Clone, From, TryInto, Serialize, Deserialize, Named)]
#[try_into(owned, ref, ref_mut)]
pub enum RValue {
    /// A single tensor held in a [`TensorCell`].
    Tensor(TensorCell),
    /// A list of tensors, each held in its own [`TensorCell`].
    TensorList(Vec<TensorCell>),
    Int(i64),
    IntList(Vec<i64>),
    Double(f64),
    Bool(bool),
    String(String),
    Device(Device),
    /// Serialized through `crate::LayoutDef`; converted to an integer `IValue`.
    Layout(#[serde(with = "crate::LayoutDef")] Layout),
    /// Serialized through `crate::ScalarTypeDef`; converted to an integer `IValue`.
    ScalarType(#[serde(with = "crate::ScalarTypeDef")] ScalarType),
    /// Serialized through `crate::MemoryFormatDef`; converted to an integer `IValue`.
    MemoryFormat(#[serde(with = "crate::MemoryFormatDef")] MemoryFormat),
    /// No value; converts to Python `None` and to the unit `IValue`.
    None,
    /// An arbitrary Python object kept in pickled form. Used as the fallback
    /// when a Python object matches none of the other variants.
    PyObject(PickledPyObject),
    /// An `IValue` whose kind has no dedicated variant (`IValueKind::Other`).
    Opaque(OpaqueIValueCell),
}
53
54pub unsafe fn rvalue_to_ivalue(rvalue: &RValue) -> IValue {
58 match rvalue {
59 RValue::Tensor(cell) => {
61 IValue::from(unsafe { cell.get_unchecked().clone_unsafe() })
64 }
65 RValue::TensorList(cells) => {
66 let mut tensors = Vec::new();
67 for cell in cells {
68 tensors.push(unsafe { cell.get_unchecked().clone_unsafe() });
71 }
72 IValue::from(tensors)
73 }
74 RValue::Int(val) => IValue::from(*val),
75 RValue::IntList(val) => IValue::from(val.as_slice()),
76 RValue::Double(val) => IValue::from(*val),
77 RValue::Bool(val) => IValue::from(*val),
78 RValue::String(val) => IValue::from(val),
79 RValue::Device(val) => IValue::from(*val),
80 RValue::Layout(val) => IValue::from(val.0 as i64),
84 RValue::ScalarType(val) => IValue::from(val.0 as i64),
85 RValue::MemoryFormat(val) => IValue::from(val.0 as i64),
86 RValue::None => IValue::from(()),
87 RValue::PyObject(val) => {
88 Python::with_gil(|py| val.unpickle(py).unwrap().extract::<IValue>())
89 .expect("unable to convert PyObject to IValue")
90 }
91 RValue::Opaque(cell) => {
92 unsafe { cell.get_unchecked().ivalue() }
95 }
96 }
97}
98
99impl From<IValue> for RValue {
100 fn from(ivalue: IValue) -> Self {
101 match ivalue.kind() {
102 IValueKind::Tensor => RValue::Tensor(TensorCell::new(ivalue.to_tensor().unwrap())),
103 IValueKind::Bool => RValue::Bool(ivalue.to_bool().unwrap()),
104 IValueKind::Int => RValue::Int(ivalue.to_int().unwrap()),
105 IValueKind::IntList => RValue::IntList(ivalue.to_int_list().unwrap()),
106 IValueKind::Double => RValue::Double(ivalue.to_double().unwrap()),
107 IValueKind::String => RValue::String(ivalue.to_string().unwrap()),
108 IValueKind::TensorList => RValue::TensorList(
109 ivalue
110 .to_tensor_list()
111 .unwrap()
112 .into_iter()
113 .map(TensorCell::new)
114 .collect(),
115 ),
116 IValueKind::Device => RValue::Device(ivalue.to_device().unwrap()),
117 IValueKind::None => RValue::None,
118 IValueKind::Other => RValue::Opaque(OpaqueIValueCell::new(ivalue.to_opaque().unwrap())),
119 }
120 }
121}
122
123impl<'py> IntoPyObject<'py> for RValue {
124 type Target = PyAny;
125 type Output = Bound<'py, Self::Target>;
126 type Error = PyErr;
127
128 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
129 match self {
130 RValue::Int(val) => IValue::from(val).into_pyobject(py),
131 RValue::IntList(val) => IValue::from(val.as_slice()).into_pyobject(py),
132 RValue::Double(val) => IValue::from(val).into_pyobject(py),
133 RValue::Bool(val) => IValue::from(val).into_pyobject(py),
134 RValue::String(val) => IValue::from(&val).into_pyobject(py),
135 RValue::Device(val) => IValue::from(val).into_pyobject(py),
136 RValue::Layout(val) => val.clone().into_pyobject(py),
139 RValue::ScalarType(val) => val.clone().into_pyobject(py),
140 RValue::MemoryFormat(val) => val.clone().into_pyobject(py),
141 RValue::None => PyNone::get(py).into_bound_py_any(py),
142 RValue::PyObject(val) => val.unpickle(py),
143 _ => Err(PyErr::new::<PyValueError, _>(format!(
144 "cannot safely create py object from {:?}",
145 self
146 ))),
147 }
148 }
149}
150
151impl<'py> TryIntoPyObjectUnsafe<'py, PyAny> for &RValue {
153 unsafe fn try_to_object_unsafe(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
154 match self {
155 RValue::Layout(val) => val.clone().into_pyobject(py),
159 RValue::ScalarType(val) => val.clone().into_pyobject(py),
160 RValue::MemoryFormat(val) => val.clone().into_pyobject(py),
161 RValue::None => PyNone::get(py).into_bound_py_any(py),
162 RValue::PyObject(val) => val.unpickle(py),
163 _ => unsafe { rvalue_to_ivalue(self).into_pyobject(py) },
166 }
167 }
168}
169
170impl FromPyObject<'_> for RValue {
171 fn extract_bound(obj: &pyo3::Bound<'_, pyo3::PyAny>) -> pyo3::PyResult<Self> {
172 if let Some(val) = ScalarType::from_py_object_or_none(obj) {
176 Ok(RValue::ScalarType(val))
177 } else if let Some(val) = Layout::from_py_object_or_none(obj) {
178 Ok(RValue::Layout(val))
179 } else if let Some(val) = MemoryFormat::from_py_object_or_none(obj) {
180 Ok(RValue::MemoryFormat(val))
181 } else if let Some(val) = IValue::from_py_object_or_none(obj) {
182 Ok(val.into())
183 } else {
184 Ok(RValue::PyObject(PickledPyObject::pickle(obj)?))
185 }
186 }
187}
188
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;

    use anyhow::Result;
    use pyo3::ffi::c_str;
    use pyo3::prelude::*;

    use super::*;

    /// An instance of a user-defined Python class matches none of the
    /// specialized conversions, so extraction must fall back to
    /// `RValue::PyObject`.
    #[test]
    fn test_py_object() -> Result<()> {
        pyo3::prepare_freethreaded_python();
        let extracted = Python::with_gil(|py| -> PyResult<RValue> {
            py.import("torch")?;
            py.run(c_str!("class Custom:\n pass"), None, None)?;
            let instance = py.eval(c_str!("Custom()"), None, None)?;
            RValue::extract_bound(&instance)
        })?;
        assert_matches!(extracted, RValue::PyObject(_));
        Ok(())
    }
}