1use pyo3::prelude::*;
12
13use crate::Tensor;
14use crate::torch_allclose;
15use crate::torch_full;
16use crate::torch_stack;
17
18pub fn allclose(a: &Tensor, b: &Tensor) -> Result<bool, String> {
19 Python::with_gil(|py| {
20 let a_obj = a.inner.bind(py);
21 let b_obj = b.inner.bind(py);
22
23 torch_allclose(py)
24 .call1((a_obj, b_obj))
25 .map_err(|e| format!("Failed to call torch.allclose: {}", e))?
26 .extract()
27 .map_err(|e| format!("Failed to extract result: {}", e))
28 })
29}
30
31pub fn cuda_full(size: &[i64], value: f32) -> Tensor {
32 Python::with_gil(|py| {
33 let size_tuple = pyo3::types::PyTuple::new(py, size).unwrap();
34
35 let kwargs = pyo3::types::PyDict::new(py);
36 kwargs.set_item("device", "cuda").unwrap();
37
38 let result = torch_full(py)
39 .call((size_tuple, value), Some(&kwargs))
40 .unwrap();
41
42 Tensor {
43 inner: result.clone().unbind(),
44 }
45 })
46}
47
48pub fn stack(tensors: &[Tensor]) -> Tensor {
49 Python::with_gil(|py| {
50 let tensor_list = pyo3::types::PyList::empty(py);
52 for tensor in tensors {
53 tensor_list.append(tensor.inner.bind(py)).unwrap();
54 }
55
56 let result = torch_stack(py).call1((tensor_list,)).unwrap();
57
58 Tensor {
59 inner: result.clone().unbind(),
60 }
61 })
62}