use crate::IValue;
use crate::backend::BoxedBackend;
use crate::backend::BoxedWork;

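/// cxx bridge between this crate's Rust types (`IValue`, `Tensor`, `BoxedBackend`,
/// `BoxedWork`, ...) and the C++ glue code in `monarch/torch-sys/src/bridge.h`.
///
/// A minimal usage sketch (not part of the bindings themselves; it assumes the
/// C++ side is built and linked, and that the items below are reachable from
/// the call site):
///
/// ```ignore
/// // Round-trip an i64 through a c10::IValue.
/// let iv = ffi::ivalue_from_int(42);
/// assert!(iv.isInt());
/// assert_eq!(iv.toInt().unwrap(), 42);
/// ```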
#[allow(dead_code)]
#[cxx::bridge(namespace = "monarch")]
pub(crate) mod ffi {
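    // Empty `impl Vec<T> {}` / `impl Box<T> {}` items ask cxx to instantiate its
    // generic bindings for these types so they can appear in the signatures below.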
    impl Vec<IValue> {}
    impl Vec<Tensor> {}
    impl Box<BoxedBackend> {}
    impl Box<BoxedWork> {}

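    // Shared structs and enums: cxx generates matching C++ definitions in the
    // `monarch` namespace so both sides see the same layout.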
    #[derive(Debug)]
    enum AliasKind {
        NewValue,
        Alias,
    }

    #[derive(Debug)]
    struct AliasInfo {
        kind: AliasKind,
        arg_idx: usize,
        arg_name: String,
    }

    #[derive(Debug)]
    struct OpCallResult {
        outputs: Vec<IValue>,
        alias_infos: Vec<AliasInfo>,
    }

    #[derive(Debug)]
    struct SchemaArgInfo<'a> {
        name: String,
        is_mutable: bool,
        type_: &'a TypePtr,
        num_elements: i32,
        allows_number_as_tensor: bool,
    }

    #[derive(Debug)]
    struct Kwarg {
        name: String,
        arg: IValue,
    }

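    // Option structs for the collective operations below, mirroring (a subset
    // of) the c10d option types from Types.hpp.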
    #[derive(Debug)]
    enum ReduceOp {
        Sum = 0,
        Avg = 1,
        Max = 2,
        Min = 3,
    }

    #[derive(Debug)]
    struct AllreduceOptions {
        reduce_op: ReduceOp,
    }

    #[derive(Debug)]
    struct BarrierOptions {
        timeout: usize,
    }

    #[derive(Debug)]
    struct ReduceOptions {
        reduce_op: ReduceOp,
        root_rank: i32,
    }

    #[derive(Debug)]
    struct ReduceScatterOptions {
        reduce_op: ReduceOp,
    }

    #[derive(Debug)]
    struct GatherOptions {
        root_rank: i32,
        timeout: usize,
    }

    #[derive(Debug)]
    struct ScatterOptions {
        root_rank: i32,
        timeout: usize,
    }

    #[derive(Debug)]
    struct BroadcastOptions {
        root_rank: i32,
    }

    #[derive(Debug)]
    struct AllToAllOptions {
        timeout: usize,
    }

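    // Rust types and methods exported to C++. `BoxedWork` is a handle to an
    // asynchronous collective; `BoxedBackend` exposes the process-group
    // collectives themselves.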
    extern "Rust" {
        type BoxedWork;
        fn wait(&self) -> Result<()>;
        fn is_completed(&self) -> Result<bool>;
    }

    extern "Rust" {
        type BoxedBackend;
        fn allreduce(
            &self,
            tensors: &CxxVector<Tensor>,
            opts: AllreduceOptions,
        ) -> Result<Box<BoxedWork>>;
        fn allgather(&self, output: &CxxVector<Tensor>, input: &Tensor) -> Result<Box<BoxedWork>>;
        fn _allgather_base(&self, output: &Tensor, input: &Tensor) -> Result<Box<BoxedWork>>;
        fn barrier(&self, opts: BarrierOptions) -> Result<Box<BoxedWork>>;
        fn reduce(&self, input: &Tensor, opts: ReduceOptions) -> Result<Box<BoxedWork>>;
        fn _reduce_scatter_base(
            &self,
            output: &Tensor,
            input: &Tensor,
            opts: ReduceScatterOptions,
        ) -> Result<Box<BoxedWork>>;
        fn send(
            &self,
            tensors: &CxxVector<Tensor>,
            dst_rank: i32,
            tag: i32,
        ) -> Result<Box<BoxedWork>>;
        fn recv(
            &self,
            tensors: &CxxVector<Tensor>,
            src_rank: i32,
            tag: i32,
        ) -> Result<Box<BoxedWork>>;
        fn gather(
            &self,
            outputs: &CxxVector<Tensor>,
            input: &Tensor,
            opts: GatherOptions,
        ) -> Result<Box<BoxedWork>>;
        fn scatter(
            &self,
            output: &Tensor,
            inputs: &CxxVector<Tensor>,
            opts: ScatterOptions,
        ) -> Result<Box<BoxedWork>>;
        fn broadcast(
            &self,
            tensors: &CxxVector<Tensor>,
            opts: BroadcastOptions,
        ) -> Result<Box<BoxedWork>>;
        fn alltoall_base(
            &self,
            output_buffer: &Tensor,
            input_buffer: &Tensor,
            opts: AllToAllOptions,
        ) -> Result<Box<BoxedWork>>;
        fn alltoall(
            &self,
            output_tensors: &CxxVector<Tensor>,
            input_tensors: &CxxVector<Tensor>,
            opts: AllToAllOptions,
        ) -> Result<Box<BoxedWork>>;
    }

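    // C++ functions and methods callable from Rust; the free functions are
    // declared in bridge.h, and the type aliases bind to the PyTorch/c10 types
    // wrapped elsewhere in this crate.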
    unsafe extern "C++" {
        include!("torch/csrc/distributed/c10d/Types.hpp");
        include!("monarch/torch-sys/src/bridge.h");
        #[namespace = "c10"]
        type IValue = crate::IValue;
        #[namespace = "torch"]
        type Tensor = crate::Tensor;
        #[namespace = "c10"]
        type Device = crate::Device;
        #[namespace = "c10"]
        type MemoryFormat = crate::MemoryFormat;
        #[namespace = "c10"]
        type Layout = crate::Layout;
        #[namespace = "monarch"]
        type FFIPyObject = crate::pyobject::FFIPyObject;
        #[namespace = "c10"]
        type TypePtr = crate::call_op::TypePtr;

        fn create_monarch_backend() -> FFIPyObject;
        fn create_null_backend() -> FFIPyObject;

        fn device_from_py_object(obj: FFIPyObject) -> Result<Device>;
        fn device_to_py_object(device: Device) -> FFIPyObject;

        fn layout_from_py_object(obj: FFIPyObject) -> Result<Layout>;
        fn layout_to_py_object(layout: Layout) -> FFIPyObject;
        fn py_object_is_layout(obj: FFIPyObject) -> bool;

        fn memory_format_from_py_object(obj: FFIPyObject) -> Result<MemoryFormat>;
        fn memory_format_to_py_object(memory_format: MemoryFormat) -> FFIPyObject;
        fn py_object_is_memory_format(obj: FFIPyObject) -> bool;

        fn tensor_from_py_object(obj: FFIPyObject) -> Result<Tensor>;
        fn tensor_to_py_object(tensor: Tensor) -> FFIPyObject;

        fn device(self: &Tensor) -> Device;
        fn scalar_type(self: &Tensor) -> ScalarType;
        fn is_cuda(self: &Tensor) -> bool;
        fn cpu(self: &Tensor) -> Tensor;
        fn is_sparse(self: &Tensor) -> bool;
        fn is_contiguous(self: &Tensor, memory_format: MemoryFormat) -> bool;
        fn numel(self: &Tensor) -> i64;
        fn nbytes(self: &Tensor) -> usize;
        fn suggest_memory_format(t: &Tensor) -> MemoryFormat;
        fn equal(self: &Tensor, other: &Tensor) -> bool;
        fn defined(self: &Tensor) -> bool;

        fn factory_zeros(
            sizes: &[i64],
            dtype: ScalarType,
            layout: Layout,
            device: Device,
        ) -> Tensor;
        fn factory_empty(
            sizes: &[i64],
            dtype: ScalarType,
            layout: Layout,
            device: Device,
        ) -> Tensor;
        fn factory_float_tensor(data: &[f32], device: Device) -> Tensor;
        fn deep_clone(t: &Tensor) -> Tensor;

        fn load_tensor(buf: &[u8]) -> Result<Tensor>;
        fn save_tensor(tensor: &Tensor) -> Result<Vec<u8>>;

        fn copy_(tensor: &mut Tensor, src: &Tensor);
        fn sizes(tensor: &Tensor) -> Vec<i32>;

        #[namespace = "c10"]
        type ScalarType = crate::ScalarType;
        #[namespace = "at"]
        #[rust_name = "is_float8_type"]
        fn isFloat8Type(t: ScalarType) -> bool;

        fn scalar_type_from_py_object(obj: FFIPyObject) -> Result<ScalarType>;
        fn scalar_type_to_py_object(scalar_type: ScalarType) -> FFIPyObject;
        fn py_object_is_scalar_type(obj: FFIPyObject) -> bool;

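        // Invoke a Torch operator by name and overload; `alias_infos` in the
        // result records, per output, whether it is a new value or an alias of
        // one of the input arguments.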
        unsafe fn call_op_raw(
            op_name: &str,
            overload: &str,
            args: &mut [IValue],
            kwargs: &mut [Kwarg],
            flatten_results: bool,
        ) -> Result<OpCallResult>;

        fn get_schema_args_info<'a>(
            op_name: &'a str,
            overload: &'a str,
        ) -> Result<Vec<SchemaArgInfo<'a>>>;

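        // Constructors wrapping plain Rust values into c10::IValue.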
        fn ivalue_from_int(val: i64) -> IValue;
        fn ivalue_from_int_list(val: &[i64]) -> IValue;
        fn ivalue_from_double(val: f64) -> IValue;
        fn ivalue_from_bool(val: bool) -> IValue;
        fn ivalue_from_string(val: &String) -> IValue;
        fn ivalue_from_tensor(val: Tensor) -> IValue;
        fn ivalue_from_tensor_list(val: Vec<Tensor>) -> IValue;
        fn ivalue_from_device(val: Device) -> IValue;
        fn ivalue_from_layout(val: Layout) -> IValue;
        fn ivalue_from_scalar_type(val: ScalarType) -> IValue;
        fn ivalue_from_none() -> IValue;

        fn arbitrary_ivalue_to_py_object(val: IValue) -> Result<FFIPyObject>;
        fn ivalue_from_arbitrary_py_object(obj: FFIPyObject) -> Result<IValue>;
        fn py_object_is_ivalue(obj: FFIPyObject) -> bool;
        fn ivalue_from_py_object_with_type(
            obj: FFIPyObject,
            type_: &TypePtr,
            num_elements: i32,
            allow_nums_as_tensors: bool,
        ) -> Result<IValue>;

        fn ivalues_equal_operator(a: &IValue, b: &IValue) -> bool;

        fn serialize_ivalue(val: &IValue) -> Result<Vec<u8>>;
        fn deserialize_ivalue(buf: &[u8]) -> Result<IValue>;

        fn ivalue_deepcopy(iv: &IValue) -> Result<IValue>;

        fn dump(self: &IValue) -> ();

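        // Type probes and conversions mirroring the C++ `IValue` API; they keep
        // the upstream camelCase names and are hidden from rustdoc.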
        #[doc(hidden)]
        fn isBool(self: &IValue) -> bool;
        #[doc(hidden)]
        fn toBool(self: &IValue) -> Result<bool>;

        #[doc(hidden)]
        fn isInt(self: &IValue) -> bool;
        #[doc(hidden)]
        fn toInt(self: &IValue) -> Result<i64>;

        #[doc(hidden)]
        fn isDouble(self: &IValue) -> bool;
        #[doc(hidden)]
        fn toDouble(self: &IValue) -> Result<f64>;

        #[doc(hidden)]
        fn isIntList(self: &IValue) -> bool;
        #[doc(hidden)]
        fn toIntList(iv: &IValue) -> Result<Vec<i64>>;

        #[doc(hidden)]
        fn isString(self: &IValue) -> bool;
        #[doc(hidden)]
        fn toString(iv: &IValue) -> Result<String>;

        #[doc(hidden)]
        fn isTensor(self: &IValue) -> bool;
        #[doc(hidden)]
        fn toTensor(iv: IValue) -> Result<Tensor>;

        #[doc(hidden)]
        fn isTensorList(self: &IValue) -> bool;
        #[doc(hidden)]
        fn toTensorList(iv: IValue) -> Result<Vec<Tensor>>;

        #[doc(hidden)]
        fn isDevice(self: &IValue) -> bool;
        #[doc(hidden)]
        fn toDevice(self: &IValue) -> Result<Device>;

        #[doc(hidden)]
        fn toLayout(self: &IValue) -> Result<Layout>;

        #[doc(hidden)]
        fn toScalarType(self: &IValue) -> Result<ScalarType>;

        #[doc(hidden)]
        fn isNone(self: &IValue) -> bool;
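
        // Predicates classifying `c10::TypePtr`s, used when converting Python
        // objects into typed IValues.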
        fn type_ptr_is_tensor(t: &TypePtr) -> bool;
        fn type_ptr_is_tensor_list(t: &TypePtr) -> bool;
        fn type_ptr_is_optional_tensor(t: &TypePtr) -> bool;
        fn type_ptr_is_optional_tensor_list(t: &TypePtr) -> bool;

        fn debug_type_str(val: &IValue) -> Result<String>;
        fn debug_print(val: &IValue) -> Result<String>;

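        // Test-only helpers implemented under the `monarch::test` C++ namespace.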
        #[namespace = "monarch::test"]
        fn test_make_undefined_tensor_ivalue() -> IValue;

        #[namespace = "monarch::test"]
        fn test_make_opaque_ivalue() -> IValue;

        #[namespace = "monarch::test"]
        fn test_make_tensor() -> Tensor;

        #[namespace = "monarch::test"]
        fn cuda_full(size: &[i64], value: f32) -> Tensor;

        #[namespace = "monarch::test"]
        unsafe fn test_make_alias(t: &Tensor) -> Tensor;

        #[namespace = "monarch::test"]
        fn allclose(a: &Tensor, b: &Tensor) -> Result<bool>;

        #[namespace = "monarch::test"]
        fn repr(t: &Tensor) -> String;

        #[namespace = "monarch::test"]
        fn stack(tensor: &[Tensor]) -> Tensor;

        fn is_alias(lhs: &Tensor, rhs: &Tensor) -> bool;
    }

    impl CxxVector<Tensor> {}
}

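// Raw C-ABI helpers, outside the cxx bridge: manual reference counting for
// C++-owned objects and direct access to tensor data pointers.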
unsafe extern "C" {
    pub(crate) fn cpp_decref(ptr: *mut std::ffi::c_void);
    pub(crate) fn drop(this: *mut IValue);
    pub(crate) fn clone_iv(this: *const IValue, new: *mut IValue);
    pub(crate) fn cpp_incref(ptr: *mut std::ffi::c_void);

    pub(crate) fn const_data_ptr(tensor: *mut std::ffi::c_void) -> *const std::ffi::c_void;
    pub(crate) fn mut_data_ptr(tensor: *mut std::ffi::c_void) -> *mut std::ffi::c_void;
}