1#![allow(non_camel_case_types)]
10
11pub mod testing;
17
18use monarch_types::py_global;
19use pyo3::prelude::*;
20use pyo3::types::PyAny;
21use serde::Deserialize;
22use serde::Serialize;
23use thiserror::Error;
24
// Handles to attributes of the Python `torch` module, declared with the
// project-local `py_global!` macro. Judging by the call sites in this file
// (e.g. `torch_strided(py)` returned where a `Bound<'py, PyAny>` is expected),
// each invocation defines a function that takes a `Python<'py>` token and
// returns the named attribute. NOTE(review): confirm against the macro
// definition (caching behavior, exact return type).

// `torch.device` constructor; called by `Device`'s `IntoPyObject` impl.
py_global!(torch_device, "torch", "device");
// Layout singletons.
py_global!(torch_strided, "torch", "strided");
py_global!(torch_sparse_coo, "torch", "sparse_coo");
// Memory-format singletons.
py_global!(torch_contiguous_format, "torch", "contiguous_format");
py_global!(torch_preserve_format, "torch", "preserve_format");
py_global!(torch_channels_last, "torch", "channels_last");
py_global!(torch_channels_last_3d, "torch", "channels_last_3d");
// Dtype singletons, used by `ScalarType`'s Python conversions.
py_global!(torch_uint8, "torch", "uint8");
py_global!(torch_int8, "torch", "int8");
py_global!(torch_int16, "torch", "int16");
py_global!(torch_int32, "torch", "int32");
py_global!(torch_int64, "torch", "int64");
py_global!(torch_float16, "torch", "float16");
py_global!(torch_float32, "torch", "float32");
py_global!(torch_float64, "torch", "float64");
py_global!(torch_complex32, "torch", "complex32");
py_global!(torch_complex64, "torch", "complex64");
py_global!(torch_complex128, "torch", "complex128");
py_global!(torch_bool, "torch", "bool");
py_global!(torch_bfloat16, "torch", "bfloat16");
py_global!(torch_float8_e5m2, "torch", "float8_e5m2");
py_global!(torch_float8_e4m3fn, "torch", "float8_e4m3fn");
// Tensor factory / utility functions. `torch_zeros`, `torch_empty` and
// `torch_tensor` are used by the factory functions in this file;
// `torch_allclose`, `torch_full` and `torch_stack` are not referenced in the
// visible code (presumably used elsewhere, e.g. the `testing` module — confirm).
py_global!(torch_zeros, "torch", "zeros");
py_global!(torch_empty, "torch", "empty");
py_global!(torch_tensor, "torch", "tensor");
py_global!(torch_allclose, "torch", "allclose");
py_global!(torch_full, "torch", "full");
py_global!(torch_stack, "torch", "stack");
54
55#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
60pub enum DeviceType {
61 CPU,
62 CUDA,
63}
64
65#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
66pub struct DeviceIndex(pub i8);
67
68impl From<DeviceIndex> for i8 {
69 fn from(idx: DeviceIndex) -> i8 {
70 idx.0
71 }
72}
73
74impl From<i8> for DeviceIndex {
75 fn from(idx: i8) -> DeviceIndex {
76 DeviceIndex(idx)
77 }
78}
79
80#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
81pub struct Device {
82 device_type: DeviceType,
83 index: Option<DeviceIndex>,
84}
85
86impl Device {
87 pub fn device_type(&self) -> DeviceType {
88 self.device_type
89 }
90}
91
92impl FromPyObject<'_> for Device {
93 fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
94 let device_str: String = obj.str()?.extract()?;
95 device_str.parse().map_err(|e: DeviceParseError| {
97 PyErr::new::<pyo3::exceptions::PyValueError, _>(e.to_string())
98 })
99 }
100}
101
102impl<'py> IntoPyObject<'py> for Device {
103 type Target = PyAny;
104 type Output = Bound<'py, Self::Target>;
105 type Error = PyErr;
106
107 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
108 let device_str = self.to_string();
109 let device = torch_device(py).call1((device_str,))?;
110 Ok(device)
111 }
112}
113
114impl std::fmt::Display for Device {
115 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
116 match self.device_type {
117 DeviceType::CPU => write!(f, "cpu"),
118 DeviceType::CUDA => {
119 if let Some(index) = self.index {
120 write!(f, "cuda:{}", index.0)
121 } else {
122 write!(f, "cuda")
123 }
124 }
125 }
126 }
127}
128
129impl std::str::FromStr for Device {
130 type Err = DeviceParseError;
131
132 fn from_str(s: &str) -> Result<Self, Self::Err> {
133 if s == "cpu" {
134 Ok(Device {
135 device_type: DeviceType::CPU,
136 index: None,
137 })
138 } else if s == "cuda" {
139 Ok(Device {
140 device_type: DeviceType::CUDA,
141 index: None,
142 })
143 } else if let Some(cuda_idx) = s.strip_prefix("cuda:") {
144 let index = cuda_idx
145 .parse::<i8>()
146 .map_err(|_| DeviceParseError::InvalidDevice)?;
147 Ok(Device {
148 device_type: DeviceType::CUDA,
149 index: Some(DeviceIndex(index)),
150 })
151 } else {
152 Err(DeviceParseError::InvalidDevice)
153 }
154 }
155}
156
157impl From<CudaDevice> for Device {
158 fn from(cuda_device: CudaDevice) -> Self {
159 Device {
160 device_type: DeviceType::CUDA,
161 index: Some(cuda_device.index),
162 }
163 }
164}
165
166#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
167pub struct CudaDevice {
168 index: DeviceIndex,
169}
170
171impl CudaDevice {
172 pub fn new(index: DeviceIndex) -> Self {
173 CudaDevice { index }
174 }
175
176 pub fn index(&self) -> DeviceIndex {
177 self.index
178 }
179}
180
/// Error returned when a string is not an accepted device string
/// (produced by `Device`'s `FromStr` impl).
#[derive(Debug, Error)]
pub enum DeviceParseError {
    /// The string is not `cpu`, `cuda`, or `cuda:<i8>`.
    #[error("invalid device string")]
    InvalidDevice,
}
186
/// Tensor memory layout (a subset of torch layouts).
///
/// `repr(i32)` with explicit discriminants. NOTE(review): what these integer
/// values must match externally is not established in this file — confirm
/// before relying on them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[repr(i32)]
pub enum Layout {
    /// Dense strided layout; converts to `torch.strided`.
    Strided = 0,
    /// Sparse COO layout; converts to `torch.sparse_coo`.
    Sparse = 1,
    /// MKL-DNN layout; `IntoPyObject` returns a `ValueError` for this variant.
    Mkldnn = 2,
}

/// Serde remote definition mirroring [`Layout`] (`#[serde(remote = "Layout")]`).
/// Its variants must stay identical to `Layout`'s, in the same order.
#[derive(Serialize, Deserialize)]
#[serde(remote = "Layout")]
pub enum LayoutDef {
    Strided,
    Sparse,
    Mkldnn,
}
207
208impl FromPyObject<'_> for Layout {
209 fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
210 Python::with_gil(|py| {
211 let strided = torch_strided(py);
212 let sparse_coo = torch_sparse_coo(py);
213
214 if obj.eq(strided)? {
215 Ok(Layout::Strided)
216 } else if obj.eq(sparse_coo)? {
217 Ok(Layout::Sparse)
218 } else {
219 let obj_str: String = obj.str()?.extract()?;
221 match obj_str.as_str() {
222 "torch.strided" => Ok(Layout::Strided),
223 "torch.sparse_coo" => Ok(Layout::Sparse),
224 _ => Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
225 "Unknown layout type",
226 )),
227 }
228 }
229 })
230 }
231}
232
233impl<'py> IntoPyObject<'py> for Layout {
234 type Target = PyAny;
235 type Output = Bound<'py, Self::Target>;
236 type Error = PyErr;
237
238 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
239 match self {
240 Layout::Strided => Ok(torch_strided(py)),
241 Layout::Sparse => Ok(torch_sparse_coo(py)),
242 Layout::Mkldnn => Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
243 "MKLDNN layout not supported in PyTorch",
244 )),
245 }
246 }
247}
248
/// Tensor memory format.
///
/// `IntoPyObject` maps each variant to the corresponding torch singleton
/// (`contiguous_format`, `preserve_format`, `channels_last`,
/// `channels_last_3d`). `repr(i32)` with explicit discriminants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[repr(i32)]
pub enum MemoryFormat {
    Contiguous = 0,
    Preserve = 1,
    ChannelsLast = 2,
    ChannelsLast3d = 3,
}

/// Serde remote definition mirroring [`MemoryFormat`]
/// (`#[serde(remote = "MemoryFormat")]`). Its variants must stay identical to
/// `MemoryFormat`'s, in the same order.
#[derive(Serialize, Deserialize)]
#[serde(remote = "MemoryFormat")]
pub enum MemoryFormatDef {
    Contiguous,
    Preserve,
    ChannelsLast,
    ChannelsLast3d,
}
267
268impl<'py> IntoPyObject<'py> for MemoryFormat {
269 type Target = PyAny;
270 type Output = Bound<'py, Self::Target>;
271 type Error = PyErr;
272
273 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
274 match self {
275 MemoryFormat::Contiguous => Ok(torch_contiguous_format(py)),
276 MemoryFormat::Preserve => Ok(torch_preserve_format(py)),
277 MemoryFormat::ChannelsLast => Ok(torch_channels_last(py)),
278 MemoryFormat::ChannelsLast3d => Ok(torch_channels_last_3d(py)),
279 }
280 }
281}
282
/// Tensor element type.
///
/// `repr(i32)` with explicit discriminants. NOTE(review): the external meaning
/// of these integer values is not established in this file — confirm before
/// relying on them. Only some variants have a Python counterpart in this
/// file's `IntoPyObject` impl; the rest return a `ValueError` there.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[repr(i32)]
pub enum ScalarType {
    Byte = 0,
    Char = 1,
    Short = 2,
    Int = 3,
    Long = 4,
    Half = 5,
    Float = 6,
    Double = 7,
    ComplexHalf = 8,
    ComplexFloat = 9,
    ComplexDouble = 10,
    Bool = 11,
    QInt8 = 12,
    QUInt8 = 13,
    QInt32 = 14,
    BFloat16 = 15,
    QUInt4x2 = 16,
    QUInt2x4 = 17,
    Bits1x8 = 18,
    Bits2x4 = 19,
    Bits4x2 = 20,
    Bits8 = 21,
    Bits16 = 22,
    Float8_e5m2 = 23,
    Float8_e4m3fn = 24,
    Float8_e5m2fnuz = 25,
    Float8_e4m3fnuz = 26,
}

/// Serde remote definition mirroring [`ScalarType`]
/// (`#[serde(remote = "ScalarType")]`). Its variants must stay identical to
/// `ScalarType`'s, in the same order.
#[derive(Serialize, Deserialize)]
#[serde(remote = "ScalarType")]
pub enum ScalarTypeDef {
    Byte,
    Char,
    Short,
    Int,
    Long,
    Half,
    Float,
    Double,
    ComplexHalf,
    ComplexFloat,
    ComplexDouble,
    Bool,
    QInt8,
    QUInt8,
    QInt32,
    BFloat16,
    QUInt4x2,
    QUInt2x4,
    Bits1x8,
    Bits2x4,
    Bits4x2,
    Bits8,
    Bits16,
    Float8_e5m2,
    Float8_e4m3fn,
    Float8_e5m2fnuz,
    Float8_e4m3fnuz,
}
351
352impl FromPyObject<'_> for ScalarType {
353 fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
354 Python::with_gil(|py| {
355 let dtype_map = [
357 (torch_uint8(py), ScalarType::Byte),
358 (torch_int8(py), ScalarType::Char),
359 (torch_int16(py), ScalarType::Short),
360 (torch_int32(py), ScalarType::Int),
361 (torch_int64(py), ScalarType::Long),
362 (torch_float16(py), ScalarType::Half),
363 (torch_float32(py), ScalarType::Float),
364 (torch_float64(py), ScalarType::Double),
365 (torch_complex32(py), ScalarType::ComplexHalf),
366 (torch_complex64(py), ScalarType::ComplexFloat),
367 (torch_complex128(py), ScalarType::ComplexDouble),
368 (torch_bool(py), ScalarType::Bool),
369 (torch_bfloat16(py), ScalarType::BFloat16),
370 (torch_float8_e5m2(py), ScalarType::Float8_e5m2),
371 (torch_float8_e4m3fn(py), ScalarType::Float8_e4m3fn),
372 ];
373
374 for (dtype, scalar_type) in &dtype_map {
376 if obj.eq(dtype)? {
377 return Ok(*scalar_type);
378 }
379 }
380
381 let obj_str: String = obj.str()?.extract()?;
383 let str_map = [
384 ("uint8", ScalarType::Byte),
385 ("int8", ScalarType::Char),
386 ("int16", ScalarType::Short),
387 ("int32", ScalarType::Int),
388 ("int64", ScalarType::Long),
389 ("float16", ScalarType::Half),
390 ("float32", ScalarType::Float),
391 ("float64", ScalarType::Double),
392 ("complex32", ScalarType::ComplexHalf),
393 ("complex64", ScalarType::ComplexFloat),
394 ("complex128", ScalarType::ComplexDouble),
395 ("bool", ScalarType::Bool),
396 ("bfloat16", ScalarType::BFloat16),
397 ("float8_e5m2", ScalarType::Float8_e5m2),
398 ("float8_e4m3fn", ScalarType::Float8_e4m3fn),
399 ];
400
401 for (name, scalar_type) in &str_map {
402 if obj_str.contains(name) {
403 return Ok(*scalar_type);
404 }
405 }
406
407 Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
408 "Unknown scalar type: {}",
409 obj_str
410 )))
411 })
412 }
413}
414
415impl<'py> IntoPyObject<'py> for ScalarType {
416 type Target = PyAny;
417 type Output = Bound<'py, Self::Target>;
418 type Error = PyErr;
419
420 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
421 match self {
422 ScalarType::Byte => Ok(torch_uint8(py)),
423 ScalarType::Char => Ok(torch_int8(py)),
424 ScalarType::Short => Ok(torch_int16(py)),
425 ScalarType::Int => Ok(torch_int32(py)),
426 ScalarType::Long => Ok(torch_int64(py)),
427 ScalarType::Half => Ok(torch_float16(py)),
428 ScalarType::Float => Ok(torch_float32(py)),
429 ScalarType::Double => Ok(torch_float64(py)),
430 ScalarType::ComplexHalf => Ok(torch_complex32(py)),
431 ScalarType::ComplexFloat => Ok(torch_complex64(py)),
432 ScalarType::ComplexDouble => Ok(torch_complex128(py)),
433 ScalarType::Bool => Ok(torch_bool(py)),
434 ScalarType::BFloat16 => Ok(torch_bfloat16(py)),
435 ScalarType::Float8_e5m2 => Ok(torch_float8_e5m2(py)),
436 ScalarType::Float8_e4m3fn => Ok(torch_float8_e4m3fn(py)),
437 _ => Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
438 "Unsupported scalar type: {:?}",
439 self
440 ))),
441 }
442 }
443}
444
445#[derive(Debug)]
450pub struct Tensor {
451 inner: PyObject,
452}
453
454impl Clone for Tensor {
455 fn clone(&self) -> Self {
456 Python::with_gil(|py| Tensor {
457 inner: self.inner.clone_ref(py),
458 })
459 }
460}
461
462impl Tensor {
463 pub fn scalar_type(&self) -> ScalarType {
464 Python::with_gil(|py| {
465 let tensor = self.inner.bind(py);
466 let dtype = tensor.getattr("dtype").unwrap();
467 ScalarType::extract_bound(&dtype).unwrap()
468 })
469 }
470
471 pub fn device(&self) -> Device {
472 Python::with_gil(|py| {
473 let tensor = self.inner.bind(py);
474 let device = tensor.getattr("device").unwrap();
475 Device::extract_bound(&device).unwrap()
476 })
477 }
478
479 pub fn numel(&self) -> i64 {
480 Python::with_gil(|py| {
481 let tensor = self.inner.bind(py);
482 tensor.call_method0("numel").unwrap().extract().unwrap()
483 })
484 }
485
486 pub fn data_ptr(&self) -> *const std::ffi::c_void {
487 Python::with_gil(|py| {
488 let tensor = self.inner.bind(py);
489 let ptr: usize = tensor.call_method0("data_ptr").unwrap().extract().unwrap();
490 ptr as *const std::ffi::c_void
491 })
492 }
493
494 pub fn mut_data_ptr(&self) -> *mut std::ffi::c_void {
495 self.data_ptr() as *mut std::ffi::c_void
496 }
497
498 pub fn defined(&self) -> bool {
499 Python::with_gil(|py| {
500 let tensor = self.inner.bind(py);
501 !tensor.is_none()
503 })
504 }
505
506 pub fn is_cuda(&self) -> bool {
507 Python::with_gil(|py| {
508 let tensor = self.inner.bind(py);
509 tensor.getattr("is_cuda").unwrap().extract().unwrap()
510 })
511 }
512
513 pub fn is_sparse(&self) -> bool {
514 Python::with_gil(|py| {
515 let tensor = self.inner.bind(py);
516 tensor.getattr("is_sparse").unwrap().extract().unwrap()
517 })
518 }
519
520 pub fn is_contiguous(&self) -> bool {
521 Python::with_gil(|py| {
522 let tensor = self.inner.bind(py);
523 tensor
524 .call_method0("is_contiguous")
525 .unwrap()
526 .extract()
527 .unwrap()
528 })
529 }
530
531 pub fn nbytes(&self) -> i64 {
532 Python::with_gil(|py| {
533 let tensor = self.inner.bind(py);
534 tensor.getattr("nbytes").unwrap().extract().unwrap()
535 })
536 }
537
538 pub fn sizes(&self) -> Vec<i64> {
539 Python::with_gil(|py| {
540 let tensor = self.inner.bind(py);
541 let size = tensor.call_method0("size").unwrap();
542 size.try_iter()
543 .unwrap()
544 .map(|x| x.unwrap().extract().unwrap())
545 .collect()
546 })
547 }
548}
549
550impl pyo3::FromPyObject<'_> for Tensor {
551 fn extract_bound(ob: &pyo3::Bound<'_, pyo3::PyAny>) -> pyo3::PyResult<Self> {
552 Ok(Tensor {
553 inner: ob.clone().unbind(),
554 })
555 }
556}
557
558impl<'py> IntoPyObject<'py> for Tensor {
559 type Target = PyAny;
560 type Output = Bound<'py, Self::Target>;
561 type Error = PyErr;
562
563 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
564 Ok(self.inner.bind(py).clone())
565 }
566}
567
568#[derive(Debug, Clone)]
569pub struct TensorCell {
570 tensor: Tensor,
571}
572
573impl TensorCell {
574 pub fn new(tensor: Tensor) -> Self {
575 TensorCell { tensor }
576 }
577
578 pub fn borrow(&self) -> BorrowGuard {
579 BorrowGuard {
580 tensor: unsafe { self.tensor.clone_unsafe() },
583 }
584 }
585
586 pub fn borrow_mut(&self) -> BorrowGuardMut {
587 BorrowGuardMut {
588 tensor: unsafe { self.tensor.clone_unsafe() },
591 }
592 }
593
594 pub fn aliases(&self, other: &TensorCell) -> bool {
595 Python::with_gil(|_py| {
597 let self_ptr = self.tensor.data_ptr();
598 let other_ptr = other.tensor.data_ptr();
599 self_ptr == other_ptr && !self_ptr.is_null()
600 })
601 }
602
603 pub unsafe fn get_unchecked(&self) -> &Tensor {
606 &self.tensor
607 }
608
609 pub fn try_borrow(&self) -> Result<BorrowGuard, BorrowError> {
610 Ok(self.borrow())
611 }
612
613 pub fn try_borrow_mut(&self) -> Result<BorrowGuardMut, BorrowError> {
614 Ok(self.borrow_mut())
615 }
616
617 pub fn try_cpu(&self) -> Result<TensorCell, BorrowError> {
618 Python::with_gil(|py| {
619 let tensor = self.tensor.inner.bind(py);
620 let cpu_tensor = tensor
621 .call_method0("cpu")
622 .map_err(|_| BorrowError::BorrowError)?;
623 Ok(TensorCell::new(Tensor {
624 inner: cpu_tensor.clone().unbind(),
625 }))
626 })
627 }
628}
629
630#[derive(Debug, Clone)]
631pub struct BorrowGuard {
632 tensor: Tensor,
633}
634
635impl std::ops::Deref for BorrowGuard {
636 type Target = Tensor;
637
638 fn deref(&self) -> &Self::Target {
639 &self.tensor
640 }
641}
642
643#[derive(Debug, Clone)]
644pub struct BorrowGuardMut {
645 tensor: Tensor,
646}
647
648impl std::ops::Deref for BorrowGuardMut {
649 type Target = Tensor;
650
651 fn deref(&self) -> &Self::Target {
652 &self.tensor
653 }
654}
655
656impl std::ops::DerefMut for BorrowGuardMut {
657 fn deref_mut(&mut self) -> &mut Self::Target {
658 &mut self.tensor
659 }
660}
661
662impl BorrowGuardMut {
663 pub fn copy_(&mut self, src: &Tensor) {
664 Python::with_gil(|py| {
665 let dst_tensor = self.tensor.inner.bind(py);
666 let src_tensor = src.inner.bind(py);
667 dst_tensor.call_method1("copy_", (src_tensor,)).unwrap();
668 })
669 }
670}
671
672pub trait CloneUnsafe {
677 unsafe fn clone_unsafe(&self) -> Self;
680}
681
682impl CloneUnsafe for Tensor {
683 unsafe fn clone_unsafe(&self) -> Self {
684 self.clone()
685 }
686}
687
688#[derive(Debug, Error)]
693pub enum BorrowError {
694 #[error("borrow error")]
695 BorrowError,
696}
697
698#[derive(Debug, Clone, Copy)]
699pub enum BorrowType {
700 Shared,
701 Exclusive,
702}
703
704#[derive(Debug)]
705pub struct Borrow {
706 _private: (),
707}
708
709#[derive(Debug)]
710pub struct MultiBorrow {
711 _private: (),
712}
713
714pub fn factory_zeros(size: &[i64], dtype: ScalarType, layout: Layout, device: Device) -> Tensor {
719 Python::with_gil(|py| {
720 let size_tuple = pyo3::types::PyTuple::new(py, size).unwrap();
721 let dtype_obj = dtype.into_pyobject(py).unwrap();
722 let device_obj = device.into_pyobject(py).unwrap();
723 let layout_obj = layout.into_pyobject(py).unwrap();
724
725 let kwargs = pyo3::types::PyDict::new(py);
726 kwargs.set_item("dtype", dtype_obj).unwrap();
727 kwargs.set_item("device", device_obj).unwrap();
728 kwargs.set_item("layout", layout_obj).unwrap();
729
730 let result = torch_zeros(py).call((size_tuple,), Some(&kwargs)).unwrap();
731
732 Tensor {
733 inner: result.clone().unbind(),
734 }
735 })
736}
737
738pub fn factory_empty(size: &[i64], dtype: ScalarType, layout: Layout, device: Device) -> Tensor {
739 Python::with_gil(|py| {
740 let size_tuple = pyo3::types::PyTuple::new(py, size).unwrap();
741 let dtype_obj = dtype.into_pyobject(py).unwrap();
742 let device_obj = device.into_pyobject(py).unwrap();
743 let layout_obj = layout.into_pyobject(py).unwrap();
744
745 let kwargs = pyo3::types::PyDict::new(py);
746 kwargs.set_item("dtype", dtype_obj).unwrap();
747 kwargs.set_item("device", device_obj).unwrap();
748 kwargs.set_item("layout", layout_obj).unwrap();
749
750 let result = torch_empty(py).call((size_tuple,), Some(&kwargs)).unwrap();
751
752 Tensor {
753 inner: result.clone().unbind(),
754 }
755 })
756}
757
758pub fn factory_float_tensor(data: &[f32], device: Device) -> Tensor {
759 Python::with_gil(|py| {
760 let data_list = pyo3::types::PyList::new(py, data).unwrap();
761 let device_obj = device.into_pyobject(py).unwrap();
762
763 let kwargs = pyo3::types::PyDict::new(py);
764 kwargs.set_item("device", device_obj).unwrap();
765 kwargs.set_item("dtype", torch_float32(py)).unwrap();
766
767 let result = torch_tensor(py).call((data_list,), Some(&kwargs)).unwrap();
768
769 Tensor {
770 inner: result.clone().unbind(),
771 }
772 })
773}
774
775pub fn deep_clone(tensor: &Tensor) -> Tensor {
776 Python::with_gil(|py| {
777 let tensor_obj = tensor.inner.bind(py);
778 let cloned = tensor_obj.call_method0("clone").unwrap();
779 Tensor {
780 inner: cloned.clone().unbind(),
781 }
782 })
783}
784
785pub fn is_float8_type(scalar_type: ScalarType) -> bool {
786 matches!(
787 scalar_type,
788 ScalarType::Float8_e5m2
789 | ScalarType::Float8_e4m3fn
790 | ScalarType::Float8_e5m2fnuz
791 | ScalarType::Float8_e4m3fnuz
792 )
793}
794
795pub fn suggest_memory_format(tensor: &Tensor) -> MemoryFormat {
796 Python::with_gil(|py| {
797 let tensor_obj = tensor.inner.bind(py);
798
799 let result = tensor_obj.call_method0("suggest_memory_format").unwrap();
801
802 let result_str: String = result.str().unwrap().extract().unwrap();
804
805 if result_str.contains("channels_last_3d") {
806 MemoryFormat::ChannelsLast3d
807 } else if result_str.contains("channels_last") {
808 MemoryFormat::ChannelsLast
809 } else if result_str.contains("preserve") {
810 MemoryFormat::Preserve
811 } else {
812 MemoryFormat::Contiguous
813 }
814 })
815}