Rust bindings for libtorch C++ APIs.
These bindings were designed with the following goals:
- Fundamental types should look and perform close to the C++ version. In particular, we bind Tensor and IValue by hand, so that they can be passed around by value rather than requiring a heap allocation.
- We want to minimize the amount of application logic that needs to be written in C++, and avoid complex invariants that need to be maintained across languages.
- Types exposed by the bindings should behave like regular Rust types. In particular, they should be safe; from safe Rust code we should never be able to violate Rust’s invariants.
At the moment, these bindings implement the minimal functionality needed to work with the PyTorch object model and perform dispatch on PyTorch ops.
§Example
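use std::collections::HashMap;
use torch_sys::RValue;
// Wrapped in a function so `?` has somewhere to propagate errors;
// the anyhow::Result return type is an assumption of this example.
fn example() -> anyhow::Result<()> {
    let sizes = RValue::from(vec![2, 3]);
    // Positional arguments.
    let mut outputs =
        torch_sys::call_op::call_op("aten::ones", "", &[sizes.clone()], &HashMap::new(), true)?;
    let t1 = outputs.pop().unwrap();
    // Keyword arguments work as well.
    let kwargs = HashMap::from([("size".into(), sizes)]);
    let mut outputs = torch_sys::call_op::call_op("aten::ones", "", &[], &kwargs, true)?;
    let t2 = outputs.pop().unwrap();
    let mut outputs =
        torch_sys::call_op::call_op("aten::allclose", "", &[t1, t2], &HashMap::new(), true)?;
    let result = outputs.pop().unwrap();
    // Annotate the target type so try_into knows to convert to bool.
    let all_close: bool = result.try_into()?;
    assert!(all_close);
    Ok(())
}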
§Safety
These are considerations that apply to bindings that deal with Tensor (and, by extension, IValues, since they can contain Tensors). If a binding violates these rules, it must be marked unsafe and the additional constraints should be documented.
§Mutability
Rule: If a binding can potentially mutate a Tensor, the safe Rust function signature must take it either by value or as &mut Tensor.
You must manually audit the C++ implementation to determine whether or not it can mutate its arguments. Notably, this is true even if the C++ signature receives a const Tensor&: you can still mutate a Tensor obtained that way! LibTorch’s C++ API doesn’t have a concept of an immutable Tensor object, so we must rely on manual auditing to ensure that a Rust &Tensor is immutable.
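A minimal sketch of the mutability rule, using a hypothetical pure-Rust Tensor stand-in (not torch_sys::Tensor) and two hypothetical bindings, sum and add_scalar_. A binding audited as read-only may take &Tensor, while one that can write in place must take &mut Tensor, so the compiler rejects mutation through a shared reference.

```rust
// Hypothetical stand-in for a bound tensor: the real torch_sys::Tensor wraps
// an at::Tensor and is not built from a Vec like this.
struct Tensor {
    data: Vec<f32>,
}

// Read-only binding: assume an audit showed the underlying op never writes to
// its input, so a shared `&Tensor` is acceptable.
fn sum(t: &Tensor) -> f32 {
    t.data.iter().sum()
}

// Mutating binding: the op writes in place (even if its C++ signature took a
// `const Tensor&`), so the safe Rust signature demands `&mut Tensor`.
fn add_scalar_(t: &mut Tensor, x: f32) {
    for v in t.data.iter_mut() {
        *v += x;
    }
}

fn main() {
    let mut t = Tensor { data: vec![1.0, 2.0, 3.0] };
    add_scalar_(&mut t, 1.0);
    assert_eq!(t.data, vec![2.0, 3.0, 4.0]);
    assert_eq!(sum(&t), 9.0);
}
```

With this split, calling add_scalar_ while a &Tensor borrow is still alive is a compile error, which is exactly the guarantee the audit is meant to make sound.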
§Aliasing
Rule: A safe binding must not produce a new alias of an existing Tensor.
You must manually audit the C++ implementation to determine whether or not it can produce a new alias. This may involve inserting dynamic aliasing checks if aliasing relationships are not known statically (e.g. aten::contiguous).
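The sketch below illustrates a dynamic aliasing check under simplifying assumptions: Tensor is a hypothetical handle onto Rc-counted storage, raw_contiguous stands in for the C++ op (returning an alias when its input is already contiguous, as aten::contiguous does), and contiguous_checked is a hypothetical safe wrapper that compares storages with Rc::ptr_eq after the call. Reporting an error is only one possible policy; real bindings could instead return a borrow or force a copy.

```rust
use std::rc::Rc;

// Hypothetical model: a tensor is a handle onto refcounted storage, the way
// an at::Tensor shares its storage with any aliases. Not the torch_sys type.
struct Tensor {
    storage: Rc<Vec<f32>>,
    is_contiguous: bool,
}

// Stand-in for the raw C++ op: it returns an alias of its input when the
// input is already contiguous, and a fresh copy otherwise.
fn raw_contiguous(t: &Tensor) -> Tensor {
    let storage = if t.is_contiguous {
        Rc::clone(&t.storage)
    } else {
        Rc::new(t.storage.to_vec())
    };
    Tensor { storage, is_contiguous: true }
}

// Safe wrapper with a dynamic aliasing check: whether the result aliases the
// input is only known at runtime, so compare storages after the call and
// refuse to hand out a second owner of the same storage.
fn contiguous_checked(t: &Tensor) -> Result<Tensor, &'static str> {
    let out = raw_contiguous(t);
    if Rc::ptr_eq(&out.storage, &t.storage) {
        Err("result aliases its input")
    } else {
        Ok(out)
    }
}

fn main() {
    let already = Tensor { storage: Rc::new(vec![1.0, 2.0]), is_contiguous: true };
    assert!(contiguous_checked(&already).is_err());

    let strided = Tensor { storage: Rc::new(vec![1.0, 2.0]), is_contiguous: false };
    let copy = contiguous_checked(&strided).unwrap();
    assert_eq!(*copy.storage, vec![1.0, 2.0]);
}
```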
We want the Rust compiler to correctly track ownership and borrowing of Tensor and to enforce the invariant that only one mutable reference to a Tensor can exist at a time.
If a C++ function returned a new alias of an existing Tensor, the Rust compiler would treat the two as independent Tensor objects and would not be able to prevent a data race if we tried to mutate them on two different threads.
In Rust, shared ownership with mutability is handled by having a smart pointer own a value that is synchronized (e.g. the Arc<Mutex<T>> pattern). We cannot synchronize access to the underlying C++ TensorImpl without changing the implementation of at::Tensor, so we must disallow shared ownership in Rust code.
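For reference, the standard Arc<Mutex<T>> pattern looks like this. The function increment_in_parallel is purely illustrative: each thread owns its own Arc pointing at the same vector, and every write goes through the Mutex lock. This is the kind of synchronization that cannot be retrofitted onto TensorImpl from the Rust side.

```rust
use std::sync::{Arc, Mutex};
use std::thread;

// Shared ownership (Arc) combined with synchronized mutation (Mutex): each of
// `n` threads holds its own Arc to the same vector and must take the lock
// before writing its slot.
fn increment_in_parallel(n: usize) -> Vec<u32> {
    let data = Arc::new(Mutex::new(vec![0u32; n]));
    let handles: Vec<_> = (0..n)
        .map(|i| {
            let data = Arc::clone(&data);
            thread::spawn(move || {
                data.lock().unwrap()[i] += 1;
            })
        })
        .collect();
    for handle in handles {
        handle.join().unwrap();
    }
    // Copy the result out while the lock guard is alive, then release it.
    let result = data.lock().unwrap().clone();
    result
}

fn main() {
    assert_eq!(increment_in_parallel(4), vec![1, 1, 1, 1]);
}
```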
Modules§
Structs§
- CudaDevice - A device that is statically guaranteed to be a CUDA device.
- Device - Binding for c10::Device.
- DeviceIndex - Binding for c10::DeviceIndex.
- DeviceType - Binding for c10::DeviceType.
- IValue - Rust binding for the C++ type c10::IValue.
- Layout - Binding for c10::Layout.
- LayoutDef - Remote serde implementation.
- MemoryFormat - Binding for c10::MemoryFormat.
- MemoryFormatDef - Remote serde implementation.
- MultiBorrow - A helper that batches multiple borrows for a single borrower, deduping them so we don’t accidentally borrow the same alias twice.
- OpaqueIValue - An opaque container for an IValue. This is used to restrict safe direct access to the underlying IValue.
- ScalarType - Binding for c10::ScalarType.
- ScalarTypeDef - Remote serde implementation.
- Tensor - Rust binding for the C++ type at::Tensor.
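The deduplication idea behind MultiBorrow can be sketched as follows. Everything here is hypothetical: aliases are identified by a u64 rather than by real TensorImpl identity, Access replaces the crate's Borrow and BorrowType enums, and merging a shared and a mutable request into a single mutable borrow is an assumption of this sketch, not documented behavior.

```rust
use std::collections::hash_map::{Entry, HashMap};

// Hypothetical model of the idea behind MultiBorrow, not its actual API.
// Each request names the alias it touches (a u64 standing in for the identity
// of the underlying TensorImpl) and whether it needs mutable access.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Access {
    Shared,
    Mutable,
}

// Merge duplicate requests so each alias is borrowed exactly once. Assumption
// for this sketch: if any request for an alias is mutable, the merged borrow
// is mutable. First-seen order is preserved so results are deterministic.
fn dedup_borrows(requests: &[(u64, Access)]) -> Vec<(u64, Access)> {
    let mut merged: HashMap<u64, Access> = HashMap::new();
    let mut order = Vec::new();
    for &(alias, access) in requests {
        match merged.entry(alias) {
            Entry::Occupied(mut e) => {
                if access == Access::Mutable {
                    *e.get_mut() = Access::Mutable;
                }
            }
            Entry::Vacant(e) => {
                e.insert(access);
                order.push(alias);
            }
        }
    }
    order.into_iter().map(|alias| (alias, merged[&alias])).collect()
}

fn main() {
    let requests = [(1u64, Access::Shared), (2, Access::Shared), (1, Access::Mutable)];
    let borrows = dedup_borrows(&requests);
    assert_eq!(borrows, vec![(1, Access::Mutable), (2, Access::Shared)]);
}
```

Without the merge step, the second request for alias 1 would try to borrow the same underlying tensor again, which a runtime borrow tracker would reject.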
Enums§
- Borrow - Abstracts over the different types of borrows we can have.
- BorrowError - Errors that can occur while calling an operator.
- BorrowType
- DeviceParseError - Errors that can be returned from constructing a device from a string.
- IValueKind - Enum representing the different internal types an IValue can hold.
- RValue - A pure Rust equivalent for IValue. This is safe to treat like a normal Rust value.
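As an illustration of what a device-parsing error guards against, here is a simplified FromStr implementation for c10-style device strings such as "cuda:1". The Device, DeviceType and DeviceParseError types below are hypothetical stand-ins with only two device kinds and an illustrative index type; they are not the crate's definitions.

```rust
use std::str::FromStr;

// Hypothetical, simplified stand-ins: the real Device, DeviceType and
// DeviceParseError bind c10 types and cover many more device kinds.
#[derive(Debug, PartialEq)]
enum DeviceType {
    Cpu,
    Cuda,
}

#[derive(Debug, PartialEq)]
struct Device {
    device_type: DeviceType,
    // Index type chosen for illustration only.
    index: Option<i8>,
}

#[derive(Debug, PartialEq)]
enum DeviceParseError {
    UnknownType(String),
    InvalidIndex(String),
}

impl FromStr for Device {
    type Err = DeviceParseError;

    // Accepts strings such as "cpu", "cuda", or "cuda:1".
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let (type_str, index_str) = match s.split_once(':') {
            Some((t, i)) => (t, Some(i)),
            None => (s, None),
        };
        let device_type = match type_str {
            "cpu" => DeviceType::Cpu,
            "cuda" => DeviceType::Cuda,
            other => return Err(DeviceParseError::UnknownType(other.to_string())),
        };
        let index = match index_str {
            Some(i) => Some(
                i.parse::<i8>()
                    .map_err(|_| DeviceParseError::InvalidIndex(i.to_string()))?,
            ),
            None => None,
        };
        Ok(Device { device_type, index })
    }
}

fn main() {
    assert_eq!(
        "cuda:1".parse::<Device>(),
        Ok(Device { device_type: DeviceType::Cuda, index: Some(1) })
    );
    assert!(matches!("tpu".parse::<Device>(), Err(DeviceParseError::UnknownType(_))));
    assert!(matches!("cuda:x".parse::<Device>(), Err(DeviceParseError::InvalidIndex(_))));
}
```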
Traits§
- CloneUnsafe - CloneUnsafe is a trait that allows AliasTrackingRefCell to implement Clone for the type it wraps. The clone_unsafe method is unsafe because it does not create an independent copy of the underlying value; instead, the returned value is tracked like any other alias, and borrow-checking rules are enforced across both cells.
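The CloneUnsafe description can be modeled in a few lines. This is a hypothetical sketch, not the crate's code: Handle stands in for an alias-producing type, and this AliasTrackingRefCell tracks only exclusive borrows through a shared flag. Because every clone shares that flag, a borrow taken through one alias blocks borrows through the others, which is what makes it acceptable to expose a safe Clone on the cell even though clone_unsafe itself yields an alias.

```rust
use std::cell::Cell;
use std::rc::Rc;

// Hypothetical model of the pattern; the real APIs may differ.
trait CloneUnsafe {
    /// # Safety
    /// Returns an alias of `self`, not an independent copy. The caller must
    /// make sure every alias is borrow-checked together.
    unsafe fn clone_unsafe(&self) -> Self;
}

// Alias-producing handle, standing in for something like a Tensor.
struct Handle {
    storage: Rc<Vec<i32>>,
}

impl CloneUnsafe for Handle {
    unsafe fn clone_unsafe(&self) -> Self {
        Handle { storage: Rc::clone(&self.storage) }
    }
}

// Every alias shares one `borrowed` flag, so a borrow through any of them
// blocks borrows through the others (only exclusive borrows are modeled).
struct AliasTrackingRefCell<T> {
    value: T,
    borrowed: Rc<Cell<bool>>,
}

struct BorrowGuard {
    borrowed: Rc<Cell<bool>>,
}

impl Drop for BorrowGuard {
    fn drop(&mut self) {
        self.borrowed.set(false);
    }
}

impl<T> AliasTrackingRefCell<T> {
    fn new(value: T) -> Self {
        Self { value, borrowed: Rc::new(Cell::new(false)) }
    }

    fn try_borrow_mut(&self) -> Option<BorrowGuard> {
        if self.borrowed.replace(true) {
            return None; // some alias already holds the borrow
        }
        Some(BorrowGuard { borrowed: Rc::clone(&self.borrowed) })
    }
}

// Clone is safe to expose here because the clone shares the borrow flag.
impl<T: CloneUnsafe> Clone for AliasTrackingRefCell<T> {
    fn clone(&self) -> Self {
        // SAFETY: the alias shares `borrowed`, so borrow rules span both cells.
        let value = unsafe { self.value.clone_unsafe() };
        Self { value, borrowed: Rc::clone(&self.borrowed) }
    }
}

fn main() {
    let a = AliasTrackingRefCell::new(Handle { storage: Rc::new(vec![1, 2]) });
    let b = a.clone();
    assert!(Rc::ptr_eq(&a.value.storage, &b.value.storage)); // same data
    let guard = a.try_borrow_mut();
    assert!(guard.is_some());
    assert!(b.try_borrow_mut().is_none()); // blocked through the alias
    drop(guard);
    assert!(b.try_borrow_mut().is_some());
}
```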
Functions§
- deep_clone - Return a clone of this tensor. The semantics of clone are like torch.clone: it copies the underlying tensor storage.
- factory_empty - Binding for torch.empty.
- factory_float_tensor - Creates a new one-dimensional f32 Tensor with the provided data. Mostly used for testing; basically equivalent to a limited version of the raw torch.tensor constructor.
- factory_zeros - Binding for torch.zeros.
- is_float8_type
- rvalue_to_ivalue ⚠
- suggest_memory_format
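To make the difference between an alias and a deep clone concrete, here is a small model in which a hypothetical TensorHandle holds refcounted storage. Deriving Clone copies only the handle, so writes are shared, while the hypothetical deep_clone below copies the data, matching the torch.clone semantics described for deep_clone. This is not the crate's implementation.

```rust
use std::cell::RefCell;
use std::rc::Rc;

// Hypothetical model: copying a handle shares storage (like copying an
// at::Tensor), while deep_clone copies the storage (like torch.clone).
#[derive(Clone)]
struct TensorHandle {
    storage: Rc<RefCell<Vec<f32>>>,
}

fn deep_clone(t: &TensorHandle) -> TensorHandle {
    let copied: Vec<f32> = t.storage.borrow().clone();
    TensorHandle { storage: Rc::new(RefCell::new(copied)) }
}

fn main() {
    let original = TensorHandle { storage: Rc::new(RefCell::new(vec![1.0, 2.0])) };
    let alias = original.clone(); // shallow: same storage
    let copy = deep_clone(&original); // deep: independent storage

    copy.storage.borrow_mut()[0] = 9.0;
    assert_eq!(original.storage.borrow()[0], 1.0); // unaffected

    alias.storage.borrow_mut()[0] = 5.0;
    assert_eq!(original.storage.borrow()[0], 5.0); // shared
}
```

The shallow clone here is exactly the kind of new alias the Aliasing rule forbids safe bindings from producing, which is why a copying operation like deep_clone can be exposed safely.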