1use std::collections::HashMap;
10
11use derive_more::From;
12use derive_more::TryInto;
13use enum_as_inner::EnumAsInner;
14use hyperactor::Named;
15use monarch_types::PickledPyObject;
16use monarch_types::TryIntoPyObjectUnsafe;
17use pyo3::IntoPyObjectExt;
18use pyo3::exceptions::PyValueError;
19use pyo3::prelude::*;
20use pyo3::types::PyBool;
21use pyo3::types::PyDict;
22use pyo3::types::PyFloat;
23use pyo3::types::PyList;
24use pyo3::types::PyNone;
25use pyo3::types::PyString;
26use pyo3::types::PyTuple;
27use serde::Deserialize;
28use serde::Serialize;
29use torch_sys::Device;
30use torch_sys::Layout;
31use torch_sys::MemoryFormat;
32use torch_sys::OpaqueIValue;
33use torch_sys::ScalarType;
34
35use crate::worker::Ref;
36use crate::worker::ResolvableFunction;
37
/// A serializable value used to carry function-call arguments (see
/// `func_call_args_to_wire_values`).
///
/// Values are built from Python objects by the `FromPyObject` impl (or, for
/// torch operators, by `WireValue::from_pyobject_with_torch_op_arg_type`) and
/// turned back into Python objects by the `TryIntoPyObjectUnsafe` impl.
///
/// NOTE(review): if this is serialized with a positional format (e.g. bincode),
/// the variant order is part of the wire format and presumably must not be
/// reordered — confirm which serializer is used before changing the enum.
#[derive(
    Serialize,
    Deserialize,
    Debug,
    Clone,
    TryInto,
    Named,
    From,
    EnumAsInner
)]
pub enum WireValue {
    /// A Python `bool`.
    Bool(bool),
    /// A Python `int` that fits in an `i64`.
    Int(i64),
    /// A Python `float`.
    Double(f64),
    /// A Python `str`.
    String(String),
    /// An object recognized by `Ref::from_py_object`.
    Ref(Ref),
    /// A list of ints. An empty Python list also converts to this variant.
    IntList(Vec<i64>),
    /// A list of refs (e.g. the value of a tensor-list torch op argument).
    RefList(Vec<Ref>),
    /// A `torch_sys::Device`.
    Device(Device),
    /// A `torch_sys::Layout`, serialized with the `torch_sys::LayoutDef` helper.
    Layout(#[serde(with = "torch_sys::LayoutDef")] Layout),
    /// A `torch_sys::ScalarType`, serialized with the `torch_sys::ScalarTypeDef` helper.
    ScalarType(#[serde(with = "torch_sys::ScalarTypeDef")] ScalarType),
    /// A `torch_sys::MemoryFormat`, serialized with the `torch_sys::MemoryFormatDef` helper.
    MemoryFormat(#[serde(with = "torch_sys::MemoryFormatDef")] MemoryFormat),
    /// Python `None`; also used for an omitted optional tensor / tensor-list
    /// torch op argument.
    None(()),
    /// Fallback for any other Python object, stored in pickled form.
    PyObject(PickledPyObject),
    /// A torch `IValue` built from a torch op argument according to the op's schema.
    IValue(torch_sys::OpaqueIValue),
}
77
78impl FromPyObject<'_> for WireValue {
79 fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
80 if let Ok(ref_) = Ref::from_py_object(obj) {
81 Ok(WireValue::Ref(ref_))
82 } else if let Ok(list) = obj.downcast::<PyList>() {
83 let len = list.len();
84 if len == 0 {
85 return Ok(WireValue::IntList(vec![]));
88 }
89
90 let item = unsafe { list.get_item_unchecked(0) };
92 let len = list.len();
93 if let Ok(int) = item.extract::<i64>() {
94 let mut int_list = Vec::with_capacity(len);
95 int_list.push(int);
96 for item in list.iter().skip(1) {
97 int_list.push(item.extract::<i64>().map_err(|_| {
98 PyValueError::new_err(format!(
99 "Expected homogeneous list of ints got: {:?}",
100 list
101 ))
102 })?);
103 }
104 return Ok(WireValue::IntList(int_list));
105 }
106 if let Ok(ref_) = Ref::from_py_object(&item) {
107 let mut ref_list = Vec::with_capacity(len);
108 ref_list.push(ref_);
109 for item in list.iter().skip(1) {
110 ref_list.push(Ref::from_py_object(&item).map_err(|_| {
111 PyValueError::new_err(format!(
112 "Expected homogeneous list of ints got: {:?}",
113 list
114 ))
115 })?);
116 }
117 return Ok(WireValue::RefList(ref_list));
118 }
119 Ok(WireValue::PyObject(PickledPyObject::pickle(obj)?))
120 } else if obj.is_none() {
121 Ok(WireValue::None(()))
122 } else if let Ok(bool_) = obj.downcast::<PyBool>() {
123 Ok(WireValue::Bool(bool_.is_true()))
124 } else if let Ok(int) = obj.extract::<i64>() {
125 Ok(WireValue::Int(int))
126 } else if let Ok(double) = obj.downcast::<PyFloat>() {
127 Ok(WireValue::Double(double.value()))
128 } else if let Ok(string) = obj.downcast::<PyString>() {
129 Ok(WireValue::String(string.to_str()?.to_string()))
130 } else if let Ok(device) = obj.extract::<Device>() {
131 Ok(WireValue::Device(device))
132 } else if let Ok(layout) = obj.extract::<Layout>() {
133 Ok(WireValue::Layout(layout))
134 } else if let Ok(scalar_type) = obj.extract::<ScalarType>() {
135 Ok(WireValue::ScalarType(scalar_type))
136 } else if let Ok(memory_format) = obj.extract::<MemoryFormat>() {
137 Ok(WireValue::MemoryFormat(memory_format))
138 } else {
139 Ok(WireValue::PyObject(PickledPyObject::pickle(obj)?))
140 }
141 }
142}
143
144impl<'py> TryIntoPyObjectUnsafe<'py, PyAny> for WireValue {
145 unsafe fn try_to_object_unsafe(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
146 match self {
147 WireValue::Ref(ref_) => ref_.into_bound_py_any(py),
148 WireValue::RefList(ref_list) => ref_list.clone().into_bound_py_any(py),
149 WireValue::Int(int) => int.into_bound_py_any(py),
150 WireValue::IntList(int_list) => int_list.clone().into_bound_py_any(py),
151 WireValue::Double(double) => double.into_bound_py_any(py),
152 WireValue::Bool(bool_) => bool_.into_bound_py_any(py),
153 WireValue::String(string) => string.into_bound_py_any(py),
154 WireValue::Device(device) => device.into_bound_py_any(py),
155 WireValue::Layout(val) => val.into_bound_py_any(py),
156 WireValue::ScalarType(val) => val.into_bound_py_any(py),
157 WireValue::MemoryFormat(val) => val.into_bound_py_any(py),
158 WireValue::None(()) => PyNone::get(py).into_bound_py_any(py),
159 WireValue::PyObject(val) => val.unpickle(py),
160 WireValue::IValue(val) => unsafe { val.try_to_object_unsafe(py) },
164 }
165 }
166}
167
168impl From<PyObject> for WireValue {
169 fn from(obj: PyObject) -> Self {
170 Python::with_gil(|py| WireValue::PyObject(PickledPyObject::pickle(obj.bind(py)).unwrap()))
171 }
172}
173
174impl WireValue {
175 fn from_pyobject_with_torch_op_arg_type(
176 obj: Bound<'_, PyAny>,
177 type_: &torch_sys::call_op::TypePtr,
178 num_elements: i32,
179 allow_nums_as_tensors: bool,
180 ) -> PyResult<Self> {
181 if type_.is_tensor() || type_.is_optional_tensor() {
182 if type_.is_optional_tensor() && obj.is_none() {
183 return Ok(WireValue::None(()));
184 } else if let Ok(ref_) = Ref::from_py_object(&obj) {
185 return Ok(WireValue::Ref(ref_));
186 }
187 }
188 if type_.is_tensor_list() || type_.is_optional_tensor_list() {
189 if type_.is_optional_tensor_list() && obj.is_none() {
190 return Ok(WireValue::None(()));
191 }
192 let list = obj.downcast::<PyList>()?;
193 let len = list.len();
194 if len == 0 {
195 return Ok(WireValue::RefList(vec![]));
196 }
197 let item = unsafe { list.get_item_unchecked(0) };
199 if let Ok(ref_) = Ref::from_py_object(&item) {
200 let mut ref_list = Vec::with_capacity(len);
201 ref_list.push(ref_);
202 for item in list.iter().skip(1) {
203 ref_list.push(Ref::from_py_object(&item).map_err(|_| {
204 PyValueError::new_err(format!(
205 "Expected homogeneous list of refs got: {:?}",
206 list
207 ))
208 })?);
209 }
210 return Ok(WireValue::RefList(ref_list));
211 }
212 }
213 OpaqueIValue::from_py_object_with_type(obj, type_, num_elements, allow_nums_as_tensors)
214 .map(WireValue::IValue)
215 }
216}
217
218pub fn func_call_args_to_wire_values(
219 func: Option<&ResolvableFunction>,
220 args: &Bound<'_, PyTuple>,
221 kwargs: &Bound<'_, PyDict>,
222) -> PyResult<(Vec<WireValue>, HashMap<String, WireValue>)> {
223 if let Some((op, overload)) = func.and_then(|func| func.as_torch_op()) {
224 torch_op_args_to_wire_values(&op, &overload, args, kwargs)
225 } else {
226 python_func_args_to_wire_value(args, kwargs)
227 }
228}
229
230fn torch_op_args_to_wire_values(
231 op: &str,
232 overload: &str,
233 args: &Bound<'_, PyTuple>,
234 kwargs: &Bound<'_, PyDict>,
235) -> PyResult<(Vec<WireValue>, HashMap<String, WireValue>)> {
236 let args_info = torch_sys::call_op::get_schema_args_info(op, overload).map_err(|err| {
237 PyValueError::new_err(format!(
238 "Failed to get the operator schema for {}::{}: {}",
239 op, overload, err
240 ))
241 })?;
242
243 let args = args
244 .iter()
245 .zip(&args_info)
246 .map(|(arg, arg_info)| {
247 WireValue::from_pyobject_with_torch_op_arg_type(
248 arg,
249 arg_info.type_,
250 arg_info.num_elements,
251 arg_info.allows_number_as_tensor,
252 )
253 })
254 .collect::<Result<Vec<_>, _>>()?;
255 let kwargs = kwargs
256 .iter()
257 .map(|(k, v)| {
258 let key = k.extract::<String>()?;
259 let arg_info = args_info
260 .iter()
261 .find(|arg_info| arg_info.name == key)
262 .ok_or_else(|| {
263 PyValueError::new_err(format!(
264 "Torch op {}::{} does not support kwarg {}",
265 op, overload, key
266 ))
267 })?;
268 let val = WireValue::from_pyobject_with_torch_op_arg_type(
269 v,
270 arg_info.type_,
271 arg_info.num_elements,
272 arg_info.allows_number_as_tensor,
273 )?;
274 Ok((key, val))
275 })
276 .collect::<Result<HashMap<_, _>, PyErr>>()?;
277 Ok((args, kwargs))
278}
279
280fn python_func_args_to_wire_value(
281 args: &Bound<'_, PyTuple>,
282 kwargs: &Bound<'_, PyDict>,
283) -> PyResult<(Vec<WireValue>, HashMap<String, WireValue>)> {
284 let args = args
285 .iter()
286 .map(|arg| Ok(WireValue::PyObject(PickledPyObject::pickle(&arg)?)))
287 .collect::<PyResult<_>>()?;
288 let kwargs = kwargs
289 .iter()
290 .map(|(k, v)| {
291 Ok((
292 k.extract::<String>()?,
293 WireValue::PyObject(PickledPyObject::pickle(&v)?),
294 ))
295 })
296 .collect::<Result<HashMap<_, _>, PyErr>>()?;
297 Ok((args, kwargs))
298}
299
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;

    use anyhow::Result;
    use anyhow::bail;
    use paste::paste;
    use pyo3::Python;
    use pyo3::ffi::c_str;
    use pyo3::types::PyDict;
    use torch_sys::DeviceType;
    use torch_sys::ScalarType;

    use super::*;
    use crate::worker::Ref;

    /// Python source for a minimal class exposing `__monarch_ref__`. It checks
    /// that arbitrary objects implementing that hook convert to `WireValue::Ref`.
    const MOCK_REFERENCEABLE_MODULE: &std::ffi::CStr = c_str!(
        r#"
class Referencable:
    def __init__(self, ref: int):
        self.ref = ref

    def __monarch_ref__(self):
        return self.ref
"#
    );

    /// Initializes the embedded interpreter and imports `torch`.
    /// NOTE(review): the import presumably makes the torch_sys-backed
    /// conversions usable — confirm.
    fn setup() -> Result<()> {
        pyo3::prepare_freethreaded_python();
        Python::with_gil(|py| py.run(c_str!("import torch"), None, None))?;
        Ok(())
    }

    /// Builds a picklable Python dict (`{"foo": "bar"}`), which is expected to
    /// fall through to `WireValue::PyObject`.
    fn create_py_object() -> PyObject {
        pyo3::prepare_freethreaded_python();
        Python::with_gil(|py| {
            let dict = PyDict::new(py);
            dict.set_item("foo", "bar").unwrap();
            dict.into_any().unbind()
        })
    }

    macro_rules! generate_wire_value_from_py_tests {
        ($($kind:ident, $input:expr);* $(;)?) => {
            paste! {
                $(
                    #[test]
                    fn [<test_wire_value_from_py_$kind:snake:lower>]() -> Result<()> {
                        setup()?;
                        Python::with_gil(|py| {
                            let actual = $input.into_pyobject(py)?.extract::<WireValue>()?;
                            assert_matches!(actual, WireValue::$kind(_));
                            anyhow::Ok(())
                        })
                    }
                )*

                #[test]
                fn test_wire_value_from_py_none() -> Result<()> {
                    setup()?;
                    Python::with_gil(|py| {
                        let obj = PyNone::get(py).into_pyobject(py)?;
                        let actual = obj.extract::<WireValue>()?;
                        assert_matches!(actual, WireValue::None(_));
                        anyhow::Ok(())
                    })
                }

                #[test]
                fn test_wire_value_from_py_empty_list() -> Result<()> {
                    setup()?;
                    Python::with_gil(|py| {
                        let obj: PyObject = PyList::empty(py).into_any().unbind();
                        let actual = obj.extract::<WireValue>(py)?;
                        match actual {
                            WireValue::IntList(list) if list.is_empty() => (),
                            _ => bail!("Expected empty list to be converted to empty int list"),
                        }
                        anyhow::Ok(())
                    })
                }

                #[test]
                fn test_wire_value_from_py_referencable_class() -> Result<()> {
                    setup()?;
                    Python::with_gil(|py| {
                        let referencable = PyModule::from_code(
                            py,
                            MOCK_REFERENCEABLE_MODULE,
                            c_str!("referencable.py"),
                            c_str!("referencable"),
                        )?;
                        let ref_ = referencable.getattr("Referencable")?.call1((1,))?.unbind();
                        let actual = ref_.extract::<WireValue>(py)?;
                        assert_matches!(actual, WireValue::Ref(Ref { id: 1 }));
                        anyhow::Ok(())
                    })
                }

                /// Fails to compile if a `WireValue` variant is neither listed in
                /// the macro invocation nor handled explicitly below.
                #[test]
                fn test_wire_value_from_py_roundtrip_was_exhaustive() {
                    let val = WireValue::Int(0);
                    match val {
                        $(WireValue::$kind(_) => (),)*
                        WireValue::None(_) => (),
                        WireValue::IValue(_) => (),
                    }
                }
            }
        }
    }

    /// A list that starts with an int but contains a non-int must be rejected
    /// rather than silently pickled or truncated.
    #[test]
    fn test_wire_value_from_py_heterogeneous_list_is_rejected() -> Result<()> {
        setup()?;
        Python::with_gil(|py| {
            let obj = py.eval(c_str!("[1, 'a']"), None, None)?;
            assert!(obj.extract::<WireValue>().is_err());
            anyhow::Ok(())
        })
    }

    generate_wire_value_from_py_tests! {
        Bool, false;
        Double, 1.23f64;
        Int, 123i64;
        IntList, vec![1i64];
        Ref, Ref::from(1);
        RefList, vec![Ref::from(1), Ref::from(2)];
        String, "foobar".to_owned();
        Device, Device::new(DeviceType::CPU);
        Layout, Layout(2);
        ScalarType, ScalarType(3);
        MemoryFormat, MemoryFormat(1);
        PyObject, create_py_object();
    }
}