monarch_messages/
wire_value.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8
9use derive_more::From;
10use derive_more::TryInto;
11use monarch_types::PickledPyObject;
12use pyo3::IntoPyObjectExt;
13use pyo3::prelude::*;
14use pyo3::types::PyNone;
15use serde::Deserialize;
16use serde::Serialize;
17use torch_sys2::Device;
18use torch_sys2::Layout;
19use torch_sys2::MemoryFormat;
20use torch_sys2::ScalarType;
21use typeuri::Named;
22
23use crate::worker::Ref;
24
25/// A value used as an input to CallFunction.
26// TODO, this is basically the same as RValue, but with TensorIndices swapped
27// out for refs. And IValue is the same as RValue, but with real tensors and
28// C++ types. I wonder if there is a nicer way to express this relationship.
29// TODO extend this to support other types of values, like bytes, dicts etc.
30#[derive(Serialize, Deserialize, Debug, Clone, TryInto, Named, From)]
31pub enum WireValue {
32    // Make sure boolean goes ealier than int as bool is a subclass of int.
33    // Otherwise, bool will be converted to int.
34    Bool(bool),
35    Int(i64),
36    Double(f64),
37    String(String),
38    Ref(Ref),
39    IntList(Vec<i64>),
40    RefList(Vec<Ref>),
41    Device(Device),
42    Layout(#[serde(with = "torch_sys2::LayoutDef")] Layout),
43    ScalarType(#[serde(with = "torch_sys2::ScalarTypeDef")] ScalarType),
44    MemoryFormat(#[serde(with = "torch_sys2::MemoryFormatDef")] MemoryFormat),
45    // Make this wrap the unit type, as `pyo3::FromPyObject` doesn't work with
46    // empty enum variants.
47    None(()),
48    PyObject(PickledPyObject),
49}
50wirevalue::register_type!(WireValue);
51
52impl FromPyObject<'_> for WireValue {
53    fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
54        Ok(WireValue::PyObject(PickledPyObject::pickle(obj)?))
55    }
56}
57
58impl<'py> IntoPyObject<'py> for WireValue {
59    type Target = PyAny;
60    type Output = Bound<'py, PyAny>;
61    type Error = PyErr;
62
63    fn into_pyobject(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
64        match self {
65            WireValue::Ref(ref_) => ref_.into_bound_py_any(py),
66            WireValue::RefList(ref_list) => ref_list.clone().into_bound_py_any(py),
67            WireValue::Int(int) => int.into_bound_py_any(py),
68            WireValue::IntList(int_list) => int_list.clone().into_bound_py_any(py),
69            WireValue::Double(double) => double.into_bound_py_any(py),
70            WireValue::Bool(bool_) => bool_.into_bound_py_any(py),
71            WireValue::String(string) => string.into_bound_py_any(py),
72            WireValue::Device(device) => device.into_bound_py_any(py),
73            WireValue::Layout(val) => val.into_bound_py_any(py),
74            WireValue::ScalarType(val) => val.into_bound_py_any(py),
75            WireValue::MemoryFormat(val) => val.into_bound_py_any(py),
76            WireValue::None(()) => PyNone::get(py).into_bound_py_any(py),
77            WireValue::PyObject(val) => val.unpickle(py),
78        }
79    }
80}
81
82impl From<PyObject> for WireValue {
83    fn from(obj: PyObject) -> Self {
84        Python::with_gil(|py| WireValue::PyObject(PickledPyObject::pickle(obj.bind(py)).unwrap()))
85    }
86}