1use std::ffi::c_void;
10use std::fmt;
11use std::mem::MaybeUninit;
12
13use cxx::ExternType;
14use cxx::type_id;
15use monarch_types::TryIntoPyObjectUnsafe;
16use paste::paste;
17use pyo3::exceptions::PyValueError;
18use pyo3::prelude::*;
19use serde::Deserialize;
20use serde::Serialize;
21
22use crate::CloneUnsafe;
23use crate::Device;
24use crate::Tensor;
25use crate::bridge::clone_iv;
26use crate::bridge::ffi;
27use crate::cell::AliasTrackingRefCell;
28
/// Rust-side storage for a C++ `c10::IValue`, bound to that C++ type for cxx by the
/// `ExternType` impl for this struct, and owned and moved by value in Rust.
///
/// The layout is two pointer-sized words under `#[repr(C)]`.
/// NOTE(review): this must match `sizeof`/`alignof` of `c10::IValue` on the C++ side —
/// confirm whenever the C++ type changes.
#[repr(C)]
pub struct IValue {
    // Opaque storage for the C++ object; not accessed directly anywhere in this file.
    repr: [*mut c_void; 2],
}
47
// SAFETY: binds the Rust `IValue` struct to the C++ type `c10::IValue` for cxx.
// NOTE(review): `Kind = Trivial` lets cxx pass and move the value by value in Rust (a bitwise
// move), which presumes `c10::IValue` is trivially relocatable and layout-compatible with the
// `#[repr(C)]` Rust struct — confirm. Destruction is handled by the `Drop` impl, which calls
// `crate::bridge::drop`.
unsafe impl ExternType for IValue {
    // Fully-qualified C++ name cxx uses to match this type in bridge signatures.
    type Id = type_id!("c10::IValue");
    type Kind = cxx::kind::Trivial;
}
54
impl Drop for IValue {
    /// Releases the underlying C++ value through the bridge.
    fn drop(&mut self) {
        // SAFETY: NOTE(review): presumably runs the C++ destructor on the value in place; this
        // is sound only because `drop` runs exactly once and `self` is not used afterwards.
        // Confirm against the bridge definition of `drop`.
        unsafe { crate::bridge::drop(self) };
    }
}
62
63impl PartialEq for IValue {
64 fn eq(&self, other: &Self) -> bool {
65 ffi::ivalues_equal_operator(self, other)
66 }
67}
68
// SAFETY: NOTE(review): asserts a `c10::IValue` may be moved to, and dropped on, another
// thread. Without this impl the raw pointers in `repr` would make `IValue` `!Send`. Confirm the
// C++ value carries no thread affinity.
unsafe impl Send for IValue {}

// SAFETY: NOTE(review): asserts `&IValue` may be shared across threads. That requires the
// `&self` bridge calls used in this file (equality, serialization, `clone_iv`, `debug_print`,
// the `to*` accessors) to be safe to run concurrently on the same value. Confirm, since
// cloning may touch C++-side reference counts.
unsafe impl Sync for IValue {}
78
79impl Serialize for IValue {
80 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
81 where
82 S: serde::Serializer,
83 {
84 serializer.serialize_bytes(
85 ffi::serialize_ivalue(self)
86 .map_err(serde::ser::Error::custom)?
87 .as_slice(),
88 )
89 }
90}
91
92impl<'de> Deserialize<'de> for IValue {
93 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
94 where
95 D: serde::Deserializer<'de>,
96 {
97 let buf: &[u8] = Deserialize::deserialize(deserializer)?;
98 ffi::deserialize_ivalue(buf).map_err(serde::de::Error::custom)
99 }
100}
101
102impl CloneUnsafe for IValue {
103 unsafe fn clone_unsafe(&self) -> Self {
113 let mut ivalue = MaybeUninit::<IValue>::uninit();
114 let new = ivalue.as_mut_ptr().cast();
115 unsafe {
117 clone_iv(self, new);
118 ivalue.assume_init()
119 }
120 }
121}
122
/// Wrapper for an `IValue` whose kind has no dedicated accessor (see `IValue::to_opaque`,
/// which only builds one when `is_other()` is true).
///
/// The derived `Debug`/`Serialize`/`Deserialize` impls delegate to the inner `IValue`.
#[derive(Debug, Serialize, Deserialize)]
pub struct OpaqueIValue(IValue);
127
128impl OpaqueIValue {
129 pub(crate) unsafe fn ivalue(&self) -> IValue {
134 unsafe { self.0.clone_unsafe() }
136 }
137
138 pub fn from_py_object_with_type(
139 obj: Bound<'_, PyAny>,
140 type_: &crate::call_op::TypePtr,
141 num_elements: i32,
142 allow_nums_as_tensors: bool,
143 ) -> PyResult<OpaqueIValue> {
144 IValue::from_py_object_with_type(obj, type_, num_elements, allow_nums_as_tensors)
145 .map(OpaqueIValue)
146 }
147}
148
149impl Clone for OpaqueIValue {
150 fn clone(&self) -> Self {
153 let serialized = bincode::serialize(&self.0).unwrap();
154 bincode::deserialize(&serialized).unwrap()
155 }
156}
157
158impl CloneUnsafe for OpaqueIValue {
159 unsafe fn clone_unsafe(&self) -> Self {
160 Self(unsafe { self.0.clone_unsafe() })
162 }
163}
164
165impl<'py> TryIntoPyObjectUnsafe<'py, PyAny> for OpaqueIValue {
166 unsafe fn try_to_object_unsafe(self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
167 unsafe { self.ivalue() }.into_pyobject(py)
169 }
170}
171
/// An `OpaqueIValue` held inside the crate's `AliasTrackingRefCell`.
pub type OpaqueIValueCell = AliasTrackingRefCell<OpaqueIValue>;
173
174impl From<OpaqueIValue> for OpaqueIValueCell {
175 #[inline]
176 fn from(value: OpaqueIValue) -> Self {
177 Self::new(value)
178 }
179}
180
/// Generates kind-inspection helpers inside an `impl IValue` block:
///
/// - `is_<kind>()`: a snake_case wrapper around the camelCase `is<Kind>()` bridge method,
///   one per listed kind;
/// - `is_other()`: true when `kind()` is `$enum::Other`;
/// - `kind()`: the first listed kind whose predicate matches, else `$enum::Other`;
/// - `__exhaustive_checker`: never called; its `match` stops compiling if a variant of
///   `$enum` other than `Other` is missing from the list.
macro_rules! gen_is_impl {
    ($enum:ty, [$($kind:ident),* $(,)? ]) => {
        paste! {
            $(
                #[doc = "Snake-case wrapper around the bridge method `is" $kind "()`."]
                pub fn [<is_ $kind:snake:lower>](&self) -> bool {
                    self.[<is $kind>]()
                }
            )*

            /// Returns true if this value is none of the kinds with a dedicated predicate.
            pub fn is_other(&self) -> bool {
                self.kind() == $enum::Other
            }

            // Compile-time exhaustiveness check only; uses `$enum` throughout so the macro
            // does not silently depend on the concrete `IValueKind` name.
            fn __exhaustive_checker(foo: $enum) {
                match foo {
                    $($enum::$kind => (),)*
                    $enum::Other => (),
                }
            }

            /// Classifies this value. Predicates are tried in the order the kinds were listed,
            /// and `Other` is returned when none matches.
            pub fn kind(&self) -> $enum {
                // `if false` lets every listed kind expand uniformly as an `else if` arm.
                if false {
                    unreachable!();
                } $(else if self.[<is_ $kind:snake:lower>]() {
                    $enum::$kind
                })*
                else {
                    $enum::Other
                }
            }
        }
    }
}
214
/// Generates `impl From<$from_type> for IValue` for each `kind, type` pair. Each impl
/// forwards to the bridge function `ffi::ivalue_from_<kind>`.
///
/// It also emits a module-level `__exhaustive_checker` whose `match` stops compiling when a
/// variant of `$enum` is neither listed nor `None`/`Other`. Because of that item, invoke
/// the macro at most once per module.
macro_rules! gen_from_impl {
    ($enum:ty, $($kind:ident, $from_type:ty);* $(;)?) => {
        paste! {
            $(
                impl From<$from_type> for IValue {
                    fn from(value: $from_type) -> Self {
                        ffi::[<ivalue_from_ $kind:snake:lower>](value)
                    }
                }
            )*

            // Compile-time exhaustiveness check only. `None` and `Other` are exempt; uses
            // `$enum` rather than the concrete `IValueKind` name.
            fn __exhaustive_checker(foo: $enum) {
                match foo {
                    $($enum::$kind => (),)*
                    $enum::None => (),
                    $enum::Other => (),
                }
            }
        }
    }
}
238
239impl From<()> for IValue {
240 fn from(_value: ()) -> Self {
241 ffi::ivalue_from_none()
242 }
243}
244
245impl IValue {
246 pub fn to_tensor(self) -> Option<Tensor> {
247 ffi::toTensor(self).ok()
248 }
249 pub fn to_string(&self) -> Option<String> {
250 ffi::toString(self).ok()
251 }
252 pub fn to_int_list(&self) -> Option<Vec<i64>> {
253 ffi::toIntList(self).ok()
254 }
255 pub fn to_int(&self) -> Option<i64> {
256 self.toInt().ok()
257 }
258 pub fn to_double(&self) -> Option<f64> {
259 self.toDouble().ok()
260 }
261 pub fn to_bool(&self) -> Option<bool> {
262 self.toBool().ok()
263 }
264 pub fn to_tensor_list(self) -> Option<Vec<Tensor>> {
265 ffi::toTensorList(self).ok()
266 }
267 pub fn to_device(&self) -> Option<Device> {
268 self.toDevice().ok()
269 }
270 pub fn to_none(&self) -> Option<()> {
271 if self.is_none() { Some(()) } else { None }
272 }
273 pub fn to_opaque(self) -> Option<OpaqueIValue> {
274 if self.is_other() {
275 Some(OpaqueIValue(self))
276 } else {
277 None
278 }
279 }
280 gen_is_impl! {
285 IValueKind, [
286 Tensor,
287 String,
288 IntList,
289 Int,
290 Double,
291 Bool,
292 TensorList,
293 Device,
294 None,
295 ]
296 }
297
298 pub fn from_py_object_with_type(
299 obj: Bound<'_, PyAny>,
300 type_: &crate::call_op::TypePtr,
301 num_elements: i32,
302 allow_nums_as_tensors: bool,
303 ) -> PyResult<IValue> {
304 ffi::ivalue_from_py_object_with_type(obj.into(), type_, num_elements, allow_nums_as_tensors)
305 .map_err(|err| {
306 PyValueError::new_err(format!(
307 "Failed to extract IValue from python object: {:?}",
308 err
309 ))
310 })
311 }
312
313 pub(crate) fn from_py_object_or_none(obj: &Bound<'_, PyAny>) -> Option<IValue> {
314 ffi::py_object_is_ivalue(obj.clone().into())
315 .then(|| ffi::ivalue_from_arbitrary_py_object(obj.into()).unwrap())
316 }
317}
318
// Generates `From<T> for IValue` for each `kind, source type` pair through the matching
// `ffi::ivalue_from_<kind>` bridge function. The macro also checks at compile time that every
// `IValueKind` variant is listed here or is `None`/`Other`. `None` is covered instead by the
// hand-written `From<()>` impl.
gen_from_impl! {
    IValueKind,
    Tensor, Tensor;
    String, &String;
    IntList, &[i64];
    Int, i64;
    Double, f64;
    Bool, bool;
    TensorList, Vec<Tensor>;
    Device, Device;
}
334
/// The kinds of `IValue` with dedicated Rust accessors, as reported by `IValue::kind()`.
///
/// Every value that matches none of these kinds is `Other`.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum IValueKind {
    /// Tensor (`is_tensor()`); read with `to_tensor()`.
    Tensor,
    /// Boolean (`is_bool()`); read with `to_bool()`.
    Bool,
    /// Integer (`is_int()`); read with `to_int()` as `i64`.
    Int,
    /// List of integers (`is_int_list()`); read with `to_int_list()` as `Vec<i64>`.
    IntList,
    /// Floating-point number (`is_double()`); read with `to_double()` as `f64`.
    Double,
    /// String (`is_string()`); read with `to_string()`.
    String,
    /// List of tensors (`is_tensor_list()`); read with `to_tensor_list()`.
    TensorList,
    /// Device (`is_device()`); read with `to_device()`.
    Device,
    /// None (`is_none()`); see `to_none()`.
    None,

    /// Any kind without a dedicated accessor; such values can be wrapped with `to_opaque()`.
    Other,
}
387
388impl fmt::Debug for IValue {
389 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
390 write!(f, "{}", ffi::debug_print(self).map_err(|_| fmt::Error)?)
391 }
392}
393
394impl<'py> IntoPyObject<'py> for IValue {
395 type Target = PyAny;
396 type Output = Bound<'py, Self::Target>;
397 type Error = PyErr;
398
399 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
400 ffi::arbitrary_ivalue_to_py_object(self)
401 .map_err(|e| PyValueError::new_err(format!("Failed converting to py: {}", e)))?
402 .into_pyobject(py)
403 }
404}
405
406impl FromPyObject<'_> for IValue {
407 fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
408 ffi::ivalue_from_arbitrary_py_object(obj.into()).map_err(|e| {
409 PyValueError::new_err(format!("Failed extracting from py: {}: {}", e, obj))
410 })
411 }
412}
413
#[cfg(test)]
mod tests {
    // The roundtrip test tables below build `vec![...]` inputs that clippy flags as useless.
    #![allow(clippy::useless_vec)]

    use pyo3::types::PyFloat;

    use super::*;
    use crate::DeviceType;
    use crate::bridge::ffi::test_make_opaque_ivalue;
    use crate::bridge::ffi::test_make_tensor;
    use crate::bridge::ffi::test_make_undefined_tensor_ivalue;

    // Test-only conversion from an owned `String`, so the roundtrip tables can pass
    // `"foobar".to_owned()` directly. The non-test impl only covers `&String`.
    impl From<String> for IValue {
        fn from(value: String) -> Self {
            ffi::ivalue_from_string(&value)
        }
    }

    /// Compares two values, using element-wise `Tensor::equal` for tensors and tensor lists
    /// and falling back to `IValue`'s `PartialEq` for every other kind.
    fn ivalues_equal_with_tensor_equal(a: IValue, b: IValue) -> bool {
        if a.isTensor() {
            return b.isTensor() && a.to_tensor().unwrap().equal(&b.to_tensor().unwrap());
        }

        if a.isTensorList() {
            if !b.isTensorList() {
                return false;
            }
            let a_list = a.to_tensor_list().unwrap();
            let b_list = b.to_tensor_list().unwrap();
            if a_list.len() != b_list.len() {
                return false;
            }
            for (a_tensor, b_tensor) in a_list.iter().zip(b_list.iter()) {
                if !a_tensor.equal(b_tensor) {
                    return false;
                }
            }
            return true;
        }

        a == b
    }

    // Schema-guided conversion for `aten::_foreach_add_.Tensor`. Per the assertions below,
    // argument 0 is a tensor list and argument 1 a tensor. Python `None` is accepted for the
    // tensor argument, and a Python float only when `allow_nums_as_tensors` is true.
    #[test]
    fn test_ivalue_from_py_object_with_type() {
        pyo3::prepare_freethreaded_python();

        let args_info =
            crate::call_op::get_schema_args_info("aten::_foreach_add_", "Tensor").unwrap();
        let (list, tensor, tensor_1, tensor_1_err) = Python::with_gil(|py| {
            let list = pyo3::types::PyList::empty(py).into_any();
            let none = py.None().into_bound(py);
            let one = PyFloat::new(py, 1.0).into_any();
            (
                IValue::from_py_object_with_type(
                    list,
                    args_info[0].type_,
                    args_info[0].num_elements,
                    false,
                )
                .unwrap(),
                IValue::from_py_object_with_type(
                    none,
                    args_info[1].type_,
                    args_info[1].num_elements,
                    false,
                )
                .unwrap(),
                IValue::from_py_object_with_type(
                    one.clone(),
                    args_info[1].type_,
                    args_info[1].num_elements,
                    true,
                )
                .unwrap(),
                // Same float without `allow_nums_as_tensors`: expected to fail.
                IValue::from_py_object_with_type(
                    one,
                    args_info[1].type_,
                    args_info[1].num_elements,
                    false,
                ),
            )
        });
        assert!(list.is_tensor_list());
        assert!(tensor.is_tensor());
        assert!(tensor_1.is_tensor());
        assert!(tensor_1_err.is_err());
    }

    // For each `Kind, input` pair, generates a test that converts `IValue::from(input)` to a
    // Python object and back, then compares it with the original. The extra
    // `..._was_exhaustive` test fails to compile if an `IValueKind` variant has no entry.
    macro_rules! generate_py_object_roundtrip_tests {
        ($($kind:ident, $input:expr_2021);* $(;)?) => {
            paste! {
                $(
                    #[test]
                    fn [<test_py_object_roundtrip_ $kind:snake:lower>]() {
                        pyo3::prepare_freethreaded_python();
                        // NOTE(review): presumably needed so the conversions can find torch's
                        // Python types.
                        Python::with_gil(|py| py.run(pyo3::ffi::c_str!("import torch"), None, None)).unwrap();
                        let original = IValue::from($input);
                        // Convert a copy so `original` stays available for the comparison.
                        let converted = unsafe { original.clone_unsafe() };
                        let converted = Python::with_gil(|py| {
                            let py_object = converted.into_pyobject(py).unwrap();
                            anyhow::Ok(IValue::extract_bound(&py_object).unwrap())
                        }).unwrap();
                        assert!(ivalues_equal_with_tensor_equal(original, converted));
                    }
                )*

                #[test]
                fn test_py_object_roundtrip_was_exhaustive() {
                    match IValueKind::Int {
                        $(IValueKind::$kind => (),)*
                    }
                }
            }
        }
    }

    generate_py_object_roundtrip_tests! {
        Int, 123;
        Double, 1.23;
        String, "foobar".to_owned();
        IntList, [1, 2, 3].as_slice();
        Bool, false;
        Tensor, test_make_tensor();
        TensorList, vec![test_make_tensor()];
        Device, Device::new(DeviceType::CPU);
        None, ();
        Other, test_make_opaque_ivalue();
    }

    // For each `Kind, input` pair, generates a test that roundtrips `IValue::from(input)`
    // through bincode, then compares it with the original. Includes the same compile-time
    // exhaustiveness check as above.
    macro_rules! generate_serde_roundtrip_tests {
        ($($kind:ident, $input:expr_2021);* $(;)?) => {
            paste! {
                $(
                    #[test]
                    fn [<test_serde_roundtrip_ $kind:snake:lower>]() {
                        pyo3::prepare_freethreaded_python();
                        Python::with_gil(|py| py.run(pyo3::ffi::c_str!("import torch"), None, None)).unwrap();
                        let original = IValue::from($input);
                        let converted: IValue = bincode::deserialize(&bincode::serialize(&original).unwrap()).unwrap();
                        assert!(ivalues_equal_with_tensor_equal(original, converted));
                    }
                )*

                #[test]
                fn test_serde_roundtrip_was_exhaustive() {
                    match IValueKind::Int {
                        $(IValueKind::$kind => (),)*
                    }
                }
            }
        }
    }

    generate_serde_roundtrip_tests! {
        Int, 123;
        Double, 1.23;
        String, "foobar".to_owned();
        IntList, [1, 2, 3].as_slice();
        Bool, false;
        Tensor, test_make_tensor();
        TensorList, vec![test_make_tensor()];
        Device, Device::new(DeviceType::CPU);
        None, ();
        Other, test_make_opaque_ivalue();
    }

    // An undefined tensor must survive a bincode roundtrip as an undefined tensor. It is
    // tested separately because it cannot go through the table above.
    #[test]
    fn test_serde_roundtrip_undefined_tensor() {
        let original = test_make_undefined_tensor_ivalue();
        assert!(original.is_tensor());
        assert!(
            // `to_tensor` consumes its receiver, so inspect a copy and keep `original`.
            !unsafe { original.clone_unsafe() }
                .to_tensor()
                .unwrap()
                .defined()
        );
        let converted: IValue =
            bincode::deserialize(&bincode::serialize(&original).unwrap()).unwrap();
        assert!(converted.is_tensor());
        assert!(!converted.to_tensor().unwrap().defined());
    }
}