Crate torch_sys

Rust bindings for libtorch C++ APIs.

These bindings were designed with the following goals:

  • Fundamental types should look and perform close to the C++ version. In particular, we bind Tensor and IValue by hand, so that they can be passed around by value rather than requiring a heap allocation.
  • We want to minimize the amount of application logic that needs to be written in C++, and avoid complex invariants that need to be maintained across languages.
  • Types exposed by the bindings should behave like regular Rust types. In particular, they should be safe; from safe Rust code we should never be able to violate Rust’s invariants.
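
Passing a Tensor by value without a heap allocation requires the Rust type to have the same size and layout as the C++ object. The sketch below assumes that at::Tensor is a single intrusive pointer to a TensorImpl, which is how current libtorch defines it. TensorHandle, from_raw and take_by_value are hypothetical names, not the torch_sys API. A real binding would also need a Drop impl that releases the C++ reference count.

```rust
use std::ffi::c_void;
use std::mem::size_of;

/// Hypothetical by-value Tensor binding. Assumption: the C++ at::Tensor is
/// one intrusive pointer to a TensorImpl, so one pointer-sized field mirrors it.
#[repr(transparent)]
pub struct TensorHandle {
    // Opaque pointer to the C++ TensorImpl; never dereferenced from Rust.
    impl_ptr: *mut c_void,
}

impl TensorHandle {
    /// Wrap a raw pointer handed back from C++ (hypothetical constructor).
    pub fn from_raw(impl_ptr: *mut c_void) -> Self {
        TensorHandle { impl_ptr }
    }

    pub fn as_ptr(&self) -> *mut c_void {
        self.impl_ptr
    }
}

/// Moving a TensorHandle copies a single pointer; no Box is involved.
pub fn take_by_value(t: TensorHandle) -> TensorHandle {
    t
}

/// Checked at compile time: the handle is exactly pointer-sized.
pub const HANDLE_IS_POINTER_SIZED: bool =
    size_of::<TensorHandle>() == size_of::<*mut c_void>();
```

Because the layout matches, such a handle can sit directly in Rust structs, Vecs and function arguments, just as at::Tensor does in C++.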

At the moment, these bindings implement the minimal functionality needed to work with the PyTorch object model and perform dispatch on PyTorch ops.

Example

```rust
use std::collections::HashMap;

use torch_sys::RValue;

fn example() -> Result<(), Box<dyn std::error::Error>> {
    let sizes = RValue::from(vec![2, 3]);

    let mut outputs =
        torch_sys::call_op::call_op("aten::ones", "", &[sizes.clone()], &HashMap::new(), true)?;
    let t1 = outputs.pop().unwrap();

    // Can do kwargs as well
    let kwargs = HashMap::from([("size".into(), sizes)]);
    let mut outputs = torch_sys::call_op::call_op("aten::ones", "", &[], &kwargs, true)?;
    let t2 = outputs.pop().unwrap();

    let mut outputs =
        torch_sys::call_op::call_op("aten::allclose", "", &[t1, t2], &HashMap::new(), true)?;
    let result = outputs.pop().unwrap();

    // Convert the returned RValue into a Rust bool.
    let all_close: bool = result.try_into()?;
    assert!(all_close);
    Ok(())
}
```

Safety

These considerations apply to bindings that deal with Tensor (and, by extension, IValue, since an IValue can contain a Tensor). If a binding violates these rules, it must be marked unsafe and its additional constraints must be documented.

Mutability

Rule: If a binding can potentially mutate a Tensor, the safe Rust function signature must take it either by value or by &mut Tensor.

You must manually audit the C++ implementation to determine whether or not it can mutate its arguments.

Notably, this is true even if the C++ signature receives a const Tensor&. You can still mutate a Tensor obtained that way! LibTorch’s C++ API doesn’t have a concept of an immutable Tensor object, so we must rely on manual auditing to ensure that a Rust &Tensor is immutable.
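
A minimal sketch of how the rule shows up in signatures, using a hypothetical pure-Rust ToyTensor instead of the real binding. The read-only op takes a shared reference. The in-place op takes &mut even though its C++ analogue might accept const Tensor&. The consuming op takes the value. All names here are illustrative only.

```rust
/// Hypothetical stand-in for a Tensor; not the torch_sys type.
#[derive(Debug)]
pub struct ToyTensor {
    data: Vec<f32>,
}

impl ToyTensor {
    pub fn from_vec(data: Vec<f32>) -> Self {
        ToyTensor { data }
    }
}

/// Audited as read-only, so a shared borrow is enough.
pub fn sum(t: &ToyTensor) -> f32 {
    t.data.iter().sum()
}

/// In-place op in the style of Tensor::add_. Even if the C++ side took
/// `const Tensor&`, the safe Rust signature must demand `&mut`.
pub fn add_scalar_(t: &mut ToyTensor, value: f32) {
    for x in t.data.iter_mut() {
        *x += value;
    }
}

/// Taking the tensor by value is also allowed: after the move, no other
/// Rust reference to it can exist.
pub fn into_scaled(mut t: ToyTensor, factor: f32) -> ToyTensor {
    for x in t.data.iter_mut() {
        *x *= factor;
    }
    t
}
```

Because add_scalar_ requires &mut, the compiler rejects a call to it while any other reference to the same ToyTensor is alive, which is exactly the guarantee a mistakenly &-taking binding would lose.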

Aliasing

Rule: A safe binding must not produce a new alias of an existing Tensor.

You must manually audit the C++ implementation to determine whether or not it can produce a new alias. This may involve inserting dynamic aliasing checks if aliasing relationships are not known statically (e.g. aten::contiguous).
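
aten::contiguous returns its input unchanged when the input is already contiguous and a fresh tensor otherwise, so whether the result aliases the input is only known at runtime. The sketch below models storage as an Arc<Vec<f32>> so that aliasing can be detected with Arc::ptr_eq. StorageTensor, raw_contiguous and contiguous are hypothetical names. Falling back to a copy when an alias comes back is one possible policy, assumed here for illustration; a binding could also return an error.

```rust
use std::sync::Arc;

/// Hypothetical tensor whose storage may be shared by several handles.
#[derive(Debug)]
pub struct StorageTensor {
    storage: Arc<Vec<f32>>,
    contiguous: bool,
}

impl StorageTensor {
    pub fn new(data: Vec<f32>, contiguous: bool) -> Self {
        StorageTensor { storage: Arc::new(data), contiguous }
    }

    /// True if both tensors share the same underlying storage.
    pub fn aliases(&self, other: &StorageTensor) -> bool {
        Arc::ptr_eq(&self.storage, &other.storage)
    }
}

/// Stand-in for the raw C++ op: returns an alias of its input when the
/// input is already contiguous, otherwise a newly allocated tensor.
fn raw_contiguous(t: &StorageTensor) -> StorageTensor {
    if t.contiguous {
        StorageTensor { storage: Arc::clone(&t.storage), contiguous: true }
    } else {
        // A real op would reorder elements; copying is enough for the sketch.
        StorageTensor { storage: Arc::new(t.storage.as_ref().clone()), contiguous: true }
    }
}

/// Safe wrapper with a dynamic aliasing check: never hand back an alias.
pub fn contiguous(t: &StorageTensor) -> StorageTensor {
    let out = raw_contiguous(t);
    if out.aliases(t) {
        StorageTensor { storage: Arc::new(out.storage.as_ref().clone()), contiguous: true }
    } else {
        out
    }
}
```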

We want the Rust compiler to correctly track ownership and borrowing of Tensor and to enforce the invariant that only one mutable reference to a Tensor can exist at a time.

If a C++ function returned a new alias of an existing Tensor, the Rust compiler would treat the two as independent Tensor objects and could not prevent a data race if we mutated them on two different threads.

In Rust, shared ownership plus mutability is handled by having a smart pointer own a synchronized value (e.g. the Arc<Mutex<T>> pattern). We cannot synchronize access to the underlying C++ TensorImpl without changing the implementation of at::Tensor, so we must disallow shared ownership in Rust code.
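
For contrast, this is what the Arc<Mutex<T>> pattern looks like when the synchronization is available: a plain Vec<f32> stands in for tensor storage, two threads share ownership through Arc, and each must take the lock before writing. concurrent_fill is a hypothetical name. Two at::Tensor aliases offer no equivalent lock, which is why the bindings forbid creating them.

```rust
use std::sync::{Arc, Mutex};
use std::thread;

/// Two threads each add to every element of shared storage. The Mutex
/// serializes the writes, so the result is deterministic.
pub fn concurrent_fill() -> Vec<f32> {
    // Stand-in for tensor storage.
    let storage = Arc::new(Mutex::new(vec![0.0_f32; 4]));

    let handles: Vec<_> = (0..2)
        .map(|i| {
            // Each thread gets its own owning handle to the same Mutex.
            let storage = Arc::clone(&storage);
            thread::spawn(move || {
                // Acquire the lock; the guard releases it when dropped.
                let mut data = storage.lock().unwrap();
                for x in data.iter_mut() {
                    *x += (i + 1) as f32;
                }
            })
        })
        .collect();

    for h in handles {
        h.join().unwrap();
    }

    // Bind the copy first so the lock guard is dropped before `storage`.
    let result: Vec<f32> = storage.lock().unwrap().clone();
    result
}
```

Each element ends up as 0 + 1 + 2 = 3 regardless of thread scheduling, because the lock prevents the two writers from interleaving.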

Modules

backend
call_op
testing

Structs

CudaDevice
A device that is statically guaranteed to be a CUDA device.
Device
Binding for c10::Device.
DeviceIndex
Binding for c10::DeviceIndex.
DeviceType
Binding for c10::DeviceType.
IValue
Rust binding for the C++ type c10::IValue.
Layout
Binding for c10::Layout.
LayoutDef
Remote serde implementation.
MemoryFormat
Binding for c10::MemoryFormat.
MemoryFormatDef
Remote serde implementation.
MultiBorrow
A helper that batches multiple borrows for a single borrower, deduping them so we don’t accidentally borrow the same alias twice.
OpaqueIValue
An opaque container for an IValue. This is used to restrict safe direct access to the underlying IValue.
ScalarType
Binding for c10::ScalarType.
ScalarTypeDef
Remote serde implementation.
Tensor
Rust binding for the C++ type at::Tensor.
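
The MultiBorrow entry describes batching and deduplicating borrows so the same alias is not borrowed twice, for example when one tensor is passed as two arguments to an op. The following is a hypothetical sketch of that idea, not the crate's implementation: requests are keyed by an assumed integer alias id, and when an alias is requested more than once the stronger borrow kind wins, so each alias is borrowed exactly once.

```rust
use std::collections::BTreeMap;

/// Hypothetical borrow kinds. Declaration order makes Mutable > Shared,
/// so `max` picks the stronger request.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum Kind {
    Shared,
    Mutable,
}

/// Hypothetical batcher that collapses duplicate requests per alias.
#[derive(Default)]
pub struct BorrowBatch {
    requests: BTreeMap<usize, Kind>,
}

impl BorrowBatch {
    pub fn add(&mut self, alias_id: usize, kind: Kind) {
        self.requests
            .entry(alias_id)
            .and_modify(|k| *k = (*k).max(kind))
            .or_insert(kind);
    }

    /// One borrow per alias, ordered by alias id.
    pub fn plan(&self) -> Vec<(usize, Kind)> {
        self.requests.iter().map(|(&id, &k)| (id, k)).collect()
    }
}
```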

Enums

Borrow
Abstracts over the different types of borrows we can have.
BorrowError
Errors that can occur while calling an operator.
BorrowType
DeviceParseError
Errors that can be returned from constructing a device from a string.
IValueKind
Enum representing the different internal types an IValue can hold.
RValue
A pure Rust equivalent for IValue. This is safe to treat like a normal Rust value.
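
The Example section converts values with RValue::from(vec![2, 3]) and result.try_into(). A cut-down, hypothetical analogue of such a pure-Rust IValue mirror shows how those conversions can be expressed with the standard From and TryFrom traits. MiniRValue and its variants are illustrative; the real RValue has more variants, including tensors.

```rust
use std::convert::TryFrom;

/// Hypothetical, reduced mirror of an IValue-like value in plain Rust.
#[derive(Debug, Clone, PartialEq)]
pub enum MiniRValue {
    None,
    Bool(bool),
    Int(i64),
    Double(f64),
    IntList(Vec<i64>),
}

/// Lets callers write MiniRValue::from(vec![2, 3]) for a size argument.
impl From<Vec<i64>> for MiniRValue {
    fn from(v: Vec<i64>) -> Self {
        MiniRValue::IntList(v)
    }
}

/// Extract a bool, failing if the value holds a different variant.
impl TryFrom<MiniRValue> for bool {
    type Error = String;

    fn try_from(v: MiniRValue) -> Result<bool, String> {
        match v {
            MiniRValue::Bool(b) => Ok(b),
            other => Err(format!("expected Bool, got {:?}", other)),
        }
    }
}
```

Since such a type owns only Rust data, it can be cloned, sent and compared like any other Rust value, which matches the description of RValue as safe to treat like a normal Rust value.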

Traits

CloneUnsafe
CloneUnsafe is a trait that lets AliasTrackingRefCell implement Clone for the implementing type. The clone_unsafe method is unsafe because it does not create an independent copy of the underlying value; instead, the returned value is tracked like any other alias, and borrow-checking rules are enforced across both cells.
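
A hypothetical sketch of the idea: a cell whose clones share one borrow flag, so a mutable borrow through one alias blocks borrowing through the other. The trait signature, TrackingCell, ToyHandle and MutGuard are assumptions for illustration, not the torch_sys definitions, and shared borrows are omitted for brevity.

```rust
use std::cell::{Cell, UnsafeCell};
use std::ops::{Deref, DerefMut};
use std::rc::Rc;

/// Assumed shape of the trait: another handle to the *same* data.
pub trait CloneUnsafe: Sized {
    unsafe fn clone_unsafe(&self) -> Self;
}

/// Toy stand-in for a Tensor handle; cloning the Rc shares the data.
pub struct ToyHandle(pub Rc<Cell<i32>>);

impl CloneUnsafe for ToyHandle {
    unsafe fn clone_unsafe(&self) -> Self {
        ToyHandle(Rc::clone(&self.0))
    }
}

/// Hypothetical alias-tracking cell: every clone shares `mut_borrowed`.
pub struct TrackingCell<T> {
    value: UnsafeCell<T>,
    mut_borrowed: Rc<Cell<bool>>,
}

pub struct MutGuard<'a, T: 'a> {
    cell: &'a TrackingCell<T>,
}

impl<T> TrackingCell<T> {
    pub fn new(value: T) -> Self {
        TrackingCell {
            value: UnsafeCell::new(value),
            mut_borrowed: Rc::new(Cell::new(false)),
        }
    }

    /// Fails if this cell *or any alias of it* is mutably borrowed.
    pub fn try_borrow_mut(&self) -> Option<MutGuard<T>> {
        if self.mut_borrowed.get() {
            return None;
        }
        self.mut_borrowed.set(true);
        Some(MutGuard { cell: self })
    }
}

impl<'a, T> Deref for MutGuard<'a, T> {
    type Target = T;
    fn deref(&self) -> &T {
        // SAFETY: the shared flag guarantees this guard is the only borrow.
        unsafe { &*self.cell.value.get() }
    }
}

impl<'a, T> DerefMut for MutGuard<'a, T> {
    fn deref_mut(&mut self) -> &mut T {
        // SAFETY: as above; no other guard can exist while the flag is set.
        unsafe { &mut *self.cell.value.get() }
    }
}

impl<'a, T> Drop for MutGuard<'a, T> {
    fn drop(&mut self) {
        // Release the borrow for every alias at once.
        self.cell.mut_borrowed.set(false);
    }
}

impl<T: CloneUnsafe> Clone for TrackingCell<T> {
    fn clone(&self) -> Self {
        // Reading the value while a guard exists would be unsound.
        assert!(!self.mut_borrowed.get(), "cannot alias a mutably borrowed cell");
        TrackingCell {
            // SAFETY: the alias shares the borrow flag, so it stays tracked.
            value: UnsafeCell::new(unsafe { (*self.value.get()).clone_unsafe() }),
            mut_borrowed: Rc::clone(&self.mut_borrowed),
        }
    }
}
```

The unsafe call is confined to Clone for TrackingCell, which upholds the contract by sharing the flag; calling clone_unsafe directly on a ToyHandle would create an untracked alias.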

Functions

deep_clone
Return a clone of this tensor. The semantics of clone are like torch.clone: it will copy the underlying tensor storage.
factory_empty
Binding for torch.empty.
factory_float_tensor
Creates a new one-dimensional f32 Tensor with the provided data. Mostly used for testing; basically equivalent to a limited version of the raw torch.tensor constructor.
factory_zeros
Binding for torch.zeros.
is_float8_type
rvalue_to_ivalue
suggest_memory_format

Type Aliases

TensorCell