// torch_sys/bridge.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

9use crate::IValue;
10use crate::backend::BoxedBackend;
11use crate::backend::BoxedWork;
12
13#[allow(dead_code)]
14#[cxx::bridge(namespace = "monarch")]
15pub(crate) mod ffi {
16    // These are here to instruct the CXX codegen to generate concrete
17    // specializations of `rust::Vec` for these types in C++.
18    // Otherwise, you'll get a linker error about undefined symbols.
19    impl Vec<IValue> {}
20    impl Vec<Tensor> {}
21    impl Box<BoxedBackend> {}
22    impl Box<BoxedWork> {}
23
24    #[derive(Debug)]
25    enum AliasKind {
26        NewValue,
27        Alias,
28    }
29
30    #[derive(Debug)]
31    struct AliasInfo {
32        kind: AliasKind,
33        arg_idx: usize,
34        arg_name: String,
35    }
36
37    #[derive(Debug)]
38    struct OpCallResult {
39        outputs: Vec<IValue>,
40        alias_infos: Vec<AliasInfo>,
41    }
42
43    #[derive(Debug)]
44    struct SchemaArgInfo<'a> {
45        name: String,
46        is_mutable: bool,
47        type_: &'a TypePtr,
48        num_elements: i32,
49        // This is an operator level flag but we have it in schema arg
50        // at the moment. This should be moved to an operator level
51        // struct in future especially when we have more fields.
52        allows_number_as_tensor: bool,
53    }
54
55    #[derive(Debug)]
56    struct Kwarg {
57        name: String,
58        arg: IValue,
59    }
60
61    #[derive(Debug)]
62    enum ReduceOp {
63        Sum = 0,
64        Avg = 1,
65        Max = 2,
66        Min = 3,
67    }
68
69    #[derive(Debug)]
70    struct AllreduceOptions {
71        reduce_op: ReduceOp,
72    }
73
74    #[derive(Debug)]
75    struct BarrierOptions {
76        timeout: usize, // milliseconds
77    }
78
79    #[derive(Debug)]
80    struct ReduceOptions {
81        reduce_op: ReduceOp,
82        root_rank: i32,
83    }
84
85    #[derive(Debug)]
86    struct ReduceScatterOptions {
87        reduce_op: ReduceOp,
88    }
89
90    #[derive(Debug)]
91    struct GatherOptions {
92        root_rank: i32,
93        timeout: usize, // milliseconds
94    }
95
96    #[derive(Debug)]
97    struct ScatterOptions {
98        root_rank: i32,
99        timeout: usize, // milliseconds
100    }
101
102    #[derive(Debug)]
103    struct BroadcastOptions {
104        root_rank: i32,
105    }
106
107    #[derive(Debug)]
108    struct AllToAllOptions {
109        timeout: usize, // milliseconds
110    }
111
112    extern "Rust" {
113        type BoxedWork;
114        fn wait(&self) -> Result<()>;
115        fn is_completed(&self) -> Result<bool>;
116    }
117
118    extern "Rust" {
119        // We use `BoxedBackend` as the bridge to the our C++ backend impl.
120        // TODO(agallagher): Fill this out.
121        type BoxedBackend;
122        fn allreduce(
123            &self,
124            tensors: &CxxVector<Tensor>,
125            opts: AllreduceOptions,
126        ) -> Result<Box<BoxedWork>>;
127        fn allgather(&self, output: &CxxVector<Tensor>, input: &Tensor) -> Result<Box<BoxedWork>>;
128        fn _allgather_base(&self, output: &Tensor, input: &Tensor) -> Result<Box<BoxedWork>>;
129        fn barrier(&self, opts: BarrierOptions) -> Result<Box<BoxedWork>>;
130        fn reduce(&self, input: &Tensor, opts: ReduceOptions) -> Result<Box<BoxedWork>>;
131        fn _reduce_scatter_base(
132            &self,
133            output: &Tensor,
134            input: &Tensor,
135            opts: ReduceScatterOptions,
136        ) -> Result<Box<BoxedWork>>;
137        fn send(
138            &self,
139            tensors: &CxxVector<Tensor>,
140            dst_rank: i32,
141            tag: i32,
142        ) -> Result<Box<BoxedWork>>;
143        fn recv(
144            &self,
145            tensors: &CxxVector<Tensor>,
146            src_rank: i32,
147            tag: i32,
148        ) -> Result<Box<BoxedWork>>;
149        fn gather(
150            &self,
151            outputs: &CxxVector<Tensor>,
152            input: &Tensor,
153            opts: GatherOptions,
154        ) -> Result<Box<BoxedWork>>;
155        fn scatter(
156            &self,
157            output: &Tensor,
158            inputs: &CxxVector<Tensor>,
159            opts: ScatterOptions,
160        ) -> Result<Box<BoxedWork>>;
161        fn broadcast(
162            &self,
163            tensors: &CxxVector<Tensor>,
164            opts: BroadcastOptions,
165        ) -> Result<Box<BoxedWork>>;
166        fn alltoall_base(
167            &self,
168            output_buffer: &Tensor,
169            input_buffer: &Tensor,
170            opts: AllToAllOptions,
171        ) -> Result<Box<BoxedWork>>;
172        fn alltoall(
173            &self,
174            output_tensors: &CxxVector<Tensor>,
175            input_tensors: &CxxVector<Tensor>,
176            opts: AllToAllOptions,
177        ) -> Result<Box<BoxedWork>>;
178    }
179
180    unsafe extern "C++" {
181        ///////////////////////////////////////////////////////////////////////
182        /// NOTE: If you are implementing a new binding, please review the
183        /// safety discussion in `lib.rs`. Then, please include a "Safety"
184        /// section in your docblock, discussing how mutability/aliasing
185        /// restrictions apply to your binding.
186        ///////////////////////////////////////////////////////////////////////
187        // include!("ATen/cuda/CUDAEvent.h");
188        include!("torch/csrc/distributed/c10d/Types.hpp");
189        include!("monarch/torch-sys/src/bridge.h");
190        #[namespace = "c10"]
191        type IValue = crate::IValue;
192        #[namespace = "torch"]
193        type Tensor = crate::Tensor;
194        #[namespace = "c10"]
195        type Device = crate::Device;
196        #[namespace = "c10"]
197        type MemoryFormat = crate::MemoryFormat;
198        #[namespace = "c10"]
199        type Layout = crate::Layout;
200        #[namespace = "monarch"]
201        type FFIPyObject = crate::pyobject::FFIPyObject;
202        #[namespace = "c10"]
203        type TypePtr = crate::call_op::TypePtr;
204
205        // Creates a Python callback to be passed to `Backend.register_backend`.
206        fn create_monarch_backend() -> FFIPyObject;
207        fn create_null_backend() -> FFIPyObject;
208
209        // Device
210        fn device_from_py_object(obj: FFIPyObject) -> Result<Device>;
211        fn device_to_py_object(device: Device) -> FFIPyObject;
212
213        // Layout
214        fn layout_from_py_object(obj: FFIPyObject) -> Result<Layout>;
215        fn layout_to_py_object(layout: Layout) -> FFIPyObject;
216        fn py_object_is_layout(obj: FFIPyObject) -> bool;
217
218        // MemoryFormat
219        fn memory_format_from_py_object(obj: FFIPyObject) -> Result<MemoryFormat>;
220        fn memory_format_to_py_object(memory_format: MemoryFormat) -> FFIPyObject;
221        fn py_object_is_memory_format(obj: FFIPyObject) -> bool;
222
223        // Tensor
224        fn tensor_from_py_object(obj: FFIPyObject) -> Result<Tensor>;
225        fn tensor_to_py_object(tensor: Tensor) -> FFIPyObject;
226
227        // Methods on Tensor
228        fn device(self: &Tensor) -> Device;
229        fn scalar_type(self: &Tensor) -> ScalarType;
230        fn is_cuda(self: &Tensor) -> bool;
231        fn cpu(self: &Tensor) -> Tensor;
232        fn is_sparse(self: &Tensor) -> bool;
233        fn is_contiguous(self: &Tensor, memory_format: MemoryFormat) -> bool;
234        fn numel(self: &Tensor) -> i64;
235        fn nbytes(self: &Tensor) -> usize;
236        fn suggest_memory_format(t: &Tensor) -> MemoryFormat;
237        fn equal(self: &Tensor, other: &Tensor) -> bool;
238        fn defined(self: &Tensor) -> bool;
239
240        /// binding for `torch.zeros`
241        fn factory_zeros(
242            sizes: &[i64],
243            dtype: ScalarType,
244            layout: Layout,
245            device: Device,
246        ) -> Tensor;
247        /// binding for `torch.empty`
248        fn factory_empty(
249            sizes: &[i64],
250            dtype: ScalarType,
251            layout: Layout,
252            device: Device,
253        ) -> Tensor;
254        /// Creates a new one-dimensional f32 Tensor with the provided data.
255        /// Mostly used for testing; basically equivalent to a limited version
256        /// of the raw `torch.tensor` constructor.
257        fn factory_float_tensor(data: &[f32], device: Device) -> Tensor;
258        /// Return a clone of this tensor. The semantics of clone are like
259        /// `torch.clone`: it will copy the the underlying tensor storage.
260        ///
261        /// # Safety
262        /// This function is guaranteed to produce a fresh (non-aliasing) tensor.
263        fn deep_clone(t: &Tensor) -> Tensor;
264
265        /// Bindings for `load`/`save` for `Tensor`.
266        fn load_tensor(buf: &[u8]) -> Result<Tensor>;
267        fn save_tensor(tensor: &Tensor) -> Result<Vec<u8>>;
268
269        fn copy_(tensor: &mut Tensor, src: &Tensor);
270        fn sizes(tensor: &Tensor) -> Vec<i32>;
271
272        // ScalarType
273        #[namespace = "c10"]
274        type ScalarType = crate::ScalarType;
275        #[namespace = "at"]
276        #[rust_name = "is_float8_type"]
277        fn isFloat8Type(t: ScalarType) -> bool;
278
279        // Convert to Python object.
280        fn scalar_type_from_py_object(obj: FFIPyObject) -> Result<ScalarType>;
281        fn scalar_type_to_py_object(scalar_type: ScalarType) -> FFIPyObject;
282        fn py_object_is_scalar_type(obj: FFIPyObject) -> bool;
283
284        /// # Safety
285        /// - **Mutability**:
286        /// `call_op` may mutate the provided arguments (for example, if you
287        /// called `aten::add_`), so `args` and `kwargs` require a mutable slice.
288        ///
289        /// - **Aliasing**:
290        /// `call_op` may return aliases of the provided arguments, so it is
291        /// marked as `unsafe`. The caller is responsible for using the aliasing
292        /// info returned by `call_op` to ensure that Rust's aliasing rules are
293        /// respected.
294        //
295        // TODO this fn ends up making a bunch of small copies to marshall
296        // arguments across the FFI boundary. This could probably be improved,
297        // at the cost of a less straightforward calling convention.
298        unsafe fn call_op_raw(
299            op_name: &str,
300            overload: &str,
301            args: &mut [IValue],
302            kwargs: &mut [Kwarg],
303            flatten_results: bool,
304        ) -> Result<OpCallResult>;
305
306        /// Give information about which arguments can be mutated by the
307        /// provided operator.
308        /// TODO:
309        ///   - This returns results for all arguments, even ones not provided
310        ///     by the caller.
311        fn get_schema_args_info<'a>(
312            op_name: &'a str,
313            overload: &'a str,
314        ) -> Result<Vec<SchemaArgInfo<'a>>>;
315
316        // Constructors for IValue
317        fn ivalue_from_int(val: i64) -> IValue;
318        fn ivalue_from_int_list(val: &[i64]) -> IValue;
319        fn ivalue_from_double(val: f64) -> IValue;
320        fn ivalue_from_bool(val: bool) -> IValue;
321        fn ivalue_from_string(val: &String) -> IValue;
322        fn ivalue_from_tensor(val: Tensor) -> IValue;
323        fn ivalue_from_tensor_list(val: Vec<Tensor>) -> IValue;
324        fn ivalue_from_device(val: Device) -> IValue;
325        fn ivalue_from_layout(val: Layout) -> IValue;
326        fn ivalue_from_scalar_type(val: ScalarType) -> IValue;
327        fn ivalue_from_none() -> IValue;
328
329        // Interop with Python object.
330        fn arbitrary_ivalue_to_py_object(val: IValue) -> Result<FFIPyObject>;
331        fn ivalue_from_arbitrary_py_object(obj: FFIPyObject) -> Result<IValue>;
332        fn py_object_is_ivalue(obj: FFIPyObject) -> bool;
333        /// Converts the provided Python object to an `IValue` with the provided
334        /// type. If the object is not convertible to the provided type, an
335        /// exception will be thrown.
336        fn ivalue_from_py_object_with_type(
337            obj: FFIPyObject,
338            type_: &TypePtr,
339            num_elements: i32,
340            allow_nums_as_tensors: bool,
341        ) -> Result<IValue>;
342
343        // Equality
344        /// Allows comparing ivalues for equality using `operator==` on `IValue`.
345        fn ivalues_equal_operator(a: &IValue, b: &IValue) -> bool;
346
347        // Serde for IValue
348        fn serialize_ivalue(val: &IValue) -> Result<Vec<u8>>;
349        fn deserialize_ivalue(buf: &[u8]) -> Result<IValue>;
350
351        /// Clones the `IValue` with copying data over. Can throw an exception
352        /// if the `IValue` is not cloneable.
353        fn ivalue_deepcopy(iv: &IValue) -> Result<IValue>;
354
355        // These are methods on the C++ IValue type.
356        /// Prints a human-readable representation of the `IValue` to stdout.
357        fn dump(self: &IValue) -> ();
358
359        #[doc(hidden)]
360        fn isBool(self: &IValue) -> bool;
361        #[doc(hidden)]
362        fn toBool(self: &IValue) -> Result<bool>;
363
364        #[doc(hidden)]
365        fn isInt(self: &IValue) -> bool;
366        #[doc(hidden)]
367        fn toInt(self: &IValue) -> Result<i64>;
368
369        #[doc(hidden)]
370        fn isDouble(self: &IValue) -> bool;
371        #[doc(hidden)]
372        fn toDouble(self: &IValue) -> Result<f64>;
373
374        #[doc(hidden)]
375        fn isIntList(self: &IValue) -> bool;
376        #[doc(hidden)]
377        fn toIntList(iv: &IValue) -> Result<Vec<i64>>;
378
379        #[doc(hidden)]
380        fn isString(self: &IValue) -> bool;
381        #[doc(hidden)]
382        fn toString(iv: &IValue) -> Result<String>;
383
384        #[doc(hidden)]
385        fn isTensor(self: &IValue) -> bool;
386        #[doc(hidden)]
387        fn toTensor(iv: IValue) -> Result<Tensor>;
388
389        #[doc(hidden)]
390        fn isTensorList(self: &IValue) -> bool;
391        #[doc(hidden)]
392        fn toTensorList(iv: IValue) -> Result<Vec<Tensor>>;
393
394        #[doc(hidden)]
395        fn isDevice(self: &IValue) -> bool;
396        #[doc(hidden)]
397        fn toDevice(self: &IValue) -> Result<Device>;
398
399        //#[doc(hidden)]
400        //fn isLayout(self: &IValue) -> bool;
401        #[doc(hidden)]
402        fn toLayout(self: &IValue) -> Result<Layout>;
403
404        //#[doc(hidden)]
405        //fn isScalarType(self: &IValue) -> bool;
406        #[doc(hidden)]
407        fn toScalarType(self: &IValue) -> Result<ScalarType>;
408
409        #[doc(hidden)]
410        fn isNone(self: &IValue) -> bool;
411        //#[doc(hidden)]
412        //fn toNone(self: &IValue) -> Result<String>;
413
414        // TODO: support the rest of IValues stuff
415
416        // Utility functions on TypePtr
417        fn type_ptr_is_tensor(t: &TypePtr) -> bool;
418        fn type_ptr_is_tensor_list(t: &TypePtr) -> bool;
419        fn type_ptr_is_optional_tensor(t: &TypePtr) -> bool;
420        fn type_ptr_is_optional_tensor_list(t: &TypePtr) -> bool;
421
422        // Helpers for debugging
423        fn debug_type_str(val: &IValue) -> Result<String>;
424        fn debug_print(val: &IValue) -> Result<String>;
425
426        #[namespace = "monarch::test"]
427        fn test_make_undefined_tensor_ivalue() -> IValue;
428
429        #[namespace = "monarch::test"]
430        fn test_make_opaque_ivalue() -> IValue;
431
432        #[namespace = "monarch::test"]
433        fn test_make_tensor() -> Tensor;
434
435        #[namespace = "monarch::test"]
436        fn cuda_full(size: &[i64], value: f32) -> Tensor;
437
438        #[namespace = "monarch::test"]
439        unsafe fn test_make_alias(t: &Tensor) -> Tensor;
440
441        #[namespace = "monarch::test"]
442        fn allclose(a: &Tensor, b: &Tensor) -> Result<bool>;
443
444        #[namespace = "monarch::test"]
445        fn repr(t: &Tensor) -> String;
446
447        #[namespace = "monarch::test"]
448        fn stack(tensor: &[Tensor]) -> Tensor;
449
450        fn is_alias(lhs: &Tensor, rhs: &Tensor) -> bool;
451    }
452
453    // Allow accessing `Tensor` from `CxxVector` in the `BoxedBackend` impl.
454    impl CxxVector<Tensor> {}
455}
456
457unsafe extern "C" {
458    pub(crate) fn cpp_decref(ptr: *mut std::ffi::c_void);
459    pub(crate) fn drop(this: *mut IValue);
460    pub(crate) fn clone_iv(this: *const IValue, new: *mut IValue);
461    pub(crate) fn cpp_incref(ptr: *mut std::ffi::c_void);
462
463    pub(crate) fn const_data_ptr(tensor: *mut std::ffi::c_void) -> *const std::ffi::c_void;
464    pub(crate) fn mut_data_ptr(tensor: *mut std::ffi::c_void) -> *mut std::ffi::c_void;
465}