1use std::ffi::c_void;
10use std::fmt;
11use std::mem::MaybeUninit;
12
13use cxx::ExternType;
14use cxx::type_id;
15use monarch_types::TryIntoPyObjectUnsafe;
16use paste::paste;
17use pyo3::exceptions::PyValueError;
18use pyo3::prelude::*;
19use serde::Deserialize;
20use serde::Serialize;
21use serde::de::Visitor;
22
23use crate::CloneUnsafe;
24use crate::Device;
25use crate::Tensor;
26use crate::bridge::clone_iv;
27use crate::bridge::ffi;
28use crate::cell::AliasTrackingRefCell;
29
/// Rust-side handle to a C++ `c10::IValue`, exchanged with C++ through the cxx
/// bridge (see the `ExternType` impl, which binds it to `"c10::IValue"`).
///
/// The struct is an opaque, `#[repr(C)]` block of two pointer-sized words so
/// that Rust can hold and move it by value.
/// NOTE(review): the two-word storage presumably mirrors the size/alignment of
/// `c10::IValue` — confirm against the C++ definition when libtorch changes.
#[repr(C)]
pub struct IValue {
    // Raw storage owned by the C++ object; not accessed by name from Rust code
    // in this file.
    repr: [*mut c_void; 2],
}
48
// SAFETY: `type_id!("c10::IValue")` must name exactly the C++ type that
// `IValue` stands in for. `cxx::kind::Trivial` tells cxx that Rust may hold and
// move this type by value without indirection.
// NOTE(review): this relies on the layout of `IValue` matching `c10::IValue`
// and on a bitwise move being a valid relocation for it. C++ destruction is
// presumably handled by the `Drop` impl, which calls `crate::bridge::drop` —
// confirm.
unsafe impl ExternType for IValue {
    type Id = type_id!("c10::IValue");
    type Kind = cxx::kind::Trivial;
}
55
impl Drop for IValue {
    /// Releases the underlying C++ value by calling back into the bridge.
    fn drop(&mut self) {
        // SAFETY: NOTE(review): assumes `crate::bridge::drop` destroys the C++
        // object stored in `self` in place, exactly once. `Drop::drop` runs at
        // most once per value and `self` is not touched afterwards, so the
        // "exactly once" part holds on the Rust side — confirm the C++ side.
        unsafe { crate::bridge::drop(self) };
    }
}
63
64impl PartialEq for IValue {
65 fn eq(&self, other: &Self) -> bool {
66 ffi::ivalues_equal_operator(self, other)
67 }
68}
69
// SAFETY: NOTE(review): the original justification is not visible here.
// `IValue` stores raw pointers, which makes it `!Send` by default. This impl
// asserts that the underlying `c10::IValue` may be moved to another thread —
// confirm against the thread-safety of its payload types.
unsafe impl Send for IValue {}

// SAFETY: NOTE(review): asserts that `&IValue` may be shared across threads.
// The `&self` methods in this file call into C++ (e.g. `ffi::debug_print`,
// `ffi::serialize_ivalue`); confirm those are safe to run concurrently.
unsafe impl Sync for IValue {}
79
80impl Serialize for IValue {
81 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
82 where
83 S: serde::Serializer,
84 {
85 serializer.serialize_bytes(
86 ffi::serialize_ivalue(self)
87 .map_err(serde::ser::Error::custom)?
88 .as_slice(),
89 )
90 }
91}
92
93impl<'de> Deserialize<'de> for IValue {
94 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
95 where
96 D: serde::Deserializer<'de>,
97 {
98 struct IValueVisitor;
99
100 impl<'de> Visitor<'de> for IValueVisitor {
101 type Value = IValue;
102
103 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
104 f.write_str("raw ivalue bytes")
105 }
106
107 fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
108 where
109 E: serde::de::Error,
110 {
111 ffi::deserialize_ivalue(v).map_err(E::custom)
112 }
113 }
114
115 deserializer.deserialize_bytes(IValueVisitor)
116 }
117}
118
119impl CloneUnsafe for IValue {
120 unsafe fn clone_unsafe(&self) -> Self {
130 let mut ivalue = MaybeUninit::<IValue>::uninit();
131 let new = ivalue.as_mut_ptr().cast();
132 unsafe {
134 clone_iv(self, new);
135 ivalue.assume_init()
136 }
137 }
138}
139
/// Newtype around an [`IValue`] that Rust code treats as an uninterpreted blob.
///
/// `IValue::to_opaque` only wraps values whose kind is `IValueKind::Other`.
/// `OpaqueIValue::from_py_object_with_type` wraps whatever the typed Python
/// conversion returns.
#[derive(Debug, Serialize, Deserialize)]
pub struct OpaqueIValue(IValue);
144
145impl OpaqueIValue {
146 pub(crate) unsafe fn ivalue(&self) -> IValue {
151 unsafe { self.0.clone_unsafe() }
153 }
154
155 pub fn from_py_object_with_type(
156 obj: Bound<'_, PyAny>,
157 type_: &crate::call_op::TypePtr,
158 num_elements: i32,
159 allow_nums_as_tensors: bool,
160 ) -> PyResult<OpaqueIValue> {
161 IValue::from_py_object_with_type(obj, type_, num_elements, allow_nums_as_tensors)
162 .map(OpaqueIValue)
163 }
164}
165
166impl Clone for OpaqueIValue {
167 fn clone(&self) -> Self {
170 let serialized = bincode::serialize(&self.0).unwrap();
171 bincode::deserialize(&serialized).unwrap()
172 }
173}
174
175impl CloneUnsafe for OpaqueIValue {
176 unsafe fn clone_unsafe(&self) -> Self {
177 Self(unsafe { self.0.clone_unsafe() })
179 }
180}
181
182impl<'py> TryIntoPyObjectUnsafe<'py, PyAny> for OpaqueIValue {
183 unsafe fn try_to_object_unsafe(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
184 unsafe { self.ivalue() }.into_pyobject(py)
186 }
187}
188
189pub type OpaqueIValueCell = AliasTrackingRefCell<OpaqueIValue>;
190
191impl From<OpaqueIValue> for OpaqueIValueCell {
192 #[inline]
193 fn from(value: OpaqueIValue) -> Self {
194 Self::new(value)
195 }
196}
197
/// Generates the kind predicates inside an `impl IValue` block:
/// - one `is_<kind>()` per listed kind, each forwarding to the FFI `is<Kind>()`
///   method;
/// - `is_other()`;
/// - `kind()`.
///
/// `$enum` is the kind enum. The list must name every variant of `$enum`
/// except `Other`, and the generated `__exhaustive_checker` enforces this at
/// compile time.
macro_rules! gen_is_impl {
    ($enum:ty, [$($kind:ident),* $(,)? ]) => {
        paste! {
            $(
                pub fn [<is_ $kind:snake:lower>](&self) -> bool {
                    self.[<is $kind>]()
                }
            )*

            /// Returns `true` when none of the listed `is_*` predicates match.
            pub fn is_other(&self) -> bool {
                self.kind() == $enum::Other
            }

            // Never called. It exists so that adding a variant to `$enum`
            // without listing it above is a compile error (non-exhaustive
            // match). The leading underscores keep the dead-code lint quiet.
            fn __exhaustive_checker(foo: $enum) {
                match foo {
                    $($enum::$kind => (),)*
                    $enum::Other => (),
                }
            }

            /// Classifies this value by trying the listed `is_*` predicates in
            /// list order. The first match wins; `Other` is returned if none
            /// match.
            pub fn kind(&self) -> $enum {
                // The `if false` head lets every listed kind be emitted
                // uniformly as an `else if` arm.
                if false {
                    unreachable!();
                } $(else if self.[<is_ $kind:snake:lower>]() {
                    $enum::$kind
                })*
                else {
                    $enum::Other
                }
            }
        }
    }
}
231
/// Generates one `impl From<$from_type> for IValue` per `$kind, $from_type;`
/// entry. Each impl forwards to `ffi::ivalue_from_<kind>`, where `<kind>` is
/// the snake_case form of `$kind`.
///
/// Also emits a never-called `__exhaustive_checker`. Its `match` fails to
/// compile if a variant of `$enum` is neither listed nor one of `None`/`Other`,
/// which are handled outside this macro.
macro_rules! gen_from_impl {
    ($enum:ty, $($kind:ident, $from_type:ty);* $(;)?) => {
        paste! {
            $(
                impl From<$from_type> for IValue {
                    fn from(value: $from_type) -> Self {
                        ffi::[<ivalue_from_ $kind:snake:lower>](value)
                    }
                }

            )*

            // Compile-time exhaustiveness check only; never called. The leading
            // underscores keep the dead-code lint quiet.
            // - `None` is excluded because it is built by `From<()>`.
            // - `Other` has no dedicated constructor in this file.
            fn __exhaustive_checker(foo: $enum) {
                match foo {
                    $($enum::$kind => (),)*
                    $enum::None => (),
                    $enum::Other => (),
                }
            }
        }

    }
}
255
256impl From<()> for IValue {
257 fn from(_value: ()) -> Self {
258 ffi::ivalue_from_none()
259 }
260}
261
262impl IValue {
263 pub fn to_tensor(self) -> Option<Tensor> {
264 ffi::toTensor(self).ok()
265 }
266 pub fn to_string(&self) -> Option<String> {
267 ffi::toString(self).ok()
268 }
269 pub fn to_int_list(&self) -> Option<Vec<i64>> {
270 ffi::toIntList(self).ok()
271 }
272 pub fn to_int(&self) -> Option<i64> {
273 self.toInt().ok()
274 }
275 pub fn to_double(&self) -> Option<f64> {
276 self.toDouble().ok()
277 }
278 pub fn to_bool(&self) -> Option<bool> {
279 self.toBool().ok()
280 }
281 pub fn to_tensor_list(self) -> Option<Vec<Tensor>> {
282 ffi::toTensorList(self).ok()
283 }
284 pub fn to_device(&self) -> Option<Device> {
285 self.toDevice().ok()
286 }
287 pub fn to_none(&self) -> Option<()> {
288 if self.is_none() { Some(()) } else { None }
289 }
290 pub fn to_opaque(self) -> Option<OpaqueIValue> {
291 if self.is_other() {
292 Some(OpaqueIValue(self))
293 } else {
294 None
295 }
296 }
297 gen_is_impl! {
302 IValueKind, [
303 Tensor,
304 String,
305 IntList,
306 Int,
307 Double,
308 Bool,
309 TensorList,
310 Device,
311 None,
312 ]
313 }
314
315 pub fn from_py_object_with_type(
316 obj: Bound<'_, PyAny>,
317 type_: &crate::call_op::TypePtr,
318 num_elements: i32,
319 allow_nums_as_tensors: bool,
320 ) -> PyResult<IValue> {
321 ffi::ivalue_from_py_object_with_type(obj.into(), type_, num_elements, allow_nums_as_tensors)
322 .map_err(|err| {
323 PyValueError::new_err(format!(
324 "Failed to extract IValue from python object: {:?}",
325 err
326 ))
327 })
328 }
329
330 pub(crate) fn from_py_object_or_none(obj: &Bound<'_, PyAny>) -> Option<IValue> {
331 ffi::py_object_is_ivalue(obj.clone().into())
332 .then(|| ffi::ivalue_from_arbitrary_py_object(obj.into()).unwrap())
333 }
334}
335
336gen_from_impl! {
341 IValueKind,
342 Tensor, Tensor;
343 String, &String;
344 IntList, &[i64];
345 Int, i64;
346 Double, f64;
347 Bool, bool;
348 TensorList, Vec<Tensor>;
349 Device, Device;
350}
351
/// The categories that `IValue::kind()` sorts an [`IValue`] into.
///
/// Every variant except `Other` has a matching `is_*` predicate on `IValue`.
/// `Other` covers values none of those predicates recognize; such values can
/// be wrapped with `IValue::to_opaque`.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum IValueKind {
    /// `IValue::is_tensor()` is true.
    Tensor,
    /// `IValue::is_bool()` is true.
    Bool,
    /// `IValue::is_int()` is true.
    Int,
    /// `IValue::is_int_list()` is true.
    IntList,
    /// `IValue::is_double()` is true.
    Double,
    /// `IValue::is_string()` is true.
    String,
    /// `IValue::is_tensor_list()` is true.
    TensorList,
    /// `IValue::is_device()` is true.
    Device,
    /// `IValue::is_none()` is true.
    None,

    /// None of the kinds above matched.
    Other,
}
404
405impl fmt::Debug for IValue {
406 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
407 write!(f, "{}", ffi::debug_print(self).map_err(|_| fmt::Error)?)
408 }
409}
410
411impl<'py> IntoPyObject<'py> for IValue {
412 type Target = PyAny;
413 type Output = Bound<'py, Self::Target>;
414 type Error = PyErr;
415
416 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
417 ffi::arbitrary_ivalue_to_py_object(self)
418 .map_err(|e| PyValueError::new_err(format!("Failed converting to py: {}", e)))?
419 .into_pyobject(py)
420 }
421}
422
423impl FromPyObject<'_> for IValue {
424 fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
425 ffi::ivalue_from_arbitrary_py_object(obj.into()).map_err(|e| {
426 PyValueError::new_err(format!("Failed extracting from py: {}: {}", e, obj))
427 })
428 }
429}
430
#[cfg(test)]
mod tests {
    #![allow(clippy::useless_vec)]
    // NOTE(review): the allow above presumably silences clippy on the
    // `vec![...]` inputs in the roundtrip macros — confirm it is still needed.

    use pyo3::types::PyFloat;

    use super::*;
    use crate::DeviceType;
    use crate::bridge::ffi::test_make_opaque_ivalue;
    use crate::bridge::ffi::test_make_tensor;
    use crate::bridge::ffi::test_make_undefined_tensor_ivalue;

    // Test-only conversion so the roundtrip macros can pass an owned `String`.
    // The non-test `From` impl takes `&String`.
    impl From<String> for IValue {
        fn from(value: String) -> Self {
            ffi::ivalue_from_string(&value)
        }
    }

    /// Compares two values, checking tensors and tensor lists element-wise
    /// with `Tensor::equal`. All other kinds fall back to `IValue`'s
    /// `PartialEq`.
    /// NOTE(review): presumably needed because `IValue`'s `==` does not
    /// compare tensor contents — confirm.
    fn ivalues_equal_with_tensor_equal(a: IValue, b: IValue) -> bool {
        if a.isTensor() {
            return b.isTensor() && a.to_tensor().unwrap().equal(&b.to_tensor().unwrap());
        }

        if a.isTensorList() {
            if !b.isTensorList() {
                return false;
            }
            let a_list = a.to_tensor_list().unwrap();
            let b_list = b.to_tensor_list().unwrap();
            if a_list.len() != b_list.len() {
                return false;
            }
            for (a_tensor, b_tensor) in a_list.iter().zip(b_list.iter()) {
                if !a_tensor.equal(b_tensor) {
                    return false;
                }
            }
            return true;
        }

        a == b
    }

    // A tensor IValue survives a bincode roundtrip.
    #[test]
    fn bincode_serialize() {
        let tensor = test_make_tensor();
        let i1 = IValue::from(tensor);
        let buf = bincode::serialize(&i1).unwrap();
        let i2: IValue = bincode::deserialize(&buf).unwrap();
        assert!(ivalues_equal_with_tensor_equal(i1, i2));
    }

    // A tensor IValue survives a serde_multipart (bincode) roundtrip.
    #[test]
    fn multipart_serialize() {
        let tensor = test_make_tensor();
        let i1 = IValue::from(tensor);
        let buf = serde_multipart::serialize_bincode(&i1).unwrap();
        let i2: IValue = serde_multipart::deserialize_bincode(buf).unwrap();
        assert!(ivalues_equal_with_tensor_equal(i1, i2));
    }

    // Schema-guided conversion for `aten::_foreach_add_.Tensor`.
    // - An empty list becomes a tensor list.
    // - Python `None` converts for argument 1.
    // - A float converts only when `allow_nums_as_tensors` is true.
    #[test]
    fn test_ivalue_from_py_object_with_type() {
        pyo3::prepare_freethreaded_python();

        let args_info =
            crate::call_op::get_schema_args_info("aten::_foreach_add_", "Tensor").unwrap();
        let (list, tensor, tensor_1, tensor_1_err) = Python::with_gil(|py| {
            let list = pyo3::types::PyList::empty(py).into_any();
            let none = py.None().into_bound(py);
            let one = PyFloat::new(py, 1.0).into_any();
            (
                IValue::from_py_object_with_type(
                    list,
                    args_info[0].type_,
                    args_info[0].num_elements,
                    false,
                )
                .unwrap(),
                IValue::from_py_object_with_type(
                    none,
                    args_info[1].type_,
                    args_info[1].num_elements,
                    false,
                )
                .unwrap(),
                IValue::from_py_object_with_type(
                    one.clone(),
                    args_info[1].type_,
                    args_info[1].num_elements,
                    true,
                )
                .unwrap(),
                IValue::from_py_object_with_type(
                    one,
                    args_info[1].type_,
                    args_info[1].num_elements,
                    false,
                ),
            )
        });
        assert!(list.is_tensor_list());
        assert!(tensor.is_tensor());
        assert!(tensor_1.is_tensor());
        assert!(tensor_1_err.is_err());
    }

    // For each `Kind, input;` pair, generates a test that converts
    // IValue -> Python -> IValue and checks equality. Also generates a match
    // over `IValueKind` that stops compiling if a kind is missing from the
    // list.
    macro_rules! generate_py_object_roundtrip_tests {
        ($($kind:ident, $input:expr);* $(;)?) => {
            paste! {
                $(
                    #[test]
                    fn [<test_py_object_roundtrip_ $kind:snake:lower>]() {
                        pyo3::prepare_freethreaded_python();
                        Python::with_gil(|py| py.run(pyo3::ffi::c_str!("import torch"), None, None)).unwrap();
                        let original = IValue::from($input);
                        // Keep `original` for the final comparison and convert a copy.
                        let converted = unsafe { original.clone_unsafe() };
                        let converted = Python::with_gil(|py| {
                            let py_object = converted.into_pyobject(py).unwrap();
                            anyhow::Ok(IValue::extract_bound(&py_object).unwrap())
                        }).unwrap();
                        assert!(ivalues_equal_with_tensor_equal(original, converted));
                    }
                )*

                #[test]
                fn test_py_object_roundtrip_was_exhaustive() {
                    match IValueKind::Int {
                        $(IValueKind::$kind => (),)*
                    }
                }
            }
        }
    }

    generate_py_object_roundtrip_tests! {
        Int, 123;
        Double, 1.23;
        String, "foobar".to_owned();
        IntList, [1, 2, 3].as_slice();
        Bool, false;
        Tensor, test_make_tensor();
        TensorList, vec![test_make_tensor()];
        Device, Device::new(DeviceType::CPU);
        None, ();
        Other, test_make_opaque_ivalue();
    }

    // Same shape as the Python roundtrip macro, but roundtrips through bincode.
    macro_rules! generate_serde_roundtrip_tests {
        ($($kind:ident, $input:expr);* $(;)?) => {
            paste! {
                $(
                    #[test]
                    fn [<test_serde_roundtrip_ $kind:snake:lower>]() {
                        pyo3::prepare_freethreaded_python();
                        Python::with_gil(|py| py.run(pyo3::ffi::c_str!("import torch"), None, None)).unwrap();
                        let original = IValue::from($input);
                        let converted: IValue = bincode::deserialize(&bincode::serialize(&original).unwrap()).unwrap();
                        assert!(ivalues_equal_with_tensor_equal(original, converted));
                    }
                )*

                #[test]
                fn test_serde_roundtrip_was_exhaustive() {
                    match IValueKind::Int {
                        $(IValueKind::$kind => (),)*
                    }
                }
            }
        }
    }

    generate_serde_roundtrip_tests! {
        Int, 123;
        Double, 1.23;
        String, "foobar".to_owned();
        IntList, [1, 2, 3].as_slice();
        Bool, false;
        Tensor, test_make_tensor();
        TensorList, vec![test_make_tensor()];
        Device, Device::new(DeviceType::CPU);
        None, ();
        Other, test_make_opaque_ivalue();
    }

    // An undefined tensor is still classified as a tensor and stays undefined
    // across a bincode roundtrip.
    #[test]
    fn test_serde_roundtrip_undefined_tensor() {
        let original = test_make_undefined_tensor_ivalue();
        assert!(original.is_tensor());
        assert!(
            // `to_tensor` consumes its receiver, so inspect a copy.
            !unsafe { original.clone_unsafe() }
                .to_tensor()
                .unwrap()
                .defined()
        );
        let converted: IValue =
            bincode::deserialize(&bincode::serialize(&original).unwrap()).unwrap();
        assert!(converted.is_tensor());
        assert!(!converted.to_tensor().unwrap().defined());
    }
}
649}