torch_sys2/
testing.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

//! Testing utilities: small helpers that build and compare tensors by calling into Python `torch`.

11use pyo3::prelude::*;
12
13use crate::Tensor;
14use crate::torch_allclose;
15use crate::torch_full;
16use crate::torch_stack;
17
18pub fn allclose(a: &Tensor, b: &Tensor) -> Result<bool, String> {
19    Python::with_gil(|py| {
20        let a_obj = a.inner.bind(py);
21        let b_obj = b.inner.bind(py);
22
23        torch_allclose(py)
24            .call1((a_obj, b_obj))
25            .map_err(|e| format!("Failed to call torch.allclose: {}", e))?
26            .extract()
27            .map_err(|e| format!("Failed to extract result: {}", e))
28    })
29}
30
31pub fn cuda_full(size: &[i64], value: f32) -> Tensor {
32    Python::with_gil(|py| {
33        let size_tuple = pyo3::types::PyTuple::new(py, size).unwrap();
34
35        let kwargs = pyo3::types::PyDict::new(py);
36        kwargs.set_item("device", "cuda").unwrap();
37
38        let result = torch_full(py)
39            .call((size_tuple, value), Some(&kwargs))
40            .unwrap();
41
42        Tensor {
43            inner: result.clone().unbind(),
44        }
45    })
46}
47
48pub fn stack(tensors: &[Tensor]) -> Tensor {
49    Python::with_gil(|py| {
50        // Convert Rust tensor slice to Python list
51        let tensor_list = pyo3::types::PyList::empty(py);
52        for tensor in tensors {
53            tensor_list.append(tensor.inner.bind(py)).unwrap();
54        }
55
56        let result = torch_stack(py).call1((tensor_list,)).unwrap();
57
58        Tensor {
59            inner: result.clone().unbind(),
60        }
61    })
62}