1use std::collections::HashMap;
10
11use cxx::ExternType;
12use cxx::type_id;
13use thiserror::Error;
14
15use crate::IValue;
16use crate::IValueKind;
17use crate::RValue;
18use crate::TensorCell;
19use crate::borrow::BorrowError;
20use crate::borrow::BorrowType;
21use crate::borrow::MultiBorrow;
22use crate::bridge::ffi::AliasInfo;
23use crate::bridge::ffi::AliasKind;
24use crate::bridge::ffi::Kwarg;
25use crate::bridge::ffi::Tensor;
26pub use crate::bridge::ffi::get_schema_args_info;
27use crate::ivalue::OpaqueIValue;
28use crate::ivalue::OpaqueIValueCell;
29use crate::rvalue::rvalue_to_ivalue;
30
/// Errors returned by [`call_op`].
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum CallOpError {
    /// A C++ exception surfaced through `cxx`, e.g. while looking up the operator
    /// schema or while running the operator itself.
    #[error("torch operator error {0}")]
    TorchOperatorError(#[from] cxx::Exception),

    /// The operator's arguments could not all be borrowed as the schema requires
    /// (for example, a mutable borrow conflicting with an outstanding borrow).
    #[error("error borrowing: {0}")]
    BorrowError(#[from] BorrowError),

    /// A keyword argument name did not match any argument in the operator schema.
    #[error("invalid kwarg '{kwarg}' for op: '{operator}.{overload}'")]
    InvalidKwargs {
        /// The offending keyword argument name.
        kwarg: String,
        /// Operator name as passed to `call_op` (e.g. `aten::add`).
        operator: String,
        /// Overload name as passed to `call_op` (may be empty).
        overload: String,
    },
}
48
/// Opaque Rust stand-in for the C++ type `c10::TypePtr`.
///
/// The zero-length private field means the type cannot be constructed outside this
/// module; Rust code only sees it behind references handed out by the C++ bridge
/// (for example the `type_` field of the schema argument info used in this file's tests).
#[repr(C)]
pub struct TypePtr {
    _private: [u8; 0],
}
56
// SAFETY: `Id` names the C++ type `c10::TypePtr` that this Rust type stands in for,
// and `Kind = Opaque` tells cxx the value is never held or moved by value in Rust,
// which matches `TypePtr` being a zero-sized placeholder that Rust cannot construct.
// NOTE(review): confirm the bridge only ever exposes `TypePtr` by reference.
unsafe impl ExternType for TypePtr {
    type Id = type_id!("c10::TypePtr");
    type Kind = cxx::kind::Opaque;
}
63
64impl TypePtr {
65 #[allow(dead_code)]
66 #[inline]
67 pub fn is_tensor(&self) -> bool {
68 crate::bridge::ffi::type_ptr_is_tensor(self)
69 }
70
71 #[allow(dead_code)]
72 #[inline]
73 pub fn is_tensor_list(&self) -> bool {
74 crate::bridge::ffi::type_ptr_is_tensor_list(self)
75 }
76
77 #[allow(dead_code)]
78 #[inline]
79 pub fn is_optional_tensor(&self) -> bool {
80 crate::bridge::ffi::type_ptr_is_optional_tensor(self)
81 }
82
83 #[allow(dead_code)]
84 #[inline]
85 pub fn is_optional_tensor_list(&self) -> bool {
86 crate::bridge::ffi::type_ptr_is_optional_tensor_list(self)
87 }
88}
89
90impl std::fmt::Debug for TypePtr {
91 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92 f.debug_struct("TypePtr").field("type", &"<...>").finish()
93 }
94}
95
96fn get_aliased_rvalue<'a>(
97 alias_info: &'a AliasInfo,
98 args: &'a [RValue],
99 kwargs: &'a HashMap<String, RValue>,
100) -> &'a RValue {
101 match args.get(alias_info.arg_idx) {
102 Some(rvalue) => rvalue,
104 None => {
105 let (_name, rvalue) = kwargs
107 .iter()
108 .find(|(key, _)| *key == &alias_info.arg_name)
109 .unwrap();
112 rvalue
113 }
114 }
115}
116
117fn make_tensor_cell(
119 tensor: Tensor,
120 alias_info: &AliasInfo,
121 args: &[RValue],
122 kwargs: &HashMap<String, RValue>,
123) -> TensorCell {
124 match alias_info.kind {
125 AliasKind::NewValue => TensorCell::new(tensor),
126 AliasKind::Alias => match get_aliased_rvalue(alias_info, args, kwargs) {
127 RValue::Tensor(cell) => TensorCell::new_with_alias(tensor, cell),
128 RValue::Opaque(_) => TensorCell::new(tensor),
132 _ => panic!("must be a tensor to create an aliases tensorcell"),
133 },
134 _ => panic!("unsupported alias kind"),
135 }
136}
137
138fn make_opaque_ivalue_cell(
139 ivalue: OpaqueIValue,
140 alias_info: &AliasInfo,
141 args: &[RValue],
142 kwargs: &HashMap<String, RValue>,
143) -> OpaqueIValueCell {
144 match alias_info.kind {
145 AliasKind::NewValue => OpaqueIValueCell::new(ivalue),
146 AliasKind::Alias => match get_aliased_rvalue(alias_info, args, kwargs) {
147 RValue::Opaque(cell) => OpaqueIValueCell::new_with_alias(ivalue, cell),
148 _ => panic!("must be an opaque ivalue to create an aliases opaque ivalue cell"),
149 },
150 _ => panic!("unsupported alias kind"),
151 }
152}
153
154pub fn call_op(
170 op_name: impl AsRef<str>,
171 overload: impl AsRef<str>,
172 args: &[RValue],
173 kwargs: &HashMap<String, RValue>,
174 flatten_results: bool,
175) -> Result<Vec<RValue>, CallOpError> {
176 let mut multiborrow = MultiBorrow::new();
180
181 let mutates = get_schema_args_info(op_name.as_ref(), overload.as_ref())?;
182
183 for (arg, arg_mutability) in args.iter().zip(&mutates) {
185 let borrow_type = if arg_mutability.is_mutable {
186 BorrowType::Mutable
187 } else {
188 BorrowType::Shared
189 };
190 multiborrow.add(arg, borrow_type);
191 }
192
193 for (key, arg) in kwargs.iter() {
195 let arg_mutability = mutates.iter().find(|arg| &arg.name == key).ok_or_else(|| {
196 CallOpError::InvalidKwargs {
197 kwarg: key.to_string(),
198 operator: op_name.as_ref().to_string(),
199 overload: overload.as_ref().to_string(),
200 }
201 })?;
202 let borrow_type = if arg_mutability.is_mutable {
203 BorrowType::Mutable
204 } else {
205 BorrowType::Shared
206 };
207 multiborrow.add(arg, borrow_type);
208 }
209
210 let _borrows = multiborrow.borrow()?;
212
213 let mut ivalue_args: Vec<IValue> = args
214 .iter()
215 .map(|rvalue| unsafe { rvalue_to_ivalue(rvalue) })
220 .collect();
221 let mut ivalue_kwargs: Vec<Kwarg> = kwargs
222 .iter()
223 .map(|(key, value)| Kwarg {
224 name: key.clone(),
225 arg: unsafe { rvalue_to_ivalue(value) },
227 })
228 .collect();
229
230 let call_op_result = unsafe {
233 crate::bridge::ffi::call_op_raw(
234 op_name.as_ref(),
235 overload.as_ref(),
236 &mut ivalue_args,
237 &mut ivalue_kwargs,
238 flatten_results,
239 )?
240 };
241 Ok(call_op_result
242 .outputs
243 .into_iter()
244 .zip(call_op_result.alias_infos)
245 .map(|(ivalue, alias_info)| match ivalue.kind() {
246 IValueKind::Tensor => RValue::Tensor(make_tensor_cell(
247 ivalue.to_tensor().unwrap(),
248 &alias_info,
249 args,
250 kwargs,
251 )),
252 IValueKind::Bool => RValue::Bool(ivalue.to_bool().unwrap()),
253 IValueKind::Int => RValue::Int(ivalue.to_int().unwrap()),
254 IValueKind::IntList => RValue::IntList(ivalue.to_int_list().unwrap()),
255 IValueKind::Double => RValue::Double(ivalue.to_double().unwrap()),
256 IValueKind::String => RValue::String(ivalue.to_string().unwrap()),
257 IValueKind::TensorList => {
258 let mut tensors = Vec::new();
259 let tensor_list = ivalue.to_tensor_list().unwrap();
260 for tensor in tensor_list {
261 tensors.push(make_tensor_cell(tensor, &alias_info, args, kwargs));
262 }
263 RValue::TensorList(tensors)
264 }
265 IValueKind::Device => RValue::Device(ivalue.to_device().unwrap()),
266 IValueKind::None => RValue::None,
267 IValueKind::Other => RValue::Opaque(make_opaque_ivalue_cell(
268 ivalue.to_opaque().unwrap(),
269 &alias_info,
270 args,
271 kwargs,
272 )),
273 })
274 .collect())
275}
276
#[cfg(test)]
mod tests {
    use super::*;
    use crate::CloneUnsafe;
    use crate::bridge::ffi::AliasKind;
    use crate::bridge::ffi::call_op_raw;
    use crate::bridge::ffi::is_alias;

    #[test]
    fn call_op_raw_basic() {
        let iv = IValue::from(vec![2, 3].as_slice());
        #[allow(clippy::undocumented_unsafe_blocks)]
        let mut results =
            unsafe { call_op_raw("aten::ones", "", &mut [iv], &mut [], false) }.unwrap();

        assert_eq!(results.outputs.len(), 1, "Expected 1 output");
        assert_eq!(results.alias_infos.len(), 1, "Expected 1 output");

        let t1 = results.outputs.pop().unwrap();
        assert!(
            matches!(results.alias_infos[0].kind, AliasKind::NewValue),
            "output should be a new value"
        );

        let iv = IValue::from(vec![2, 3].as_slice());
        #[allow(clippy::undocumented_unsafe_blocks)]
        let mut results =
            unsafe { call_op_raw("aten::ones", "", &mut [iv], &mut [], false) }.unwrap();

        assert_eq!(results.outputs.len(), 1, "Expected 1 output");
        assert_eq!(results.alias_infos.len(), 1, "Expected 1 output");

        let t2 = results.outputs.pop().unwrap();
        assert!(
            matches!(results.alias_infos[0].kind, AliasKind::NewValue),
            "output should be a new value"
        );

        #[allow(clippy::undocumented_unsafe_blocks)]
        let results =
            unsafe { call_op_raw("aten::allclose", "", &mut [t1, t2], &mut [], false) }.unwrap();
        assert_eq!(results.outputs.len(), 1, "Expected 1 output");
        assert_eq!(results.alias_infos.len(), 1, "Expected 1 output");
        assert!(
            matches!(results.alias_infos[0].kind, AliasKind::NewValue),
            "output should be a new value"
        );

        assert!(
            results.outputs[0]
                .to_bool()
                .expect("expected boolean return"),
            "expected allclose to be true",
        );
    }

    #[test]
    fn call_op_raw_with_aliasing() {
        let size = IValue::from(vec![2, 3].as_slice());
        #[allow(clippy::undocumented_unsafe_blocks)]
        let mut results =
            unsafe { call_op_raw("aten::ones", "", &mut [size], &mut [], false) }.unwrap();
        assert_eq!(results.outputs.len(), 1, "Expected 1 output");
        assert_eq!(results.alias_infos.len(), 1, "Expected 1 output");

        let t1 = results.outputs.pop().unwrap();
        assert!(
            matches!(results.alias_infos[0].kind, AliasKind::NewValue),
            "output should be a new value"
        );

        let size = IValue::from(vec![2, 3].as_slice());
        let mut args = vec![t1, size];
        #[allow(clippy::undocumented_unsafe_blocks)]
        let mut results =
            unsafe { call_op_raw("aten::view", "", args.as_mut_slice(), &mut [], false) }.unwrap();
        assert_eq!(results.outputs.len(), 1, "Expected 1 output");
        assert_eq!(results.alias_infos.len(), 1, "Expected 1 output");

        assert!(
            matches!(results.alias_infos[0].kind, AliasKind::Alias),
            "output should be an alias"
        );
        assert_eq!(
            results.alias_infos[0].arg_idx, 0,
            "alias should point to the first input"
        );
        let t2 = results.outputs.pop().unwrap();

        #[allow(clippy::undocumented_unsafe_blocks)]
        let x = unsafe { &args[0].clone_unsafe().to_tensor().unwrap() };
        assert!(
            is_alias(x, &t2.to_tensor().unwrap()),
            "c++ tensors should alias"
        );
    }

    #[test]
    fn call_op_raw_with_chunk_aliasing() {
        let size = IValue::from(vec![2, 3].as_slice());
        #[allow(clippy::undocumented_unsafe_blocks)]
        let mut results =
            unsafe { call_op_raw("aten::ones", "", &mut [size], &mut [], false) }.unwrap();
        assert_eq!(results.outputs.len(), 1, "Expected 1 output");
        assert_eq!(results.alias_infos.len(), 1, "Expected 1 output");

        let t1 = results.outputs.pop().unwrap();
        assert!(
            matches!(results.alias_infos[0].kind, AliasKind::NewValue),
            "output should be a new value"
        );

        #[allow(clippy::undocumented_unsafe_blocks)]
        let mut results = unsafe {
            call_op_raw(
                "aten::chunk",
                "",
                &mut [t1.clone_unsafe(), IValue::from(2)],
                &mut [],
                false,
            )
        }
        .unwrap();

        assert_eq!(results.outputs.len(), 1, "Expected 1 output");
        assert_eq!(results.alias_infos.len(), 1, "Expected 1 output");

        let chunked_list = results.outputs.pop().unwrap();
        assert!(
            matches!(results.alias_infos[0].kind, AliasKind::Alias),
            "chunk output should be an alias"
        );
        assert_eq!(
            results.alias_infos[0].arg_idx, 0,
            "chunk output should alias the first input"
        );

        let chunked_list = chunked_list
            .to_tensor_list()
            .expect("return of chunk should be a tensor list");

        let tensor = t1.to_tensor().unwrap();
        for chunk in &chunked_list {
            assert!(is_alias(&tensor, chunk), "c++ tensors should alias");
        }
    }

    /// Calls `op_name.overload` without flattening and returns its single output,
    /// panicking if the call fails or yields a different number of outputs.
    fn call_op_one(
        op_name: impl AsRef<str>,
        overload: impl AsRef<str>,
        args: &[RValue],
        kwargs: &HashMap<String, RValue>,
    ) -> RValue {
        let mut results = call_op(op_name, overload, args, kwargs, false).unwrap();
        assert_eq!(results.len(), 1);
        results.pop().unwrap()
    }

    #[test]
    fn call_op_basic() {
        let rv = RValue::from(vec![2, 3]);
        let t1 = call_op_one("aten::ones", "", &[rv.clone()], &HashMap::new());
        let t2 = call_op_one("aten::ones", "", &[rv.clone()], &HashMap::new());

        match (&t1, &t2) {
            (RValue::Tensor(t1), RValue::Tensor(t2)) => {
                assert!(!t1.aliases(t2));
            }
            _ => panic!("expected tensor"),
        }

        let result = call_op("aten::allclose", "", &[t1, t2], &HashMap::new(), false)
            .unwrap()
            .pop()
            .unwrap();

        assert!(
            matches!(result, RValue::Bool(true)),
            "Expected true for allclose output"
        );
    }

    #[test]
    fn call_op_multi_alias() {
        let rv = RValue::from(vec![2, 3]);
        let t1 = call_op_one("aten::ones", "", &[rv.clone()], &HashMap::new());
        let t1_view = call_op_one("aten::view", "", &[t1.clone(), rv.clone()], &HashMap::new());

        let t1_cell: TensorCell = t1.clone().try_into().unwrap();
        let t1_view_cell: TensorCell = t1_view.try_into().unwrap();
        assert!(t1_cell.aliases(&t1_view_cell));

        // Two threads concurrently run a non-mutating op on aliased tensors; shared
        // borrows must not conflict.
        let handle1 = std::thread::spawn(move || {
            for _ in 0..1000 {
                call_op(
                    "aten::add",
                    "Tensor",
                    &[t1.clone(), t1.clone()],
                    &HashMap::new(),
                    false,
                )
                .unwrap();
            }
        });

        let handle2 = std::thread::spawn(move || {
            let t1_view: RValue = t1_view_cell.clone().into();
            for _ in 0..1000 {
                call_op(
                    "aten::add",
                    "Tensor",
                    &[t1_view.clone(), t1_view.clone()],
                    &HashMap::new(),
                    false,
                )
                .unwrap();
            }
        });
        handle1.join().unwrap();
        handle2.join().unwrap();
    }

    #[test]
    fn call_op_multi_alias_mutable() {
        let rv = RValue::from(vec![2, 3]);
        let t1 = call_op_one("aten::ones", "", &[rv.clone()], &HashMap::new());
        let t1_view = call_op_one("aten::view", "", &[t1.clone(), rv.clone()], &HashMap::new());

        // The same value as both the mutable `self` and the shared `other` of a single
        // call must be accepted.
        call_op_one(
            "aten::add_",
            "Tensor",
            &[t1_view.clone(), t1_view.clone()],
            &HashMap::new(),
        );
    }

    #[test]
    fn call_op_implicit_scalar_to_tensor() {
        let tensor = call_op_one(
            "aten::ones",
            "",
            &[RValue::from(vec![2, 3])],
            &HashMap::new(),
        );
        // An Int is passed where the `Tensor` overload expects a tensor `other`.
        call_op_one(
            "aten::add_",
            "Tensor",
            &[tensor, RValue::Int(1)],
            &HashMap::new(),
        );
    }

    #[should_panic]
    #[test]
    fn call_op_mutating_while_borrowed() {
        let rv = RValue::from(vec![2, 3]);
        let t1 = call_op_one("aten::ones", "", &[rv.clone()], &HashMap::new());
        let t1_view = call_op_one("aten::view", "", &[t1.clone(), rv.clone()], &HashMap::new());

        let t1_cell: TensorCell = t1.clone().try_into().unwrap();
        let t1_view_cell: TensorCell = t1_view.try_into().unwrap();
        assert!(t1_cell.aliases(&t1_view_cell));

        // Reader thread: shared borrows of `t1`.
        let handle1 = std::thread::spawn(move || {
            for _ in 0..1000 {
                call_op_one(
                    "aten::add",
                    "Tensor",
                    &[t1.clone(), t1.clone()],
                    &HashMap::new(),
                );
            }
        });

        // Writer thread: mutable borrows of an alias of `t1`; expected to collide
        // with the reader and fail to borrow.
        let handle2 = std::thread::spawn(move || {
            let t1_view: RValue = t1_view_cell.clone().into();
            for _ in 0..1000 {
                call_op_one(
                    "aten::add_",
                    "Tensor",
                    &[t1_view.clone(), t1_view.clone()],
                    &HashMap::new(),
                );
            }
        });
        handle1.join().unwrap();
        handle2.join().unwrap();
    }

    #[test]
    fn kwargs() {
        let rv = RValue::from(vec![2, 3]);
        let kwargs = HashMap::from([("size".into(), rv)]);
        let t1 = call_op_one("aten::ones", "", &[], &kwargs);
        let t2 = call_op_one("aten::ones", "", &[], &kwargs);

        match (&t1, &t2) {
            (RValue::Tensor(t1), RValue::Tensor(t2)) => {
                assert!(!t1.aliases(t2));
            }
            _ => panic!("expected tensor"),
        }

        let result = call_op("aten::allclose", "", &[t1, t2], &HashMap::new(), true)
            .unwrap()
            .pop()
            .unwrap();

        assert!(
            matches!(result, RValue::Bool(true)),
            "Expected true for allclose output"
        );
    }

    #[test]
    fn kwargs_alias() {
        let rv = RValue::from(vec![2, 3]);
        let kwargs = HashMap::from([("size".into(), rv.clone())]);
        let t1 = call_op_one("aten::ones", "", &[], &kwargs);

        let kwargs = HashMap::from([("size".into(), rv.clone()), ("self".into(), t1.clone())]);
        let t1_view = call_op_one("aten::view", "", &[], &kwargs);

        let t1_cell: TensorCell = t1.clone().try_into().unwrap();
        let t1_view_cell: TensorCell = t1_view.try_into().unwrap();
        assert!(t1_cell.aliases(&t1_view_cell));
    }

    #[test]
    fn kwargs_chunk_alias() {
        let size = RValue::from(vec![2, 3]);
        let kwargs = HashMap::from([("size".into(), size)]);
        let t1 = call_op_one("aten::ones", "", &[], &kwargs);

        let kwargs = HashMap::from([("self".into(), t1.clone()), ("chunks".into(), 2.into())]);
        let chunked = call_op_one("aten::chunk", "", &[], &kwargs);

        let cells: Vec<TensorCell> = chunked
            .try_into()
            .expect("return of chunk should be a tensor list");

        let original_cell: TensorCell = t1.try_into().expect("return of ones should be a tensor");
        for cell in cells {
            assert!(original_cell.aliases(&cell));
        }
    }

    #[should_panic]
    #[test]
    fn kwargs_mutate_double_borrow() {
        let size = RValue::from(vec![2, 3]);
        let kwargs = HashMap::from([("size".into(), size)]);
        let t1 = call_op_one("aten::ones", "", &[], &kwargs);
        let t1_view = call_op_one("aten::view", "", &[t1.clone()], &kwargs);
        let t1_view_cell: TensorCell = t1_view.try_into().unwrap();

        // `other` is passed by keyword so the kwargs borrow path is exercised. It must
        // be a real schema argument of `aten::add.Tensor`; an unknown name would make
        // `call_op` fail with `InvalidKwargs` and the test would pass for the wrong reason.
        let handle1 = std::thread::spawn(move || {
            for _ in 0..1000 {
                let kwargs = HashMap::from([("other".into(), t1.clone())]);
                call_op_one("aten::add", "Tensor", &[t1.clone()], &kwargs);
            }
        });

        let handle2 = std::thread::spawn(move || {
            let t1_view: RValue = t1_view_cell.into();
            for _ in 0..1000 {
                let kwargs = HashMap::from([("other".into(), t1_view.clone())]);
                call_op_one("aten::add_", "Tensor", &[t1_view.clone()], &kwargs);
            }
        });
        handle1.join().unwrap();
        handle2.join().unwrap();
    }

    #[test]
    fn test_flatten_results() {
        let size = RValue::from(vec![5, 2]);
        let kwargs = HashMap::from([("size".into(), size)]);
        let t = call_op_one("aten::ones", "", &[], &kwargs);
        let res = call_op(
            "aten::split_with_sizes",
            "",
            &[t, vec![1, 4].into()],
            &HashMap::new(),
            true,
        )
        .unwrap();
        assert_eq!(res.len(), 2);
        match (&res[0], &res[1]) {
            (RValue::Tensor(t1), RValue::Tensor(t2)) => {
                assert_eq!(t1.borrow().numel(), 2);
                assert_eq!(t2.borrow().numel(), 8);
            }
            _ => panic!("unexpected results: {:?}", res),
        }
    }

    #[test]
    fn test_call_op_mutating_self_with_no_return() {
        let t: TensorCell = call_op_one("aten::ones", "", &[vec![5, 1].into()], &HashMap::new())
            .try_into()
            .unwrap();
        let res = call_op(
            "aten::_foreach_add_",
            "Scalar",
            &[vec![t.clone()].into(), 1.into()],
            &HashMap::new(),
            true,
        )
        .unwrap();
        assert_eq!(res.len(), 0);
        let expected: TensorCell = call_op_one(
            "aten::full",
            "",
            &[vec![5, 1].into(), 2.into()],
            &HashMap::new(),
        )
        .try_into()
        .unwrap();
        assert!(t.borrow().equal(&expected.borrow()));
    }

    #[test]
    fn test_call_op_arg_types() {
        let args_info =
            crate::bridge::ffi::get_schema_args_info("aten::_foreach_add_", "Scalar").unwrap();
        assert_eq!(args_info.len(), 2);
        assert_eq!(args_info[0].name, "self");
        assert!(args_info[0].is_mutable);
        assert!(args_info[0].type_.is_tensor_list());
        assert!(!args_info[0].type_.is_tensor());
        assert_eq!(args_info[1].name, "scalar");
        assert!(!args_info[1].is_mutable);
        assert!(!args_info[1].type_.is_tensor());
        assert!(!args_info[1].type_.is_tensor_list());

        let args_info =
            crate::bridge::ffi::get_schema_args_info("aten::_foreach_add_", "Tensor").unwrap();
        assert_eq!(args_info[1].name, "other");
        assert!(!args_info[1].is_mutable);
        assert!(args_info[1].type_.is_tensor());
        assert!(!args_info[1].type_.is_tensor_list());
    }
}