1use std::sync::LazyLock;
12
13use cxx::ExternType;
14use cxx::type_id;
15use derive_more::From;
16use derive_more::Into;
17use pyo3::IntoPyObjectExt;
18use pyo3::exceptions::PyValueError;
19use pyo3::prelude::*;
20use regex::Regex;
21use serde::Deserialize;
22use serde::Serialize;
23use thiserror::Error;
24
25use crate::bridge::ffi;
26
27#[derive(Error, Debug)]
29#[non_exhaustive]
30pub enum DeviceParseError {
31 #[error("invalid device type specified: {0}")]
32 InvalidDeviceType(String),
33
34 #[error("invalid device index specified: {0}")]
35 InvalidDeviceIndex(#[from] std::num::ParseIntError),
36
37 #[error("invalid device string: {0}")]
38 ParserFailure(String),
39}
40
41#[derive(Debug, Copy, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
51#[repr(transparent)]
52pub struct DeviceType {
53 pub repr: i8,
54}
55
56#[allow(dead_code)]
57#[allow(non_upper_case_globals)]
58impl DeviceType {
59 pub const CPU: Self = DeviceType { repr: 0 };
60 pub const CUDA: Self = DeviceType { repr: 1 };
61 pub const MKLDNN: Self = DeviceType { repr: 2 };
62 pub const OPENGL: Self = DeviceType { repr: 3 };
63 pub const OPENCL: Self = DeviceType { repr: 4 };
64 pub const IDEEP: Self = DeviceType { repr: 5 };
65 pub const HIP: Self = DeviceType { repr: 6 };
66 pub const FPGA: Self = DeviceType { repr: 7 };
67 pub const MAIA: Self = DeviceType { repr: 8 };
68 pub const XLA: Self = DeviceType { repr: 9 };
69 pub const Vulkan: Self = DeviceType { repr: 10 };
70 pub const Metal: Self = DeviceType { repr: 11 };
71 pub const XPU: Self = DeviceType { repr: 12 };
72 pub const MPS: Self = DeviceType { repr: 13 };
73 pub const Meta: Self = DeviceType { repr: 14 };
74 pub const HPU: Self = DeviceType { repr: 15 };
75 pub const VE: Self = DeviceType { repr: 16 };
76 pub const Lazy: Self = DeviceType { repr: 17 };
77 pub const IPU: Self = DeviceType { repr: 18 };
78 pub const MTIA: Self = DeviceType { repr: 19 };
79 pub const PrivateUse1: Self = DeviceType { repr: 20 };
80 pub const CompileTimeMaxDeviceTypes: Self = DeviceType { repr: 21 };
81}
82
83impl TryFrom<&str> for DeviceType {
84 type Error = DeviceParseError;
85 fn try_from(val: &str) -> Result<DeviceType, Self::Error> {
86 Ok(match val {
87 "cpu" => DeviceType::CPU,
88 "cuda" => DeviceType::CUDA,
89 "ipu" => DeviceType::IPU,
90 "xpu" => DeviceType::XPU,
91 "mkldnn" => DeviceType::MKLDNN,
92 "opengl" => DeviceType::OPENGL,
93 "opencl" => DeviceType::OPENCL,
94 "ideep" => DeviceType::IDEEP,
95 "hip" => DeviceType::HIP,
96 "ve" => DeviceType::VE,
97 "fpga" => DeviceType::FPGA,
98 "maia" => DeviceType::MAIA,
99 "xla" => DeviceType::XLA,
100 "lazy" => DeviceType::Lazy,
101 "vulkan" => DeviceType::Vulkan,
102 "mps" => DeviceType::MPS,
103 "meta" => DeviceType::Meta,
104 "hpu" => DeviceType::HPU,
105 "mtia" => DeviceType::MTIA,
106 "privateuseone" => DeviceType::PrivateUse1,
107 _ => return Err(DeviceParseError::InvalidDeviceType(val.to_string())),
108 })
109 }
110}
111
112impl std::fmt::Display for DeviceType {
113 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114 match *self {
115 DeviceType::CPU => write!(f, "cpu"),
116 DeviceType::CUDA => write!(f, "cuda"),
117 DeviceType::IPU => write!(f, "ipu"),
118 DeviceType::XPU => write!(f, "xpu"),
119 DeviceType::MKLDNN => write!(f, "mkldnn"),
120 DeviceType::OPENGL => write!(f, "opengl"),
121 DeviceType::OPENCL => write!(f, "opencl"),
122 DeviceType::IDEEP => write!(f, "ideep"),
123 DeviceType::HIP => write!(f, "hip"),
124 DeviceType::VE => write!(f, "ve"),
125 DeviceType::FPGA => write!(f, "fpga"),
126 DeviceType::MAIA => write!(f, "maia"),
127 DeviceType::XLA => write!(f, "xla"),
128 DeviceType::Lazy => write!(f, "lazy"),
129 DeviceType::Vulkan => write!(f, "vulkan"),
130 DeviceType::MPS => write!(f, "mps"),
131 DeviceType::Meta => write!(f, "meta"),
132 DeviceType::HPU => write!(f, "hpu"),
133 DeviceType::MTIA => write!(f, "mtia"),
134 DeviceType::PrivateUse1 => write!(f, "privateuseone"),
135 _ => write!(f, "unknown"),
136 }
137 }
138}
139
140unsafe impl ExternType for DeviceType {
142 type Id = type_id!("c10::DeviceType");
143 type Kind = cxx::kind::Trivial;
145}
146
147impl FromPyObject<'_> for DeviceType {
148 fn extract_bound(obj: &pyo3::Bound<'_, pyo3::PyAny>) -> pyo3::PyResult<Self> {
149 obj.extract::<String>()?
150 .as_str()
151 .try_into()
152 .map_err(|e| PyErr::new::<PyValueError, _>(format!("Failed extracting from py: {}", e)))
153 }
154}
155
156impl<'py> IntoPyObject<'py> for DeviceType {
157 type Target = PyAny;
158 type Output = Bound<'py, Self::Target>;
159 type Error = PyErr;
160
161 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
162 format!("{}", self).into_bound_py_any(py)
163 }
164}
165
166#[derive(
176 Debug,
177 Copy,
178 Clone,
179 Serialize,
180 Deserialize,
181 PartialEq,
182 Eq,
183 Hash,
184 Into,
185 From
186)]
187#[repr(transparent)]
188pub struct DeviceIndex(pub i8);
189
190unsafe impl ExternType for DeviceIndex {
192 type Id = type_id!("c10::DeviceIndex");
193 type Kind = cxx::kind::Trivial;
195}
196
197impl FromPyObject<'_> for DeviceIndex {
198 fn extract_bound(obj: &pyo3::Bound<'_, pyo3::PyAny>) -> pyo3::PyResult<Self> {
199 let v = obj.extract::<Option<i8>>()?;
200 Ok(DeviceIndex(v.unwrap_or(-1)))
201 }
202}
203
204impl<'py> IntoPyObject<'py> for DeviceIndex {
205 type Target = PyAny;
206 type Output = Bound<'py, Self::Target>;
207 type Error = PyErr;
208
209 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
210 self.0.into_bound_py_any(py)
211 }
212}
213
214#[derive(Debug, Copy, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
227#[repr(C)]
228pub struct Device {
229 device_type: DeviceType,
230 index: DeviceIndex,
231}
232
233impl Device {
234 pub fn new(device_type: DeviceType) -> Self {
236 Self {
237 device_type,
238 index: DeviceIndex(-1),
239 }
240 }
241
242 pub fn new_with_index(device_type: DeviceType, index: DeviceIndex) -> Self {
243 debug_assert!(
244 index.0 >= -1,
245 "Device index must be -1 or non-negative, got: {}",
246 index.0
247 );
248 debug_assert!(
249 !matches!(device_type, DeviceType::CPU) || index.0 <= 0,
250 "Device index for CPU must be -1 or 0, got: {}",
251 index.0
252 );
253 Self { device_type, index }
254 }
255
256 pub fn device_type(&self) -> DeviceType {
257 self.device_type
258 }
259
260 pub fn index(&self) -> DeviceIndex {
261 self.index
262 }
263}
264
265impl TryFrom<Device> for CudaDevice {
266 type Error = &'static str;
267 fn try_from(value: Device) -> Result<Self, Self::Error> {
268 if value.device_type() == DeviceType::CUDA {
269 Ok(CudaDevice { index: value.index })
270 } else {
271 Err("Device is not a CUDA device")
272 }
273 }
274}
275
276#[derive(Debug, Copy, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
278pub struct CudaDevice {
279 index: DeviceIndex,
280}
281
282impl CudaDevice {
283 pub fn new(index: DeviceIndex) -> Self {
284 Self { index }
285 }
286
287 pub fn index(&self) -> DeviceIndex {
288 self.index
289 }
290}
291
292impl From<CudaDevice> for Device {
293 fn from(device: CudaDevice) -> Self {
294 Device::new_with_index(DeviceType::CUDA, device.index)
295 }
296}
297
298static DEVICE_REGEX: LazyLock<Regex> =
299 LazyLock::new(|| Regex::new("([a-zA-Z_]+)(?::([1-9]\\d*|0))?").unwrap());
300
301impl TryFrom<&str> for Device {
302 type Error = DeviceParseError;
303 fn try_from(val: &str) -> Result<Device, Self::Error> {
304 let captures = DEVICE_REGEX
305 .captures(val)
306 .ok_or_else(|| DeviceParseError::ParserFailure(val.to_string()))?;
307
308 if captures.get(0).unwrap().len() != val.len() {
309 return Err(DeviceParseError::ParserFailure(val.to_string()));
310 }
311
312 let device_type: DeviceType = captures
313 .get(1)
314 .ok_or_else(|| DeviceParseError::ParserFailure(val.to_string()))?
315 .as_str()
316 .try_into()?;
317
318 let index = captures.get(2);
319 match index {
320 Some(match_) => Ok(Device::new_with_index(
321 device_type,
322 DeviceIndex(
323 match_
324 .as_str()
325 .parse::<i8>()
326 .map_err(DeviceParseError::from)?,
327 ),
328 )),
329 None => Ok(Device::new(device_type)),
330 }
331 }
332}
333
334impl TryFrom<String> for Device {
335 type Error = DeviceParseError;
336 fn try_from(val: String) -> Result<Device, Self::Error> {
337 Device::try_from(val.as_ref())
338 }
339}
340
341impl std::fmt::Display for Device {
342 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
343 if self.index.0 == -1 {
344 write!(f, "{}", self.device_type)
345 } else {
346 write!(f, "{}:{}", self.device_type, self.index.0)
347 }
348 }
349}
350
351unsafe impl ExternType for Device {
353 type Id = type_id!("c10::Device");
354 type Kind = cxx::kind::Trivial;
356}
357
358impl FromPyObject<'_> for Device {
359 fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
360 ffi::device_from_py_object(obj.into()).map_err(|e| {
361 PyValueError::new_err(format!(
362 "Failed extracting {} from py as Device: {}",
363 obj, e
364 ))
365 })
366 }
367}
368
369impl<'py> IntoPyObject<'py> for Device {
370 type Target = PyAny;
371 type Output = Bound<'py, Self::Target>;
372 type Error = PyErr;
373
374 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
375 ffi::device_to_py_object(self).into_pyobject(py)
376 }
377}
378
#[cfg(test)]
mod tests {
    use anyhow::Result;

    use super::*;

    #[test]
    fn from_str_bad() {
        let result: Result<Device, DeviceParseError> = "a@#fij".try_into();
        assert!(matches!(result, Err(DeviceParseError::ParserFailure(_))));

        let result: Result<Device, DeviceParseError> = "asdf:4".try_into();
        assert!(matches!(
            result,
            Err(DeviceParseError::InvalidDeviceType(_))
        ));
    }

    #[test]
    fn from_str_rejects_malformed_strings() {
        // None of these match the `<type>[:<index>]` grammar as a whole.
        for bad in ["", ":0", "cuda:", "cuda:01", "cuda:-1", "cuda:1:2", "cuda :1"] {
            let result: Result<Device, DeviceParseError> = bad.try_into();
            assert!(
                matches!(result, Err(DeviceParseError::ParserFailure(_))),
                "expected ParserFailure for {bad:?}, got {result:?}"
            );
        }
    }

    #[test]
    fn from_str_index_overflow() {
        let result: Result<Device, DeviceParseError> = "cuda:200".try_into();
        assert!(matches!(
            result,
            Err(DeviceParseError::InvalidDeviceIndex(_))
        ));
    }

    #[test]
    fn from_str_good() {
        let device: Device = "cuda".try_into().unwrap();
        assert!(matches!(device.device_type(), DeviceType::CUDA));
        assert_eq!(device.index(), DeviceIndex(-1));
    }

    #[test]
    fn from_str_index() {
        let device: Device = "cuda:5".try_into().unwrap();
        assert!(matches!(device.device_type(), DeviceType::CUDA));
        assert_eq!(device.index(), DeviceIndex(5));
    }

    #[test]
    fn display_round_trips() {
        for text in ["cpu", "cpu:0", "cuda", "cuda:3", "meta"] {
            let device: Device = text.try_into().unwrap();
            assert_eq!(device.to_string(), text);
        }
    }

    #[test]
    fn device_type_convert_to_py_and_back() -> Result<()> {
        pyo3::prepare_freethreaded_python();
        let expected: DeviceType = DeviceType::CUDA;
        let actual = Python::with_gil(|py| {
            py.import("torch")?;
            let obj = expected.into_pyobject(py)?;
            obj.extract::<DeviceType>()
        })?;
        assert_eq!(actual, expected);
        Ok(())
    }

    #[test]
    fn device_index_convert_to_py_and_back() -> Result<()> {
        pyo3::prepare_freethreaded_python();
        let expected: DeviceIndex = 3.into();
        let actual = Python::with_gil(|py| {
            py.import("torch")?;
            let obj = expected.into_pyobject(py)?;
            obj.extract::<DeviceIndex>()
        })?;
        assert_eq!(actual, expected);
        Ok(())
    }

    #[test]
    fn device_convert_to_py_and_back() -> Result<()> {
        pyo3::prepare_freethreaded_python();
        let expected: Device = "cuda:2".try_into()?;
        let actual = Python::with_gil(|py| {
            py.import("torch")?;
            let obj = expected.into_pyobject(py)?;
            obj.extract::<Device>()
        })?;
        assert_eq!(actual, expected);
        Ok(())
    }

    #[test]
    fn device_from_py() -> Result<()> {
        pyo3::prepare_freethreaded_python();
        let expected: Device = "cuda:2".try_into()?;
        let actual = Python::with_gil(|py| {
            let obj = py.import("torch")?.getattr("device")?.call1(("cuda:2",))?;
            obj.extract::<Device>()
        })?;
        assert_eq!(actual, expected);
        Ok(())
    }
}
464}