// torch_sys_cuda/cuda.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

//! Bindings for torch's wrappers around CUDA-related functionality.
use std::time::Duration;

use cxx::SharedPtr;
use cxx::UniquePtr;
use derive_more::Into;
use nccl_sys::cudaError_t;
use nccl_sys::cudaSetDevice;
use nccl_sys::cudaStream_t;
use thiserror::Error;
use torch_sys::CudaDevice;

use crate::bridge::ffi::{self};

23/// Wrapper around a CUDA stream.
24///
25/// A CUDA stream is a linear sequence of execution that belongs to a specific
26/// device, independent from other streams.  See the documentation for
27/// `torch.cuda.Stream` for more details.
28#[derive(Debug, Clone, Into)]
29#[into(ref)]
30pub struct Stream {
31    inner: SharedPtr<ffi::CUDAStream>,
32}
33
34// SAFETY: CUDAStream is thread safe
35unsafe impl Send for Stream {}
36// SAFETY: see above
37unsafe impl Sync for Stream {}
38
39impl Stream {
40    /// Create a new stream on the current device, at priority 0.
41    pub fn new() -> Self {
42        Self {
43            inner: ffi::create_stream(-1, 0),
44        }
45    }
46
47    /// Create a new stream on the specified device, at priority 0.
48    pub fn new_with_device(device: CudaDevice) -> Self {
49        Self {
50            inner: ffi::create_stream(device.index().into(), 0),
51        }
52    }
53
54    /// Get the current stream on the current device.
55    pub fn get_current_stream() -> Self {
56        Self {
57            inner: ffi::get_current_stream(-1),
58        }
59    }
60    /// Get the current stream on the specified device.
61    pub fn get_current_stream_on_device(device: CudaDevice) -> Self {
62        Self {
63            inner: ffi::get_current_stream(device.index().into()),
64        }
65    }
66
67    /// Set the provided stream as the current stream. Also sets the current
68    /// device to be the same as the stream's device.
69    pub fn set_current_stream(stream: &Stream) {
70        ffi::set_current_stream(stream.as_ref())
71    }
72
73    /// Make all future work submitted to this stream wait for an event.
74    pub fn wait_event(&self, event: &mut Event) {
75        event.wait(Some(self))
76    }
77
78    /// Synchronize with another stream.
79    ///
80    /// All future work submitted to this stream will wait until all kernels
81    /// submitted to a given stream at the time of call entry complete.
82    pub fn wait_stream(&self, stream: &Stream) {
83        self.wait_event(&mut stream.record_event(None))
84    }
85
86    /// Record an event on this stream. If no event is provided one will be
87    /// created.
88    pub fn record_event(&self, event: Option<Event>) -> Event {
89        let mut event = event.unwrap_or(Event::new());
90        event.record(Some(self));
91        event
92    }
93
94    /// Check if all work submitted to this stream has completed.
95    pub fn query(&self) -> bool {
96        self.inner.query()
97    }
98
99    /// Wait for all kernels in this stream to complete.
100    pub fn synchronize(&self) {
101        self.inner.synchronize()
102    }
103
104    pub fn stream(&self) -> cudaStream_t {
105        self.inner.stream()
106    }
107}
108
109impl AsRef<ffi::CUDAStream> for Stream {
110    fn as_ref(&self) -> &ffi::CUDAStream {
111        // Fine to unwrap here, `Stream` guarantees that `inner` is never null.
112        self.inner.as_ref().unwrap()
113    }
114}
115
116impl PartialEq for Stream {
117    fn eq(&self, other: &Self) -> bool {
118        self.stream() == other.stream()
119    }
120}
121
122/// Wrapper around a CUDA event.
123///
124/// CUDA events are synchronization markers that can be used to monitor the
125/// device's progress, to accurately measure timing, and to synchronize CUDA
126/// streams.
127///
128/// The underlying CUDA events are lazily initialized when the event is first
129/// recorded or exported to another process. After creation, only streams on the
130/// same device may record the event. However, streams on any device can wait on
131/// the event.
132///
133/// See the docs of `torch.cuda.Event` for more details.
134#[derive(Debug)]
135pub struct Event {
136    inner: UniquePtr<ffi::CUDAEvent>,
137}
138
139impl Event {
140    /// Create a new event.
141    // TODO: add support for flags.
142    pub fn new() -> Self {
143        Self {
144            inner: ffi::create_cuda_event(false, false, false),
145        }
146    }
147
148    /// Record the event on the current stream.
149    ///
150    /// Uses the current stream if no stream is provided.
151    pub fn record(&mut self, stream: Option<&Stream>) {
152        match stream {
153            Some(stream) => self.inner.pin_mut().record(stream.as_ref()),
154            None => self
155                .inner
156                .pin_mut()
157                .record(Stream::get_current_stream().as_ref()),
158        }
159    }
160
161    /// Make all future work submitted to the given stream wait for this event.
162    ///
163    /// Uses the current stream if no stream is specified.
164    pub fn wait(&mut self, stream: Option<&Stream>) {
165        match stream {
166            Some(stream) => self.inner.pin_mut().block(stream.as_ref()),
167            None => self
168                .inner
169                .pin_mut()
170                .block(Stream::get_current_stream().as_ref()),
171        }
172    }
173
174    /// Check if all work currently captured by event has completed.
175    pub fn query(&self) -> bool {
176        self.inner.query()
177    }
178
179    /// Return the time elapsed.
180    ///
181    /// Time reported in after the event was recorded and before the end_event
182    /// was recorded.
183    pub fn elapsed_time(&self, end_event: &Event) -> Duration {
184        Duration::from_millis(self.inner.elapsed_time(end_event.as_ref()) as u64)
185    }
186
187    /// Wait for the event to complete.
188    /// Waits until the completion of all work currently captured in this event.
189    /// This prevents the CPU thread from proceeding until the event completes.
190    pub fn synchronize(&self) {
191        self.inner.synchronize()
192    }
193}
194
195impl AsRef<ffi::CUDAEvent> for Event {
196    fn as_ref(&self) -> &ffi::CUDAEvent {
197        // Fine to unwrap here, `Event` guarantees that `inner` is never null.
198        self.inner.as_ref().unwrap()
199    }
200}
201
202/// Corresponds to the CUDA error codes.
203#[derive(Debug, Error)]
204pub enum CudaError {
205    #[error(
206        "one or more parameters passed to the API call is not within an acceptable range of values"
207    )]
208    InvalidValue,
209    #[error("the API call failed due to insufficient memory or resources")]
210    MemoryAllocation,
211    #[error("failed to initialize the CUDA driver and runtime")]
212    InitializationError,
213    #[error("CUDA Runtime API call was executed after the CUDA driver has been unloaded")]
214    CudartUnloading,
215    #[error("profiler is not initialized for this run, possibly due to an external profiling tool")]
216    ProfilerDisabled,
217    #[error("deprecated. Attempted to enable/disable profiling without initialization")]
218    ProfilerNotInitialized,
219    #[error("deprecated. Profiling is already started")]
220    ProfilerAlreadyStarted,
221    #[error("deprecated. Profiling is already stopped")]
222    ProfilerAlreadyStopped,
223    #[error("kernel launch requested resources that cannot be satisfied by the current device")]
224    InvalidConfiguration,
225    #[error("one or more of the pitch-related parameters passed to the API call is out of range")]
226    InvalidPitchValue,
227    #[error("the symbol name/identifier passed to the API call is invalid")]
228    InvalidSymbol,
229    #[error("the host pointer passed to the API call is invalid")]
230    InvalidHostPointer,
231    #[error("the device pointer passed to the API call is invalid")]
232    InvalidDevicePointer,
233    #[error("the texture passed to the API call is invalid")]
234    InvalidTexture,
235    #[error("the texture binding is invalid")]
236    InvalidTextureBinding,
237    #[error("the channel descriptor passed to the API call is invalid")]
238    InvalidChannelDescriptor,
239    #[error("the direction of the memcpy operation is invalid")]
240    InvalidMemcpyDirection,
241    #[error(
242        "attempted to take the address of a constant variable, which is forbidden before CUDA 3.1"
243    )]
244    AddressOfConstant,
245    #[error("deprecated. A texture fetch operation failed")]
246    TextureFetchFailed,
247    #[error("deprecated. The texture is not bound for access")]
248    TextureNotBound,
249    #[error("a synchronization operation failed")]
250    SynchronizationError,
251    #[error(
252        "a non-float texture was accessed with linear filtering, which is not supported by CUDA"
253    )]
254    InvalidFilterSetting,
255    #[error(
256        "attempted to read a non-float texture as a normalized float, which is not supported by CUDA"
257    )]
258    InvalidNormSetting,
259    #[error("the API call is not yet implemented")]
260    NotYetImplemented,
261    #[error("an emulated device pointer exceeded the 32-bit address range")]
262    MemoryValueTooLarge,
263    #[error("the CUDA driver is a stub library")]
264    StubLibrary,
265    #[error("the installed NVIDIA CUDA driver is older than the CUDA runtime library")]
266    InsufficientDriver,
267    #[error("the API call requires a newer CUDA driver")]
268    CallRequiresNewerDriver,
269    #[error("the surface passed to the API call is invalid")]
270    InvalidSurface,
271    #[error("multiple global or constant variables share the same string name")]
272    DuplicateVariableName,
273    #[error("multiple textures share the same string name")]
274    DuplicateTextureName,
275    #[error("multiple surfaces share the same string name")]
276    DuplicateSurfaceName,
277    #[error("all CUDA devices are currently busy or unavailable")]
278    DevicesUnavailable,
279    #[error("the current CUDA context is not compatible with the runtime")]
280    IncompatibleDriverContext,
281    #[error("the device function being invoked was not previously configured")]
282    MissingConfiguration,
283    #[error("a previous kernel launch failed")]
284    PriorLaunchFailure,
285    #[error(
286        "the depth of the child grid exceeded the maximum supported number of nested grid launches"
287    )]
288    LaunchMaxDepthExceeded,
289    #[error("a grid launch did not occur because file-scoped textures are unsupported")]
290    LaunchFileScopedTex,
291    #[error("a grid launch did not occur because file-scoped surfaces are unsupported")]
292    LaunchFileScopedSurf,
293    #[error("a call to cudaDeviceSynchronize failed due to exceeding the sync depth")]
294    SyncDepthExceeded,
295    #[error(
296        "a grid launch failed because the launch exceeded the limit of pending device runtime launches"
297    )]
298    LaunchPendingCountExceeded,
299    #[error(
300        "the requested device function does not exist or is not compiled for the proper device architecture"
301    )]
302    InvalidDeviceFunction,
303    #[error("no CUDA-capable devices were detected")]
304    NoDevice,
305    #[error("the device ordinal supplied does not correspond to a valid CUDA device")]
306    InvalidDevice,
307    #[error("the device does not have a valid Grid License")]
308    DeviceNotLicensed,
309    #[error("an internal startup failure occurred in the CUDA runtime")]
310    StartupFailure,
311    #[error("the device kernel image is invalid")]
312    InvalidKernelImage,
313    #[error("the device is not initialized")]
314    DeviceUninitialized,
315    #[error("the buffer object could not be mapped")]
316    MapBufferObjectFailed,
317    #[error("the buffer object could not be unmapped")]
318    UnmapBufferObjectFailed,
319    #[error("the specified array is currently mapped and cannot be destroyed")]
320    ArrayIsMapped,
321    #[error("the resource is already mapped")]
322    AlreadyMapped,
323    #[error("there is no kernel image available that is suitable for the device")]
324    NoKernelImageForDevice,
325    #[error("the resource has already been acquired")]
326    AlreadyAcquired,
327    #[error("the resource is not mapped")]
328    NotMapped,
329    #[error("the mapped resource is not available for access as an array")]
330    NotMappedAsArray,
331    #[error("the mapped resource is not available for access as a pointer")]
332    NotMappedAsPointer,
333    #[error("an uncorrectable ECC error was detected")]
334    ECCUncorrectable,
335    #[error("the specified cudaLimit is not supported by the device")]
336    UnsupportedLimit,
337    #[error("a call tried to access an exclusive-thread device that is already in use")]
338    DeviceAlreadyInUse,
339    #[error("P2P access is not supported across the given devices")]
340    PeerAccessUnsupported,
341    #[error("a PTX compilation failed")]
342    InvalidPtx,
343    #[error("an error occurred with the OpenGL or DirectX context")]
344    InvalidGraphicsContext,
345    #[error("an uncorrectable NVLink error was detected during execution")]
346    NvlinkUncorrectable,
347    #[error("the PTX JIT compiler library was not found")]
348    JitCompilerNotFound,
349    #[error("the provided PTX was compiled with an unsupported toolchain")]
350    UnsupportedPtxVersion,
351    #[error("JIT compilation was disabled")]
352    JitCompilationDisabled,
353    #[error("the provided execution affinity is not supported by the device")]
354    UnsupportedExecAffinity,
355    #[error("the operation is not permitted when the stream is capturing")]
356    StreamCaptureUnsupported,
357    #[error(
358        "the current capture sequence on the stream has been invalidated due to a previous error"
359    )]
360    StreamCaptureInvalidated,
361    #[error("a merge of two independent capture sequences was not allowed")]
362    StreamCaptureMerge,
363    #[error("the capture was not initiated in this stream")]
364    StreamCaptureUnmatched,
365    #[error("a stream capture sequence was passed to cudaStreamEndCapture in a different thread")]
366    StreamCaptureWrongThread,
367    #[error("the wait operation has timed out")]
368    Timeout,
369    #[error("an unknown internal error occurred")]
370    Unknown,
371    #[error("the API call returned a failure")]
372    ApiFailureBase,
373}
374
375pub fn cuda_check(result: cudaError_t) -> Result<(), CudaError> {
376    match result.0 {
377        0 => Ok(()),
378        1 => Err(CudaError::InvalidValue),
379        2 => Err(CudaError::MemoryAllocation),
380        3 => Err(CudaError::InitializationError),
381        4 => Err(CudaError::CudartUnloading),
382        5 => Err(CudaError::ProfilerDisabled),
383        6 => Err(CudaError::ProfilerNotInitialized),
384        7 => Err(CudaError::ProfilerAlreadyStarted),
385        8 => Err(CudaError::ProfilerAlreadyStopped),
386        9 => Err(CudaError::InvalidConfiguration),
387        12 => Err(CudaError::InvalidPitchValue),
388        13 => Err(CudaError::InvalidSymbol),
389        16 => Err(CudaError::InvalidHostPointer),
390        17 => Err(CudaError::InvalidDevicePointer),
391        18 => Err(CudaError::InvalidTexture),
392        19 => Err(CudaError::InvalidTextureBinding),
393        20 => Err(CudaError::InvalidChannelDescriptor),
394        21 => Err(CudaError::InvalidMemcpyDirection),
395        22 => Err(CudaError::AddressOfConstant),
396        23 => Err(CudaError::TextureFetchFailed),
397        24 => Err(CudaError::TextureNotBound),
398        25 => Err(CudaError::SynchronizationError),
399        26 => Err(CudaError::InvalidFilterSetting),
400        27 => Err(CudaError::InvalidNormSetting),
401        31 => Err(CudaError::NotYetImplemented),
402        32 => Err(CudaError::MemoryValueTooLarge),
403        34 => Err(CudaError::StubLibrary),
404        35 => Err(CudaError::InsufficientDriver),
405        36 => Err(CudaError::CallRequiresNewerDriver),
406        37 => Err(CudaError::InvalidSurface),
407        43 => Err(CudaError::DuplicateVariableName),
408        44 => Err(CudaError::DuplicateTextureName),
409        45 => Err(CudaError::DuplicateSurfaceName),
410        46 => Err(CudaError::DevicesUnavailable),
411        49 => Err(CudaError::IncompatibleDriverContext),
412        52 => Err(CudaError::MissingConfiguration),
413        53 => Err(CudaError::PriorLaunchFailure),
414        65 => Err(CudaError::LaunchMaxDepthExceeded),
415        66 => Err(CudaError::LaunchFileScopedTex),
416        67 => Err(CudaError::LaunchFileScopedSurf),
417        68 => Err(CudaError::SyncDepthExceeded),
418        69 => Err(CudaError::LaunchPendingCountExceeded),
419        98 => Err(CudaError::InvalidDeviceFunction),
420        100 => Err(CudaError::NoDevice),
421        101 => Err(CudaError::InvalidDevice),
422        102 => Err(CudaError::DeviceNotLicensed),
423        127 => Err(CudaError::StartupFailure),
424        200 => Err(CudaError::InvalidKernelImage),
425        201 => Err(CudaError::DeviceUninitialized),
426        205 => Err(CudaError::MapBufferObjectFailed),
427        206 => Err(CudaError::UnmapBufferObjectFailed),
428        207 => Err(CudaError::ArrayIsMapped),
429        208 => Err(CudaError::AlreadyMapped),
430        209 => Err(CudaError::NoKernelImageForDevice),
431        210 => Err(CudaError::AlreadyAcquired),
432        211 => Err(CudaError::NotMapped),
433        212 => Err(CudaError::NotMappedAsArray),
434        213 => Err(CudaError::NotMappedAsPointer),
435        214 => Err(CudaError::ECCUncorrectable),
436        215 => Err(CudaError::UnsupportedLimit),
437        216 => Err(CudaError::DeviceAlreadyInUse),
438        217 => Err(CudaError::PeerAccessUnsupported),
439        218 => Err(CudaError::InvalidPtx),
440        219 => Err(CudaError::InvalidGraphicsContext),
441        220 => Err(CudaError::NvlinkUncorrectable),
442        221 => Err(CudaError::JitCompilerNotFound),
443        222 => Err(CudaError::UnsupportedPtxVersion),
444        223 => Err(CudaError::JitCompilationDisabled),
445        224 => Err(CudaError::UnsupportedExecAffinity),
446        900 => Err(CudaError::StreamCaptureUnsupported),
447        901 => Err(CudaError::StreamCaptureInvalidated),
448        902 => Err(CudaError::StreamCaptureMerge),
449        903 => Err(CudaError::StreamCaptureUnmatched),
450        904 => Err(CudaError::StreamCaptureWrongThread),
451        909 => Err(CudaError::Timeout),
452        999 => Err(CudaError::Unknown),
453        _ => panic!("Unknown cudaError_t: {:?}", result.0),
454    }
455}
456
457pub fn set_device(device: CudaDevice) -> Result<(), CudaError> {
458    let index: i8 = device.index().into();
459    // SAFETY: intended usage of this function
460    unsafe { cuda_check(cudaSetDevice(index.into())) }
461}