torch_sys2/
lib.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8
9#![allow(non_camel_case_types)]
10
//! Simplified Rust bindings for PyTorch, implemented by calling the Python
//! `torch` module through PyO3 rather than by binding libtorch's C++ API.
//!
//! This is a streamlined version that only includes the functionality
//! actually used by the monarch codebase.
15
16pub mod testing;
17
18use monarch_types::py_global;
19use pyo3::prelude::*;
20use pyo3::types::PyAny;
21use serde::Deserialize;
22use serde::Serialize;
23use thiserror::Error;
24
25// Cached imports for torch APIs
26py_global!(torch_device, "torch", "device");
27py_global!(torch_strided, "torch", "strided");
28py_global!(torch_sparse_coo, "torch", "sparse_coo");
29py_global!(torch_contiguous_format, "torch", "contiguous_format");
30py_global!(torch_preserve_format, "torch", "preserve_format");
31py_global!(torch_channels_last, "torch", "channels_last");
32py_global!(torch_channels_last_3d, "torch", "channels_last_3d");
33py_global!(torch_uint8, "torch", "uint8");
34py_global!(torch_int8, "torch", "int8");
35py_global!(torch_int16, "torch", "int16");
36py_global!(torch_int32, "torch", "int32");
37py_global!(torch_int64, "torch", "int64");
38py_global!(torch_float16, "torch", "float16");
39py_global!(torch_float32, "torch", "float32");
40py_global!(torch_float64, "torch", "float64");
41py_global!(torch_complex32, "torch", "complex32");
42py_global!(torch_complex64, "torch", "complex64");
43py_global!(torch_complex128, "torch", "complex128");
44py_global!(torch_bool, "torch", "bool");
45py_global!(torch_bfloat16, "torch", "bfloat16");
46py_global!(torch_float8_e5m2, "torch", "float8_e5m2");
47py_global!(torch_float8_e4m3fn, "torch", "float8_e4m3fn");
48py_global!(torch_zeros, "torch", "zeros");
49py_global!(torch_empty, "torch", "empty");
50py_global!(torch_tensor, "torch", "tensor");
51py_global!(torch_allclose, "torch", "allclose");
52py_global!(torch_full, "torch", "full");
53py_global!(torch_stack, "torch", "stack");
54
55// ============================================================================
56// Device Types
57// ============================================================================
58
59#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
60pub enum DeviceType {
61    CPU,
62    CUDA,
63}
64
65#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
66pub struct DeviceIndex(pub i8);
67
68impl From<DeviceIndex> for i8 {
69    fn from(idx: DeviceIndex) -> i8 {
70        idx.0
71    }
72}
73
74impl From<i8> for DeviceIndex {
75    fn from(idx: i8) -> DeviceIndex {
76        DeviceIndex(idx)
77    }
78}
79
80#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
81pub struct Device {
82    device_type: DeviceType,
83    index: Option<DeviceIndex>,
84}
85
86impl Device {
87    pub fn device_type(&self) -> DeviceType {
88        self.device_type
89    }
90}
91
92impl FromPyObject<'_> for Device {
93    fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
94        let device_str: String = obj.str()?.extract()?;
95        // Parse the device string
96        device_str.parse().map_err(|e: DeviceParseError| {
97            PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string())
98        })
99    }
100}
101
102impl<'py> IntoPyObject<'py> for Device {
103    type Target = PyAny;
104    type Output = Bound<'py, Self::Target>;
105    type Error = PyErr;
106
107    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
108        let device_str = self.to_string();
109        let device = torch_device(py).call1((device_str,))?;
110        Ok(device)
111    }
112}
113
114impl std::fmt::Display for Device {
115    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
116        match self.device_type {
117            DeviceType::CPU => write!(f, "cpu"),
118            DeviceType::CUDA => {
119                if let Some(index) = self.index {
120                    write!(f, "cuda:{}", index.0)
121                } else {
122                    write!(f, "cuda")
123                }
124            }
125        }
126    }
127}
128
129impl std::str::FromStr for Device {
130    type Err = DeviceParseError;
131
132    fn from_str(s: &str) -> Result<Self, Self::Err> {
133        if s == "cpu" {
134            Ok(Device {
135                device_type: DeviceType::CPU,
136                index: None,
137            })
138        } else if s == "cuda" {
139            Ok(Device {
140                device_type: DeviceType::CUDA,
141                index: None,
142            })
143        } else if let Some(cuda_idx) = s.strip_prefix("cuda:") {
144            let index = cuda_idx
145                .parse::<i8>()
146                .map_err(|_| DeviceParseError::InvalidDevice)?;
147            Ok(Device {
148                device_type: DeviceType::CUDA,
149                index: Some(DeviceIndex(index)),
150            })
151        } else {
152            Err(DeviceParseError::InvalidDevice)
153        }
154    }
155}
156
157impl From<CudaDevice> for Device {
158    fn from(cuda_device: CudaDevice) -> Self {
159        Device {
160            device_type: DeviceType::CUDA,
161            index: Some(cuda_device.index),
162        }
163    }
164}
165
166#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
167pub struct CudaDevice {
168    index: DeviceIndex,
169}
170
171impl CudaDevice {
172    pub fn new(index: DeviceIndex) -> Self {
173        CudaDevice { index }
174    }
175
176    pub fn index(&self) -> DeviceIndex {
177        self.index
178    }
179}
180
181#[derive(Debug, Error)]
182pub enum DeviceParseError {
183    #[error("invalid device string")]
184    InvalidDevice,
185}
186
187// ============================================================================
188// Layout and Memory Format
189// ============================================================================
190
191#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
192#[repr(i32)]
193pub enum Layout {
194    Strided = 0,
195    Sparse = 1,
196    Mkldnn = 2,
197}
198
199/// Remote serde implementation for Layout
200#[derive(Serialize, Deserialize)]
201#[serde(remote = "Layout")]
202pub enum LayoutDef {
203    Strided,
204    Sparse,
205    Mkldnn,
206}
207
208impl FromPyObject<'_> for Layout {
209    fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
210        Python::with_gil(|py| {
211            let strided = torch_strided(py);
212            let sparse_coo = torch_sparse_coo(py);
213
214            if obj.eq(strided)? {
215                Ok(Layout::Strided)
216            } else if obj.eq(sparse_coo)? {
217                Ok(Layout::Sparse)
218            } else {
219                // Try to match by string representation
220                let obj_str: String = obj.str()?.extract()?;
221                match obj_str.as_str() {
222                    "torch.strided" => Ok(Layout::Strided),
223                    "torch.sparse_coo" => Ok(Layout::Sparse),
224                    _ => Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
225                        "Unknown layout type",
226                    )),
227                }
228            }
229        })
230    }
231}
232
233impl<'py> IntoPyObject<'py> for Layout {
234    type Target = PyAny;
235    type Output = Bound<'py, Self::Target>;
236    type Error = PyErr;
237
238    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
239        match self {
240            Layout::Strided => Ok(torch_strided(py)),
241            Layout::Sparse => Ok(torch_sparse_coo(py)),
242            Layout::Mkldnn => Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
243                "MKLDNN layout not supported in PyTorch",
244            )),
245        }
246    }
247}
248
249#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
250#[repr(i32)]
251pub enum MemoryFormat {
252    Contiguous = 0,
253    Preserve = 1,
254    ChannelsLast = 2,
255    ChannelsLast3d = 3,
256}
257
258/// Remote serde implementation for MemoryFormat
259#[derive(Serialize, Deserialize)]
260#[serde(remote = "MemoryFormat")]
261pub enum MemoryFormatDef {
262    Contiguous,
263    Preserve,
264    ChannelsLast,
265    ChannelsLast3d,
266}
267
268impl<'py> IntoPyObject<'py> for MemoryFormat {
269    type Target = PyAny;
270    type Output = Bound<'py, Self::Target>;
271    type Error = PyErr;
272
273    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
274        match self {
275            MemoryFormat::Contiguous => Ok(torch_contiguous_format(py)),
276            MemoryFormat::Preserve => Ok(torch_preserve_format(py)),
277            MemoryFormat::ChannelsLast => Ok(torch_channels_last(py)),
278            MemoryFormat::ChannelsLast3d => Ok(torch_channels_last_3d(py)),
279        }
280    }
281}
282
283// ============================================================================
284// ScalarType
285// ============================================================================
286
287#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
288#[repr(i32)]
289pub enum ScalarType {
290    Byte = 0,
291    Char = 1,
292    Short = 2,
293    Int = 3,
294    Long = 4,
295    Half = 5,
296    Float = 6,
297    Double = 7,
298    ComplexHalf = 8,
299    ComplexFloat = 9,
300    ComplexDouble = 10,
301    Bool = 11,
302    QInt8 = 12,
303    QUInt8 = 13,
304    QInt32 = 14,
305    BFloat16 = 15,
306    QUInt4x2 = 16,
307    QUInt2x4 = 17,
308    Bits1x8 = 18,
309    Bits2x4 = 19,
310    Bits4x2 = 20,
311    Bits8 = 21,
312    Bits16 = 22,
313    Float8_e5m2 = 23,
314    Float8_e4m3fn = 24,
315    Float8_e5m2fnuz = 25,
316    Float8_e4m3fnuz = 26,
317}
318
319/// Remote serde implementation for ScalarType
320#[derive(Serialize, Deserialize)]
321#[serde(remote = "ScalarType")]
322pub enum ScalarTypeDef {
323    Byte,
324    Char,
325    Short,
326    Int,
327    Long,
328    Half,
329    Float,
330    Double,
331    ComplexHalf,
332    ComplexFloat,
333    ComplexDouble,
334    Bool,
335    QInt8,
336    QUInt8,
337    QInt32,
338    BFloat16,
339    QUInt4x2,
340    QUInt2x4,
341    Bits1x8,
342    Bits2x4,
343    Bits4x2,
344    Bits8,
345    Bits16,
346    Float8_e5m2,
347    Float8_e4m3fn,
348    Float8_e5m2fnuz,
349    Float8_e4m3fnuz,
350}
351
352impl FromPyObject<'_> for ScalarType {
353    fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
354        Python::with_gil(|py| {
355            // Map of PyTorch dtype getters to ScalarType
356            let dtype_map = [
357                (torch_uint8(py), ScalarType::Byte),
358                (torch_int8(py), ScalarType::Char),
359                (torch_int16(py), ScalarType::Short),
360                (torch_int32(py), ScalarType::Int),
361                (torch_int64(py), ScalarType::Long),
362                (torch_float16(py), ScalarType::Half),
363                (torch_float32(py), ScalarType::Float),
364                (torch_float64(py), ScalarType::Double),
365                (torch_complex32(py), ScalarType::ComplexHalf),
366                (torch_complex64(py), ScalarType::ComplexFloat),
367                (torch_complex128(py), ScalarType::ComplexDouble),
368                (torch_bool(py), ScalarType::Bool),
369                (torch_bfloat16(py), ScalarType::BFloat16),
370                (torch_float8_e5m2(py), ScalarType::Float8_e5m2),
371                (torch_float8_e4m3fn(py), ScalarType::Float8_e4m3fn),
372            ];
373
374            // Try matching by equality with torch dtypes
375            for (dtype, scalar_type) in &dtype_map {
376                if obj.eq(dtype)? {
377                    return Ok(*scalar_type);
378                }
379            }
380
381            // Try matching by string representation
382            let obj_str: String = obj.str()?.extract()?;
383            let str_map = [
384                ("uint8", ScalarType::Byte),
385                ("int8", ScalarType::Char),
386                ("int16", ScalarType::Short),
387                ("int32", ScalarType::Int),
388                ("int64", ScalarType::Long),
389                ("float16", ScalarType::Half),
390                ("float32", ScalarType::Float),
391                ("float64", ScalarType::Double),
392                ("complex32", ScalarType::ComplexHalf),
393                ("complex64", ScalarType::ComplexFloat),
394                ("complex128", ScalarType::ComplexDouble),
395                ("bool", ScalarType::Bool),
396                ("bfloat16", ScalarType::BFloat16),
397                ("float8_e5m2", ScalarType::Float8_e5m2),
398                ("float8_e4m3fn", ScalarType::Float8_e4m3fn),
399            ];
400
401            for (name, scalar_type) in &str_map {
402                if obj_str.contains(name) {
403                    return Ok(*scalar_type);
404                }
405            }
406
407            Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
408                "Unknown scalar type: {}",
409                obj_str
410            )))
411        })
412    }
413}
414
415impl<'py> IntoPyObject<'py> for ScalarType {
416    type Target = PyAny;
417    type Output = Bound<'py, Self::Target>;
418    type Error = PyErr;
419
420    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
421        match self {
422            ScalarType::Byte => Ok(torch_uint8(py)),
423            ScalarType::Char => Ok(torch_int8(py)),
424            ScalarType::Short => Ok(torch_int16(py)),
425            ScalarType::Int => Ok(torch_int32(py)),
426            ScalarType::Long => Ok(torch_int64(py)),
427            ScalarType::Half => Ok(torch_float16(py)),
428            ScalarType::Float => Ok(torch_float32(py)),
429            ScalarType::Double => Ok(torch_float64(py)),
430            ScalarType::ComplexHalf => Ok(torch_complex32(py)),
431            ScalarType::ComplexFloat => Ok(torch_complex64(py)),
432            ScalarType::ComplexDouble => Ok(torch_complex128(py)),
433            ScalarType::Bool => Ok(torch_bool(py)),
434            ScalarType::BFloat16 => Ok(torch_bfloat16(py)),
435            ScalarType::Float8_e5m2 => Ok(torch_float8_e5m2(py)),
436            ScalarType::Float8_e4m3fn => Ok(torch_float8_e4m3fn(py)),
437            _ => Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
438                "Unsupported scalar type: {:?}",
439                self
440            ))),
441        }
442    }
443}
444
445// ============================================================================
446// Tensor and TensorCell
447// ============================================================================
448
449#[derive(Debug)]
450pub struct Tensor {
451    inner: PyObject,
452}
453
454impl Clone for Tensor {
455    fn clone(&self) -> Self {
456        Python::with_gil(|py| Tensor {
457            inner: self.inner.clone_ref(py),
458        })
459    }
460}
461
462impl Tensor {
463    pub fn scalar_type(&self) -> ScalarType {
464        Python::with_gil(|py| {
465            let tensor = self.inner.bind(py);
466            let dtype = tensor.getattr("dtype").unwrap();
467            ScalarType::extract_bound(&dtype).unwrap()
468        })
469    }
470
471    pub fn device(&self) -> Device {
472        Python::with_gil(|py| {
473            let tensor = self.inner.bind(py);
474            let device = tensor.getattr("device").unwrap();
475            Device::extract_bound(&device).unwrap()
476        })
477    }
478
479    pub fn numel(&self) -> i64 {
480        Python::with_gil(|py| {
481            let tensor = self.inner.bind(py);
482            tensor.call_method0("numel").unwrap().extract().unwrap()
483        })
484    }
485
486    pub fn data_ptr(&self) -> *const std::ffi::c_void {
487        Python::with_gil(|py| {
488            let tensor = self.inner.bind(py);
489            let ptr: usize = tensor.call_method0("data_ptr").unwrap().extract().unwrap();
490            ptr as *const std::ffi::c_void
491        })
492    }
493
494    pub fn mut_data_ptr(&self) -> *mut std::ffi::c_void {
495        self.data_ptr() as *mut std::ffi::c_void
496    }
497
498    pub fn defined(&self) -> bool {
499        Python::with_gil(|py| {
500            let tensor = self.inner.bind(py);
501            // A tensor is defined if it's not None and has storage
502            !tensor.is_none()
503        })
504    }
505
506    pub fn is_cuda(&self) -> bool {
507        Python::with_gil(|py| {
508            let tensor = self.inner.bind(py);
509            tensor.getattr("is_cuda").unwrap().extract().unwrap()
510        })
511    }
512
513    pub fn is_sparse(&self) -> bool {
514        Python::with_gil(|py| {
515            let tensor = self.inner.bind(py);
516            tensor.getattr("is_sparse").unwrap().extract().unwrap()
517        })
518    }
519
520    pub fn is_contiguous(&self) -> bool {
521        Python::with_gil(|py| {
522            let tensor = self.inner.bind(py);
523            tensor
524                .call_method0("is_contiguous")
525                .unwrap()
526                .extract()
527                .unwrap()
528        })
529    }
530
531    pub fn nbytes(&self) -> i64 {
532        Python::with_gil(|py| {
533            let tensor = self.inner.bind(py);
534            tensor.getattr("nbytes").unwrap().extract().unwrap()
535        })
536    }
537
538    pub fn sizes(&self) -> Vec<i64> {
539        Python::with_gil(|py| {
540            let tensor = self.inner.bind(py);
541            let size = tensor.call_method0("size").unwrap();
542            size.try_iter()
543                .unwrap()
544                .map(|x| x.unwrap().extract().unwrap())
545                .collect()
546        })
547    }
548}
549
550impl pyo3::FromPyObject<'_> for Tensor {
551    fn extract_bound(ob: &pyo3::Bound<'_, pyo3::PyAny>) -> pyo3::PyResult<Self> {
552        Ok(Tensor {
553            inner: ob.clone().unbind(),
554        })
555    }
556}
557
558impl<'py> IntoPyObject<'py> for Tensor {
559    type Target = PyAny;
560    type Output = Bound<'py, Self::Target>;
561    type Error = PyErr;
562
563    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
564        Ok(self.inner.bind(py).clone())
565    }
566}
567
568#[derive(Debug, Clone)]
569pub struct TensorCell {
570    tensor: Tensor,
571}
572
573impl TensorCell {
574    pub fn new(tensor: Tensor) -> Self {
575        TensorCell { tensor }
576    }
577
578    pub fn borrow(&self) -> BorrowGuard {
579        BorrowGuard {
580            // SAFETY: TensorCell owns the tensor and the returned BorrowGuard
581            // maintains proper ownership semantics by holding a clone.
582            tensor: unsafe { self.tensor.clone_unsafe() },
583        }
584    }
585
586    pub fn borrow_mut(&self) -> BorrowGuardMut {
587        BorrowGuardMut {
588            // SAFETY: TensorCell owns the tensor and the returned BorrowGuardMut
589            // maintains proper ownership semantics by holding a clone.
590            tensor: unsafe { self.tensor.clone_unsafe() },
591        }
592    }
593
594    pub fn aliases(&self, other: &TensorCell) -> bool {
595        // Check if two tensors share the same underlying storage
596        Python::with_gil(|_py| {
597            let self_ptr = self.tensor.data_ptr();
598            let other_ptr = other.tensor.data_ptr();
599            self_ptr == other_ptr && !self_ptr.is_null()
600        })
601    }
602
603    /// # Safety
604    /// Caller must ensure that the TensorCell is borrowed appropriately
605    pub unsafe fn get_unchecked(&self) -> &Tensor {
606        &self.tensor
607    }
608
609    pub fn try_borrow(&self) -> Result<BorrowGuard, BorrowError> {
610        Ok(self.borrow())
611    }
612
613    pub fn try_borrow_mut(&self) -> Result<BorrowGuardMut, BorrowError> {
614        Ok(self.borrow_mut())
615    }
616
617    pub fn try_cpu(&self) -> Result<TensorCell, BorrowError> {
618        Python::with_gil(|py| {
619            let tensor = self.tensor.inner.bind(py);
620            let cpu_tensor = tensor
621                .call_method0("cpu")
622                .map_err(|_| BorrowError::BorrowError)?;
623            Ok(TensorCell::new(Tensor {
624                inner: cpu_tensor.clone().unbind(),
625            }))
626        })
627    }
628}
629
630#[derive(Debug, Clone)]
631pub struct BorrowGuard {
632    tensor: Tensor,
633}
634
635impl std::ops::Deref for BorrowGuard {
636    type Target = Tensor;
637
638    fn deref(&self) -> &Self::Target {
639        &self.tensor
640    }
641}
642
643#[derive(Debug, Clone)]
644pub struct BorrowGuardMut {
645    tensor: Tensor,
646}
647
648impl std::ops::Deref for BorrowGuardMut {
649    type Target = Tensor;
650
651    fn deref(&self) -> &Self::Target {
652        &self.tensor
653    }
654}
655
656impl std::ops::DerefMut for BorrowGuardMut {
657    fn deref_mut(&mut self) -> &mut Self::Target {
658        &mut self.tensor
659    }
660}
661
662impl BorrowGuardMut {
663    pub fn copy_(&mut self, src: &Tensor) {
664        Python::with_gil(|py| {
665            let dst_tensor = self.tensor.inner.bind(py);
666            let src_tensor = src.inner.bind(py);
667            dst_tensor.call_method1("copy_", (src_tensor,)).unwrap();
668        })
669    }
670}
671
672// ============================================================================
673// CloneUnsafe trait
674// ============================================================================
675
676pub trait CloneUnsafe {
677    /// # Safety
678    /// Caller must ensure proper ownership semantics
679    unsafe fn clone_unsafe(&self) -> Self;
680}
681
682impl CloneUnsafe for Tensor {
683    unsafe fn clone_unsafe(&self) -> Self {
684        self.clone()
685    }
686}
687
688// ============================================================================
689// Borrow errors
690// ============================================================================
691
692#[derive(Debug, Error)]
693pub enum BorrowError {
694    #[error("borrow error")]
695    BorrowError,
696}
697
698#[derive(Debug, Clone, Copy)]
699pub enum BorrowType {
700    Shared,
701    Exclusive,
702}
703
704#[derive(Debug)]
705pub struct Borrow {
706    _private: (),
707}
708
709#[derive(Debug)]
710pub struct MultiBorrow {
711    _private: (),
712}
713
714// ============================================================================
715// Factory functions
716// ============================================================================
717
718pub fn factory_zeros(size: &[i64], dtype: ScalarType, layout: Layout, device: Device) -> Tensor {
719    Python::with_gil(|py| {
720        let size_tuple = pyo3::types::PyTuple::new(py, size).unwrap();
721        let dtype_obj = dtype.into_pyobject(py).unwrap();
722        let device_obj = device.into_pyobject(py).unwrap();
723        let layout_obj = layout.into_pyobject(py).unwrap();
724
725        let kwargs = pyo3::types::PyDict::new(py);
726        kwargs.set_item("dtype", dtype_obj).unwrap();
727        kwargs.set_item("device", device_obj).unwrap();
728        kwargs.set_item("layout", layout_obj).unwrap();
729
730        let result = torch_zeros(py).call((size_tuple,), Some(&kwargs)).unwrap();
731
732        Tensor {
733            inner: result.clone().unbind(),
734        }
735    })
736}
737
738pub fn factory_empty(size: &[i64], dtype: ScalarType, layout: Layout, device: Device) -> Tensor {
739    Python::with_gil(|py| {
740        let size_tuple = pyo3::types::PyTuple::new(py, size).unwrap();
741        let dtype_obj = dtype.into_pyobject(py).unwrap();
742        let device_obj = device.into_pyobject(py).unwrap();
743        let layout_obj = layout.into_pyobject(py).unwrap();
744
745        let kwargs = pyo3::types::PyDict::new(py);
746        kwargs.set_item("dtype", dtype_obj).unwrap();
747        kwargs.set_item("device", device_obj).unwrap();
748        kwargs.set_item("layout", layout_obj).unwrap();
749
750        let result = torch_empty(py).call((size_tuple,), Some(&kwargs)).unwrap();
751
752        Tensor {
753            inner: result.clone().unbind(),
754        }
755    })
756}
757
758pub fn factory_float_tensor(data: &[f32], device: Device) -> Tensor {
759    Python::with_gil(|py| {
760        let data_list = pyo3::types::PyList::new(py, data).unwrap();
761        let device_obj = device.into_pyobject(py).unwrap();
762
763        let kwargs = pyo3::types::PyDict::new(py);
764        kwargs.set_item("device", device_obj).unwrap();
765        kwargs.set_item("dtype", torch_float32(py)).unwrap();
766
767        let result = torch_tensor(py).call((data_list,), Some(&kwargs)).unwrap();
768
769        Tensor {
770            inner: result.clone().unbind(),
771        }
772    })
773}
774
775pub fn deep_clone(tensor: &Tensor) -> Tensor {
776    Python::with_gil(|py| {
777        let tensor_obj = tensor.inner.bind(py);
778        let cloned = tensor_obj.call_method0("clone").unwrap();
779        Tensor {
780            inner: cloned.clone().unbind(),
781        }
782    })
783}
784
785pub fn is_float8_type(scalar_type: ScalarType) -> bool {
786    matches!(
787        scalar_type,
788        ScalarType::Float8_e5m2
789            | ScalarType::Float8_e4m3fn
790            | ScalarType::Float8_e5m2fnuz
791            | ScalarType::Float8_e4m3fnuz
792    )
793}
794
795pub fn suggest_memory_format(tensor: &Tensor) -> MemoryFormat {
796    Python::with_gil(|py| {
797        let tensor_obj = tensor.inner.bind(py);
798
799        // Call suggest_memory_format method on the tensor
800        let result = tensor_obj.call_method0("suggest_memory_format").unwrap();
801
802        // Convert the result back to our enum
803        let result_str: String = result.str().unwrap().extract().unwrap();
804
805        if result_str.contains("channels_last_3d") {
806            MemoryFormat::ChannelsLast3d
807        } else if result_str.contains("channels_last") {
808            MemoryFormat::ChannelsLast
809        } else if result_str.contains("preserve") {
810            MemoryFormat::Preserve
811        } else {
812            MemoryFormat::Contiguous
813        }
814    })
815}