torch_sys/
ivalue.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8
9use std::ffi::c_void;
10use std::fmt;
11use std::mem::MaybeUninit;
12
13use cxx::ExternType;
14use cxx::type_id;
15use monarch_types::TryIntoPyObjectUnsafe;
16use paste::paste;
17use pyo3::exceptions::PyValueError;
18use pyo3::prelude::*;
19use serde::Deserialize;
20use serde::Serialize;
21
22use crate::CloneUnsafe;
23use crate::Device;
24use crate::Tensor;
25use crate::bridge::clone_iv;
26use crate::bridge::ffi;
27use crate::cell::AliasTrackingRefCell;
28
/// Rust binding for the C++ type `c10::IValue`.
///
/// `IValue` is a tagged union type that can hold any input to a PyTorch
/// operator. See [`IValueKind`] for the list of supported types.
///
/// # Safety
///
/// `IValue` either contains [`Copy`]-able data or a Tensor-like object, so it
/// inherits the safety properties of [`Tensor`]. See the safety discussion in
/// [`Tensor#safety`] for more info.
#[repr(C)]
pub struct IValue {
    /// Internal representation of `c10::IValue` in C++: 16 bytes, 8 for the
    /// payload and 8 for the type tag. `bridge.h` asserts that the size and
    /// alignment are what we expect.
    #[doc(hidden)]
    repr: [*mut c_void; 2],
}
47
48// SAFETY: Register our custom bindings with cxx. IValue is trivial, see the
49// discussion in `bridge.h`.
50unsafe impl ExternType for IValue {
51    type Id = type_id!("c10::IValue");
52    type Kind = cxx::kind::Trivial;
53}
54
55impl Drop for IValue {
56    fn drop(&mut self) {
57        // SAFETY: calls the C++ destructor for IValue, potentially
58        // decrementing a tensor refcount.
59        unsafe { crate::bridge::drop(self) };
60    }
61}
62
63impl PartialEq for IValue {
64    fn eq(&self, other: &Self) -> bool {
65        ffi::ivalues_equal_operator(self, other)
66    }
67}
68
69/// SAFETY: IValue is [`Send`], it is either a copyable type or atomically
70/// refcounted via `c10::intrusive_ptr`.
71unsafe impl Send for IValue {}
72
73/// SAFETY: IValue is [`Sync`], due to safety in exposing any of the interior
74/// mutability of the payload it holds. The value is converted to native types
75/// like [`Tensor`] for use in rust or left opaque.
76/// See [`OpaqueIValue`] for more details.
77unsafe impl Sync for IValue {}
78
79impl Serialize for IValue {
80    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
81    where
82        S: serde::Serializer,
83    {
84        serializer.serialize_bytes(
85            ffi::serialize_ivalue(self)
86                .map_err(serde::ser::Error::custom)?
87                .as_slice(),
88        )
89    }
90}
91
92impl<'de> Deserialize<'de> for IValue {
93    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
94    where
95        D: serde::Deserializer<'de>,
96    {
97        let buf: &[u8] = Deserialize::deserialize(deserializer)?;
98        ffi::deserialize_ivalue(buf).map_err(serde::de::Error::custom)
99    }
100}
101
102impl CloneUnsafe for IValue {
103    /// This is *unsafe*, it creates an alias of the underlying data that is
104    /// not tracked by Rust. We use this to interface with C++ functions that
105    /// expect an owned IValue.
106    ///
107    /// The contract for calling this function is that the clone is local and
108    /// ephemeral. More precisely:
109    /// 1. The clone must not be sent to another thread (local).
110    /// 2. You must guarantee that clone is dropped before the originating
111    ///    mutable reference is dropped (ephemeral).
112    unsafe fn clone_unsafe(&self) -> Self {
113        let mut ivalue = MaybeUninit::<IValue>::uninit();
114        let new = ivalue.as_mut_ptr().cast();
115        // SAFETY: `ivalue` will be correctly initialized by the call to `clone_iv`.
116        unsafe {
117            clone_iv(self, new);
118            ivalue.assume_init()
119        }
120    }
121}
122
123/// An opaque container for an [`IValue`]. This is used to restrict safe direct access
124/// to the underlying [`IValue`].
125#[derive(Debug, Serialize, Deserialize)]
126pub struct OpaqueIValue(IValue);
127
128impl OpaqueIValue {
129    /// This is *unsafe*, it creates an alias of the underlying data that is
130    /// not tracked by Rust. We need this to interface with C++ functions that
131    /// expect an owned IValue. The caller is responsible for ensuring that
132    /// this is done in a safe way.
133    pub(crate) unsafe fn ivalue(&self) -> IValue {
134        // SAFETY: See above
135        unsafe { self.0.clone_unsafe() }
136    }
137
138    pub fn from_py_object_with_type(
139        obj: Bound<'_, PyAny>,
140        type_: &crate::call_op::TypePtr,
141        num_elements: i32,
142        allow_nums_as_tensors: bool,
143    ) -> PyResult<OpaqueIValue> {
144        IValue::from_py_object_with_type(obj, type_, num_elements, allow_nums_as_tensors)
145            .map(OpaqueIValue)
146    }
147}
148
149impl Clone for OpaqueIValue {
150    /// This creates a deep copy of the underlying data and can be expensive.
151    /// It might also panic if the `IValue` is not cloneable.
152    fn clone(&self) -> Self {
153        let serialized = bincode::serialize(&self.0).unwrap();
154        bincode::deserialize(&serialized).unwrap()
155    }
156}
157
158impl CloneUnsafe for OpaqueIValue {
159    unsafe fn clone_unsafe(&self) -> Self {
160        // SAFETY: See discussion for `IValue::clone_unsafe`.
161        Self(unsafe { self.0.clone_unsafe() })
162    }
163}
164
165impl<'py> TryIntoPyObjectUnsafe<'py, PyAny> for OpaqueIValue {
166    unsafe fn try_to_object_unsafe(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
167        // SAFETY: See discussion for `IValue::clone_unsafe`.
168        unsafe { self.ivalue() }.into_pyobject(py)
169    }
170}
171
172pub type OpaqueIValueCell = AliasTrackingRefCell<OpaqueIValue>;
173
174impl From<OpaqueIValue> for OpaqueIValueCell {
175    #[inline]
176    fn from(value: OpaqueIValue) -> Self {
177        Self::new(value)
178    }
179}
180
/// Generates kind-inspection methods inside an `impl IValue` block:
/// - `is_<kind>()` for each listed kind, forwarding to the bridged
///   `is<Kind>()` method;
/// - `is_other()`, which is true when `kind()` is `$enum::Other`;
/// - `kind()`, which returns the first listed kind whose predicate matches
///   (checked in list order) and otherwise falls back to `$enum::Other`;
/// - `__exhaustive_checker`, an uncalled function whose `match` stops
///   compiling if a variant of `$enum` is missing from the list.
macro_rules! gen_is_impl {
    ($enum:ty, [$($kind:ident),* $(,)? ]) => {
        paste! {
            $(
                pub fn [<is_ $kind:snake:lower>](&self) -> bool {
                    self.[<is $kind>]()
                }
            )*

            pub fn is_other(&self) -> bool {
                self.kind() == $enum::Other
            }

            fn __exhaustive_checker(variant: $enum) {
                match variant {
                    $($enum::$kind => (),)*
                    $enum::Other => (),
                }
            }

            pub fn kind(&self) -> $enum {
                $(
                    if self.[<is_ $kind:snake:lower>]() {
                        return $enum::$kind;
                    }
                )*
                $enum::Other
            }
        }
    };
}
214
/// Generates `impl From<$from_type> for IValue` for each `$kind, $from_type`
/// pair. Each impl calls the bridge constructor `ffi::ivalue_from_<kind>`.
///
/// It also emits `__exhaustive_checker`, an uncalled function whose `match`
/// stops compiling if a variant of `$enum` is neither listed here nor one of
/// the two variants handled separately (`None` and `Other`).
macro_rules! gen_from_impl {
    ($enum:ty, $($kind:ident, $from_type:ty);* $(;)?) => {
        paste! {
            $(
                impl From<$from_type> for IValue {
                    fn from(value: $from_type) -> Self {
                        ffi::[<ivalue_from_ $kind:snake:lower>](value)
                    }
                }
            )*

            fn __exhaustive_checker(variant: $enum) {
                match variant {
                    $($enum::$kind => (),)*
                    // Constructed through the hand-written `From<()>` impl.
                    $enum::None => (),
                    // Catch-all kind with no native Rust source type.
                    $enum::Other => (),
                }
            }
        }
    };
}
238
239impl From<()> for IValue {
240    fn from(_value: ()) -> Self {
241        ffi::ivalue_from_none()
242    }
243}
244
245impl IValue {
246    pub fn to_tensor(self) -> Option<Tensor> {
247        ffi::toTensor(self).ok()
248    }
249    pub fn to_string(&self) -> Option<String> {
250        ffi::toString(self).ok()
251    }
252    pub fn to_int_list(&self) -> Option<Vec<i64>> {
253        ffi::toIntList(self).ok()
254    }
255    pub fn to_int(&self) -> Option<i64> {
256        self.toInt().ok()
257    }
258    pub fn to_double(&self) -> Option<f64> {
259        self.toDouble().ok()
260    }
261    pub fn to_bool(&self) -> Option<bool> {
262        self.toBool().ok()
263    }
264    pub fn to_tensor_list(self) -> Option<Vec<Tensor>> {
265        ffi::toTensorList(self).ok()
266    }
267    pub fn to_device(&self) -> Option<Device> {
268        self.toDevice().ok()
269    }
270    pub fn to_none(&self) -> Option<()> {
271        if self.is_none() { Some(()) } else { None }
272    }
273    pub fn to_opaque(self) -> Option<OpaqueIValue> {
274        if self.is_other() {
275            Some(OpaqueIValue(self))
276        } else {
277            None
278        }
279    }
280    // Generate is_ methods for all IValue kinds.
281    // If you get a compile error here, make sure:
282    //   - Your new kind is registered on IValueKind
283    //   - You added a field here.
284    gen_is_impl! {
285        IValueKind, [
286            Tensor,
287            String,
288            IntList,
289            Int,
290            Double,
291            Bool,
292            TensorList,
293            Device,
294            None,
295        ]
296    }
297
298    pub fn from_py_object_with_type(
299        obj: Bound<'_, PyAny>,
300        type_: &crate::call_op::TypePtr,
301        num_elements: i32,
302        allow_nums_as_tensors: bool,
303    ) -> PyResult<IValue> {
304        ffi::ivalue_from_py_object_with_type(obj.into(), type_, num_elements, allow_nums_as_tensors)
305            .map_err(|err| {
306                PyValueError::new_err(format!(
307                    "Failed to extract IValue from python object: {:?}",
308                    err
309                ))
310            })
311    }
312
313    pub(crate) fn from_py_object_or_none(obj: &Bound<'_, PyAny>) -> Option<IValue> {
314        ffi::py_object_is_ivalue(obj.clone().into())
315            .then(|| ffi::ivalue_from_arbitrary_py_object(obj.into()).unwrap())
316    }
317}
318
319// impl `From` for all IValue kinds.
320// If you get a compile error here, make sure:
321//   - Your new kind is registered on IValueKind
322//   - You added a field here.
323gen_from_impl! {
324    IValueKind,
325    Tensor, Tensor;
326    String, &String;
327    IntList, &[i64];
328    Int, i64;
329    Double, f64;
330    Bool, bool;
331    TensorList, Vec<Tensor>;
332    Device, Device;
333}
334
/// Enum representing the different internal types an [`IValue`] can hold.
///
/// Each variant's docs give its internal storage type in C++ and the cost of
/// moving it across the Rust<>C++ boundary.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum IValueKind {
    /// - C++ type: `at::Tensor`
    /// - Rust type: [`Tensor`]
    Tensor,
    /// - C++ type: `bool`
    /// - Rust type: [`bool`]
    Bool,
    /// - C++ type: `int64_t`
    /// - Rust type: [`i64`]
    Int,
    /// - C++ type: `c10::List<int64_t>`
    /// - Rust type: [`Vec<i64>`]
    ///
    /// <div class="warning">
    /// Passing across the C++-Rust boundary will copy the vector.
    /// </div>
    IntList,
    /// - C++ type: `double`
    /// - Rust type: [`f64`]
    Double,
    /// - C++ type: `c10::intrusive_ptr<ConstantString>`
    /// - Rust type: [`String`]
    ///
    /// <div class="warning">
    /// Passing across the C++-Rust boundary will copy the string.
    /// </div>
    String,
    /// - C++ type: `c10::List<at::Tensor>`
    /// - Rust type: [`Vec<Tensor>`]
    ///
    /// <div class="warning">
    /// Passing across the C++-Rust boundary will copy the vector.
    /// </div>
    TensorList,
    /// - C++ type: `c10::Device`
    /// - Rust type: [`Device`]
    Device,
    /// - C++ type: a `None` IValue (checked with `isNone`)
    /// - Rust type: `()` (built with `From<()>`, read with `IValue::to_none`)
    None,

    /// Catch-all for every other type.
    ///
    /// These types have no native Rust representation. They stay opaque
    /// IValues (see [`OpaqueIValue`]) so they can still be passed to torch
    /// APIs. Converting a trivial value into this kind adds the overhead of
    /// alias and borrow tracking, so commonly used types should be supported
    /// natively instead. Most of them already are.
    Other,
}
387
388impl fmt::Debug for IValue {
389    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
390        write!(f, "{}", ffi::debug_print(self).map_err(|_| fmt::Error)?)
391    }
392}
393
394impl<'py> IntoPyObject<'py> for IValue {
395    type Target = PyAny;
396    type Output = Bound<'py, Self::Target>;
397    type Error = PyErr;
398
399    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
400        ffi::arbitrary_ivalue_to_py_object(self)
401            .map_err(|e| PyValueError::new_err(format!("Failed converting to py: {}", e)))?
402            .into_pyobject(py)
403    }
404}
405
406impl FromPyObject<'_> for IValue {
407    fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
408        ffi::ivalue_from_arbitrary_py_object(obj.into()).map_err(|e| {
409            PyValueError::new_err(format!("Failed extracting from py: {}: {}", e, obj))
410        })
411    }
412}
413
#[cfg(test)]
mod tests {
    // NOTE(review): presumably silences clippy on the `vec![...]` inputs in the
    // roundtrip tables below; confirm it is still needed.
    #![allow(clippy::useless_vec)]

    use pyo3::types::PyFloat;

    use super::*;
    use crate::DeviceType;
    use crate::bridge::ffi::test_make_opaque_ivalue;
    use crate::bridge::ffi::test_make_tensor;
    use crate::bridge::ffi::test_make_undefined_tensor_ivalue;

    // Test-only conversion so the roundtrip tables can pass an owned `String`.
    // The production `From` impl takes `&String`.
    impl From<String> for IValue {
        fn from(value: String) -> Self {
            ffi::ivalue_from_string(&value)
        }
    }

    /// Compares two IValues. Tensors and tensor lists are compared with
    /// `Tensor::equal`; every other kind uses `IValue`'s `PartialEq`.
    fn ivalues_equal_with_tensor_equal(a: IValue, b: IValue) -> bool {
        if a.isTensor() {
            return b.isTensor() && a.to_tensor().unwrap().equal(&b.to_tensor().unwrap());
        }

        if a.isTensorList() {
            if !b.isTensorList() {
                return false;
            }
            let a_list = a.to_tensor_list().unwrap();
            let b_list = b.to_tensor_list().unwrap();
            return a_list.len() == b_list.len()
                && a_list.iter().zip(&b_list).all(|(x, y)| x.equal(y));
        }

        a == b
    }

    /// Converts Python objects using the schema of `aten::_foreach_add_`
    /// (`Tensor` overload):
    /// - an empty list for arg 0 becomes a tensor list;
    /// - `None` for arg 1 becomes a tensor;
    /// - a float for arg 1 converts only when `allow_nums_as_tensors` is set.
    #[test]
    fn test_ivalue_from_py_object_with_type() {
        pyo3::prepare_freethreaded_python();

        let args_info =
            crate::call_op::get_schema_args_info("aten::_foreach_add_", "Tensor").unwrap();
        let (list, tensor, tensor_1, tensor_1_err) = Python::with_gil(|py| {
            let list = pyo3::types::PyList::empty(py).into_any();
            let none = py.None().into_bound(py);
            let one = PyFloat::new(py, 1.0).into_any();
            (
                IValue::from_py_object_with_type(
                    list,
                    args_info[0].type_,
                    args_info[0].num_elements,
                    false,
                )
                .unwrap(),
                IValue::from_py_object_with_type(
                    none,
                    args_info[1].type_,
                    args_info[1].num_elements,
                    false,
                )
                .unwrap(),
                IValue::from_py_object_with_type(
                    one.clone(),
                    args_info[1].type_,
                    args_info[1].num_elements,
                    true,
                )
                .unwrap(),
                IValue::from_py_object_with_type(
                    one,
                    args_info[1].type_,
                    args_info[1].num_elements,
                    false,
                ),
            )
        });
        assert!(list.is_tensor_list());
        assert!(tensor.is_tensor());
        assert!(tensor_1.is_tensor());
        assert!(tensor_1_err.is_err());
    }

    macro_rules! generate_py_object_roundtrip_tests {
        ($($kind:ident, $input:expr_2021);* $(;)?) => {
            paste! {
                $(
                    #[test]
                    fn [<test_py_object_roundtrip_ $kind:snake:lower>]() {
                        pyo3::prepare_freethreaded_python();
                        // We need to load torch to initialize some internal
                        // structures used by the FFI funcs we use to convert
                        // ivalues to/from py objects.
                        Python::with_gil(|py| py.run(pyo3::ffi::c_str!("import torch"), None, None)).unwrap();
                        let original = IValue::from($input);
                        // SAFETY: `into_pyobject` consumes its receiver, so we
                        // convert an ephemeral alias and keep `original` for the
                        // comparison below. The alias is consumed on this thread
                        // before `original` is dropped.
                        let alias = unsafe { original.clone_unsafe() };
                        let converted = Python::with_gil(|py| {
                            let py_object = alias.into_pyobject(py).unwrap();
                            IValue::extract_bound(&py_object).unwrap()
                        });
                        assert!(ivalues_equal_with_tensor_equal(original, converted));
                    }
                )*

                #[test]
                fn test_py_object_roundtrip_was_exhaustive() {
                    match IValueKind::Int {
                        $(IValueKind::$kind => (),)*
                    }
                }
            }
        }
    }

    // Generate exhaustive roundtrip tests for all IValue kinds.
    // If you got a "non-exhaustive patterns" error here, you need to add a new
    // test entry for your IValue kind!
    generate_py_object_roundtrip_tests! {
        Int, 123;
        Double, 1.23;
        String, "foobar".to_owned();
        IntList, [1, 2, 3].as_slice();
        Bool, false;
        Tensor, test_make_tensor();
        TensorList, vec![test_make_tensor()];
        Device, Device::new(DeviceType::CPU);
        None, ();
        Other, test_make_opaque_ivalue();
    }

    macro_rules! generate_serde_roundtrip_tests {
        ($($kind:ident, $input:expr_2021);* $(;)?) => {
            paste! {
                $(
                    #[test]
                    fn [<test_serde_roundtrip_ $kind:snake:lower>]() {
                        pyo3::prepare_freethreaded_python();
                        // As in the Python roundtrip tests, load torch first so
                        // internal structures used by the FFI are initialized.
                        Python::with_gil(|py| py.run(pyo3::ffi::c_str!("import torch"), None, None)).unwrap();
                        let original = IValue::from($input);
                        let converted: IValue = bincode::deserialize(&bincode::serialize(&original).unwrap()).unwrap();
                        assert!(ivalues_equal_with_tensor_equal(original, converted));
                    }
                )*

                #[test]
                fn test_serde_roundtrip_was_exhaustive() {
                    match IValueKind::Int {
                        $(IValueKind::$kind => (),)*
                    }
                }
            }
        }
    }

    // Generate exhaustive serde roundtrip tests for all IValue kinds.
    // If you got a "non-exhaustive patterns" error here, you need to add a new
    // test entry for your IValue kind!
    generate_serde_roundtrip_tests! {
        Int, 123;
        Double, 1.23;
        String, "foobar".to_owned();
        IntList, [1, 2, 3].as_slice();
        Bool, false;
        Tensor, test_make_tensor();
        TensorList, vec![test_make_tensor()];
        Device, Device::new(DeviceType::CPU);
        None, ();
        Other, test_make_opaque_ivalue();
    }

    /// An undefined tensor must survive a serde roundtrip as an undefined tensor.
    #[test]
    fn test_serde_roundtrip_undefined_tensor() {
        let original = test_make_undefined_tensor_ivalue();
        assert!(original.is_tensor());
        assert!(
            // SAFETY: Since it is an undefined tensor that we dont mutate,
            // it is safe to clone in this test.
            !unsafe { original.clone_unsafe() }
                .to_tensor()
                .unwrap()
                .defined()
        );
        let converted: IValue =
            bincode::deserialize(&bincode::serialize(&original).unwrap()).unwrap();
        assert!(converted.is_tensor());
        assert!(!converted.to_tensor().unwrap().defined());
    }
}