torch_sys/
device.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! Bindings for `c10::Device` and friends.
10
11use std::sync::LazyLock;
12
13use cxx::ExternType;
14use cxx::type_id;
15use derive_more::From;
16use derive_more::Into;
17use pyo3::IntoPyObjectExt;
18use pyo3::exceptions::PyValueError;
19use pyo3::prelude::*;
20use regex::Regex;
21use serde::Deserialize;
22use serde::Serialize;
23use thiserror::Error;
24
25use crate::bridge::ffi;
26
27/// Errors that can be returned from constructing a device from a string.
28#[derive(Error, Debug)]
29#[non_exhaustive]
30pub enum DeviceParseError {
31    #[error("invalid device type specified: {0}")]
32    InvalidDeviceType(String),
33
34    #[error("invalid device index specified: {0}")]
35    InvalidDeviceIndex(#[from] std::num::ParseIntError),
36
37    #[error("invalid device string: {0}")]
38    ParserFailure(String),
39}
40
41/// Binding for `c10::DeviceType`.
42///
43/// This is an `int8_t` enum class in C++. The reason it looks ridiculous here
44/// is because C++ allows the value of an enum to any `int8_t` value, even if
45/// there is no discriminant specified. This is UB in Rust, so in order to
46/// control for that case, we follow the `cxx` strategy of defining a struct
47/// that looks more or less like an enum.
48///
49/// This is a little pedantic but better safe than sorry :)
50#[derive(Debug, Copy, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
51#[repr(transparent)]
52pub struct DeviceType {
53    pub repr: i8,
54}
55
56#[allow(dead_code)]
57#[allow(non_upper_case_globals)]
58impl DeviceType {
59    pub const CPU: Self = DeviceType { repr: 0 };
60    pub const CUDA: Self = DeviceType { repr: 1 };
61    pub const MKLDNN: Self = DeviceType { repr: 2 };
62    pub const OPENGL: Self = DeviceType { repr: 3 };
63    pub const OPENCL: Self = DeviceType { repr: 4 };
64    pub const IDEEP: Self = DeviceType { repr: 5 };
65    pub const HIP: Self = DeviceType { repr: 6 };
66    pub const FPGA: Self = DeviceType { repr: 7 };
67    pub const MAIA: Self = DeviceType { repr: 8 };
68    pub const XLA: Self = DeviceType { repr: 9 };
69    pub const Vulkan: Self = DeviceType { repr: 10 };
70    pub const Metal: Self = DeviceType { repr: 11 };
71    pub const XPU: Self = DeviceType { repr: 12 };
72    pub const MPS: Self = DeviceType { repr: 13 };
73    pub const Meta: Self = DeviceType { repr: 14 };
74    pub const HPU: Self = DeviceType { repr: 15 };
75    pub const VE: Self = DeviceType { repr: 16 };
76    pub const Lazy: Self = DeviceType { repr: 17 };
77    pub const IPU: Self = DeviceType { repr: 18 };
78    pub const MTIA: Self = DeviceType { repr: 19 };
79    pub const PrivateUse1: Self = DeviceType { repr: 20 };
80    pub const CompileTimeMaxDeviceTypes: Self = DeviceType { repr: 21 };
81}
82
83impl TryFrom<&str> for DeviceType {
84    type Error = DeviceParseError;
85    fn try_from(val: &str) -> Result<DeviceType, Self::Error> {
86        Ok(match val {
87            "cpu" => DeviceType::CPU,
88            "cuda" => DeviceType::CUDA,
89            "ipu" => DeviceType::IPU,
90            "xpu" => DeviceType::XPU,
91            "mkldnn" => DeviceType::MKLDNN,
92            "opengl" => DeviceType::OPENGL,
93            "opencl" => DeviceType::OPENCL,
94            "ideep" => DeviceType::IDEEP,
95            "hip" => DeviceType::HIP,
96            "ve" => DeviceType::VE,
97            "fpga" => DeviceType::FPGA,
98            "maia" => DeviceType::MAIA,
99            "xla" => DeviceType::XLA,
100            "lazy" => DeviceType::Lazy,
101            "vulkan" => DeviceType::Vulkan,
102            "mps" => DeviceType::MPS,
103            "meta" => DeviceType::Meta,
104            "hpu" => DeviceType::HPU,
105            "mtia" => DeviceType::MTIA,
106            "privateuseone" => DeviceType::PrivateUse1,
107            _ => return Err(DeviceParseError::InvalidDeviceType(val.to_string())),
108        })
109    }
110}
111
112impl std::fmt::Display for DeviceType {
113    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114        match *self {
115            DeviceType::CPU => write!(f, "cpu"),
116            DeviceType::CUDA => write!(f, "cuda"),
117            DeviceType::IPU => write!(f, "ipu"),
118            DeviceType::XPU => write!(f, "xpu"),
119            DeviceType::MKLDNN => write!(f, "mkldnn"),
120            DeviceType::OPENGL => write!(f, "opengl"),
121            DeviceType::OPENCL => write!(f, "opencl"),
122            DeviceType::IDEEP => write!(f, "ideep"),
123            DeviceType::HIP => write!(f, "hip"),
124            DeviceType::VE => write!(f, "ve"),
125            DeviceType::FPGA => write!(f, "fpga"),
126            DeviceType::MAIA => write!(f, "maia"),
127            DeviceType::XLA => write!(f, "xla"),
128            DeviceType::Lazy => write!(f, "lazy"),
129            DeviceType::Vulkan => write!(f, "vulkan"),
130            DeviceType::MPS => write!(f, "mps"),
131            DeviceType::Meta => write!(f, "meta"),
132            DeviceType::HPU => write!(f, "hpu"),
133            DeviceType::MTIA => write!(f, "mtia"),
134            DeviceType::PrivateUse1 => write!(f, "privateuseone"),
135            _ => write!(f, "unknown"),
136        }
137    }
138}
139
140// SAFETY: Register our custom type implementation with cxx.
141unsafe impl ExternType for DeviceType {
142    type Id = type_id!("c10::DeviceType");
143    // Yes, it's trivial, it's just an i8.
144    type Kind = cxx::kind::Trivial;
145}
146
147impl FromPyObject<'_> for DeviceType {
148    fn extract_bound(obj: &pyo3::Bound<'_, pyo3::PyAny>) -> pyo3::PyResult<Self> {
149        obj.extract::<String>()?
150            .as_str()
151            .try_into()
152            .map_err(|e| PyErr::new::<PyValueError, _>(format!("Failed extracting from py: {}", e)))
153    }
154}
155
156impl<'py> IntoPyObject<'py> for DeviceType {
157    type Target = PyAny;
158    type Output = Bound<'py, Self::Target>;
159    type Error = PyErr;
160
161    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
162        format!("{}", self).into_bound_py_any(py)
163    }
164}
165
166/// Binding for `c10::DeviceIndex`.
167///
168/// An index representing a specific device; e.g., the 1 in GPU 1.  A
169/// DeviceIndex is not independently meaningful without knowing the DeviceType
170/// it is associated; try to use Device rather than DeviceIndex directly.
171///
172/// Marked `repr(transparent)` because `c10::DeviceType` is really just a type
173/// alias for `int8_t`, so we want to guarantee that the representation is
174/// identical to that.
175#[derive(
176    Debug,
177    Copy,
178    Clone,
179    Serialize,
180    Deserialize,
181    PartialEq,
182    Eq,
183    Hash,
184    Into,
185    From
186)]
187#[repr(transparent)]
188pub struct DeviceIndex(pub i8);
189
190// SAFETY: Register our custom type implementation with cxx.
191unsafe impl ExternType for DeviceIndex {
192    type Id = type_id!("c10::DeviceIndex");
193    // Yes, it's trivial, it's just an i8.
194    type Kind = cxx::kind::Trivial;
195}
196
197impl FromPyObject<'_> for DeviceIndex {
198    fn extract_bound(obj: &pyo3::Bound<'_, pyo3::PyAny>) -> pyo3::PyResult<Self> {
199        let v = obj.extract::<Option<i8>>()?;
200        Ok(DeviceIndex(v.unwrap_or(-1)))
201    }
202}
203
204impl<'py> IntoPyObject<'py> for DeviceIndex {
205    type Target = PyAny;
206    type Output = Bound<'py, Self::Target>;
207    type Error = PyErr;
208
209    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
210        self.0.into_bound_py_any(py)
211    }
212}
213
214/// Binding for `c10::Device`.
215///
216/// Represents a compute device on which a tensor is located. A device is
217/// uniquely identified by a type, which specifies the type of machine it is
218/// (e.g. CPU or CUDA GPU), and a device index or ordinal, which identifies the
219/// specific compute device when there is more than one of a certain type. The
220/// device index is optional, and in its defaulted state represents (abstractly)
221/// "the current device". Further, there are two constraints on the value of the
222/// device index, if one is explicitly stored:
223/// 1. A -1 represents the current device, a non-negative index
224///    represents a specific, concrete device,
225/// 2. When the device type is CPU, the device index must be -1 or zero.
226#[derive(Debug, Copy, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
227#[repr(C)]
228pub struct Device {
229    device_type: DeviceType,
230    index: DeviceIndex,
231}
232
233impl Device {
234    /// Create a new device of the specified type on the "current" device index.
235    pub fn new(device_type: DeviceType) -> Self {
236        Self {
237            device_type,
238            index: DeviceIndex(-1),
239        }
240    }
241
242    pub fn new_with_index(device_type: DeviceType, index: DeviceIndex) -> Self {
243        debug_assert!(
244            index.0 >= -1,
245            "Device index must be -1 or non-negative, got: {}",
246            index.0
247        );
248        debug_assert!(
249            !matches!(device_type, DeviceType::CPU) || index.0 <= 0,
250            "Device index for CPU must be -1 or 0, got: {}",
251            index.0
252        );
253        Self { device_type, index }
254    }
255
256    pub fn device_type(&self) -> DeviceType {
257        self.device_type
258    }
259
260    pub fn index(&self) -> DeviceIndex {
261        self.index
262    }
263}
264
265impl TryFrom<Device> for CudaDevice {
266    type Error = &'static str;
267    fn try_from(value: Device) -> Result<Self, Self::Error> {
268        if value.device_type() == DeviceType::CUDA {
269            Ok(CudaDevice { index: value.index })
270        } else {
271            Err("Device is not a CUDA device")
272        }
273    }
274}
275
276/// A device that is statically guaranteed to be a CUDA device.
277#[derive(Debug, Copy, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
278pub struct CudaDevice {
279    index: DeviceIndex,
280}
281
282impl CudaDevice {
283    pub fn new(index: DeviceIndex) -> Self {
284        Self { index }
285    }
286
287    pub fn index(&self) -> DeviceIndex {
288        self.index
289    }
290}
291
292impl From<CudaDevice> for Device {
293    fn from(device: CudaDevice) -> Self {
294        Device::new_with_index(DeviceType::CUDA, device.index)
295    }
296}
297
298static DEVICE_REGEX: LazyLock<Regex> =
299    LazyLock::new(|| Regex::new("([a-zA-Z_]+)(?::([1-9]\\d*|0))?").unwrap());
300
301impl TryFrom<&str> for Device {
302    type Error = DeviceParseError;
303    fn try_from(val: &str) -> Result<Device, Self::Error> {
304        let captures = DEVICE_REGEX
305            .captures(val)
306            .ok_or_else(|| DeviceParseError::ParserFailure(val.to_string()))?;
307
308        if captures.get(0).unwrap().len() != val.len() {
309            return Err(DeviceParseError::ParserFailure(val.to_string()));
310        }
311
312        let device_type: DeviceType = captures
313            .get(1)
314            .ok_or_else(|| DeviceParseError::ParserFailure(val.to_string()))?
315            .as_str()
316            .try_into()?;
317
318        let index = captures.get(2);
319        match index {
320            Some(match_) => Ok(Device::new_with_index(
321                device_type,
322                DeviceIndex(
323                    match_
324                        .as_str()
325                        .parse::<i8>()
326                        .map_err(DeviceParseError::from)?,
327                ),
328            )),
329            None => Ok(Device::new(device_type)),
330        }
331    }
332}
333
334impl TryFrom<String> for Device {
335    type Error = DeviceParseError;
336    fn try_from(val: String) -> Result<Device, Self::Error> {
337        Device::try_from(val.as_ref())
338    }
339}
340
341impl std::fmt::Display for Device {
342    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
343        if self.index.0 == -1 {
344            write!(f, "{}", self.device_type)
345        } else {
346            write!(f, "{}:{}", self.device_type, self.index.0)
347        }
348    }
349}
350
351// SAFETY: Register our custom type implementation with cxx.
352unsafe impl ExternType for Device {
353    type Id = type_id!("c10::Device");
354    // Yes, it's trivial, it's just two i8s.
355    type Kind = cxx::kind::Trivial;
356}
357
358impl FromPyObject<'_> for Device {
359    fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
360        ffi::device_from_py_object(obj.into()).map_err(|e| {
361            PyValueError::new_err(format!(
362                "Failed extracting {} from py as Device: {}",
363                obj, e
364            ))
365        })
366    }
367}
368
369impl<'py> IntoPyObject<'py> for Device {
370    type Target = PyAny;
371    type Output = Bound<'py, Self::Target>;
372    type Error = PyErr;
373
374    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
375        ffi::device_to_py_object(self).into_pyobject(py)
376    }
377}
378
#[cfg(test)]
mod tests {
    use anyhow::Result;

    use super::*;

    /// Every device-type name accepted by `DeviceType::try_from(&str)`.
    const DEVICE_TYPE_NAMES: &[&str] = &[
        "cpu",
        "cuda",
        "ipu",
        "xpu",
        "mkldnn",
        "opengl",
        "opencl",
        "ideep",
        "hip",
        "ve",
        "fpga",
        "maia",
        "xla",
        "lazy",
        "vulkan",
        "mps",
        "meta",
        "hpu",
        "mtia",
        "privateuseone",
    ];

    #[test]
    fn from_str_bad() {
        let result: Result<Device, DeviceParseError> = "a@#fij".try_into();
        assert!(matches!(result, Err(DeviceParseError::ParserFailure(_))));

        let result: Result<Device, DeviceParseError> = "asdf:4".try_into();
        assert!(matches!(
            result,
            Err(DeviceParseError::InvalidDeviceType(_))
        ));
    }

    #[test]
    fn from_str_bad_index() {
        // Leading zeros and a dangling `:` are not valid device strings.
        let result: Result<Device, DeviceParseError> = "cuda:05".try_into();
        assert!(matches!(result, Err(DeviceParseError::ParserFailure(_))));

        let result: Result<Device, DeviceParseError> = "cuda:".try_into();
        assert!(matches!(result, Err(DeviceParseError::ParserFailure(_))));

        // Syntactically fine, but does not fit in an `i8`.
        let result: Result<Device, DeviceParseError> = "cuda:200".try_into();
        assert!(matches!(
            result,
            Err(DeviceParseError::InvalidDeviceIndex(_))
        ));
    }

    #[test]
    fn from_str_good() {
        let device: Device = "cuda".try_into().unwrap();
        assert_eq!(device.device_type(), DeviceType::CUDA);
        assert_eq!(device.index(), DeviceIndex(-1));
    }

    #[test]
    fn from_str_index() {
        let device: Device = "cuda:5".try_into().unwrap();
        assert_eq!(device.device_type(), DeviceType::CUDA);
        assert_eq!(device.index(), DeviceIndex(5));
    }

    #[test]
    fn device_display_round_trips() -> Result<()> {
        for s in ["cpu", "cpu:0", "cuda", "cuda:5", "meta", "privateuseone:3"] {
            let device: Device = s.try_into()?;
            assert_eq!(device.to_string(), s);
        }
        Ok(())
    }

    #[test]
    fn device_type_names_round_trip() -> Result<()> {
        for &name in DEVICE_TYPE_NAMES {
            let device_type: DeviceType = name.try_into()?;
            assert_eq!(device_type.to_string(), name);
        }
        Ok(())
    }

    #[test]
    fn device_type_convert_to_py_and_back() -> Result<()> {
        pyo3::prepare_freethreaded_python();
        let expected: DeviceType = DeviceType::CUDA;
        let actual = Python::with_gil(|py| {
            // The DeviceType conversion is str-based and does not need torch;
            // importing it keeps this test's environment like the others.
            py.import("torch")?;
            let obj = expected.into_pyobject(py)?;
            obj.extract::<DeviceType>()
        })?;
        assert_eq!(actual, expected);
        Ok(())
    }

    #[test]
    fn device_index_convert_to_py_and_back() -> Result<()> {
        pyo3::prepare_freethreaded_python();
        let expected: DeviceIndex = 3.into();
        let actual = Python::with_gil(|py| {
            // The DeviceIndex conversion is int-based and does not need torch;
            // importing it keeps this test's environment like the others.
            py.import("torch")?;
            let obj = expected.into_pyobject(py)?;
            obj.extract::<DeviceIndex>()
        })?;
        assert_eq!(actual, expected);
        Ok(())
    }

    #[test]
    fn device_convert_to_py_and_back() -> Result<()> {
        pyo3::prepare_freethreaded_python();
        let expected: Device = "cuda:2".try_into()?;
        let actual = Python::with_gil(|py| {
            // Import torch before converting; the Device conversion goes
            // through the ffi bridge, which presumably needs torch loaded.
            py.import("torch")?;
            let obj = expected.into_pyobject(py)?;
            obj.extract::<Device>()
        })?;
        assert_eq!(actual, expected);
        Ok(())
    }

    #[test]
    fn device_from_py() -> Result<()> {
        pyo3::prepare_freethreaded_python();
        let expected: Device = "cuda:2".try_into()?;
        let actual = Python::with_gil(|py| {
            let obj = py.import("torch")?.getattr("device")?.call1(("cuda:2",))?;
            obj.extract::<Device>()
        })?;
        assert_eq!(actual, expected);
        Ok(())
    }
}