// torch_sys_cuda/cuda.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

//! Bindings for torch's wrappers around CUDA-related functionality.
10use std::time::Duration;
11
12use derive_more::Into;
13use monarch_types::py_global;
14use nccl_sys::cudaError_t;
15use nccl_sys::cudaSetDevice;
16use nccl_sys::cudaStream_t;
17use pyo3::Py;
18use pyo3::PyAny;
19use pyo3::prelude::*;
20use thiserror::Error;
21use torch_sys2::CudaDevice;
22
// Cached imports for torch.cuda APIs.
// Each `py_global!` invocation defines an accessor (named by the first
// argument) that resolves the given attribute of the given module; the macro
// comes from `monarch_types` and presumably caches the lookup across calls —
// see `monarch_types::py_global` for the exact contract.
py_global!(cuda_stream_class, "torch.cuda", "Stream");
py_global!(cuda_event_class, "torch.cuda", "Event");
py_global!(cuda_current_stream, "torch.cuda", "current_stream");
py_global!(cuda_set_stream, "torch.cuda", "set_stream");
py_global!(cuda_current_device, "torch.cuda", "current_device");
py_global!(cuda_set_device, "torch.cuda", "set_device");
30
/// Wrapper around a CUDA stream.
///
/// A CUDA stream is a linear sequence of execution that belongs to a specific
/// device, independent from other streams.  See the documentation for
/// `torch.cuda.Stream` for more details.
#[derive(Debug, Into)]
#[into(ref)]
pub struct Stream {
    // Handle to the underlying `torch.cuda.Stream` Python object; every
    // method in this file delegates to it while attached to the interpreter.
    inner: Py<PyAny>,
}
41
42impl Clone for Stream {
43    fn clone(&self) -> Self {
44        Python::attach(|py| Self {
45            inner: self.inner.clone_ref(py),
46        })
47    }
48}
49
50impl Stream {
51    /// Create a new stream on the current device, at priority 0.
52    pub fn new() -> Self {
53        Python::attach(|py| {
54            let stream = cuda_stream_class(py).call0().unwrap();
55            Self {
56                inner: stream.into(),
57            }
58        })
59    }
60
61    /// Clone this stream reference. Requires holding the GIL.
62    pub fn clone_ref(&self, py: Python<'_>) -> Self {
63        Self {
64            inner: self.inner.clone_ref(py),
65        }
66    }
67
68    /// Create a new stream on the specified device, at priority 0.
69    pub fn new_with_device(device: CudaDevice) -> Self {
70        Python::attach(|py| {
71            let device_idx: i8 = device.index().into();
72            let stream = cuda_stream_class(py).call1((device_idx,)).unwrap();
73            Self {
74                inner: stream.into(),
75            }
76        })
77    }
78
79    /// Get the current stream on the current device.
80    pub fn get_current_stream() -> Self {
81        Python::attach(|py| {
82            let stream = cuda_current_stream(py).call0().unwrap();
83            Self {
84                inner: stream.into(),
85            }
86        })
87    }
88
89    /// Get the current stream on the specified device.
90    pub fn get_current_stream_on_device(device: CudaDevice) -> Self {
91        Python::attach(|py| {
92            let device_idx: i8 = device.index().into();
93            let stream = cuda_current_stream(py).call1((device_idx,)).unwrap();
94            Self {
95                inner: stream.into(),
96            }
97        })
98    }
99
100    /// Set the provided stream as the current stream. Also sets the current
101    /// device to be the same as the stream's device.
102    pub fn set_current_stream(stream: &Stream) {
103        Python::attach(|py| {
104            let stream_obj = stream.inner.bind(py);
105
106            // Get current device and stream device
107            let current_device = cuda_current_device(py)
108                .call0()
109                .unwrap()
110                .extract::<i64>()
111                .unwrap();
112
113            let stream_device = stream_obj
114                .getattr("device_index")
115                .unwrap()
116                .extract::<i64>()
117                .unwrap();
118
119            // Set device if different
120            if current_device != stream_device {
121                cuda_set_device(py).call1((stream_device,)).unwrap();
122            }
123
124            // Set the stream
125            cuda_set_stream(py).call1((stream_obj,)).unwrap();
126        })
127    }
128
129    /// Make all future work submitted to this stream wait for an event.
130    pub fn wait_event(&self, event: &mut Event) {
131        event.wait(Some(self))
132    }
133
134    /// Synchronize with another stream.
135    ///
136    /// All future work submitted to this stream will wait until all kernels
137    /// submitted to a given stream at the time of call entry complete.
138    pub fn wait_stream(&self, stream: &Stream) {
139        self.wait_event(&mut stream.record_event(None))
140    }
141
142    /// Record an event on this stream. If no event is provided one will be
143    /// created.
144    pub fn record_event(&self, event: Option<Event>) -> Event {
145        let mut event = event.unwrap_or(Event::new());
146        event.record(Some(self));
147        event
148    }
149
150    /// Check if all work submitted to this stream has completed.
151    pub fn query(&self) -> bool {
152        Python::attach(|py| {
153            let stream_obj = self.inner.bind(py);
154            stream_obj
155                .call_method0("query")
156                .unwrap()
157                .extract::<bool>()
158                .unwrap()
159        })
160    }
161
162    /// Wait for all kernels in this stream to complete.
163    pub fn synchronize(&self) {
164        Python::attach(|py| {
165            let stream_obj = self.inner.bind(py);
166            stream_obj.call_method0("synchronize").unwrap();
167        })
168    }
169
170    pub fn stream(&self) -> cudaStream_t {
171        Python::attach(|py| {
172            let stream_obj = self.inner.bind(py);
173            let cuda_stream = stream_obj.getattr("cuda_stream").unwrap();
174
175            // Extract the raw pointer
176            let ptr = cuda_stream.extract::<usize>().unwrap();
177
178            ptr as cudaStream_t
179        })
180    }
181}
182
183impl PartialEq for Stream {
184    fn eq(&self, other: &Self) -> bool {
185        self.stream() == other.stream()
186    }
187}
188
/// Wrapper around a CUDA event.
///
/// CUDA events are synchronization markers that can be used to monitor the
/// device's progress, to accurately measure timing, and to synchronize CUDA
/// streams.
///
/// The underlying CUDA events are lazily initialized when the event is first
/// recorded or exported to another process. After creation, only streams on the
/// same device may record the event. However, streams on any device can wait on
/// the event.
///
/// See the docs of `torch.cuda.Event` for more details.
#[derive(Debug)]
pub struct Event {
    // Handle to the underlying `torch.cuda.Event` Python object.
    inner: Py<PyAny>,
}
205
206impl Clone for Event {
207    fn clone(&self) -> Self {
208        Python::attach(|py| Self {
209            inner: self.inner.clone_ref(py),
210        })
211    }
212}
213
214impl Event {
215    /// Create a new event.
216    // TODO: add support for flags.
217    pub fn new() -> Self {
218        Python::attach(|py| {
219            let event = cuda_event_class(py).call0().unwrap();
220            Self {
221                inner: event.into(),
222            }
223        })
224    }
225
226    /// Record the event on the current stream.
227    ///
228    /// Uses the current stream if no stream is provided.
229    pub fn record(&mut self, stream: Option<&Stream>) {
230        Python::attach(|py| {
231            let event_obj = self.inner.bind(py);
232
233            match stream {
234                Some(stream) => {
235                    let stream_obj = stream.inner.bind(py);
236                    event_obj.call_method1("record", (stream_obj,)).unwrap();
237                }
238                None => {
239                    event_obj.call_method0("record").unwrap();
240                }
241            }
242        })
243    }
244
245    /// Make all future work submitted to the given stream wait for this event.
246    ///
247    /// Uses the current stream if no stream is specified.
248    pub fn wait(&mut self, stream: Option<&Stream>) {
249        Python::attach(|py| {
250            let event_obj = self.inner.bind(py);
251
252            match stream {
253                Some(stream) => {
254                    let stream_obj = stream.inner.bind(py);
255                    event_obj.call_method1("wait", (stream_obj,)).unwrap();
256                }
257                None => {
258                    event_obj.call_method0("wait").unwrap();
259                }
260            }
261        })
262    }
263
264    /// Check if all work currently captured by event has completed.
265    pub fn query(&self) -> bool {
266        Python::attach(|py| {
267            let event_obj = self.inner.bind(py);
268            event_obj
269                .call_method0("query")
270                .unwrap()
271                .extract::<bool>()
272                .unwrap()
273        })
274    }
275
276    /// Return the time elapsed.
277    ///
278    /// Time reported in after the event was recorded and before the end_event
279    /// was recorded.
280    pub fn elapsed_time(&self, end_event: &Event) -> Duration {
281        Python::attach(|py| {
282            let event_obj = self.inner.bind(py);
283            let end_event_obj = end_event.inner.bind(py);
284
285            let elapsed_ms = event_obj
286                .call_method1("elapsed_time", (end_event_obj,))
287                .unwrap()
288                .extract::<f64>()
289                .unwrap();
290
291            Duration::from_millis(elapsed_ms as u64)
292        })
293    }
294
295    /// Wait for the event to complete.
296    /// Waits until the completion of all work currently captured in this event.
297    /// This prevents the CPU thread from proceeding until the event completes.
298    pub fn synchronize(&self) {
299        Python::attach(|py| {
300            let event_obj = self.inner.bind(py);
301            event_obj.call_method0("synchronize").unwrap();
302        })
303    }
304}
305
/// Corresponds to the CUDA error codes.
///
/// Variants mirror the runtime's `cudaError_t` values; see `cuda_check` below
/// for the numeric-code-to-variant mapping. Display text for each variant is
/// provided via `thiserror`'s `#[error(...)]` attribute.
#[derive(Debug, Error)]
pub enum CudaError {
    #[error(
        "one or more parameters passed to the API call is not within an acceptable range of values"
    )]
    InvalidValue,
    #[error("the API call failed due to insufficient memory or resources")]
    MemoryAllocation,
    #[error("failed to initialize the CUDA driver and runtime")]
    InitializationError,
    #[error("CUDA Runtime API call was executed after the CUDA driver has been unloaded")]
    CudartUnloading,
    #[error("profiler is not initialized for this run, possibly due to an external profiling tool")]
    ProfilerDisabled,
    #[error("deprecated. Attempted to enable/disable profiling without initialization")]
    ProfilerNotInitialized,
    #[error("deprecated. Profiling is already started")]
    ProfilerAlreadyStarted,
    #[error("deprecated. Profiling is already stopped")]
    ProfilerAlreadyStopped,
    #[error("kernel launch requested resources that cannot be satisfied by the current device")]
    InvalidConfiguration,
    #[error("one or more of the pitch-related parameters passed to the API call is out of range")]
    InvalidPitchValue,
    #[error("the symbol name/identifier passed to the API call is invalid")]
    InvalidSymbol,
    #[error("the host pointer passed to the API call is invalid")]
    InvalidHostPointer,
    #[error("the device pointer passed to the API call is invalid")]
    InvalidDevicePointer,
    #[error("the texture passed to the API call is invalid")]
    InvalidTexture,
    #[error("the texture binding is invalid")]
    InvalidTextureBinding,
    #[error("the channel descriptor passed to the API call is invalid")]
    InvalidChannelDescriptor,
    #[error("the direction of the memcpy operation is invalid")]
    InvalidMemcpyDirection,
    #[error(
        "attempted to take the address of a constant variable, which is forbidden before CUDA 3.1"
    )]
    AddressOfConstant,
    #[error("deprecated. A texture fetch operation failed")]
    TextureFetchFailed,
    #[error("deprecated. The texture is not bound for access")]
    TextureNotBound,
    #[error("a synchronization operation failed")]
    SynchronizationError,
    #[error(
        "a non-float texture was accessed with linear filtering, which is not supported by CUDA"
    )]
    InvalidFilterSetting,
    #[error(
        "attempted to read a non-float texture as a normalized float, which is not supported by CUDA"
    )]
    InvalidNormSetting,
    #[error("the API call is not yet implemented")]
    NotYetImplemented,
    #[error("an emulated device pointer exceeded the 32-bit address range")]
    MemoryValueTooLarge,
    #[error("the CUDA driver is a stub library")]
    StubLibrary,
    #[error("the installed NVIDIA CUDA driver is older than the CUDA runtime library")]
    InsufficientDriver,
    #[error("the API call requires a newer CUDA driver")]
    CallRequiresNewerDriver,
    #[error("the surface passed to the API call is invalid")]
    InvalidSurface,
    #[error("multiple global or constant variables share the same string name")]
    DuplicateVariableName,
    #[error("multiple textures share the same string name")]
    DuplicateTextureName,
    #[error("multiple surfaces share the same string name")]
    DuplicateSurfaceName,
    #[error("all CUDA devices are currently busy or unavailable")]
    DevicesUnavailable,
    #[error("the current CUDA context is not compatible with the runtime")]
    IncompatibleDriverContext,
    #[error("the device function being invoked was not previously configured")]
    MissingConfiguration,
    #[error("a previous kernel launch failed")]
    PriorLaunchFailure,
    #[error(
        "the depth of the child grid exceeded the maximum supported number of nested grid launches"
    )]
    LaunchMaxDepthExceeded,
    #[error("a grid launch did not occur because file-scoped textures are unsupported")]
    LaunchFileScopedTex,
    #[error("a grid launch did not occur because file-scoped surfaces are unsupported")]
    LaunchFileScopedSurf,
    #[error("a call to cudaDeviceSynchronize failed due to exceeding the sync depth")]
    SyncDepthExceeded,
    #[error(
        "a grid launch failed because the launch exceeded the limit of pending device runtime launches"
    )]
    LaunchPendingCountExceeded,
    #[error(
        "the requested device function does not exist or is not compiled for the proper device architecture"
    )]
    InvalidDeviceFunction,
    #[error("no CUDA-capable devices were detected")]
    NoDevice,
    #[error("the device ordinal supplied does not correspond to a valid CUDA device")]
    InvalidDevice,
    #[error("the device does not have a valid Grid License")]
    DeviceNotLicensed,
    #[error("an internal startup failure occurred in the CUDA runtime")]
    StartupFailure,
    #[error("the device kernel image is invalid")]
    InvalidKernelImage,
    #[error("the device is not initialized")]
    DeviceUninitialized,
    #[error("the buffer object could not be mapped")]
    MapBufferObjectFailed,
    #[error("the buffer object could not be unmapped")]
    UnmapBufferObjectFailed,
    #[error("the specified array is currently mapped and cannot be destroyed")]
    ArrayIsMapped,
    #[error("the resource is already mapped")]
    AlreadyMapped,
    #[error("there is no kernel image available that is suitable for the device")]
    NoKernelImageForDevice,
    #[error("the resource has already been acquired")]
    AlreadyAcquired,
    #[error("the resource is not mapped")]
    NotMapped,
    #[error("the mapped resource is not available for access as an array")]
    NotMappedAsArray,
    #[error("the mapped resource is not available for access as a pointer")]
    NotMappedAsPointer,
    #[error("an uncorrectable ECC error was detected")]
    ECCUncorrectable,
    #[error("the specified cudaLimit is not supported by the device")]
    UnsupportedLimit,
    #[error("a call tried to access an exclusive-thread device that is already in use")]
    DeviceAlreadyInUse,
    #[error("P2P access is not supported across the given devices")]
    PeerAccessUnsupported,
    #[error("a PTX compilation failed")]
    InvalidPtx,
    #[error("an error occurred with the OpenGL or DirectX context")]
    InvalidGraphicsContext,
    #[error("an uncorrectable NVLink error was detected during execution")]
    NvlinkUncorrectable,
    #[error("the PTX JIT compiler library was not found")]
    JitCompilerNotFound,
    #[error("the provided PTX was compiled with an unsupported toolchain")]
    UnsupportedPtxVersion,
    #[error("JIT compilation was disabled")]
    JitCompilationDisabled,
    #[error("the provided execution affinity is not supported by the device")]
    UnsupportedExecAffinity,
    #[error("the operation is not permitted when the stream is capturing")]
    StreamCaptureUnsupported,
    #[error(
        "the current capture sequence on the stream has been invalidated due to a previous error"
    )]
    StreamCaptureInvalidated,
    #[error("a merge of two independent capture sequences was not allowed")]
    StreamCaptureMerge,
    #[error("the capture was not initiated in this stream")]
    StreamCaptureUnmatched,
    #[error("a stream capture sequence was passed to cudaStreamEndCapture in a different thread")]
    StreamCaptureWrongThread,
    #[error("the wait operation has timed out")]
    Timeout,
    #[error("an unknown internal error occurred")]
    Unknown,
    #[error("the API call returned a failure")]
    ApiFailureBase,
}
478
479pub fn cuda_check(result: cudaError_t) -> Result<(), CudaError> {
480    match result.0 {
481        0 => Ok(()),
482        1 => Err(CudaError::InvalidValue),
483        2 => Err(CudaError::MemoryAllocation),
484        3 => Err(CudaError::InitializationError),
485        4 => Err(CudaError::CudartUnloading),
486        5 => Err(CudaError::ProfilerDisabled),
487        6 => Err(CudaError::ProfilerNotInitialized),
488        7 => Err(CudaError::ProfilerAlreadyStarted),
489        8 => Err(CudaError::ProfilerAlreadyStopped),
490        9 => Err(CudaError::InvalidConfiguration),
491        12 => Err(CudaError::InvalidPitchValue),
492        13 => Err(CudaError::InvalidSymbol),
493        16 => Err(CudaError::InvalidHostPointer),
494        17 => Err(CudaError::InvalidDevicePointer),
495        18 => Err(CudaError::InvalidTexture),
496        19 => Err(CudaError::InvalidTextureBinding),
497        20 => Err(CudaError::InvalidChannelDescriptor),
498        21 => Err(CudaError::InvalidMemcpyDirection),
499        22 => Err(CudaError::AddressOfConstant),
500        23 => Err(CudaError::TextureFetchFailed),
501        24 => Err(CudaError::TextureNotBound),
502        25 => Err(CudaError::SynchronizationError),
503        26 => Err(CudaError::InvalidFilterSetting),
504        27 => Err(CudaError::InvalidNormSetting),
505        31 => Err(CudaError::NotYetImplemented),
506        32 => Err(CudaError::MemoryValueTooLarge),
507        34 => Err(CudaError::StubLibrary),
508        35 => Err(CudaError::InsufficientDriver),
509        36 => Err(CudaError::CallRequiresNewerDriver),
510        37 => Err(CudaError::InvalidSurface),
511        43 => Err(CudaError::DuplicateVariableName),
512        44 => Err(CudaError::DuplicateTextureName),
513        45 => Err(CudaError::DuplicateSurfaceName),
514        46 => Err(CudaError::DevicesUnavailable),
515        49 => Err(CudaError::IncompatibleDriverContext),
516        52 => Err(CudaError::MissingConfiguration),
517        53 => Err(CudaError::PriorLaunchFailure),
518        65 => Err(CudaError::LaunchMaxDepthExceeded),
519        66 => Err(CudaError::LaunchFileScopedTex),
520        67 => Err(CudaError::LaunchFileScopedSurf),
521        68 => Err(CudaError::SyncDepthExceeded),
522        69 => Err(CudaError::LaunchPendingCountExceeded),
523        98 => Err(CudaError::InvalidDeviceFunction),
524        100 => Err(CudaError::NoDevice),
525        101 => Err(CudaError::InvalidDevice),
526        102 => Err(CudaError::DeviceNotLicensed),
527        127 => Err(CudaError::StartupFailure),
528        200 => Err(CudaError::InvalidKernelImage),
529        201 => Err(CudaError::DeviceUninitialized),
530        205 => Err(CudaError::MapBufferObjectFailed),
531        206 => Err(CudaError::UnmapBufferObjectFailed),
532        207 => Err(CudaError::ArrayIsMapped),
533        208 => Err(CudaError::AlreadyMapped),
534        209 => Err(CudaError::NoKernelImageForDevice),
535        210 => Err(CudaError::AlreadyAcquired),
536        211 => Err(CudaError::NotMapped),
537        212 => Err(CudaError::NotMappedAsArray),
538        213 => Err(CudaError::NotMappedAsPointer),
539        214 => Err(CudaError::ECCUncorrectable),
540        215 => Err(CudaError::UnsupportedLimit),
541        216 => Err(CudaError::DeviceAlreadyInUse),
542        217 => Err(CudaError::PeerAccessUnsupported),
543        218 => Err(CudaError::InvalidPtx),
544        219 => Err(CudaError::InvalidGraphicsContext),
545        220 => Err(CudaError::NvlinkUncorrectable),
546        221 => Err(CudaError::JitCompilerNotFound),
547        222 => Err(CudaError::UnsupportedPtxVersion),
548        223 => Err(CudaError::JitCompilationDisabled),
549        224 => Err(CudaError::UnsupportedExecAffinity),
550        900 => Err(CudaError::StreamCaptureUnsupported),
551        901 => Err(CudaError::StreamCaptureInvalidated),
552        902 => Err(CudaError::StreamCaptureMerge),
553        903 => Err(CudaError::StreamCaptureUnmatched),
554        904 => Err(CudaError::StreamCaptureWrongThread),
555        909 => Err(CudaError::Timeout),
556        999 => Err(CudaError::Unknown),
557        _ => panic!("Unknown cudaError_t: {:?}", result.0),
558    }
559}
560
561pub fn set_device(device: CudaDevice) -> Result<(), CudaError> {
562    let index: i8 = device.index().into();
563    // SAFETY: intended usage of this function
564    unsafe { cuda_check(cudaSetDevice(index.into())) }
565}