Skip to main content

torch_sys_cuda/
cuda.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9//! Bindings for torch's wrappers around CUDA-related functionality.
10use std::time::Duration;
11
12use derive_more::Into;
13use monarch_types::py_global;
14use pyo3::Py;
15use pyo3::PyAny;
16use pyo3::prelude::*;
17use thiserror::Error;
18use torch_sys2::CudaDevice;
19
20// Cached imports for torch.cuda APIs
21py_global!(cuda_stream_class, "torch.cuda", "Stream");
22py_global!(cuda_event_class, "torch.cuda", "Event");
23py_global!(cuda_current_stream, "torch.cuda", "current_stream");
24py_global!(cuda_set_stream, "torch.cuda", "set_stream");
25py_global!(cuda_current_device, "torch.cuda", "current_device");
26py_global!(cuda_set_device, "torch.cuda", "set_device");
27
28/// Wrapper around a CUDA stream.
29///
30/// A CUDA stream is a linear sequence of execution that belongs to a specific
31/// device, independent from other streams.  See the documentation for
32/// `torch.cuda.Stream` for more details.
33#[derive(Debug, Into)]
34#[into(ref)]
35pub struct Stream {
36    inner: Py<PyAny>,
37}
38
39impl Clone for Stream {
40    fn clone(&self) -> Self {
41        Python::attach(|py| Self {
42            inner: self.inner.clone_ref(py),
43        })
44    }
45}
46
47impl Stream {
48    /// Create a new stream on the current device, at priority 0.
49    pub fn new() -> Self {
50        Python::attach(|py| {
51            let stream = cuda_stream_class(py).call0().unwrap();
52            Self {
53                inner: stream.into(),
54            }
55        })
56    }
57
58    /// Clone this stream reference. Requires holding the GIL.
59    pub fn clone_ref(&self, py: Python<'_>) -> Self {
60        Self {
61            inner: self.inner.clone_ref(py),
62        }
63    }
64
65    /// Create a new stream on the specified device, at priority 0.
66    pub fn new_with_device(device: CudaDevice) -> Self {
67        Python::attach(|py| {
68            let device_idx: i8 = device.index().into();
69            let stream = cuda_stream_class(py).call1((device_idx,)).unwrap();
70            Self {
71                inner: stream.into(),
72            }
73        })
74    }
75
76    /// Get the current stream on the current device.
77    pub fn get_current_stream() -> Self {
78        Python::attach(|py| {
79            let stream = cuda_current_stream(py).call0().unwrap();
80            Self {
81                inner: stream.into(),
82            }
83        })
84    }
85
86    /// Get the current stream on the specified device.
87    pub fn get_current_stream_on_device(device: CudaDevice) -> Self {
88        Python::attach(|py| {
89            let device_idx: i8 = device.index().into();
90            let stream = cuda_current_stream(py).call1((device_idx,)).unwrap();
91            Self {
92                inner: stream.into(),
93            }
94        })
95    }
96
97    /// Set the provided stream as the current stream. Also sets the current
98    /// device to be the same as the stream's device.
99    pub fn set_current_stream(stream: &Stream) {
100        Python::attach(|py| {
101            let stream_obj = stream.inner.bind(py);
102
103            // Get current device and stream device
104            let current_device = cuda_current_device(py)
105                .call0()
106                .unwrap()
107                .extract::<i64>()
108                .unwrap();
109
110            let stream_device = stream_obj
111                .getattr("device_index")
112                .unwrap()
113                .extract::<i64>()
114                .unwrap();
115
116            // Set device if different
117            if current_device != stream_device {
118                cuda_set_device(py).call1((stream_device,)).unwrap();
119            }
120
121            // Set the stream
122            cuda_set_stream(py).call1((stream_obj,)).unwrap();
123        })
124    }
125
126    /// Make all future work submitted to this stream wait for an event.
127    pub fn wait_event(&self, event: &mut Event) {
128        event.wait(Some(self))
129    }
130
131    /// Synchronize with another stream.
132    ///
133    /// All future work submitted to this stream will wait until all kernels
134    /// submitted to a given stream at the time of call entry complete.
135    pub fn wait_stream(&self, stream: &Stream) {
136        self.wait_event(&mut stream.record_event(None))
137    }
138
139    /// Record an event on this stream. If no event is provided one will be
140    /// created.
141    pub fn record_event(&self, event: Option<Event>) -> Event {
142        let mut event = event.unwrap_or(Event::new());
143        event.record(Some(self));
144        event
145    }
146
147    /// Check if all work submitted to this stream has completed.
148    pub fn query(&self) -> bool {
149        Python::attach(|py| {
150            let stream_obj = self.inner.bind(py);
151            stream_obj
152                .call_method0("query")
153                .unwrap()
154                .extract::<bool>()
155                .unwrap()
156        })
157    }
158
159    /// Wait for all kernels in this stream to complete.
160    pub fn synchronize(&self) {
161        Python::attach(|py| {
162            let stream_obj = self.inner.bind(py);
163            stream_obj.call_method0("synchronize").unwrap();
164        })
165    }
166
167    /// Get the raw CUDA stream pointer. Only available with the `cuda` feature.
168    #[cfg(feature = "cuda")]
169    pub fn stream(&self) -> nccl_sys::cudaStream_t {
170        Python::attach(|py| {
171            let stream_obj = self.inner.bind(py);
172            let cuda_stream = stream_obj.getattr("cuda_stream").unwrap();
173
174            // Extract the raw pointer
175            let ptr = cuda_stream.extract::<usize>().unwrap();
176
177            ptr as nccl_sys::cudaStream_t
178        })
179    }
180}
181
182impl PartialEq for Stream {
183    fn eq(&self, other: &Self) -> bool {
184        #[cfg(feature = "cuda")]
185        {
186            self.stream() == other.stream()
187        }
188
189        #[cfg(not(feature = "cuda"))]
190        {
191            self.inner.as_ptr() == other.inner.as_ptr()
192        }
193    }
194}
195
196/// Wrapper around a CUDA event.
197///
198/// CUDA events are synchronization markers that can be used to monitor the
199/// device's progress, to accurately measure timing, and to synchronize CUDA
200/// streams.
201///
202/// The underlying CUDA events are lazily initialized when the event is first
203/// recorded or exported to another process. After creation, only streams on the
204/// same device may record the event. However, streams on any device can wait on
205/// the event.
206///
207/// See the docs of `torch.cuda.Event` for more details.
208#[derive(Debug)]
209pub struct Event {
210    inner: Py<PyAny>,
211}
212
213impl Clone for Event {
214    fn clone(&self) -> Self {
215        Python::attach(|py| Self {
216            inner: self.inner.clone_ref(py),
217        })
218    }
219}
220
221impl Event {
222    /// Create a new event.
223    // TODO: add support for flags.
224    pub fn new() -> Self {
225        Python::attach(|py| {
226            let event = cuda_event_class(py).call0().unwrap();
227            Self {
228                inner: event.into(),
229            }
230        })
231    }
232
233    /// Record the event on the current stream.
234    ///
235    /// Uses the current stream if no stream is provided.
236    pub fn record(&mut self, stream: Option<&Stream>) {
237        Python::attach(|py| {
238            let event_obj = self.inner.bind(py);
239
240            match stream {
241                Some(stream) => {
242                    let stream_obj = stream.inner.bind(py);
243                    event_obj.call_method1("record", (stream_obj,)).unwrap();
244                }
245                None => {
246                    event_obj.call_method0("record").unwrap();
247                }
248            }
249        })
250    }
251
252    /// Make all future work submitted to the given stream wait for this event.
253    ///
254    /// Uses the current stream if no stream is specified.
255    pub fn wait(&mut self, stream: Option<&Stream>) {
256        Python::attach(|py| {
257            let event_obj = self.inner.bind(py);
258
259            match stream {
260                Some(stream) => {
261                    let stream_obj = stream.inner.bind(py);
262                    event_obj.call_method1("wait", (stream_obj,)).unwrap();
263                }
264                None => {
265                    event_obj.call_method0("wait").unwrap();
266                }
267            }
268        })
269    }
270
271    /// Check if all work currently captured by event has completed.
272    pub fn query(&self) -> bool {
273        Python::attach(|py| {
274            let event_obj = self.inner.bind(py);
275            event_obj
276                .call_method0("query")
277                .unwrap()
278                .extract::<bool>()
279                .unwrap()
280        })
281    }
282
283    /// Return the time elapsed.
284    ///
285    /// Time reported in after the event was recorded and before the end_event
286    /// was recorded.
287    pub fn elapsed_time(&self, end_event: &Event) -> Duration {
288        Python::attach(|py| {
289            let event_obj = self.inner.bind(py);
290            let end_event_obj = end_event.inner.bind(py);
291
292            let elapsed_ms = event_obj
293                .call_method1("elapsed_time", (end_event_obj,))
294                .unwrap()
295                .extract::<f64>()
296                .unwrap();
297
298            Duration::from_millis(elapsed_ms as u64)
299        })
300    }
301
302    /// Wait for the event to complete.
303    /// Waits until the completion of all work currently captured in this event.
304    /// This prevents the CPU thread from proceeding until the event completes.
305    pub fn synchronize(&self) {
306        Python::attach(|py| {
307            let event_obj = self.inner.bind(py);
308            event_obj.call_method0("synchronize").unwrap();
309        })
310    }
311}
312
313/// Corresponds to the CUDA error codes.
314#[derive(Debug, Error)]
315pub enum CudaError {
316    #[error(
317        "one or more parameters passed to the API call is not within an acceptable range of values"
318    )]
319    InvalidValue,
320    #[error("the API call failed due to insufficient memory or resources")]
321    MemoryAllocation,
322    #[error("failed to initialize the CUDA driver and runtime")]
323    InitializationError,
324    #[error("CUDA Runtime API call was executed after the CUDA driver has been unloaded")]
325    CudartUnloading,
326    #[error("profiler is not initialized for this run, possibly due to an external profiling tool")]
327    ProfilerDisabled,
328    #[error("deprecated. Attempted to enable/disable profiling without initialization")]
329    ProfilerNotInitialized,
330    #[error("deprecated. Profiling is already started")]
331    ProfilerAlreadyStarted,
332    #[error("deprecated. Profiling is already stopped")]
333    ProfilerAlreadyStopped,
334    #[error("kernel launch requested resources that cannot be satisfied by the current device")]
335    InvalidConfiguration,
336    #[error("one or more of the pitch-related parameters passed to the API call is out of range")]
337    InvalidPitchValue,
338    #[error("the symbol name/identifier passed to the API call is invalid")]
339    InvalidSymbol,
340    #[error("the host pointer passed to the API call is invalid")]
341    InvalidHostPointer,
342    #[error("the device pointer passed to the API call is invalid")]
343    InvalidDevicePointer,
344    #[error("the texture passed to the API call is invalid")]
345    InvalidTexture,
346    #[error("the texture binding is invalid")]
347    InvalidTextureBinding,
348    #[error("the channel descriptor passed to the API call is invalid")]
349    InvalidChannelDescriptor,
350    #[error("the direction of the memcpy operation is invalid")]
351    InvalidMemcpyDirection,
352    #[error(
353        "attempted to take the address of a constant variable, which is forbidden before CUDA 3.1"
354    )]
355    AddressOfConstant,
356    #[error("deprecated. A texture fetch operation failed")]
357    TextureFetchFailed,
358    #[error("deprecated. The texture is not bound for access")]
359    TextureNotBound,
360    #[error("a synchronization operation failed")]
361    SynchronizationError,
362    #[error(
363        "a non-float texture was accessed with linear filtering, which is not supported by CUDA"
364    )]
365    InvalidFilterSetting,
366    #[error(
367        "attempted to read a non-float texture as a normalized float, which is not supported by CUDA"
368    )]
369    InvalidNormSetting,
370    #[error("the API call is not yet implemented")]
371    NotYetImplemented,
372    #[error("an emulated device pointer exceeded the 32-bit address range")]
373    MemoryValueTooLarge,
374    #[error("the CUDA driver is a stub library")]
375    StubLibrary,
376    #[error("the installed NVIDIA CUDA driver is older than the CUDA runtime library")]
377    InsufficientDriver,
378    #[error("the API call requires a newer CUDA driver")]
379    CallRequiresNewerDriver,
380    #[error("the surface passed to the API call is invalid")]
381    InvalidSurface,
382    #[error("multiple global or constant variables share the same string name")]
383    DuplicateVariableName,
384    #[error("multiple textures share the same string name")]
385    DuplicateTextureName,
386    #[error("multiple surfaces share the same string name")]
387    DuplicateSurfaceName,
388    #[error("all CUDA devices are currently busy or unavailable")]
389    DevicesUnavailable,
390    #[error("the current CUDA context is not compatible with the runtime")]
391    IncompatibleDriverContext,
392    #[error("the device function being invoked was not previously configured")]
393    MissingConfiguration,
394    #[error("a previous kernel launch failed")]
395    PriorLaunchFailure,
396    #[error(
397        "the depth of the child grid exceeded the maximum supported number of nested grid launches"
398    )]
399    LaunchMaxDepthExceeded,
400    #[error("a grid launch did not occur because file-scoped textures are unsupported")]
401    LaunchFileScopedTex,
402    #[error("a grid launch did not occur because file-scoped surfaces are unsupported")]
403    LaunchFileScopedSurf,
404    #[error("a call to cudaDeviceSynchronize failed due to exceeding the sync depth")]
405    SyncDepthExceeded,
406    #[error(
407        "a grid launch failed because the launch exceeded the limit of pending device runtime launches"
408    )]
409    LaunchPendingCountExceeded,
410    #[error(
411        "the requested device function does not exist or is not compiled for the proper device architecture"
412    )]
413    InvalidDeviceFunction,
414    #[error("no CUDA-capable devices were detected")]
415    NoDevice,
416    #[error("the device ordinal supplied does not correspond to a valid CUDA device")]
417    InvalidDevice,
418    #[error("the device does not have a valid Grid License")]
419    DeviceNotLicensed,
420    #[error("an internal startup failure occurred in the CUDA runtime")]
421    StartupFailure,
422    #[error("the device kernel image is invalid")]
423    InvalidKernelImage,
424    #[error("the device is not initialized")]
425    DeviceUninitialized,
426    #[error("the buffer object could not be mapped")]
427    MapBufferObjectFailed,
428    #[error("the buffer object could not be unmapped")]
429    UnmapBufferObjectFailed,
430    #[error("the specified array is currently mapped and cannot be destroyed")]
431    ArrayIsMapped,
432    #[error("the resource is already mapped")]
433    AlreadyMapped,
434    #[error("there is no kernel image available that is suitable for the device")]
435    NoKernelImageForDevice,
436    #[error("the resource has already been acquired")]
437    AlreadyAcquired,
438    #[error("the resource is not mapped")]
439    NotMapped,
440    #[error("the mapped resource is not available for access as an array")]
441    NotMappedAsArray,
442    #[error("the mapped resource is not available for access as a pointer")]
443    NotMappedAsPointer,
444    #[error("an uncorrectable ECC error was detected")]
445    ECCUncorrectable,
446    #[error("the specified cudaLimit is not supported by the device")]
447    UnsupportedLimit,
448    #[error("a call tried to access an exclusive-thread device that is already in use")]
449    DeviceAlreadyInUse,
450    #[error("P2P access is not supported across the given devices")]
451    PeerAccessUnsupported,
452    #[error("a PTX compilation failed")]
453    InvalidPtx,
454    #[error("an error occurred with the OpenGL or DirectX context")]
455    InvalidGraphicsContext,
456    #[error("an uncorrectable NVLink error was detected during execution")]
457    NvlinkUncorrectable,
458    #[error("the PTX JIT compiler library was not found")]
459    JitCompilerNotFound,
460    #[error("the provided PTX was compiled with an unsupported toolchain")]
461    UnsupportedPtxVersion,
462    #[error("JIT compilation was disabled")]
463    JitCompilationDisabled,
464    #[error("the provided execution affinity is not supported by the device")]
465    UnsupportedExecAffinity,
466    #[error("the operation is not permitted when the stream is capturing")]
467    StreamCaptureUnsupported,
468    #[error(
469        "the current capture sequence on the stream has been invalidated due to a previous error"
470    )]
471    StreamCaptureInvalidated,
472    #[error("a merge of two independent capture sequences was not allowed")]
473    StreamCaptureMerge,
474    #[error("the capture was not initiated in this stream")]
475    StreamCaptureUnmatched,
476    #[error("a stream capture sequence was passed to cudaStreamEndCapture in a different thread")]
477    StreamCaptureWrongThread,
478    #[error("the wait operation has timed out")]
479    Timeout,
480    #[error("an unknown internal error occurred")]
481    Unknown,
482    #[error("the API call returned a failure")]
483    ApiFailureBase,
484}
485
/// Convert a raw CUDA runtime status code into a `Result`.
///
/// `cudaSuccess` (0) becomes `Ok(())`. Each recognized error code becomes the
/// matching [`CudaError`] variant. Codes at or above `cudaErrorApiFailureBase`
/// (10000) become [`CudaError::ApiFailureBase`], because the runtime reports
/// unhandled driver errors as that base value plus the driver code.
///
/// # Panics
///
/// Panics if `result` holds a code below 10000 that has no corresponding
/// [`CudaError`] variant.
#[cfg(feature = "cuda")]
pub fn cuda_check(result: nccl_sys::cudaError_t) -> Result<(), CudaError> {
    let err = match result.0 {
        0 => return Ok(()),
        1 => CudaError::InvalidValue,
        2 => CudaError::MemoryAllocation,
        3 => CudaError::InitializationError,
        4 => CudaError::CudartUnloading,
        5 => CudaError::ProfilerDisabled,
        6 => CudaError::ProfilerNotInitialized,
        7 => CudaError::ProfilerAlreadyStarted,
        8 => CudaError::ProfilerAlreadyStopped,
        9 => CudaError::InvalidConfiguration,
        12 => CudaError::InvalidPitchValue,
        13 => CudaError::InvalidSymbol,
        16 => CudaError::InvalidHostPointer,
        17 => CudaError::InvalidDevicePointer,
        18 => CudaError::InvalidTexture,
        19 => CudaError::InvalidTextureBinding,
        20 => CudaError::InvalidChannelDescriptor,
        21 => CudaError::InvalidMemcpyDirection,
        22 => CudaError::AddressOfConstant,
        23 => CudaError::TextureFetchFailed,
        24 => CudaError::TextureNotBound,
        25 => CudaError::SynchronizationError,
        26 => CudaError::InvalidFilterSetting,
        27 => CudaError::InvalidNormSetting,
        31 => CudaError::NotYetImplemented,
        32 => CudaError::MemoryValueTooLarge,
        34 => CudaError::StubLibrary,
        35 => CudaError::InsufficientDriver,
        36 => CudaError::CallRequiresNewerDriver,
        37 => CudaError::InvalidSurface,
        43 => CudaError::DuplicateVariableName,
        44 => CudaError::DuplicateTextureName,
        45 => CudaError::DuplicateSurfaceName,
        46 => CudaError::DevicesUnavailable,
        49 => CudaError::IncompatibleDriverContext,
        52 => CudaError::MissingConfiguration,
        53 => CudaError::PriorLaunchFailure,
        65 => CudaError::LaunchMaxDepthExceeded,
        66 => CudaError::LaunchFileScopedTex,
        67 => CudaError::LaunchFileScopedSurf,
        68 => CudaError::SyncDepthExceeded,
        69 => CudaError::LaunchPendingCountExceeded,
        98 => CudaError::InvalidDeviceFunction,
        100 => CudaError::NoDevice,
        101 => CudaError::InvalidDevice,
        102 => CudaError::DeviceNotLicensed,
        127 => CudaError::StartupFailure,
        200 => CudaError::InvalidKernelImage,
        201 => CudaError::DeviceUninitialized,
        205 => CudaError::MapBufferObjectFailed,
        206 => CudaError::UnmapBufferObjectFailed,
        207 => CudaError::ArrayIsMapped,
        208 => CudaError::AlreadyMapped,
        209 => CudaError::NoKernelImageForDevice,
        210 => CudaError::AlreadyAcquired,
        211 => CudaError::NotMapped,
        212 => CudaError::NotMappedAsArray,
        213 => CudaError::NotMappedAsPointer,
        214 => CudaError::ECCUncorrectable,
        215 => CudaError::UnsupportedLimit,
        216 => CudaError::DeviceAlreadyInUse,
        217 => CudaError::PeerAccessUnsupported,
        218 => CudaError::InvalidPtx,
        219 => CudaError::InvalidGraphicsContext,
        220 => CudaError::NvlinkUncorrectable,
        221 => CudaError::JitCompilerNotFound,
        222 => CudaError::UnsupportedPtxVersion,
        223 => CudaError::JitCompilationDisabled,
        224 => CudaError::UnsupportedExecAffinity,
        900 => CudaError::StreamCaptureUnsupported,
        901 => CudaError::StreamCaptureInvalidated,
        902 => CudaError::StreamCaptureMerge,
        903 => CudaError::StreamCaptureUnmatched,
        // 904 is `cudaErrorStreamCaptureUnjoined`, which has no variant here;
        // `cudaErrorStreamCaptureWrongThread` is 908.
        908 => CudaError::StreamCaptureWrongThread,
        909 => CudaError::Timeout,
        999 => CudaError::Unknown,
        // `cudaErrorApiFailureBase`: unhandled driver errors are reported as
        // this base value plus the driver error code.
        10000.. => CudaError::ApiFailureBase,
        _ => panic!("Unknown cudaError_t: {:?}", result.0),
    };
    Err(err)
}
568
/// Make `device` the current CUDA device by calling `cudaSetDevice`.
///
/// # Errors
///
/// Returns the [`CudaError`] corresponding to the status reported by
/// `cudaSetDevice`.
#[cfg(feature = "cuda")]
pub fn set_device(device: CudaDevice) -> Result<(), CudaError> {
    let ordinal: i8 = device.index().into();
    // SAFETY: `cudaSetDevice` takes its argument by value and reads no
    // caller-owned memory; failures such as an invalid ordinal are reported
    // through the returned status, which `cuda_check` inspects below.
    let status = unsafe { nccl_sys::cudaSetDevice(ordinal.into()) };
    cuda_check(status)
}
575
576#[cfg(not(feature = "cuda"))]
577pub fn set_device(_device: CudaDevice) -> Result<(), CudaError> {
578    panic!("set_device requires the `cuda` feature (CUDA/ROCm runtime)")
579}