// torch_sys_cuda/cuda.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

//! Bindings for torch's wrappers around CUDA-related functionality.
10use std::time::Duration;
11
12use derive_more::Into;
13use monarch_types::py_global;
14use nccl_sys::cudaError_t;
15use nccl_sys::cudaSetDevice;
16use nccl_sys::cudaStream_t;
17use pyo3::PyObject;
18use pyo3::prelude::*;
19use thiserror::Error;
20use torch_sys2::CudaDevice;
21
// Cached imports for torch.cuda APIs.
//
// Each `py_global!` invocation generates an accessor function (first argument)
// that resolves the named attribute (third argument) from the given Python
// module (second argument).
// NOTE(review): presumably the resolved object is cached after the first
// lookup, per the "Cached imports" label — confirm against
// `monarch_types::py_global`'s definition.
py_global!(cuda_stream_class, "torch.cuda", "Stream");
py_global!(cuda_event_class, "torch.cuda", "Event");
py_global!(cuda_current_stream, "torch.cuda", "current_stream");
py_global!(cuda_set_stream, "torch.cuda", "set_stream");
py_global!(cuda_current_device, "torch.cuda", "current_device");
py_global!(cuda_set_device, "torch.cuda", "set_device");
29
/// Wrapper around a CUDA stream.
///
/// A CUDA stream is a linear sequence of execution that belongs to a specific
/// device, independent from other streams.  See the documentation for
/// `torch.cuda.Stream` for more details.
#[derive(Debug, Into)]
#[into(ref)]
pub struct Stream {
    // Owned reference to the underlying `torch.cuda.Stream` Python object;
    // all methods delegate to it under the GIL.
    inner: PyObject,
}
40
41impl Clone for Stream {
42    fn clone(&self) -> Self {
43        Python::with_gil(|py| Self {
44            inner: self.inner.clone_ref(py),
45        })
46    }
47}
48
49impl Stream {
50    /// Create a new stream on the current device, at priority 0.
51    pub fn new() -> Self {
52        Python::with_gil(|py| {
53            let stream = cuda_stream_class(py).call0().unwrap();
54            Self {
55                inner: stream.into(),
56            }
57        })
58    }
59
60    /// Clone this stream reference. Requires holding the GIL.
61    pub fn clone_ref(&self, py: Python<'_>) -> Self {
62        Self {
63            inner: self.inner.clone_ref(py),
64        }
65    }
66
67    /// Create a new stream on the specified device, at priority 0.
68    pub fn new_with_device(device: CudaDevice) -> Self {
69        Python::with_gil(|py| {
70            let device_idx: i8 = device.index().into();
71            let stream = cuda_stream_class(py).call1((device_idx,)).unwrap();
72            Self {
73                inner: stream.into(),
74            }
75        })
76    }
77
78    /// Get the current stream on the current device.
79    pub fn get_current_stream() -> Self {
80        Python::with_gil(|py| {
81            let stream = cuda_current_stream(py).call0().unwrap();
82            Self {
83                inner: stream.into(),
84            }
85        })
86    }
87
88    /// Get the current stream on the specified device.
89    pub fn get_current_stream_on_device(device: CudaDevice) -> Self {
90        Python::with_gil(|py| {
91            let device_idx: i8 = device.index().into();
92            let stream = cuda_current_stream(py).call1((device_idx,)).unwrap();
93            Self {
94                inner: stream.into(),
95            }
96        })
97    }
98
99    /// Set the provided stream as the current stream. Also sets the current
100    /// device to be the same as the stream's device.
101    pub fn set_current_stream(stream: &Stream) {
102        Python::with_gil(|py| {
103            let stream_obj = stream.inner.bind(py);
104
105            // Get current device and stream device
106            let current_device = cuda_current_device(py)
107                .call0()
108                .unwrap()
109                .extract::<i64>()
110                .unwrap();
111
112            let stream_device = stream_obj
113                .getattr("device_index")
114                .unwrap()
115                .extract::<i64>()
116                .unwrap();
117
118            // Set device if different
119            if current_device != stream_device {
120                cuda_set_device(py).call1((stream_device,)).unwrap();
121            }
122
123            // Set the stream
124            cuda_set_stream(py).call1((stream_obj,)).unwrap();
125        })
126    }
127
128    /// Make all future work submitted to this stream wait for an event.
129    pub fn wait_event(&self, event: &mut Event) {
130        event.wait(Some(self))
131    }
132
133    /// Synchronize with another stream.
134    ///
135    /// All future work submitted to this stream will wait until all kernels
136    /// submitted to a given stream at the time of call entry complete.
137    pub fn wait_stream(&self, stream: &Stream) {
138        self.wait_event(&mut stream.record_event(None))
139    }
140
141    /// Record an event on this stream. If no event is provided one will be
142    /// created.
143    pub fn record_event(&self, event: Option<Event>) -> Event {
144        let mut event = event.unwrap_or(Event::new());
145        event.record(Some(self));
146        event
147    }
148
149    /// Check if all work submitted to this stream has completed.
150    pub fn query(&self) -> bool {
151        Python::with_gil(|py| {
152            let stream_obj = self.inner.bind(py);
153            stream_obj
154                .call_method0("query")
155                .unwrap()
156                .extract::<bool>()
157                .unwrap()
158        })
159    }
160
161    /// Wait for all kernels in this stream to complete.
162    pub fn synchronize(&self) {
163        Python::with_gil(|py| {
164            let stream_obj = self.inner.bind(py);
165            stream_obj.call_method0("synchronize").unwrap();
166        })
167    }
168
169    pub fn stream(&self) -> cudaStream_t {
170        Python::with_gil(|py| {
171            let stream_obj = self.inner.bind(py);
172            let cuda_stream = stream_obj.getattr("cuda_stream").unwrap();
173
174            // Extract the raw pointer
175            let ptr = cuda_stream.extract::<usize>().unwrap();
176
177            ptr as cudaStream_t
178        })
179    }
180}
181
182impl PartialEq for Stream {
183    fn eq(&self, other: &Self) -> bool {
184        self.stream() == other.stream()
185    }
186}
187
/// Wrapper around a CUDA event.
///
/// CUDA events are synchronization markers that can be used to monitor the
/// device's progress, to accurately measure timing, and to synchronize CUDA
/// streams.
///
/// The underlying CUDA events are lazily initialized when the event is first
/// recorded or exported to another process. After creation, only streams on the
/// same device may record the event. However, streams on any device can wait on
/// the event.
///
/// See the docs of `torch.cuda.Event` for more details.
#[derive(Debug)]
pub struct Event {
    // Owned reference to the underlying `torch.cuda.Event` Python object;
    // all methods delegate to it under the GIL.
    inner: PyObject,
}
204
205impl Clone for Event {
206    fn clone(&self) -> Self {
207        Python::with_gil(|py| Self {
208            inner: self.inner.clone_ref(py),
209        })
210    }
211}
212
213impl Event {
214    /// Create a new event.
215    // TODO: add support for flags.
216    pub fn new() -> Self {
217        Python::with_gil(|py| {
218            let event = cuda_event_class(py).call0().unwrap();
219            Self {
220                inner: event.into(),
221            }
222        })
223    }
224
225    /// Record the event on the current stream.
226    ///
227    /// Uses the current stream if no stream is provided.
228    pub fn record(&mut self, stream: Option<&Stream>) {
229        Python::with_gil(|py| {
230            let event_obj = self.inner.bind(py);
231
232            match stream {
233                Some(stream) => {
234                    let stream_obj = stream.inner.bind(py);
235                    event_obj.call_method1("record", (stream_obj,)).unwrap();
236                }
237                None => {
238                    event_obj.call_method0("record").unwrap();
239                }
240            }
241        })
242    }
243
244    /// Make all future work submitted to the given stream wait for this event.
245    ///
246    /// Uses the current stream if no stream is specified.
247    pub fn wait(&mut self, stream: Option<&Stream>) {
248        Python::with_gil(|py| {
249            let event_obj = self.inner.bind(py);
250
251            match stream {
252                Some(stream) => {
253                    let stream_obj = stream.inner.bind(py);
254                    event_obj.call_method1("wait", (stream_obj,)).unwrap();
255                }
256                None => {
257                    event_obj.call_method0("wait").unwrap();
258                }
259            }
260        })
261    }
262
263    /// Check if all work currently captured by event has completed.
264    pub fn query(&self) -> bool {
265        Python::with_gil(|py| {
266            let event_obj = self.inner.bind(py);
267            event_obj
268                .call_method0("query")
269                .unwrap()
270                .extract::<bool>()
271                .unwrap()
272        })
273    }
274
275    /// Return the time elapsed.
276    ///
277    /// Time reported in after the event was recorded and before the end_event
278    /// was recorded.
279    pub fn elapsed_time(&self, end_event: &Event) -> Duration {
280        Python::with_gil(|py| {
281            let event_obj = self.inner.bind(py);
282            let end_event_obj = end_event.inner.bind(py);
283
284            let elapsed_ms = event_obj
285                .call_method1("elapsed_time", (end_event_obj,))
286                .unwrap()
287                .extract::<f64>()
288                .unwrap();
289
290            Duration::from_millis(elapsed_ms as u64)
291        })
292    }
293
294    /// Wait for the event to complete.
295    /// Waits until the completion of all work currently captured in this event.
296    /// This prevents the CPU thread from proceeding until the event completes.
297    pub fn synchronize(&self) {
298        Python::with_gil(|py| {
299            let event_obj = self.inner.bind(py);
300            event_obj.call_method0("synchronize").unwrap();
301        })
302    }
303}
304
/// Corresponds to the CUDA error codes.
///
/// Each variant maps to one `cudaError_t` value; see `cuda_check` for the
/// code-to-variant mapping. Display text (via `thiserror`) paraphrases the
/// CUDA Runtime API error descriptions.
#[derive(Debug, Error)]
pub enum CudaError {
    #[error(
        "one or more parameters passed to the API call is not within an acceptable range of values"
    )]
    InvalidValue,
    #[error("the API call failed due to insufficient memory or resources")]
    MemoryAllocation,
    #[error("failed to initialize the CUDA driver and runtime")]
    InitializationError,
    #[error("CUDA Runtime API call was executed after the CUDA driver has been unloaded")]
    CudartUnloading,
    #[error("profiler is not initialized for this run, possibly due to an external profiling tool")]
    ProfilerDisabled,
    #[error("deprecated. Attempted to enable/disable profiling without initialization")]
    ProfilerNotInitialized,
    #[error("deprecated. Profiling is already started")]
    ProfilerAlreadyStarted,
    #[error("deprecated. Profiling is already stopped")]
    ProfilerAlreadyStopped,
    #[error("kernel launch requested resources that cannot be satisfied by the current device")]
    InvalidConfiguration,
    #[error("one or more of the pitch-related parameters passed to the API call is out of range")]
    InvalidPitchValue,
    #[error("the symbol name/identifier passed to the API call is invalid")]
    InvalidSymbol,
    #[error("the host pointer passed to the API call is invalid")]
    InvalidHostPointer,
    #[error("the device pointer passed to the API call is invalid")]
    InvalidDevicePointer,
    #[error("the texture passed to the API call is invalid")]
    InvalidTexture,
    #[error("the texture binding is invalid")]
    InvalidTextureBinding,
    #[error("the channel descriptor passed to the API call is invalid")]
    InvalidChannelDescriptor,
    #[error("the direction of the memcpy operation is invalid")]
    InvalidMemcpyDirection,
    #[error(
        "attempted to take the address of a constant variable, which is forbidden before CUDA 3.1"
    )]
    AddressOfConstant,
    #[error("deprecated. A texture fetch operation failed")]
    TextureFetchFailed,
    #[error("deprecated. The texture is not bound for access")]
    TextureNotBound,
    #[error("a synchronization operation failed")]
    SynchronizationError,
    #[error(
        "a non-float texture was accessed with linear filtering, which is not supported by CUDA"
    )]
    InvalidFilterSetting,
    #[error(
        "attempted to read a non-float texture as a normalized float, which is not supported by CUDA"
    )]
    InvalidNormSetting,
    #[error("the API call is not yet implemented")]
    NotYetImplemented,
    #[error("an emulated device pointer exceeded the 32-bit address range")]
    MemoryValueTooLarge,
    #[error("the CUDA driver is a stub library")]
    StubLibrary,
    #[error("the installed NVIDIA CUDA driver is older than the CUDA runtime library")]
    InsufficientDriver,
    #[error("the API call requires a newer CUDA driver")]
    CallRequiresNewerDriver,
    #[error("the surface passed to the API call is invalid")]
    InvalidSurface,
    #[error("multiple global or constant variables share the same string name")]
    DuplicateVariableName,
    #[error("multiple textures share the same string name")]
    DuplicateTextureName,
    #[error("multiple surfaces share the same string name")]
    DuplicateSurfaceName,
    #[error("all CUDA devices are currently busy or unavailable")]
    DevicesUnavailable,
    #[error("the current CUDA context is not compatible with the runtime")]
    IncompatibleDriverContext,
    #[error("the device function being invoked was not previously configured")]
    MissingConfiguration,
    #[error("a previous kernel launch failed")]
    PriorLaunchFailure,
    // Errors related to CUDA dynamic parallelism (device-side launches).
    #[error(
        "the depth of the child grid exceeded the maximum supported number of nested grid launches"
    )]
    LaunchMaxDepthExceeded,
    #[error("a grid launch did not occur because file-scoped textures are unsupported")]
    LaunchFileScopedTex,
    #[error("a grid launch did not occur because file-scoped surfaces are unsupported")]
    LaunchFileScopedSurf,
    #[error("a call to cudaDeviceSynchronize failed due to exceeding the sync depth")]
    SyncDepthExceeded,
    #[error(
        "a grid launch failed because the launch exceeded the limit of pending device runtime launches"
    )]
    LaunchPendingCountExceeded,
    #[error(
        "the requested device function does not exist or is not compiled for the proper device architecture"
    )]
    InvalidDeviceFunction,
    #[error("no CUDA-capable devices were detected")]
    NoDevice,
    #[error("the device ordinal supplied does not correspond to a valid CUDA device")]
    InvalidDevice,
    #[error("the device does not have a valid Grid License")]
    DeviceNotLicensed,
    #[error("an internal startup failure occurred in the CUDA runtime")]
    StartupFailure,
    #[error("the device kernel image is invalid")]
    InvalidKernelImage,
    #[error("the device is not initialized")]
    DeviceUninitialized,
    // Resource-mapping errors.
    #[error("the buffer object could not be mapped")]
    MapBufferObjectFailed,
    #[error("the buffer object could not be unmapped")]
    UnmapBufferObjectFailed,
    #[error("the specified array is currently mapped and cannot be destroyed")]
    ArrayIsMapped,
    #[error("the resource is already mapped")]
    AlreadyMapped,
    #[error("there is no kernel image available that is suitable for the device")]
    NoKernelImageForDevice,
    #[error("the resource has already been acquired")]
    AlreadyAcquired,
    #[error("the resource is not mapped")]
    NotMapped,
    #[error("the mapped resource is not available for access as an array")]
    NotMappedAsArray,
    #[error("the mapped resource is not available for access as a pointer")]
    NotMappedAsPointer,
    #[error("an uncorrectable ECC error was detected")]
    ECCUncorrectable,
    #[error("the specified cudaLimit is not supported by the device")]
    UnsupportedLimit,
    #[error("a call tried to access an exclusive-thread device that is already in use")]
    DeviceAlreadyInUse,
    #[error("P2P access is not supported across the given devices")]
    PeerAccessUnsupported,
    #[error("a PTX compilation failed")]
    InvalidPtx,
    #[error("an error occurred with the OpenGL or DirectX context")]
    InvalidGraphicsContext,
    #[error("an uncorrectable NVLink error was detected during execution")]
    NvlinkUncorrectable,
    #[error("the PTX JIT compiler library was not found")]
    JitCompilerNotFound,
    #[error("the provided PTX was compiled with an unsupported toolchain")]
    UnsupportedPtxVersion,
    #[error("JIT compilation was disabled")]
    JitCompilationDisabled,
    #[error("the provided execution affinity is not supported by the device")]
    UnsupportedExecAffinity,
    // Stream-capture (CUDA graph) errors.
    #[error("the operation is not permitted when the stream is capturing")]
    StreamCaptureUnsupported,
    #[error(
        "the current capture sequence on the stream has been invalidated due to a previous error"
    )]
    StreamCaptureInvalidated,
    #[error("a merge of two independent capture sequences was not allowed")]
    StreamCaptureMerge,
    #[error("the capture was not initiated in this stream")]
    StreamCaptureUnmatched,
    #[error("a stream capture sequence was passed to cudaStreamEndCapture in a different thread")]
    StreamCaptureWrongThread,
    #[error("the wait operation has timed out")]
    Timeout,
    #[error("an unknown internal error occurred")]
    Unknown,
    #[error("the API call returned a failure")]
    ApiFailureBase,
}
477
478pub fn cuda_check(result: cudaError_t) -> Result<(), CudaError> {
479    match result.0 {
480        0 => Ok(()),
481        1 => Err(CudaError::InvalidValue),
482        2 => Err(CudaError::MemoryAllocation),
483        3 => Err(CudaError::InitializationError),
484        4 => Err(CudaError::CudartUnloading),
485        5 => Err(CudaError::ProfilerDisabled),
486        6 => Err(CudaError::ProfilerNotInitialized),
487        7 => Err(CudaError::ProfilerAlreadyStarted),
488        8 => Err(CudaError::ProfilerAlreadyStopped),
489        9 => Err(CudaError::InvalidConfiguration),
490        12 => Err(CudaError::InvalidPitchValue),
491        13 => Err(CudaError::InvalidSymbol),
492        16 => Err(CudaError::InvalidHostPointer),
493        17 => Err(CudaError::InvalidDevicePointer),
494        18 => Err(CudaError::InvalidTexture),
495        19 => Err(CudaError::InvalidTextureBinding),
496        20 => Err(CudaError::InvalidChannelDescriptor),
497        21 => Err(CudaError::InvalidMemcpyDirection),
498        22 => Err(CudaError::AddressOfConstant),
499        23 => Err(CudaError::TextureFetchFailed),
500        24 => Err(CudaError::TextureNotBound),
501        25 => Err(CudaError::SynchronizationError),
502        26 => Err(CudaError::InvalidFilterSetting),
503        27 => Err(CudaError::InvalidNormSetting),
504        31 => Err(CudaError::NotYetImplemented),
505        32 => Err(CudaError::MemoryValueTooLarge),
506        34 => Err(CudaError::StubLibrary),
507        35 => Err(CudaError::InsufficientDriver),
508        36 => Err(CudaError::CallRequiresNewerDriver),
509        37 => Err(CudaError::InvalidSurface),
510        43 => Err(CudaError::DuplicateVariableName),
511        44 => Err(CudaError::DuplicateTextureName),
512        45 => Err(CudaError::DuplicateSurfaceName),
513        46 => Err(CudaError::DevicesUnavailable),
514        49 => Err(CudaError::IncompatibleDriverContext),
515        52 => Err(CudaError::MissingConfiguration),
516        53 => Err(CudaError::PriorLaunchFailure),
517        65 => Err(CudaError::LaunchMaxDepthExceeded),
518        66 => Err(CudaError::LaunchFileScopedTex),
519        67 => Err(CudaError::LaunchFileScopedSurf),
520        68 => Err(CudaError::SyncDepthExceeded),
521        69 => Err(CudaError::LaunchPendingCountExceeded),
522        98 => Err(CudaError::InvalidDeviceFunction),
523        100 => Err(CudaError::NoDevice),
524        101 => Err(CudaError::InvalidDevice),
525        102 => Err(CudaError::DeviceNotLicensed),
526        127 => Err(CudaError::StartupFailure),
527        200 => Err(CudaError::InvalidKernelImage),
528        201 => Err(CudaError::DeviceUninitialized),
529        205 => Err(CudaError::MapBufferObjectFailed),
530        206 => Err(CudaError::UnmapBufferObjectFailed),
531        207 => Err(CudaError::ArrayIsMapped),
532        208 => Err(CudaError::AlreadyMapped),
533        209 => Err(CudaError::NoKernelImageForDevice),
534        210 => Err(CudaError::AlreadyAcquired),
535        211 => Err(CudaError::NotMapped),
536        212 => Err(CudaError::NotMappedAsArray),
537        213 => Err(CudaError::NotMappedAsPointer),
538        214 => Err(CudaError::ECCUncorrectable),
539        215 => Err(CudaError::UnsupportedLimit),
540        216 => Err(CudaError::DeviceAlreadyInUse),
541        217 => Err(CudaError::PeerAccessUnsupported),
542        218 => Err(CudaError::InvalidPtx),
543        219 => Err(CudaError::InvalidGraphicsContext),
544        220 => Err(CudaError::NvlinkUncorrectable),
545        221 => Err(CudaError::JitCompilerNotFound),
546        222 => Err(CudaError::UnsupportedPtxVersion),
547        223 => Err(CudaError::JitCompilationDisabled),
548        224 => Err(CudaError::UnsupportedExecAffinity),
549        900 => Err(CudaError::StreamCaptureUnsupported),
550        901 => Err(CudaError::StreamCaptureInvalidated),
551        902 => Err(CudaError::StreamCaptureMerge),
552        903 => Err(CudaError::StreamCaptureUnmatched),
553        904 => Err(CudaError::StreamCaptureWrongThread),
554        909 => Err(CudaError::Timeout),
555        999 => Err(CudaError::Unknown),
556        _ => panic!("Unknown cudaError_t: {:?}", result.0),
557    }
558}
559
560pub fn set_device(device: CudaDevice) -> Result<(), CudaError> {
561    let index: i8 = device.index().into();
562    // SAFETY: intended usage of this function
563    unsafe { cuda_check(cudaSetDevice(index.into())) }
564}