monarch_messages/
wire_value.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::collections::HashMap;
10
11use derive_more::From;
12use derive_more::TryInto;
13use enum_as_inner::EnumAsInner;
14use hyperactor::Named;
15use monarch_types::PickledPyObject;
16use monarch_types::TryIntoPyObjectUnsafe;
17use pyo3::IntoPyObjectExt;
18use pyo3::exceptions::PyValueError;
19use pyo3::prelude::*;
20use pyo3::types::PyBool;
21use pyo3::types::PyDict;
22use pyo3::types::PyFloat;
23use pyo3::types::PyList;
24use pyo3::types::PyNone;
25use pyo3::types::PyString;
26use pyo3::types::PyTuple;
27use serde::Deserialize;
28use serde::Serialize;
29use torch_sys::Device;
30use torch_sys::Layout;
31use torch_sys::MemoryFormat;
32use torch_sys::OpaqueIValue;
33use torch_sys::ScalarType;
34
35use crate::worker::Ref;
36use crate::worker::ResolvableFunction;
37
38/// A value used as an input to CallFunction.
39// TODO, this is basically the same as RValue, but with TensorIndices swapped
40// out for refs. And IValue is the same as RValue, but with real tensors and
41// C++ types. I wonder if there is a nicer way to express this relationship.
42// TODO extend this to support other types of values, like bytes, dicts etc.
43#[derive(
44    Serialize,
45    Deserialize,
46    Debug,
47    Clone,
48    TryInto,
49    Named,
50    From,
51    EnumAsInner
52)]
53pub enum WireValue {
54    // Make sure boolean goes ealier than int as bool is a subclass of int.
55    // Otherwise, bool will be converted to int.
56    Bool(bool),
57    Int(i64),
58    Double(f64),
59    String(String),
60    Ref(Ref),
61    IntList(Vec<i64>),
62    RefList(Vec<Ref>),
63    Device(Device),
64    Layout(#[serde(with = "torch_sys::LayoutDef")] Layout),
65    ScalarType(#[serde(with = "torch_sys::ScalarTypeDef")] ScalarType),
66    MemoryFormat(#[serde(with = "torch_sys::MemoryFormatDef")] MemoryFormat),
67    // Make this wrap the unit type, as `pyo3::FromPyObject` doesn't work with
68    // empty enum variants.
69    None(()),
70    PyObject(PickledPyObject),
71    // It is ok to just have IValue without an alias tracking cell as we just use
72    // WireValue as a way to serialize and send args to workers. We dont mutate the
73    // IValue and use the opaque wrapper to make accessing the IValue directly
74    // an unsafe op.
75    IValue(torch_sys::OpaqueIValue),
76}
77
78impl FromPyObject<'_> for WireValue {
79    fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
80        if let Ok(ref_) = Ref::from_py_object(obj) {
81            Ok(WireValue::Ref(ref_))
82        } else if let Ok(list) = obj.downcast::<PyList>() {
83            let len = list.len();
84            if len == 0 {
85                // TODO: This is done for now as this seems to be the most common case for empty lists
86                // in torch ops but we should use the op schema to do this correctly.
87                return Ok(WireValue::IntList(vec![]));
88            }
89
90            // SAFETY: We know it is within bounds
91            let item = unsafe { list.get_item_unchecked(0) };
92            let len = list.len();
93            if let Ok(int) = item.extract::<i64>() {
94                let mut int_list = Vec::with_capacity(len);
95                int_list.push(int);
96                for item in list.iter().skip(1) {
97                    int_list.push(item.extract::<i64>().map_err(|_| {
98                        PyValueError::new_err(format!(
99                            "Expected homogeneous list of ints got: {:?}",
100                            list
101                        ))
102                    })?);
103                }
104                return Ok(WireValue::IntList(int_list));
105            }
106            if let Ok(ref_) = Ref::from_py_object(&item) {
107                let mut ref_list = Vec::with_capacity(len);
108                ref_list.push(ref_);
109                for item in list.iter().skip(1) {
110                    ref_list.push(Ref::from_py_object(&item).map_err(|_| {
111                        PyValueError::new_err(format!(
112                            "Expected homogeneous list of ints got: {:?}",
113                            list
114                        ))
115                    })?);
116                }
117                return Ok(WireValue::RefList(ref_list));
118            }
119            Ok(WireValue::PyObject(PickledPyObject::pickle(obj)?))
120        } else if obj.is_none() {
121            Ok(WireValue::None(()))
122        } else if let Ok(bool_) = obj.downcast::<PyBool>() {
123            Ok(WireValue::Bool(bool_.is_true()))
124        } else if let Ok(int) = obj.extract::<i64>() {
125            Ok(WireValue::Int(int))
126        } else if let Ok(double) = obj.downcast::<PyFloat>() {
127            Ok(WireValue::Double(double.value()))
128        } else if let Ok(string) = obj.downcast::<PyString>() {
129            Ok(WireValue::String(string.to_str()?.to_string()))
130        } else if let Ok(device) = obj.extract::<Device>() {
131            Ok(WireValue::Device(device))
132        } else if let Ok(layout) = obj.extract::<Layout>() {
133            Ok(WireValue::Layout(layout))
134        } else if let Ok(scalar_type) = obj.extract::<ScalarType>() {
135            Ok(WireValue::ScalarType(scalar_type))
136        } else if let Ok(memory_format) = obj.extract::<MemoryFormat>() {
137            Ok(WireValue::MemoryFormat(memory_format))
138        } else {
139            Ok(WireValue::PyObject(PickledPyObject::pickle(obj)?))
140        }
141    }
142}
143
144impl<'py> TryIntoPyObjectUnsafe<'py, PyAny> for WireValue {
145    unsafe fn try_to_object_unsafe(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
146        match self {
147            WireValue::Ref(ref_) => ref_.into_bound_py_any(py),
148            WireValue::RefList(ref_list) => ref_list.clone().into_bound_py_any(py),
149            WireValue::Int(int) => int.into_bound_py_any(py),
150            WireValue::IntList(int_list) => int_list.clone().into_bound_py_any(py),
151            WireValue::Double(double) => double.into_bound_py_any(py),
152            WireValue::Bool(bool_) => bool_.into_bound_py_any(py),
153            WireValue::String(string) => string.into_bound_py_any(py),
154            WireValue::Device(device) => device.into_bound_py_any(py),
155            WireValue::Layout(val) => val.into_bound_py_any(py),
156            WireValue::ScalarType(val) => val.into_bound_py_any(py),
157            WireValue::MemoryFormat(val) => val.into_bound_py_any(py),
158            WireValue::None(()) => PyNone::get(py).into_bound_py_any(py),
159            WireValue::PyObject(val) => val.unpickle(py),
160            // SAFETY: WireValue is only used for serde between client and worker.
161            // This function is used to access the args / kwargs of a function call
162            // on the client side only.
163            WireValue::IValue(val) => unsafe { val.try_to_object_unsafe(py) },
164        }
165    }
166}
167
168impl From<PyObject> for WireValue {
169    fn from(obj: PyObject) -> Self {
170        Python::with_gil(|py| WireValue::PyObject(PickledPyObject::pickle(obj.bind(py)).unwrap()))
171    }
172}
173
174impl WireValue {
175    fn from_pyobject_with_torch_op_arg_type(
176        obj: Bound<'_, PyAny>,
177        type_: &torch_sys::call_op::TypePtr,
178        num_elements: i32,
179        allow_nums_as_tensors: bool,
180    ) -> PyResult<Self> {
181        if type_.is_tensor() || type_.is_optional_tensor() {
182            if type_.is_optional_tensor() && obj.is_none() {
183                return Ok(WireValue::None(()));
184            } else if let Ok(ref_) = Ref::from_py_object(&obj) {
185                return Ok(WireValue::Ref(ref_));
186            }
187        }
188        if type_.is_tensor_list() || type_.is_optional_tensor_list() {
189            if type_.is_optional_tensor_list() && obj.is_none() {
190                return Ok(WireValue::None(()));
191            }
192            let list = obj.downcast::<PyList>()?;
193            let len = list.len();
194            if len == 0 {
195                return Ok(WireValue::RefList(vec![]));
196            }
197            // SAFETY: We know it is within bounds
198            let item = unsafe { list.get_item_unchecked(0) };
199            if let Ok(ref_) = Ref::from_py_object(&item) {
200                let mut ref_list = Vec::with_capacity(len);
201                ref_list.push(ref_);
202                for item in list.iter().skip(1) {
203                    ref_list.push(Ref::from_py_object(&item).map_err(|_| {
204                        PyValueError::new_err(format!(
205                            "Expected homogeneous list of refs got: {:?}",
206                            list
207                        ))
208                    })?);
209                }
210                return Ok(WireValue::RefList(ref_list));
211            }
212        }
213        OpaqueIValue::from_py_object_with_type(obj, type_, num_elements, allow_nums_as_tensors)
214            .map(WireValue::IValue)
215    }
216}
217
218pub fn func_call_args_to_wire_values(
219    func: Option<&ResolvableFunction>,
220    args: &Bound<'_, PyTuple>,
221    kwargs: &Bound<'_, PyDict>,
222) -> PyResult<(Vec<WireValue>, HashMap<String, WireValue>)> {
223    if let Some((op, overload)) = func.and_then(|func| func.as_torch_op()) {
224        torch_op_args_to_wire_values(&op, &overload, args, kwargs)
225    } else {
226        python_func_args_to_wire_value(args, kwargs)
227    }
228}
229
230fn torch_op_args_to_wire_values(
231    op: &str,
232    overload: &str,
233    args: &Bound<'_, PyTuple>,
234    kwargs: &Bound<'_, PyDict>,
235) -> PyResult<(Vec<WireValue>, HashMap<String, WireValue>)> {
236    let args_info = torch_sys::call_op::get_schema_args_info(op, overload).map_err(|err| {
237        PyValueError::new_err(format!(
238            "Failed to get the operator schema for {}::{}: {}",
239            op, overload, err
240        ))
241    })?;
242
243    let args = args
244        .iter()
245        .zip(&args_info)
246        .map(|(arg, arg_info)| {
247            WireValue::from_pyobject_with_torch_op_arg_type(
248                arg,
249                arg_info.type_,
250                arg_info.num_elements,
251                arg_info.allows_number_as_tensor,
252            )
253        })
254        .collect::<Result<Vec<_>, _>>()?;
255    let kwargs = kwargs
256        .iter()
257        .map(|(k, v)| {
258            let key = k.extract::<String>()?;
259            let arg_info = args_info
260                .iter()
261                .find(|arg_info| arg_info.name == key)
262                .ok_or_else(|| {
263                    PyValueError::new_err(format!(
264                        "Torch op {}::{} does not support kwarg {}",
265                        op, overload, key
266                    ))
267                })?;
268            let val = WireValue::from_pyobject_with_torch_op_arg_type(
269                v,
270                arg_info.type_,
271                arg_info.num_elements,
272                arg_info.allows_number_as_tensor,
273            )?;
274            Ok((key, val))
275        })
276        .collect::<Result<HashMap<_, _>, PyErr>>()?;
277    Ok((args, kwargs))
278}
279
280fn python_func_args_to_wire_value(
281    args: &Bound<'_, PyTuple>,
282    kwargs: &Bound<'_, PyDict>,
283) -> PyResult<(Vec<WireValue>, HashMap<String, WireValue>)> {
284    let args = args
285        .iter()
286        .map(|arg| Ok(WireValue::PyObject(PickledPyObject::pickle(&arg)?)))
287        .collect::<PyResult<_>>()?;
288    let kwargs = kwargs
289        .iter()
290        .map(|(k, v)| {
291            Ok((
292                k.extract::<String>()?,
293                WireValue::PyObject(PickledPyObject::pickle(&v)?),
294            ))
295        })
296        .collect::<Result<HashMap<_, _>, PyErr>>()?;
297    Ok((args, kwargs))
298}
299
#[cfg(test)]
mod tests {
    use std::assert_matches::assert_matches;

    use anyhow::Result;
    use anyhow::bail;
    use paste::paste;
    use pyo3::Python;
    use pyo3::ffi::c_str;
    use pyo3::types::PyDict;
    use torch_sys::DeviceType;
    use torch_sys::ScalarType;

    use super::*;
    use crate::worker::Ref;

    /// Python source of a minimal class exposing `__monarch_ref__`, used to
    /// check that such objects are converted to `WireValue::Ref`.
    const MOCK_REFERENCABLE_MODULE: &std::ffi::CStr = c_str!(
        r#"
class Referencable:
    def __init__(self, ref: int):
        self.ref = ref

    def __monarch_ref__(self):
        return self.ref
"#
    );

    /// Initializes the Python interpreter and imports `torch`.
    fn setup() -> Result<()> {
        pyo3::prepare_freethreaded_python();
        // We need to load torch to initialize some internal structures used by
        // the FFI funcs we use to convert ivalues to/from py objects.
        Python::with_gil(|py| py.run(c_str!("import torch"), None, None))?;
        Ok(())
    }

    /// Builds a small Python dict (`{"foo": "bar"}`), i.e. an object that can
    /// only be carried over the wire in pickled form.
    fn create_py_object() -> PyObject {
        pyo3::prepare_freethreaded_python();
        Python::with_gil(|py| {
            let dict = PyDict::new(py);
            dict.set_item("foo", "bar").unwrap();
            dict.into_any().unbind()
        })
    }

    /// Generates one from-Python conversion test per `Kind, input;` entry,
    /// plus hand-written tests for `None`, the empty list and objects
    /// exposing `__monarch_ref__`, and a compile-time check that every
    /// `WireValue` variant is accounted for.
    macro_rules! generate_wire_value_from_py_tests {
        ($($kind:ident, $input:expr);* $(;)?) => {
            paste! {
                $(
                    #[test]
                    fn [<test_wire_value_from_py_$kind:snake:lower>]() -> Result<()> {
                        setup()?;
                        Python::with_gil(|py| {
                            let actual = $input.into_pyobject(py)?.extract::<WireValue>()?;
                            assert_matches!(actual, WireValue::$kind(_));
                            anyhow::Ok(())
                        })
                    }
                )*

                #[test]
                fn test_wire_value_from_py_none() -> Result<()> {
                    setup()?;
                    Python::with_gil(|py| {
                        let obj = PyNone::get(py).into_pyobject(py)?;
                        let actual = obj.extract::<WireValue>()?;
                        assert_matches!(actual, WireValue::None(_));
                        anyhow::Ok(())
                    })
                }

                #[test]
                fn test_wire_value_from_py_empty_list() -> Result<()> {
                    setup()?;
                    Python::with_gil(|py| {
                        let obj: PyObject = PyList::empty(py).into_any().unbind();
                        let actual = obj.extract::<WireValue>(py)?;
                        match actual {
                            WireValue::IntList(list) if list.is_empty() => (),
                            _ => bail!("Expected empty list to be converted to empty int list"),
                        }
                        anyhow::Ok(())
                    })
                }

                #[test]
                fn test_wire_value_from_py_referencable_class() -> Result<()> {
                    setup()?;
                    Python::with_gil(|py| {
                        let referencable = PyModule::from_code(
                            py,
                            MOCK_REFERENCABLE_MODULE,
                            c_str!("referencable.py"),
                            c_str!("referencable"),
                        )?;
                        let ref_ = referencable.getattr("Referencable")?.call1((1,))?.unbind();
                        let actual = ref_.extract::<WireValue>(py)?;
                        assert_matches!(actual, WireValue::Ref(Ref { id: 1 }));
                        anyhow::Ok(())
                    })
                }

                #[test]
                fn test_wire_value_from_py_roundtrip_was_exhaustive() {
                    // This match only exists to fail compilation when a
                    // variant has no test entry; it checks nothing at runtime.
                    let val = WireValue::Int(0);
                    match val {
                        $(WireValue::$kind(_) => (),)*
                        WireValue::None(_) => (),
                        // Can't test from py here as PyObject behaves as a catch-all for
                        // conversion from Python. Torch op args are converted to IValue
                        // manually, respecting the schema, so it's not super important
                        // to have this.
                        WireValue::IValue(_) => (),
                    }
                }
            }
        }
    }

    // Generate a from-Python conversion test for every WireValue kind.
    // If you got a "non-exhaustive patterns" error here, you need to add a new
    // test entry for your WireValue kind!
    generate_wire_value_from_py_tests! {
        Bool, false;
        Double, 1.23f64;
        Int, 123i64;
        IntList, vec![1i64];
        Ref, Ref::from(1);
        RefList, vec![Ref::from(1), Ref::from(2)];
        String, "foobar".to_owned();
        Device, Device::new(DeviceType::CPU);
        Layout, Layout(2);
        ScalarType, ScalarType(3);
        MemoryFormat, MemoryFormat(1);
        PyObject, create_py_object();
    }
}