torch_sys/
call_op.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9use std::collections::HashMap;
10
11use cxx::ExternType;
12use cxx::type_id;
13use thiserror::Error;
14
15use crate::IValue;
16use crate::IValueKind;
17use crate::RValue;
18use crate::TensorCell;
19use crate::borrow::BorrowError;
20use crate::borrow::BorrowType;
21use crate::borrow::MultiBorrow;
22use crate::bridge::ffi::AliasInfo;
23use crate::bridge::ffi::AliasKind;
24use crate::bridge::ffi::Kwarg;
25use crate::bridge::ffi::Tensor;
26pub use crate::bridge::ffi::get_schema_args_info;
27use crate::ivalue::OpaqueIValue;
28use crate::ivalue::OpaqueIValueCell;
29use crate::rvalue::rvalue_to_ivalue;
30
/// Errors that can occur while calling an operator.
#[derive(Error, Debug)]
#[non_exhaustive]
pub enum CallOpError {
    /// A C++ exception surfaced through the cxx bridge; converted automatically
    /// from [`cxx::Exception`] via `#[from]`.
    #[error("torch operator error {0}")]
    TorchOperatorError(#[from] cxx::Exception),

    /// Borrowing the operator's inputs failed; converted automatically from
    /// [`BorrowError`] via `#[from]`.
    #[error("error borrowing: {0}")]
    BorrowError(#[from] BorrowError),

    /// A keyword argument name did not match any argument reported by the
    /// operator's schema.
    #[error("invalid kwarg '{kwarg}' for op: '{operator}.{overload}'")]
    InvalidKwargs {
        /// Name of the keyword argument that could not be matched.
        kwarg: String,
        /// Operator name that was passed to the call.
        operator: String,
        /// Overload name that was passed to the call.
        overload: String,
    },
}
48
/// An opaque type that represents the type of an argument to a torch operator.
/// This is essentially used to interface with C++ code and should not be instantiated
/// or owned by Rust code.
#[repr(C)]
pub struct TypePtr {
    // Zero-sized private field: keeps the struct from being built with a
    // struct literal outside this module.
    _private: [u8; 0],
}
56
// SAFETY: Register our custom bindings with cxx. This is just treating
// `c10::TypePtr` (the C++ type named by `Id` below) as an opaque type, and
// we would only have refs to it.
unsafe impl ExternType for TypePtr {
    type Id = type_id!("c10::TypePtr");
    type Kind = cxx::kind::Opaque;
}
63
/// Type predicates for [`TypePtr`]; each one forwards to the bridge function
/// of the matching name in `crate::bridge::ffi`.
impl TypePtr {
    /// Returns the result of the bridge's `type_ptr_is_tensor` check for this type.
    #[allow(dead_code)]
    #[inline]
    pub fn is_tensor(&self) -> bool {
        crate::bridge::ffi::type_ptr_is_tensor(self)
    }

    /// Returns the result of the bridge's `type_ptr_is_tensor_list` check for this type.
    #[allow(dead_code)]
    #[inline]
    pub fn is_tensor_list(&self) -> bool {
        crate::bridge::ffi::type_ptr_is_tensor_list(self)
    }

    /// Returns the result of the bridge's `type_ptr_is_optional_tensor` check for this type.
    #[allow(dead_code)]
    #[inline]
    pub fn is_optional_tensor(&self) -> bool {
        crate::bridge::ffi::type_ptr_is_optional_tensor(self)
    }

    /// Returns the result of the bridge's `type_ptr_is_optional_tensor_list` check for this type.
    #[allow(dead_code)]
    #[inline]
    pub fn is_optional_tensor_list(&self) -> bool {
        crate::bridge::ffi::type_ptr_is_optional_tensor_list(self)
    }
}
89
90impl std::fmt::Debug for TypePtr {
91    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
92        f.debug_struct("TypePtr").field("type", &"<...>").finish()
93    }
94}
95
96fn get_aliased_rvalue<'a>(
97    alias_info: &'a AliasInfo,
98    args: &'a [RValue],
99    kwargs: &'a HashMap<String, RValue>,
100) -> &'a RValue {
101    match args.get(alias_info.arg_idx) {
102        // The alias references an arg.
103        Some(rvalue) => rvalue,
104        None => {
105            // This alias references a kwarg.
106            let (_name, rvalue) = kwargs
107                .iter()
108                .find(|(key, _)| *key == &alias_info.arg_name)
109                // The aliased value must have been passed in through
110                // either args or kwargs, panic if not.
111                .unwrap();
112            rvalue
113        }
114    }
115}
116
117/// Create a TensorCell out of a tensor, with the correct aliasing information.
118fn make_tensor_cell(
119    tensor: Tensor,
120    alias_info: &AliasInfo,
121    args: &[RValue],
122    kwargs: &HashMap<String, RValue>,
123) -> TensorCell {
124    match alias_info.kind {
125        AliasKind::NewValue => TensorCell::new(tensor),
126        AliasKind::Alias => match get_aliased_rvalue(alias_info, args, kwargs) {
127            RValue::Tensor(cell) => TensorCell::new_with_alias(tensor, cell),
128            // TODO: call_op should broken down into helpers and directly used in stream.rs
129            // and there if wirevalue was IValue we just create a new TensorCell even if it is being
130            // aliased as it will not be tracked on the rust worker yet.
131            RValue::Opaque(_) => TensorCell::new(tensor),
132            _ => panic!("must be a tensor to create an aliases tensorcell"),
133        },
134        _ => panic!("unsupported alias kind"),
135    }
136}
137
138fn make_opaque_ivalue_cell(
139    ivalue: OpaqueIValue,
140    alias_info: &AliasInfo,
141    args: &[RValue],
142    kwargs: &HashMap<String, RValue>,
143) -> OpaqueIValueCell {
144    match alias_info.kind {
145        AliasKind::NewValue => OpaqueIValueCell::new(ivalue),
146        AliasKind::Alias => match get_aliased_rvalue(alias_info, args, kwargs) {
147            RValue::Opaque(cell) => OpaqueIValueCell::new_with_alias(ivalue, cell),
148            _ => panic!("must be an opaque ivalue to create an aliases opaque ivalue cell"),
149        },
150        _ => panic!("unsupported alias kind"),
151    }
152}
153
154/// Call a PyTorch-dispatched operator by name.
155///
156/// `op_name` is the fully qualified name of the operator, like `"aten::add"`.
157///
158/// `overload` is the name of the overload, like `"Scalar"`. Due to a
159/// quirk of libtorch, the `default` overload must be called by passing
160/// an empty string.
161///
162/// `flatten_results` is a flag that indicates whether the results of the
163/// operator should be flattened into a single list. Extracting out values
164/// from lists, tuples and dicts recursively.
165///
166/// # Errors
167/// If the called operator throws an exception, a [`cxx::Exception`]
168/// will be returned which contains the C++ exception.
169pub fn call_op(
170    op_name: impl AsRef<str>,
171    overload: impl AsRef<str>,
172    args: &[RValue],
173    kwargs: &HashMap<String, RValue>,
174    flatten_results: bool,
175) -> Result<Vec<RValue>, CallOpError> {
176    // SAFETY: We will be making an unchecked clone of each tensor to pass to to
177    // C++, so we need to hold a borrow of each input tensor for the duration of
178    // this function.
179    let mut multiborrow = MultiBorrow::new();
180
181    let mutates = get_schema_args_info(op_name.as_ref(), overload.as_ref())?;
182
183    // Queue up borrows for the args.
184    for (arg, arg_mutability) in args.iter().zip(&mutates) {
185        let borrow_type = if arg_mutability.is_mutable {
186            BorrowType::Mutable
187        } else {
188            BorrowType::Shared
189        };
190        multiborrow.add(arg, borrow_type);
191    }
192
193    // Queue up borrows for the kwargs.
194    for (key, arg) in kwargs.iter() {
195        let arg_mutability = mutates.iter().find(|arg| &arg.name == key).ok_or_else(|| {
196            CallOpError::InvalidKwargs {
197                kwarg: key.to_string(),
198                operator: op_name.as_ref().to_string(),
199                overload: overload.as_ref().to_string(),
200            }
201        })?;
202        let borrow_type = if arg_mutability.is_mutable {
203            BorrowType::Mutable
204        } else {
205            BorrowType::Shared
206        };
207        multiborrow.add(arg, borrow_type);
208    }
209
210    // Actually execute the borrows.
211    let _borrows = multiborrow.borrow()?;
212
213    let mut ivalue_args: Vec<IValue> = args
214        .iter()
215        // SAFETY: The borrows above guard the unchecked clones done by
216        // `rvalue_to_ivalue`. This may result in multiple mutable references to
217        // tensor data, but the C++ side is responsible for making sure that is safe
218        // within the context of a single operator invocation.
219        .map(|rvalue| unsafe { rvalue_to_ivalue(rvalue) })
220        .collect();
221    let mut ivalue_kwargs: Vec<Kwarg> = kwargs
222        .iter()
223        .map(|(key, value)| Kwarg {
224            name: key.clone(),
225            // SAFETY: see above
226            arg: unsafe { rvalue_to_ivalue(value) },
227        })
228        .collect();
229
230    // SAFETY: we will be unifying the ownership of potential aliases in the
231    // returned TensorCells so this is okay to call.
232    let call_op_result = unsafe {
233        crate::bridge::ffi::call_op_raw(
234            op_name.as_ref(),
235            overload.as_ref(),
236            &mut ivalue_args,
237            &mut ivalue_kwargs,
238            flatten_results,
239        )?
240    };
241    Ok(call_op_result
242        .outputs
243        .into_iter()
244        .zip(call_op_result.alias_infos)
245        .map(|(ivalue, alias_info)| match ivalue.kind() {
246            IValueKind::Tensor => RValue::Tensor(make_tensor_cell(
247                ivalue.to_tensor().unwrap(),
248                &alias_info,
249                args,
250                kwargs,
251            )),
252            IValueKind::Bool => RValue::Bool(ivalue.to_bool().unwrap()),
253            IValueKind::Int => RValue::Int(ivalue.to_int().unwrap()),
254            IValueKind::IntList => RValue::IntList(ivalue.to_int_list().unwrap()),
255            IValueKind::Double => RValue::Double(ivalue.to_double().unwrap()),
256            IValueKind::String => RValue::String(ivalue.to_string().unwrap()),
257            IValueKind::TensorList => {
258                let mut tensors = Vec::new();
259                let tensor_list = ivalue.to_tensor_list().unwrap();
260                for tensor in tensor_list {
261                    tensors.push(make_tensor_cell(tensor, &alias_info, args, kwargs));
262                }
263                RValue::TensorList(tensors)
264            }
265            IValueKind::Device => RValue::Device(ivalue.to_device().unwrap()),
266            IValueKind::None => RValue::None,
267            IValueKind::Other => RValue::Opaque(make_opaque_ivalue_cell(
268                ivalue.to_opaque().unwrap(),
269                &alias_info,
270                args,
271                kwargs,
272            )),
273        })
274        .collect())
275}
276
277#[cfg(test)]
278mod tests {
279    use core::panic;
280
281    use super::*;
282    use crate::CloneUnsafe;
283    use crate::bridge::ffi::AliasKind;
284    use crate::bridge::ffi::call_op_raw;
285    use crate::bridge::ffi::is_alias;
286
287    #[test]
288    fn call_op_raw_basic() {
289        let iv = IValue::from(vec![2, 3].as_slice());
290        #[allow(clippy::undocumented_unsafe_blocks)]
291        let mut results =
292            unsafe { call_op_raw("aten::ones", "", &mut [iv], &mut [], false) }.unwrap();
293
294        assert_eq!(results.outputs.len(), 1, "Expected 1 output");
295        assert_eq!(results.alias_infos.len(), 1, "Expected 1 output");
296
297        let t1 = results.outputs.pop().unwrap();
298        assert!(
299            matches!(results.alias_infos[0].kind, AliasKind::NewValue),
300            "output should be a new value"
301        );
302
303        let iv = IValue::from(vec![2, 3].as_slice());
304        #[allow(clippy::undocumented_unsafe_blocks)]
305        let mut results =
306            unsafe { call_op_raw("aten::ones", "", &mut [iv], &mut [], false) }.unwrap();
307
308        assert_eq!(results.outputs.len(), 1, "Expected 1 output");
309        assert_eq!(results.alias_infos.len(), 1, "Expected 1 output");
310
311        let t2 = results.outputs.pop().unwrap();
312        assert!(
313            matches!(results.alias_infos[0].kind, AliasKind::NewValue),
314            "output should be a new value"
315        );
316
317        #[allow(clippy::undocumented_unsafe_blocks)]
318        let results =
319            unsafe { call_op_raw("aten::allclose", "", &mut [t1, t2], &mut [], false) }.unwrap();
320        assert_eq!(results.outputs.len(), 1, "Expected 1 output");
321        assert_eq!(results.alias_infos.len(), 1, "Expected 1 output");
322        assert!(
323            matches!(results.alias_infos[0].kind, AliasKind::NewValue),
324            "output should be a new value"
325        );
326
327        assert!(
328            results.outputs[0]
329                .to_bool()
330                .expect("expected boolean return"),
331            "expected allclose to be true",
332        );
333    }
334
335    #[test]
336    fn call_op_raw_with_aliasing() {
337        let size = IValue::from(vec![2, 3].as_slice());
338        #[allow(clippy::undocumented_unsafe_blocks)]
339        let mut results =
340            unsafe { call_op_raw("aten::ones", "", &mut [size], &mut [], false) }.unwrap();
341        assert_eq!(results.outputs.len(), 1, "Expected 1 output");
342        assert_eq!(results.alias_infos.len(), 1, "Expected 1 output");
343
344        let t1 = results.outputs.pop().unwrap();
345        assert!(
346            matches!(results.alias_infos[0].kind, AliasKind::NewValue),
347            "output should be a new value"
348        );
349
350        let size = IValue::from(vec![2, 3].as_slice());
351        let mut args = vec![t1, size];
352        #[allow(clippy::undocumented_unsafe_blocks)]
353        let mut results =
354            unsafe { call_op_raw("aten::view", "", args.as_mut_slice(), &mut [], false) }.unwrap();
355        assert_eq!(results.outputs.len(), 1, "Expected 1 output");
356        assert_eq!(results.alias_infos.len(), 1, "Expected 1 output");
357
358        assert!(
359            matches!(results.alias_infos[0].kind, AliasKind::Alias),
360            "output should be an alias"
361        );
362        assert!(
363            matches!(results.alias_infos[0].arg_idx, 0),
364            "alias should point to the first input"
365        );
366        let t2 = results.outputs.pop().unwrap();
367
368        #[allow(clippy::undocumented_unsafe_blocks)]
369        let x = unsafe { &args[0].clone_unsafe().to_tensor().unwrap() };
370        assert!(
371            is_alias(x, &t2.to_tensor().unwrap()),
372            "c++ tensors should alias"
373        );
374    }
375
376    #[test]
377    fn call_op_raw_with_chunk_aliasing() {
378        let size = IValue::from(vec![2, 3].as_slice());
379        #[allow(clippy::undocumented_unsafe_blocks)]
380        let mut results =
381            unsafe { call_op_raw("aten::ones", "", &mut [size], &mut [], false) }.unwrap();
382        assert_eq!(results.outputs.len(), 1, "Expected 1 output");
383        assert_eq!(results.alias_infos.len(), 1, "Expected 1 output");
384
385        let t1 = results.outputs.pop().unwrap();
386        assert!(
387            matches!(results.alias_infos[0].kind, AliasKind::NewValue),
388            "output should be a new value"
389        );
390
391        #[allow(clippy::undocumented_unsafe_blocks)]
392        let mut results = unsafe {
393            call_op_raw(
394                "aten::chunk",
395                "",
396                &mut [t1.clone_unsafe(), IValue::from(2)],
397                &mut [],
398                false,
399            )
400        }
401        .unwrap();
402
403        assert_eq!(results.outputs.len(), 1, "Expected 1 output");
404        assert_eq!(results.alias_infos.len(), 1, "Expected 1 output");
405
406        let chunked_list = results.outputs.pop().unwrap();
407        assert!(
408            matches!(results.alias_infos[0].kind, AliasKind::Alias),
409            "chunk output should be an alias"
410        );
411        assert_eq!(
412            results.alias_infos[0].arg_idx, 0,
413            "chunk output should alias the first input"
414        );
415
416        let chunked_list = chunked_list
417            .to_tensor_list()
418            .expect("return of chunk should be a tensor list");
419
420        let tensor = t1.to_tensor().unwrap();
421        for chunk in &chunked_list {
422            assert!(is_alias(&tensor, chunk,), "c++ tensors should alias");
423        }
424    }
425
426    /// Convenience function to avoid lots of unwrapping in test code.
427    ///
428    /// # Panics
429    /// Panics if the arg has more than one result, or if an error occurred.
430    fn call_op_one(
431        op_name: impl AsRef<str>,
432        overload: impl AsRef<str>,
433        args: &[RValue],
434        kwargs: &HashMap<String, RValue>,
435    ) -> RValue {
436        let mut results = call_op(op_name, overload, args, kwargs, false).unwrap();
437        assert_eq!(results.len(), 1);
438        results.pop().unwrap()
439    }
440
441    #[test]
442    fn call_op_basic() {
443        let rv = RValue::from(vec![2, 3]);
444        let t1 = call_op_one("aten::ones", "", &[rv.clone()], &HashMap::new());
445        let t2 = call_op_one("aten::ones", "", &[rv.clone()], &HashMap::new());
446
447        match (&t1, &t2) {
448            (RValue::Tensor(t1), RValue::Tensor(t2)) => {
449                assert!(!t1.aliases(t2));
450            }
451            _ => panic!("expected tensor"),
452        }
453
454        let result = call_op("aten::allclose", "", &[t1, t2], &HashMap::new(), false)
455            .unwrap()
456            .pop()
457            .unwrap();
458
459        assert!(
460            matches!(result, RValue::Bool(true)),
461            "Expected true for allclose output"
462        );
463    }
464
465    #[test]
466    fn call_op_multi_alias() {
467        let rv = RValue::from(vec![2, 3]);
468        let t1 = call_op_one("aten::ones", "", &[rv.clone()], &HashMap::new());
469        let t1_view = call_op_one("aten::view", "", &[t1.clone(), rv.clone()], &HashMap::new());
470
471        let t1_cell: TensorCell = t1.clone().try_into().unwrap();
472        let t1_view_cell: TensorCell = t1_view.try_into().unwrap();
473        assert!(t1_cell.aliases(&t1_view_cell));
474
475        // Two threads can call non-mutating ops on the same alias with no problem.
476        let handle1 = std::thread::spawn(move || {
477            for _ in 0..1000 {
478                call_op(
479                    "aten::add",
480                    "Tensor",
481                    &[t1.clone(), t1.clone()],
482                    &HashMap::new(),
483                    false,
484                )
485                .unwrap();
486            }
487        });
488
489        let handle2 = std::thread::spawn(move || {
490            let t1_view: RValue = t1_view_cell.clone().into();
491            for _ in 0..1000 {
492                call_op(
493                    "aten::add",
494                    "Tensor",
495                    &[t1_view.clone(), t1_view.clone()],
496                    &HashMap::new(),
497                    false,
498                )
499                .unwrap();
500            }
501        });
502        handle1.join().unwrap();
503        handle2.join().unwrap();
504    }
505
506    /// Trying to call an op with a mutable and immutable borrow of the same alias should work.
507    #[test]
508    fn call_op_multi_alias_mutable() {
509        let rv = RValue::from(vec![2, 3]);
510        let t1 = call_op_one("aten::ones", "", &[rv.clone()], &HashMap::new());
511        let t1_view = call_op_one("aten::view", "", &[t1.clone(), rv.clone()], &HashMap::new());
512
513        call_op_one(
514            "aten::add_",
515            "Tensor",
516            &[t1_view.clone(), t1_view.clone()],
517            &HashMap::new(),
518        );
519    }
520
521    /// Test that we implicitly convert scalar args to tensors for the appropriate
522    /// operations.
523    #[test]
524    fn call_op_implicit_scalar_to_tensor() {
525        let tensor = call_op_one(
526            "aten::ones",
527            "",
528            &[RValue::from(vec![2, 3])],
529            &HashMap::new(),
530        );
531        call_op_one(
532            "aten::add_",
533            "Tensor",
534            &[tensor, RValue::Int(1)],
535            &HashMap::new(),
536        );
537    }
538
    /// Mutating a tensor through one alias while another thread is using the
    /// same tensor in a non-mutating op should panic.
    ///
    /// NOTE(review): the panic depends on the two threads' calls overlapping in
    /// time; presumably 1000 iterations per thread is enough to make that
    /// reliable — confirm.
    #[should_panic]
    #[test]
    fn call_op_mutating_while_borrowed() {
        let rv = RValue::from(vec![2, 3]);
        let t1 = call_op_one("aten::ones", "", &[rv.clone()], &HashMap::new());
        let t1_view = call_op_one("aten::view", "", &[t1.clone(), rv.clone()], &HashMap::new());

        let t1_cell: TensorCell = t1.clone().try_into().unwrap();
        let t1_view_cell: TensorCell = t1_view.try_into().unwrap();
        assert!(t1_cell.aliases(&t1_view_cell));

        // The first thread only calls a non-mutating op on the original tensor.
        let handle1 = std::thread::spawn(move || {
            for _ in 0..1000 {
                call_op_one(
                    "aten::add",
                    "Tensor",
                    &[t1.clone(), t1.clone()],
                    &HashMap::new(),
                );
            }
        });

        let handle2 = std::thread::spawn(move || {
            let t1_view: RValue = t1_view_cell.clone().into();
            // Trying to mutate this tensor while it is borrowed by the first
            // thread should panic!
            for _ in 0..1000 {
                call_op_one(
                    "aten::add_",
                    "Tensor",
                    &[t1_view.clone(), t1_view.clone()],
                    &HashMap::new(),
                );
            }
        });
        // A panic inside either thread comes back as `Err` from `join`, so the
        // `unwrap` re-raises it on the test thread for `#[should_panic]`.
        handle1.join().unwrap();
        handle2.join().unwrap();
    }
578
579    #[test]
580    fn kwargs() {
581        let rv = RValue::from(vec![2, 3]);
582        let kwargs = HashMap::from([("size".into(), rv)]);
583        let t1 = call_op_one("aten::ones", "", &[], &kwargs.clone());
584        let t2 = call_op_one("aten::ones", "", &[], &kwargs);
585
586        match (&t1, &t2) {
587            (RValue::Tensor(t1), RValue::Tensor(t2)) => {
588                assert!(!t1.aliases(t2));
589            }
590            _ => panic!("expected tensor"),
591        }
592
593        let result = call_op("aten::allclose", "", &[t1, t2], &HashMap::new(), true)
594            .unwrap()
595            .pop()
596            .unwrap();
597
598        assert!(
599            matches!(result, RValue::Bool(true)),
600            "Expected true for allclose output"
601        );
602    }
603
604    #[test]
605    fn kwargs_alias() {
606        let rv = RValue::from(vec![2, 3]);
607        let kwargs = HashMap::from([("size".into(), rv.clone())]);
608        let t1 = call_op_one("aten::ones", "", &[], &kwargs);
609
610        let kwargs = HashMap::from([("size".into(), rv.clone()), ("self".into(), t1.clone())]);
611        let t1_view = call_op_one("aten::view", "", &[], &kwargs);
612
613        let t1_cell: TensorCell = t1.clone().try_into().unwrap();
614        let t1_view_cell: TensorCell = t1_view.try_into().unwrap();
615        assert!(t1_cell.aliases(&t1_view_cell));
616    }
617
618    #[test]
619    fn kwargs_chunk_alias() {
620        let size = RValue::from(vec![2, 3]);
621        let kwargs = HashMap::from([("size".into(), size)]);
622        let t1 = call_op_one("aten::ones", "", &[], &kwargs);
623
624        let kwargs = HashMap::from([("self".into(), t1.clone()), ("chunks".into(), 2.into())]);
625        let chunked = call_op_one("aten::chunk", "", &[], &kwargs);
626
627        let cells: Vec<TensorCell> = chunked
628            .try_into()
629            .expect("return of chunk should be a tensor list");
630
631        let original_cell: TensorCell = t1.try_into().expect("return of ones should be a tensor");
632        for cell in cells {
633            assert!(original_cell.aliases(&cell));
634        }
635    }
636
637    #[should_panic]
638    #[test]
639    fn kwargs_mutate_double_borrow() {
640        let size = RValue::from(vec![2, 3]);
641        let kwargs = HashMap::from([("size".into(), size)]);
642        let t1 = call_op_one("aten::ones", "", &[], &kwargs.clone());
643        let t1_view = call_op_one("aten::view", "", &[t1.clone()], &kwargs);
644        let t1_view_cell: TensorCell = t1_view.try_into().unwrap();
645
646        let handle1 = std::thread::spawn(move || {
647            for _ in 0..1000 {
648                let kwargs = HashMap::from([("size".into(), t1.clone())]);
649                call_op_one("aten::add", "Tensor", &[t1.clone()], &kwargs);
650            }
651        });
652
653        let handle2 = std::thread::spawn(move || {
654            let t1_view: RValue = t1_view_cell.into();
655            // Trying to mutate this tensor while it is borrowed by the first
656            // thread should panic!
657            for _ in 0..1000 {
658                let kwargs = HashMap::from([("size".into(), t1_view.clone())]);
659                call_op_one("aten::add_", "Tensor", &[t1_view.clone()], &kwargs);
660            }
661        });
662        handle1.join().unwrap();
663        handle2.join().unwrap();
664    }
665
666    #[test]
667    fn test_flatten_results() {
668        let size = RValue::from(vec![5, 2]);
669        let kwargs = HashMap::from([("size".into(), size)]);
670        let t = call_op_one("aten::ones", "", &[], &kwargs.clone());
671        let res = call_op(
672            "aten::split_with_sizes",
673            "",
674            &[t, vec![1, 4].into()],
675            &HashMap::new(),
676            true,
677        )
678        .unwrap();
679        assert_eq!(res.len(), 2);
680        match (&res[0], &res[1]) {
681            (RValue::Tensor(t1), RValue::Tensor(t2)) => {
682                assert_eq!(t1.borrow().numel(), 2);
683                assert_eq!(t2.borrow().numel(), 8);
684            }
685            _ => panic!("unexpected results: {:?}", res),
686        }
687    }
688
689    #[test]
690    fn test_call_op_mutating_self_with_no_return() {
691        let t: TensorCell = call_op_one("aten::ones", "", &[vec![5, 1].into()], &HashMap::new())
692            .try_into()
693            .unwrap();
694        let res = call_op(
695            "aten::_foreach_add_",
696            "Scalar",
697            &[vec![t.clone()].into(), 1.into()],
698            &HashMap::new(),
699            true,
700        )
701        .unwrap();
702        assert_eq!(res.len(), 0);
703        let expected: TensorCell = call_op_one(
704            "aten::full",
705            "",
706            &[vec![5, 1].into(), 2.into()],
707            &HashMap::new(),
708        )
709        .try_into()
710        .unwrap();
711        assert!(t.borrow().equal(&expected.borrow()));
712    }
713
714    #[test]
715    fn test_call_op_arg_types() {
716        let args_info =
717            crate::bridge::ffi::get_schema_args_info("aten::_foreach_add_", "Scalar").unwrap();
718        assert_eq!(args_info.len(), 2);
719        assert_eq!(args_info[0].name, "self");
720        assert!(args_info[0].is_mutable);
721        assert!(args_info[0].type_.is_tensor_list());
722        assert!(!args_info[0].type_.is_tensor());
723        assert_eq!(args_info[1].name, "scalar");
724        assert!(!args_info[1].is_mutable);
725        assert!(!args_info[1].type_.is_tensor());
726        assert!(!args_info[1].type_.is_tensor_list());
727
728        let args_info =
729            crate::bridge::ffi::get_schema_args_info("aten::_foreach_add_", "Tensor").unwrap();
730        assert_eq!(args_info[1].name, "other");
731        assert!(!args_info[1].is_mutable);
732        assert!(args_info[1].type_.is_tensor());
733        assert!(!args_info[1].type_.is_tensor_list());
734    }
735}