1use std::time::Duration;
11
12use derive_more::Into;
13use monarch_types::py_global;
14use nccl_sys::cudaError_t;
15use nccl_sys::cudaSetDevice;
16use nccl_sys::cudaStream_t;
17use pyo3::Py;
18use pyo3::PyAny;
19use pyo3::prelude::*;
20use thiserror::Error;
21use torch_sys2::CudaDevice;
22
// Accessors for the `torch.cuda` attributes used by this file. Each
// `py_global!` invocation names an accessor (first argument) together with the
// Python module and attribute it refers to; the accessors are invoked below
// with a `Python` token, e.g. `cuda_stream_class(py)`.
// NOTE(review): whether the lookup is cached between calls is decided by the
// macro in `monarch_types` and is not visible here.
py_global!(cuda_stream_class, "torch.cuda", "Stream");
py_global!(cuda_event_class, "torch.cuda", "Event");
py_global!(cuda_current_stream, "torch.cuda", "current_stream");
py_global!(cuda_set_stream, "torch.cuda", "set_stream");
py_global!(cuda_current_device, "torch.cuda", "current_device");
py_global!(cuda_set_device, "torch.cuda", "set_device");
30
31#[derive(Debug, Into)]
37#[into(ref)]
38pub struct Stream {
39 inner: Py<PyAny>,
40}
41
42impl Clone for Stream {
43 fn clone(&self) -> Self {
44 Python::attach(|py| Self {
45 inner: self.inner.clone_ref(py),
46 })
47 }
48}
49
50impl Stream {
51 pub fn new() -> Self {
53 Python::attach(|py| {
54 let stream = cuda_stream_class(py).call0().unwrap();
55 Self {
56 inner: stream.into(),
57 }
58 })
59 }
60
61 pub fn clone_ref(&self, py: Python<'_>) -> Self {
63 Self {
64 inner: self.inner.clone_ref(py),
65 }
66 }
67
68 pub fn new_with_device(device: CudaDevice) -> Self {
70 Python::attach(|py| {
71 let device_idx: i8 = device.index().into();
72 let stream = cuda_stream_class(py).call1((device_idx,)).unwrap();
73 Self {
74 inner: stream.into(),
75 }
76 })
77 }
78
79 pub fn get_current_stream() -> Self {
81 Python::attach(|py| {
82 let stream = cuda_current_stream(py).call0().unwrap();
83 Self {
84 inner: stream.into(),
85 }
86 })
87 }
88
89 pub fn get_current_stream_on_device(device: CudaDevice) -> Self {
91 Python::attach(|py| {
92 let device_idx: i8 = device.index().into();
93 let stream = cuda_current_stream(py).call1((device_idx,)).unwrap();
94 Self {
95 inner: stream.into(),
96 }
97 })
98 }
99
100 pub fn set_current_stream(stream: &Stream) {
103 Python::attach(|py| {
104 let stream_obj = stream.inner.bind(py);
105
106 let current_device = cuda_current_device(py)
108 .call0()
109 .unwrap()
110 .extract::<i64>()
111 .unwrap();
112
113 let stream_device = stream_obj
114 .getattr("device_index")
115 .unwrap()
116 .extract::<i64>()
117 .unwrap();
118
119 if current_device != stream_device {
121 cuda_set_device(py).call1((stream_device,)).unwrap();
122 }
123
124 cuda_set_stream(py).call1((stream_obj,)).unwrap();
126 })
127 }
128
129 pub fn wait_event(&self, event: &mut Event) {
131 event.wait(Some(self))
132 }
133
134 pub fn wait_stream(&self, stream: &Stream) {
139 self.wait_event(&mut stream.record_event(None))
140 }
141
142 pub fn record_event(&self, event: Option<Event>) -> Event {
145 let mut event = event.unwrap_or(Event::new());
146 event.record(Some(self));
147 event
148 }
149
150 pub fn query(&self) -> bool {
152 Python::attach(|py| {
153 let stream_obj = self.inner.bind(py);
154 stream_obj
155 .call_method0("query")
156 .unwrap()
157 .extract::<bool>()
158 .unwrap()
159 })
160 }
161
162 pub fn synchronize(&self) {
164 Python::attach(|py| {
165 let stream_obj = self.inner.bind(py);
166 stream_obj.call_method0("synchronize").unwrap();
167 })
168 }
169
170 pub fn stream(&self) -> cudaStream_t {
171 Python::attach(|py| {
172 let stream_obj = self.inner.bind(py);
173 let cuda_stream = stream_obj.getattr("cuda_stream").unwrap();
174
175 let ptr = cuda_stream.extract::<usize>().unwrap();
177
178 ptr as cudaStream_t
179 })
180 }
181}
182
183impl PartialEq for Stream {
184 fn eq(&self, other: &Self) -> bool {
185 self.stream() == other.stream()
186 }
187}
188
189#[derive(Debug)]
202pub struct Event {
203 inner: Py<PyAny>,
204}
205
206impl Clone for Event {
207 fn clone(&self) -> Self {
208 Python::attach(|py| Self {
209 inner: self.inner.clone_ref(py),
210 })
211 }
212}
213
214impl Event {
215 pub fn new() -> Self {
218 Python::attach(|py| {
219 let event = cuda_event_class(py).call0().unwrap();
220 Self {
221 inner: event.into(),
222 }
223 })
224 }
225
226 pub fn record(&mut self, stream: Option<&Stream>) {
230 Python::attach(|py| {
231 let event_obj = self.inner.bind(py);
232
233 match stream {
234 Some(stream) => {
235 let stream_obj = stream.inner.bind(py);
236 event_obj.call_method1("record", (stream_obj,)).unwrap();
237 }
238 None => {
239 event_obj.call_method0("record").unwrap();
240 }
241 }
242 })
243 }
244
245 pub fn wait(&mut self, stream: Option<&Stream>) {
249 Python::attach(|py| {
250 let event_obj = self.inner.bind(py);
251
252 match stream {
253 Some(stream) => {
254 let stream_obj = stream.inner.bind(py);
255 event_obj.call_method1("wait", (stream_obj,)).unwrap();
256 }
257 None => {
258 event_obj.call_method0("wait").unwrap();
259 }
260 }
261 })
262 }
263
264 pub fn query(&self) -> bool {
266 Python::attach(|py| {
267 let event_obj = self.inner.bind(py);
268 event_obj
269 .call_method0("query")
270 .unwrap()
271 .extract::<bool>()
272 .unwrap()
273 })
274 }
275
276 pub fn elapsed_time(&self, end_event: &Event) -> Duration {
281 Python::attach(|py| {
282 let event_obj = self.inner.bind(py);
283 let end_event_obj = end_event.inner.bind(py);
284
285 let elapsed_ms = event_obj
286 .call_method1("elapsed_time", (end_event_obj,))
287 .unwrap()
288 .extract::<f64>()
289 .unwrap();
290
291 Duration::from_millis(elapsed_ms as u64)
292 })
293 }
294
295 pub fn synchronize(&self) {
299 Python::attach(|py| {
300 let event_obj = self.inner.bind(py);
301 event_obj.call_method0("synchronize").unwrap();
302 })
303 }
304}
305
/// Error values produced by `cuda_check` when a CUDA runtime call reports a
/// non-zero status.
///
/// Each variant's `#[error]` attribute supplies the message shown by its
/// `Display` implementation.
#[derive(Debug, Error)]
pub enum CudaError {
    #[error(
        "one or more parameters passed to the API call is not within an acceptable range of values"
    )]
    InvalidValue,
    #[error("the API call failed due to insufficient memory or resources")]
    MemoryAllocation,
    #[error("failed to initialize the CUDA driver and runtime")]
    InitializationError,
    #[error("CUDA Runtime API call was executed after the CUDA driver has been unloaded")]
    CudartUnloading,
    #[error("profiler is not initialized for this run, possibly due to an external profiling tool")]
    ProfilerDisabled,
    #[error("deprecated. Attempted to enable/disable profiling without initialization")]
    ProfilerNotInitialized,
    #[error("deprecated. Profiling is already started")]
    ProfilerAlreadyStarted,
    #[error("deprecated. Profiling is already stopped")]
    ProfilerAlreadyStopped,
    #[error("kernel launch requested resources that cannot be satisfied by the current device")]
    InvalidConfiguration,
    #[error("one or more of the pitch-related parameters passed to the API call is out of range")]
    InvalidPitchValue,
    #[error("the symbol name/identifier passed to the API call is invalid")]
    InvalidSymbol,
    #[error("the host pointer passed to the API call is invalid")]
    InvalidHostPointer,
    #[error("the device pointer passed to the API call is invalid")]
    InvalidDevicePointer,
    #[error("the texture passed to the API call is invalid")]
    InvalidTexture,
    #[error("the texture binding is invalid")]
    InvalidTextureBinding,
    #[error("the channel descriptor passed to the API call is invalid")]
    InvalidChannelDescriptor,
    #[error("the direction of the memcpy operation is invalid")]
    InvalidMemcpyDirection,
    #[error(
        "attempted to take the address of a constant variable, which is forbidden before CUDA 3.1"
    )]
    AddressOfConstant,
    #[error("deprecated. A texture fetch operation failed")]
    TextureFetchFailed,
    #[error("deprecated. The texture is not bound for access")]
    TextureNotBound,
    #[error("a synchronization operation failed")]
    SynchronizationError,
    #[error(
        "a non-float texture was accessed with linear filtering, which is not supported by CUDA"
    )]
    InvalidFilterSetting,
    #[error(
        "attempted to read a non-float texture as a normalized float, which is not supported by CUDA"
    )]
    InvalidNormSetting,
    #[error("the API call is not yet implemented")]
    NotYetImplemented,
    #[error("an emulated device pointer exceeded the 32-bit address range")]
    MemoryValueTooLarge,
    #[error("the CUDA driver is a stub library")]
    StubLibrary,
    #[error("the installed NVIDIA CUDA driver is older than the CUDA runtime library")]
    InsufficientDriver,
    #[error("the API call requires a newer CUDA driver")]
    CallRequiresNewerDriver,
    #[error("the surface passed to the API call is invalid")]
    InvalidSurface,
    #[error("multiple global or constant variables share the same string name")]
    DuplicateVariableName,
    #[error("multiple textures share the same string name")]
    DuplicateTextureName,
    #[error("multiple surfaces share the same string name")]
    DuplicateSurfaceName,
    #[error("all CUDA devices are currently busy or unavailable")]
    DevicesUnavailable,
    #[error("the current CUDA context is not compatible with the runtime")]
    IncompatibleDriverContext,
    #[error("the device function being invoked was not previously configured")]
    MissingConfiguration,
    #[error("a previous kernel launch failed")]
    PriorLaunchFailure,
    #[error(
        "the depth of the child grid exceeded the maximum supported number of nested grid launches"
    )]
    LaunchMaxDepthExceeded,
    #[error("a grid launch did not occur because file-scoped textures are unsupported")]
    LaunchFileScopedTex,
    #[error("a grid launch did not occur because file-scoped surfaces are unsupported")]
    LaunchFileScopedSurf,
    #[error("a call to cudaDeviceSynchronize failed due to exceeding the sync depth")]
    SyncDepthExceeded,
    #[error(
        "a grid launch failed because the launch exceeded the limit of pending device runtime launches"
    )]
    LaunchPendingCountExceeded,
    #[error(
        "the requested device function does not exist or is not compiled for the proper device architecture"
    )]
    InvalidDeviceFunction,
    #[error("no CUDA-capable devices were detected")]
    NoDevice,
    #[error("the device ordinal supplied does not correspond to a valid CUDA device")]
    InvalidDevice,
    #[error("the device does not have a valid Grid License")]
    DeviceNotLicensed,
    #[error("an internal startup failure occurred in the CUDA runtime")]
    StartupFailure,
    #[error("the device kernel image is invalid")]
    InvalidKernelImage,
    #[error("the device is not initialized")]
    DeviceUninitialized,
    #[error("the buffer object could not be mapped")]
    MapBufferObjectFailed,
    #[error("the buffer object could not be unmapped")]
    UnmapBufferObjectFailed,
    #[error("the specified array is currently mapped and cannot be destroyed")]
    ArrayIsMapped,
    #[error("the resource is already mapped")]
    AlreadyMapped,
    #[error("there is no kernel image available that is suitable for the device")]
    NoKernelImageForDevice,
    #[error("the resource has already been acquired")]
    AlreadyAcquired,
    #[error("the resource is not mapped")]
    NotMapped,
    #[error("the mapped resource is not available for access as an array")]
    NotMappedAsArray,
    #[error("the mapped resource is not available for access as a pointer")]
    NotMappedAsPointer,
    #[error("an uncorrectable ECC error was detected")]
    ECCUncorrectable,
    #[error("the specified cudaLimit is not supported by the device")]
    UnsupportedLimit,
    #[error("a call tried to access an exclusive-thread device that is already in use")]
    DeviceAlreadyInUse,
    #[error("P2P access is not supported across the given devices")]
    PeerAccessUnsupported,
    #[error("a PTX compilation failed")]
    InvalidPtx,
    #[error("an error occurred with the OpenGL or DirectX context")]
    InvalidGraphicsContext,
    #[error("an uncorrectable NVLink error was detected during execution")]
    NvlinkUncorrectable,
    #[error("the PTX JIT compiler library was not found")]
    JitCompilerNotFound,
    #[error("the provided PTX was compiled with an unsupported toolchain")]
    UnsupportedPtxVersion,
    #[error("JIT compilation was disabled")]
    JitCompilationDisabled,
    #[error("the provided execution affinity is not supported by the device")]
    UnsupportedExecAffinity,
    #[error("the operation is not permitted when the stream is capturing")]
    StreamCaptureUnsupported,
    #[error(
        "the current capture sequence on the stream has been invalidated due to a previous error"
    )]
    StreamCaptureInvalidated,
    #[error("a merge of two independent capture sequences was not allowed")]
    StreamCaptureMerge,
    #[error("the capture was not initiated in this stream")]
    StreamCaptureUnmatched,
    #[error("a stream capture sequence was passed to cudaStreamEndCapture in a different thread")]
    StreamCaptureWrongThread,
    #[error("the wait operation has timed out")]
    Timeout,
    #[error("an unknown internal error occurred")]
    Unknown,
    #[error("the API call returned a failure")]
    ApiFailureBase,
}
478
479pub fn cuda_check(result: cudaError_t) -> Result<(), CudaError> {
480 match result.0 {
481 0 => Ok(()),
482 1 => Err(CudaError::InvalidValue),
483 2 => Err(CudaError::MemoryAllocation),
484 3 => Err(CudaError::InitializationError),
485 4 => Err(CudaError::CudartUnloading),
486 5 => Err(CudaError::ProfilerDisabled),
487 6 => Err(CudaError::ProfilerNotInitialized),
488 7 => Err(CudaError::ProfilerAlreadyStarted),
489 8 => Err(CudaError::ProfilerAlreadyStopped),
490 9 => Err(CudaError::InvalidConfiguration),
491 12 => Err(CudaError::InvalidPitchValue),
492 13 => Err(CudaError::InvalidSymbol),
493 16 => Err(CudaError::InvalidHostPointer),
494 17 => Err(CudaError::InvalidDevicePointer),
495 18 => Err(CudaError::InvalidTexture),
496 19 => Err(CudaError::InvalidTextureBinding),
497 20 => Err(CudaError::InvalidChannelDescriptor),
498 21 => Err(CudaError::InvalidMemcpyDirection),
499 22 => Err(CudaError::AddressOfConstant),
500 23 => Err(CudaError::TextureFetchFailed),
501 24 => Err(CudaError::TextureNotBound),
502 25 => Err(CudaError::SynchronizationError),
503 26 => Err(CudaError::InvalidFilterSetting),
504 27 => Err(CudaError::InvalidNormSetting),
505 31 => Err(CudaError::NotYetImplemented),
506 32 => Err(CudaError::MemoryValueTooLarge),
507 34 => Err(CudaError::StubLibrary),
508 35 => Err(CudaError::InsufficientDriver),
509 36 => Err(CudaError::CallRequiresNewerDriver),
510 37 => Err(CudaError::InvalidSurface),
511 43 => Err(CudaError::DuplicateVariableName),
512 44 => Err(CudaError::DuplicateTextureName),
513 45 => Err(CudaError::DuplicateSurfaceName),
514 46 => Err(CudaError::DevicesUnavailable),
515 49 => Err(CudaError::IncompatibleDriverContext),
516 52 => Err(CudaError::MissingConfiguration),
517 53 => Err(CudaError::PriorLaunchFailure),
518 65 => Err(CudaError::LaunchMaxDepthExceeded),
519 66 => Err(CudaError::LaunchFileScopedTex),
520 67 => Err(CudaError::LaunchFileScopedSurf),
521 68 => Err(CudaError::SyncDepthExceeded),
522 69 => Err(CudaError::LaunchPendingCountExceeded),
523 98 => Err(CudaError::InvalidDeviceFunction),
524 100 => Err(CudaError::NoDevice),
525 101 => Err(CudaError::InvalidDevice),
526 102 => Err(CudaError::DeviceNotLicensed),
527 127 => Err(CudaError::StartupFailure),
528 200 => Err(CudaError::InvalidKernelImage),
529 201 => Err(CudaError::DeviceUninitialized),
530 205 => Err(CudaError::MapBufferObjectFailed),
531 206 => Err(CudaError::UnmapBufferObjectFailed),
532 207 => Err(CudaError::ArrayIsMapped),
533 208 => Err(CudaError::AlreadyMapped),
534 209 => Err(CudaError::NoKernelImageForDevice),
535 210 => Err(CudaError::AlreadyAcquired),
536 211 => Err(CudaError::NotMapped),
537 212 => Err(CudaError::NotMappedAsArray),
538 213 => Err(CudaError::NotMappedAsPointer),
539 214 => Err(CudaError::ECCUncorrectable),
540 215 => Err(CudaError::UnsupportedLimit),
541 216 => Err(CudaError::DeviceAlreadyInUse),
542 217 => Err(CudaError::PeerAccessUnsupported),
543 218 => Err(CudaError::InvalidPtx),
544 219 => Err(CudaError::InvalidGraphicsContext),
545 220 => Err(CudaError::NvlinkUncorrectable),
546 221 => Err(CudaError::JitCompilerNotFound),
547 222 => Err(CudaError::UnsupportedPtxVersion),
548 223 => Err(CudaError::JitCompilationDisabled),
549 224 => Err(CudaError::UnsupportedExecAffinity),
550 900 => Err(CudaError::StreamCaptureUnsupported),
551 901 => Err(CudaError::StreamCaptureInvalidated),
552 902 => Err(CudaError::StreamCaptureMerge),
553 903 => Err(CudaError::StreamCaptureUnmatched),
554 904 => Err(CudaError::StreamCaptureWrongThread),
555 909 => Err(CudaError::Timeout),
556 999 => Err(CudaError::Unknown),
557 _ => panic!("Unknown cudaError_t: {:?}", result.0),
558 }
559}
560
561pub fn set_device(device: CudaDevice) -> Result<(), CudaError> {
562 let index: i8 = device.index().into();
563 unsafe { cuda_check(cudaSetDevice(index.into())) }
565}