torch_sys/
ivalue.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::ffi::c_void;
10use std::fmt;
11use std::mem::MaybeUninit;
12
13use cxx::ExternType;
14use cxx::type_id;
15use monarch_types::TryIntoPyObjectUnsafe;
16use paste::paste;
17use pyo3::exceptions::PyValueError;
18use pyo3::prelude::*;
19use serde::Deserialize;
20use serde::Serialize;
21use serde::de::Visitor;
22
23use crate::CloneUnsafe;
24use crate::Device;
25use crate::Tensor;
26use crate::bridge::clone_iv;
27use crate::bridge::ffi;
28use crate::cell::AliasTrackingRefCell;
29
/// Rust binding for the C++ type `c10::IValue`.
///
/// `IValue` is a tagged union type that can hold any input to a PyTorch
/// operator. See [`IValueKind`] for the list of supported types.
///
/// # Safety
///
/// `IValue` either contains [`Copy`]-able data or a Tensor-like object, so it
/// inherits the safety properties of [`Tensor`]. See the safety discussion in
/// [`Tensor#safety`] for more info.
#[repr(C)]
pub struct IValue {
    /// Internal representation of IValue in C++. An IValue is 16 bytes, with 8
    /// bytes for a payload and 8 bytes for a type tag. We assert in `bridge.h`
    /// that the size and alignment are what we expect.
    // The field is an implementation detail of the binding; `doc(hidden)` is
    // a real attribute here (it used to be doc-comment text and had no
    // effect).
    #[doc(hidden)]
    repr: [*mut c_void; 2],
}
48
49// SAFETY: Register our custom bindings with cxx. IValue is trivial, see the
50// discussion in `bridge.h`.
51unsafe impl ExternType for IValue {
52    type Id = type_id!("c10::IValue");
53    type Kind = cxx::kind::Trivial;
54}
55
56impl Drop for IValue {
57    fn drop(&mut self) {
58        // SAFETY: calls the C++ destructor for IValue, potentially
59        // decrementing a tensor refcount.
60        unsafe { crate::bridge::drop(self) };
61    }
62}
63
64impl PartialEq for IValue {
65    fn eq(&self, other: &Self) -> bool {
66        ffi::ivalues_equal_operator(self, other)
67    }
68}
69
/// SAFETY: IValue is [`Send`]: its payload is either a copyable type or is
/// atomically refcounted via `c10::intrusive_ptr`.
unsafe impl Send for IValue {}
73
/// SAFETY: IValue is treated as [`Sync`] because the interior mutability of
/// the payload it holds is not exposed directly: the value is either
/// converted to a native type like [`Tensor`] for use in Rust, or left opaque.
/// See [`OpaqueIValue`] for more details.
unsafe impl Sync for IValue {}
79
80impl Serialize for IValue {
81    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
82    where
83        S: serde::Serializer,
84    {
85        serializer.serialize_bytes(
86            ffi::serialize_ivalue(self)
87                .map_err(serde::ser::Error::custom)?
88                .as_slice(),
89        )
90    }
91}
92
93impl<'de> Deserialize<'de> for IValue {
94    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
95    where
96        D: serde::Deserializer<'de>,
97    {
98        struct IValueVisitor;
99
100        impl<'de> Visitor<'de> for IValueVisitor {
101            type Value = IValue;
102
103            fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
104                f.write_str("raw ivalue bytes")
105            }
106
107            fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
108            where
109                E: serde::de::Error,
110            {
111                ffi::deserialize_ivalue(v).map_err(E::custom)
112            }
113        }
114
115        deserializer.deserialize_bytes(IValueVisitor)
116    }
117}
118
119impl CloneUnsafe for IValue {
120    /// This is *unsafe*, it creates an alias of the underlying data that is
121    /// not tracked by Rust. We use this to interface with C++ functions that
122    /// expect an owned IValue.
123    ///
124    /// The contract for calling this function is that the clone is local and
125    /// ephemeral. More precisely:
126    /// 1. The clone must not be sent to another thread (local).
127    /// 2. You must guarantee that clone is dropped before the originating
128    ///    mutable reference is dropped (ephemeral).
129    unsafe fn clone_unsafe(&self) -> Self {
130        let mut ivalue = MaybeUninit::<IValue>::uninit();
131        let new = ivalue.as_mut_ptr().cast();
132        // SAFETY: `ivalue` will be correctly initialized by the call to `clone_iv`.
133        unsafe {
134            clone_iv(self, new);
135            ivalue.assume_init()
136        }
137    }
138}
139
/// An opaque container for an [`IValue`]. This is used to restrict safe direct
/// access to the underlying [`IValue`]: the wrapped value is a private field,
/// so code outside this module cannot reach it directly.
///
/// Serialization is derived and forwards to the wrapped [`IValue`].
#[derive(Debug, Serialize, Deserialize)]
pub struct OpaqueIValue(IValue);
144
145impl OpaqueIValue {
146    /// This is *unsafe*, it creates an alias of the underlying data that is
147    /// not tracked by Rust. We need this to interface with C++ functions that
148    /// expect an owned IValue. The caller is responsible for ensuring that
149    /// this is done in a safe way.
150    pub(crate) unsafe fn ivalue(&self) -> IValue {
151        // SAFETY: See above
152        unsafe { self.0.clone_unsafe() }
153    }
154
155    pub fn from_py_object_with_type(
156        obj: Bound<'_, PyAny>,
157        type_: &crate::call_op::TypePtr,
158        num_elements: i32,
159        allow_nums_as_tensors: bool,
160    ) -> PyResult<OpaqueIValue> {
161        IValue::from_py_object_with_type(obj, type_, num_elements, allow_nums_as_tensors)
162            .map(OpaqueIValue)
163    }
164}
165
166impl Clone for OpaqueIValue {
167    /// This creates a deep copy of the underlying data and can be expensive.
168    /// It might also panic if the `IValue` is not cloneable.
169    fn clone(&self) -> Self {
170        let serialized = bincode::serialize(&self.0).unwrap();
171        bincode::deserialize(&serialized).unwrap()
172    }
173}
174
175impl CloneUnsafe for OpaqueIValue {
176    unsafe fn clone_unsafe(&self) -> Self {
177        // SAFETY: See discussion for `IValue::clone_unsafe`.
178        Self(unsafe { self.0.clone_unsafe() })
179    }
180}
181
182impl<'py> TryIntoPyObjectUnsafe<'py, PyAny> for OpaqueIValue {
183    unsafe fn try_to_object_unsafe(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
184        // SAFETY: See discussion for `IValue::clone_unsafe`.
185        unsafe { self.ivalue() }.into_pyobject(py)
186    }
187}
188
189pub type OpaqueIValueCell = AliasTrackingRefCell<OpaqueIValue>;
190
191impl From<OpaqueIValue> for OpaqueIValueCell {
192    #[inline]
193    fn from(value: OpaqueIValue) -> Self {
194        Self::new(value)
195    }
196}
197
// Generates kind-inspection methods; meant to be invoked inside an `impl`
// block, since the generated functions take `&self`.
//
// Input: the kind enum type, then a bracketed list of its variants (excluding
// the catch-all `Other`).
//
// For every listed variant `Foo` this emits `is_foo(&self)`, forwarding to the
// `isFoo()` method on `self`. It also emits:
//   - `is_other(&self)`: true when `kind()` yields `$enum::Other`.
//   - `kind(&self)`: tries each `is_*` predicate in list order and falls back
//     to `$enum::Other`.
//   - `__exhaustive_checker`: never meant to be called; its `match` fails to
//     compile if a variant of `$enum` is missing from the list.
macro_rules! gen_is_impl {
    ($enum:ty, [$($kind:ident),* $(,)? ]) => {
        paste! {
            $(
                pub fn [<is_ $kind:snake:lower>](&self) -> bool {
                    self.[<is $kind>]()
                }
            )*

            pub fn is_other(&self) -> bool {
                self.kind() == $enum::Other
            }

            fn __exhaustive_checker(foo: $enum) {
                match foo {
                    $($enum::$kind => (),)*
                    // Use the macro's enum parameter rather than a hard-coded
                    // type name, consistent with the rest of this macro.
                    $enum::Other => (),
                }
            }

            pub fn kind(&self) -> $enum {
                // The leading `if false` only gives the generated `else if`
                // arms something to chain onto.
                if false {
                    unreachable!();
                } $(else if self.[<is_ $kind:snake:lower>]() {
                    $enum::$kind
                })*
                else {
                    $enum::Other
                }
            }
        }
    }
}
231
// Generates `impl From<T> for IValue` for each `Kind, T;` entry.
//
// Input: the kind enum type, then `;`-separated `Variant, RustType` pairs.
// Each impl forwards to the bridge function `ffi::ivalue_from_<variant>`
// (variant name in snake case).
//
// `__exhaustive_checker` is never meant to be called; its `match` fails to
// compile if a variant of `$enum` other than `None` and `Other` is missing
// from the entry list.
macro_rules! gen_from_impl {
    ($enum:ty, $($kind:ident, $from_type:ty);* $(;)?) => {
        paste! {
            $(
                impl From<$from_type> for IValue {
                    fn from(value: $from_type) -> Self {
                        ffi::[<ivalue_from_ $kind:snake:lower>](value)
                    }
                }

            )*

            fn __exhaustive_checker(foo: $enum) {
                match foo {
                    $($enum::$kind => (),)*
                    // `None` and `Other` are not generated from an entry. Use
                    // the macro's enum parameter rather than a hard-coded
                    // type name, consistent with the arms above.
                    $enum::None => (),
                    $enum::Other => (),
                }
            }
        }

    }
}
255
256impl From<()> for IValue {
257    fn from(_value: ()) -> Self {
258        ffi::ivalue_from_none()
259    }
260}
261
262impl IValue {
263    pub fn to_tensor(self) -> Option<Tensor> {
264        ffi::toTensor(self).ok()
265    }
266    pub fn to_string(&self) -> Option<String> {
267        ffi::toString(self).ok()
268    }
269    pub fn to_int_list(&self) -> Option<Vec<i64>> {
270        ffi::toIntList(self).ok()
271    }
272    pub fn to_int(&self) -> Option<i64> {
273        self.toInt().ok()
274    }
275    pub fn to_double(&self) -> Option<f64> {
276        self.toDouble().ok()
277    }
278    pub fn to_bool(&self) -> Option<bool> {
279        self.toBool().ok()
280    }
281    pub fn to_tensor_list(self) -> Option<Vec<Tensor>> {
282        ffi::toTensorList(self).ok()
283    }
284    pub fn to_device(&self) -> Option<Device> {
285        self.toDevice().ok()
286    }
287    pub fn to_none(&self) -> Option<()> {
288        if self.is_none() { Some(()) } else { None }
289    }
290    pub fn to_opaque(self) -> Option<OpaqueIValue> {
291        if self.is_other() {
292            Some(OpaqueIValue(self))
293        } else {
294            None
295        }
296    }
297    // Generate is_ methods for all IValue kinds.
298    // If you get a compile error here, make sure:
299    //   - Your new kind is registered on IValueKind
300    //   - You added a field here.
301    gen_is_impl! {
302        IValueKind, [
303            Tensor,
304            String,
305            IntList,
306            Int,
307            Double,
308            Bool,
309            TensorList,
310            Device,
311            None,
312        ]
313    }
314
315    pub fn from_py_object_with_type(
316        obj: Bound<'_, PyAny>,
317        type_: &crate::call_op::TypePtr,
318        num_elements: i32,
319        allow_nums_as_tensors: bool,
320    ) -> PyResult<IValue> {
321        ffi::ivalue_from_py_object_with_type(obj.into(), type_, num_elements, allow_nums_as_tensors)
322            .map_err(|err| {
323                PyValueError::new_err(format!(
324                    "Failed to extract IValue from python object: {:?}",
325                    err
326                ))
327            })
328    }
329
330    pub(crate) fn from_py_object_or_none(obj: &Bound<'_, PyAny>) -> Option<IValue> {
331        ffi::py_object_is_ivalue(obj.clone().into())
332            .then(|| ffi::ivalue_from_arbitrary_py_object(obj.into()).unwrap())
333    }
334}
335
// Implement `From` for the natively supported IValue kinds. Each entry is
// `<IValueKind variant>, <Rust source type>;`. `None` is not listed here: it
// has its own `From<()>` impl.
// If you get a compile error here, make sure:
//   - Your new kind is registered on IValueKind
//   - You added an entry for it to the list below.
gen_from_impl! {
    IValueKind,
    Tensor, Tensor;
    String, &String;
    IntList, &[i64];
    Int, i64;
    Double, f64;
    Bool, bool;
    TensorList, Vec<Tensor>;
    Device, Device;
}
351
/// Enum representing the different internal types an [`IValue`] can hold.
///
/// Check each variant's docs to see what the internal storage in C++ is, and
/// what the cost of moving across the Rust<>C++ boundary is.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum IValueKind {
    /// - C++ type: `at::Tensor`
    /// - Rust type: [`Tensor`]
    Tensor,
    /// - C++ type: `bool`
    /// - Rust type: [`bool`]
    Bool,
    /// - C++ type: `int64_t`
    /// - Rust type: [`i64`]
    Int,
    /// - C++ type: `c10::List<int64_t>`
    /// - Rust type: [`Vec<i64>`]
    ///
    /// <div class="warning">
    /// Passing across the C++-Rust boundary will copy the vector.
    /// </div>
    IntList,
    /// - C++ type: `double`
    /// - Rust type: [`f64`]
    Double,
    /// - C++ type: `c10::intrusive_ptr<ConstantString>`
    /// - Rust type: [`String`]
    ///
    /// <div class="warning">
    /// Passing across the C++-Rust boundary will copy the string.
    /// </div>
    String,
    /// - C++ type: `c10::List<at::Tensor>`
    /// - Rust type: [`Vec<Tensor>`]
    ///
    /// <div class="warning">
    /// Passing across the C++-Rust boundary will copy the vector.
    /// </div>
    TensorList,
    /// - C++ type: `c10::Device`
    /// - Rust type: [`Device`]
    Device,
    /// The `None` value.
    /// - Rust type: `()` (see `From<()> for IValue` and `IValue::to_none`)
    None,

    /// Catch-all for all other types. This is used for types that are not
    /// natively supported in Rust and can remain as opaque IValues for
    /// interacting with torch APIs. There is an overhead associated with
    /// tracking aliases and borrows for any trivial IValues being converted
    /// to this type, so they should be natively supported. Most of them
    /// already are.
    Other,
}
404
405impl fmt::Debug for IValue {
406    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
407        write!(f, "{}", ffi::debug_print(self).map_err(|_| fmt::Error)?)
408    }
409}
410
411impl<'py> IntoPyObject<'py> for IValue {
412    type Target = PyAny;
413    type Output = Bound<'py, Self::Target>;
414    type Error = PyErr;
415
416    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
417        ffi::arbitrary_ivalue_to_py_object(self)
418            .map_err(|e| PyValueError::new_err(format!("Failed converting to py: {}", e)))?
419            .into_pyobject(py)
420    }
421}
422
423impl FromPyObject<'_> for IValue {
424    fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
425        ffi::ivalue_from_arbitrary_py_object(obj.into()).map_err(|e| {
426            PyValueError::new_err(format!("Failed extracting from py: {}: {}", e, obj))
427        })
428    }
429}
430
431#[cfg(test)]
432mod tests {
433    #![allow(clippy::useless_vec)]
434
435    use pyo3::types::PyFloat;
436
437    use super::*;
438    use crate::DeviceType;
439    use crate::bridge::ffi::test_make_opaque_ivalue;
440    use crate::bridge::ffi::test_make_tensor;
441    use crate::bridge::ffi::test_make_undefined_tensor_ivalue;
442
443    impl From<String> for IValue {
444        fn from(value: String) -> Self {
445            ffi::ivalue_from_string(&value)
446        }
447    }
448
449    // Check for the equality of two IValues using tensor.equal to compare
450    // tensors.
451    fn ivalues_equal_with_tensor_equal(a: IValue, b: IValue) -> bool {
452        if a.isTensor() {
453            return b.isTensor() && a.to_tensor().unwrap().equal(&b.to_tensor().unwrap());
454        }
455
456        if a.isTensorList() {
457            if !b.isTensorList() {
458                return false;
459            }
460            let a_list = a.to_tensor_list().unwrap();
461            let b_list = b.to_tensor_list().unwrap();
462            if a_list.len() != b_list.len() {
463                return false;
464            }
465            for (a_tensor, b_tensor) in a_list.iter().zip(b_list.iter()) {
466                if !a_tensor.equal(b_tensor) {
467                    return false;
468                }
469            }
470            return true;
471        }
472
473        a == b
474    }
475
476    #[test]
477    fn bincode_serialize() {
478        let tensor = test_make_tensor();
479        let i1 = IValue::from(tensor);
480        let buf = bincode::serialize(&i1).unwrap();
481        let i2: IValue = bincode::deserialize(&buf).unwrap();
482        assert!(ivalues_equal_with_tensor_equal(i1, i2));
483    }
484
485    #[test]
486    fn multipart_serialize() {
487        let tensor = test_make_tensor();
488        let i1 = IValue::from(tensor);
489        let buf = serde_multipart::serialize_bincode(&i1).unwrap();
490        let i2: IValue = serde_multipart::deserialize_bincode(buf).unwrap();
491        assert!(ivalues_equal_with_tensor_equal(i1, i2));
492    }
493
494    #[test]
495    fn test_ivalue_from_py_object_with_type() {
496        pyo3::prepare_freethreaded_python();
497
498        let args_info =
499            crate::call_op::get_schema_args_info("aten::_foreach_add_", "Tensor").unwrap();
500        let (list, tensor, tensor_1, tensor_1_err) = Python::with_gil(|py| {
501            let list = pyo3::types::PyList::empty(py).into_any();
502            let none = py.None().into_bound(py);
503            let one = PyFloat::new(py, 1.0).into_any();
504            (
505                IValue::from_py_object_with_type(
506                    list,
507                    args_info[0].type_,
508                    args_info[0].num_elements,
509                    false,
510                )
511                .unwrap(),
512                IValue::from_py_object_with_type(
513                    none,
514                    args_info[1].type_,
515                    args_info[1].num_elements,
516                    false,
517                )
518                .unwrap(),
519                IValue::from_py_object_with_type(
520                    one.clone(),
521                    args_info[1].type_,
522                    args_info[1].num_elements,
523                    true,
524                )
525                .unwrap(),
526                IValue::from_py_object_with_type(
527                    one,
528                    args_info[1].type_,
529                    args_info[1].num_elements,
530                    false,
531                ),
532            )
533        });
534        assert!(list.is_tensor_list());
535        assert!(tensor.is_tensor());
536        assert!(tensor_1.is_tensor());
537        assert!(tensor_1_err.is_err());
538    }
539
    // Generates one `#[test]` per `Kind, input;` entry that converts
    // `IValue::from(input)` to a Python object and back, then checks the
    // result against the original with `ivalues_equal_with_tensor_equal`.
    //
    // It also generates `test_py_object_roundtrip_was_exhaustive`, whose
    // `match` only compiles when every `IValueKind` variant has an entry.
    macro_rules! generate_py_object_roundtrip_tests {
        ($($kind:ident, $input:expr);* $(;)?) => {
            paste! {
                $(
                    #[test]
                    fn [<test_py_object_roundtrip_ $kind:snake:lower>]() {
                            pyo3::prepare_freethreaded_python();
                            // We need to load torch to initialize some internal
                            // structures used by the FFI funcs we use to convert
                            // ivalues to/from py objects.
                            Python::with_gil(|py| py.run(pyo3::ffi::c_str!("import torch"), None, None)).unwrap();
                            let original = IValue::from($input);
                            // SAFETY: `TryIntoPyObject` consumes the value, so
                            // we clone here to use for the `assert_eq` at end.
                            let converted = unsafe { original.clone_unsafe() };
                            let converted = Python::with_gil(|py| {
                                let py_object = converted.into_pyobject(py).unwrap();
                                anyhow::Ok(IValue::extract_bound(&py_object).unwrap())
                            }).unwrap();
                            assert!(ivalues_equal_with_tensor_equal(original, converted));
                    }
                )*

                #[test]
                fn test_py_object_roundtrip_was_exhaustive() {
                    // Compile-time check only: fails with "non-exhaustive
                    // patterns" when a kind has no entry.
                    match IValueKind::Int {
                        $(IValueKind::$kind => (),)*
                    }
                }
            }
        }
    }
572
    // Generate exhaustive Python-object roundtrip tests for all IValue kinds.
    // Each entry is `<IValueKind variant>, <expression convertible to IValue>;`.
    // If you got a "non-exhaustive patterns" error here, you need to add a new
    // test entry for your IValue kind!
    generate_py_object_roundtrip_tests! {
        Int, 123;
        Double, 1.23;
        String, "foobar".to_owned();
        IntList, [1, 2, 3].as_slice();
        Bool, false;
        Tensor, test_make_tensor();
        TensorList, vec![test_make_tensor()];
        Device, Device::new(DeviceType::CPU);
        None, ();
        Other, test_make_opaque_ivalue();
    }
588
    // Generates one `#[test]` per `Kind, input;` entry that serializes
    // `IValue::from(input)` with bincode, deserializes it again, and checks
    // the result against the original with `ivalues_equal_with_tensor_equal`.
    //
    // It also generates `test_serde_roundtrip_was_exhaustive`, whose `match`
    // only compiles when every `IValueKind` variant has an entry.
    macro_rules! generate_serde_roundtrip_tests {
        ($($kind:ident, $input:expr);* $(;)?) => {
            paste! {
                $(
                    #[test]
                    fn [<test_serde_roundtrip_ $kind:snake:lower>]() {
                            pyo3::prepare_freethreaded_python();
                            // We need to load torch to initialize some internal
                            // structures used by the FFI funcs we use to convert
                            // ivalues to/from py objects.
                            Python::with_gil(|py| py.run(pyo3::ffi::c_str!("import torch"), None, None)).unwrap();
                            let original = IValue::from($input);
                            let converted: IValue = bincode::deserialize(&bincode::serialize(&original).unwrap()).unwrap();
                            assert!(ivalues_equal_with_tensor_equal(original, converted));
                    }
                )*

                #[test]
                fn test_serde_roundtrip_was_exhaustive() {
                    // Compile-time check only: fails with "non-exhaustive
                    // patterns" when a kind has no entry.
                    match IValueKind::Int {
                        $(IValueKind::$kind => (),)*
                    }
                }
            }
        }
    }
615
    // Generate exhaustive serde roundtrip tests for all IValue kinds.
    // Each entry is `<IValueKind variant>, <expression convertible to IValue>;`.
    // If you got a "non-exhaustive patterns" error here, you need to add a new
    // test entry for your IValue kind!
    generate_serde_roundtrip_tests! {
        Int, 123;
        Double, 1.23;
        String, "foobar".to_owned();
        IntList, [1, 2, 3].as_slice();
        Bool, false;
        Tensor, test_make_tensor();
        TensorList, vec![test_make_tensor()];
        Device, Device::new(DeviceType::CPU);
        None, ();
        Other, test_make_opaque_ivalue();
    }
631
632    #[test]
633    fn test_serde_roundtrip_undefined_tensor() {
634        let original = test_make_undefined_tensor_ivalue();
635        assert!(original.is_tensor());
636        assert!(
637            // SAFETY: Since it is an undefined tensor that we dont mutate,
638            // it is safe to clone in this test.
639            !unsafe { original.clone_unsafe() }
640                .to_tensor()
641                .unwrap()
642                .defined()
643        );
644        let converted: IValue =
645            bincode::deserialize(&bincode::serialize(&original).unwrap()).unwrap();
646        assert!(converted.is_tensor());
647        assert!(!converted.to_tensor().unwrap().defined());
648    }
649}