torch_sys/lib.rs
1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! Rust bindings for libtorch C++ APIs.
10//!
11//! These bindings were designed with the following goals:
12//! - Fundamental types should look and perform close to the C++ version. In
13//! particular, we bind [`Tensor`] and [`IValue`] by hand, so that they can be
14//! passed around by value rather than requiring a heap allocation.
15//! - We want to minimize the amount of application logic that needs to be
16//! written in C++, and avoid complex invariants that need to be maintained
17//! across languages.
18//! - Types exposed by the bindings should behave like regular Rust types. In
19//! particular, they should be safe; from safe Rust code we should never
20//! be able to violate Rust's invariants.
21//!
22//! At the moment, these bindings implement the minimal functionality needed to
23//! work with the PyTorch object model and perform dispatch on PyTorch ops.
24//!
25//! # Example
26//! ```
27//! # use std::collections::HashMap;
28//! # use std::error::Error;
29//! # use torch_sys::RValue;
30//! # fn main() -> Result<(), Box<dyn Error>> {
31//! let sizes = RValue::from(vec![2, 3]);
32//!
33//! let mut outputs =
34//! torch_sys::call_op::call_op("aten::ones", "", &[sizes.clone()], &HashMap::new(), true)?;
35//! let t1 = outputs.pop().unwrap();
36//!
37//! // Can do kwargs as well
38//! let kwargs = HashMap::from([("size".into(), sizes)]);
39//! let mut outputs = torch_sys::call_op::call_op("aten::ones", "", &[], &kwargs, true)?;
40//! let t2 = outputs.pop().unwrap();
41//!
42//! let mut outputs =
43//! torch_sys::call_op::call_op("aten::allclose", "", &[t1, t2], &HashMap::new(), true)?;
44//! let result = outputs.pop().unwrap();
45//!
46//! assert!(result.try_into()?);
47//! # Ok(())
48//! # }
49//! ```
50//!
51//! # Safety
52//! These considerations apply to bindings that deal with `Tensor` (and
53//! by extension, `IValue`s since they can contain `Tensor`s). If a binding
54//! violates these rules, it must be marked `unsafe` and the additional
55//! constraints should be documented.
56//!
57//! ## Mutability
58//!
59//! **Rule**: If a binding can potentially mutate a `Tensor`, the safe Rust
60//! function signature *must* take it by either value or `&mut Tensor`.
61//!
62//! You must manually audit the C++ implementation to determine whether or not
63//! it can mutate its arguments.
64//!
65//! Notably, this is true even if the C++ signature receives a `const Tensor&`.
66//! You can still mutate a `Tensor` obtained that way! LibTorch's C++ API
67//! doesn't have a concept of an immutable `Tensor` object, so we must rely on
68//! manual auditing to ensure that a Rust `&Tensor` is immutable.
69//!
70//! ## Aliasing
71//!
72//! **Rule**: A safe binding *must not* produce a new alias of an existing
73//! `Tensor`.
74//!
75//! You must manually audit the C++ implementation to determine whether or not
76//! it can produce a new alias. This may involve inserting dynamic aliasing
77//! checks if aliasing relationships are not known statically (e.g.
78//! `aten::contiguous`).
79//!
80//! We want the Rust compiler to correctly track ownership and borrowing
81//! of `Tensor` and enforcing the invariant that only one mutable reference to a
82//! `Tensor` can exist at a time.
83//!
84//! If a C++ function returned a new alias of an existing `Tensor`, the Rust
85//! compiler would treat them as independent `Tensor` objects, and would not be
86//! able to prevent a data race if we tried to mutate them on two different
87//! threads.
88//!
89//! In Rust, shared ownership + mutability is handled by having a smart pointer
90//! own a value that is synchronized (e.g. the `Arc<Mutex<T>>` pattern). We
91//! cannot synchronize access to the C++ underlying `TensorImpl` without
92//! changing the implementation of `at::Tensor`, so we must disallow shared
93//! ownership in Rust code.
94
95#![feature(assert_matches)]
96#![feature(once_cell_try)]
97
98mod bindings;
99mod borrow;
100mod bridge;
101pub mod call_op;
102mod cell;
103mod device;
104mod ivalue;
105mod layout;
106mod memory_format;
107mod pyobject;
108mod rvalue;
109mod scalar_type;
110mod tensor;
111
112pub mod backend;
113
114/// Binding for `c10::Layout`.
115pub use bindings::root::c10::Layout;
116/// Binding for `c10::MemoryFormat`.
117pub use bindings::root::c10::MemoryFormat;
118/// Binding for `c10::ScalarType`.
119pub use bindings::root::c10::ScalarType;
120pub use borrow::Borrow;
121pub use borrow::BorrowError;
122pub use borrow::BorrowType;
123pub use borrow::MultiBorrow;
124pub use cell::CloneUnsafe;
125pub use device::CudaDevice;
126pub use device::Device;
127pub use device::DeviceIndex;
128pub use device::DeviceParseError;
129pub use device::DeviceType;
130pub use ivalue::IValue;
131pub use ivalue::IValueKind;
132pub use ivalue::OpaqueIValue;
133pub use rvalue::RValue;
134pub use rvalue::rvalue_to_ivalue;
135pub use tensor::Tensor;
136pub use tensor::TensorCell;
137
138pub use crate::bridge::ffi::deep_clone;
139pub use crate::bridge::ffi::factory_float_tensor;
140/// Remote serde implementation.
141pub use crate::layout::LayoutDef;
142/// Remote serde implementation.
143pub use crate::memory_format::MemoryFormatDef;
144/// Remote serde implementation.
145pub use crate::scalar_type::ScalarTypeDef;
146pub mod testing {
147 /// Compares two tensors with `torch.allclose`.
148 pub use crate::bridge::ffi::allclose;
149 pub use crate::bridge::ffi::cuda_full;
150 pub use crate::bridge::ffi::stack;
151}
152pub use crate::bridge::ffi::factory_empty;
153pub use crate::bridge::ffi::factory_zeros;
154pub use crate::bridge::ffi::is_float8_type;
155pub use crate::bridge::ffi::suggest_memory_format;
156// Only here to make them available to doctests!
157#[doc(hidden)]
158pub use crate::bridge::ffi::test_make_alias;
159#[doc(hidden)]
160pub use crate::bridge::ffi::test_make_tensor;