1use std::time::Duration;
11
12use derive_more::Into;
13use monarch_types::py_global;
14use nccl_sys::cudaError_t;
15use nccl_sys::cudaSetDevice;
16use nccl_sys::cudaStream_t;
17use pyo3::PyObject;
18use pyo3::prelude::*;
19use thiserror::Error;
20use torch_sys2::CudaDevice;
21
22py_global!(cuda_stream_class, "torch.cuda", "Stream");
24py_global!(cuda_event_class, "torch.cuda", "Event");
25py_global!(cuda_current_stream, "torch.cuda", "current_stream");
26py_global!(cuda_set_stream, "torch.cuda", "set_stream");
27py_global!(cuda_current_device, "torch.cuda", "current_device");
28py_global!(cuda_set_device, "torch.cuda", "set_device");
29
30#[derive(Debug, Into)]
36#[into(ref)]
37pub struct Stream {
38 inner: PyObject,
39}
40
41impl Clone for Stream {
42 fn clone(&self) -> Self {
43 Python::with_gil(|py| Self {
44 inner: self.inner.clone_ref(py),
45 })
46 }
47}
48
49impl Stream {
50 pub fn new() -> Self {
52 Python::with_gil(|py| {
53 let stream = cuda_stream_class(py).call0().unwrap();
54 Self {
55 inner: stream.into(),
56 }
57 })
58 }
59
60 pub fn clone_ref(&self, py: Python<'_>) -> Self {
62 Self {
63 inner: self.inner.clone_ref(py),
64 }
65 }
66
67 pub fn new_with_device(device: CudaDevice) -> Self {
69 Python::with_gil(|py| {
70 let device_idx: i8 = device.index().into();
71 let stream = cuda_stream_class(py).call1((device_idx,)).unwrap();
72 Self {
73 inner: stream.into(),
74 }
75 })
76 }
77
78 pub fn get_current_stream() -> Self {
80 Python::with_gil(|py| {
81 let stream = cuda_current_stream(py).call0().unwrap();
82 Self {
83 inner: stream.into(),
84 }
85 })
86 }
87
88 pub fn get_current_stream_on_device(device: CudaDevice) -> Self {
90 Python::with_gil(|py| {
91 let device_idx: i8 = device.index().into();
92 let stream = cuda_current_stream(py).call1((device_idx,)).unwrap();
93 Self {
94 inner: stream.into(),
95 }
96 })
97 }
98
99 pub fn set_current_stream(stream: &Stream) {
102 Python::with_gil(|py| {
103 let stream_obj = stream.inner.bind(py);
104
105 let current_device = cuda_current_device(py)
107 .call0()
108 .unwrap()
109 .extract::<i64>()
110 .unwrap();
111
112 let stream_device = stream_obj
113 .getattr("device_index")
114 .unwrap()
115 .extract::<i64>()
116 .unwrap();
117
118 if current_device != stream_device {
120 cuda_set_device(py).call1((stream_device,)).unwrap();
121 }
122
123 cuda_set_stream(py).call1((stream_obj,)).unwrap();
125 })
126 }
127
128 pub fn wait_event(&self, event: &mut Event) {
130 event.wait(Some(self))
131 }
132
133 pub fn wait_stream(&self, stream: &Stream) {
138 self.wait_event(&mut stream.record_event(None))
139 }
140
141 pub fn record_event(&self, event: Option<Event>) -> Event {
144 let mut event = event.unwrap_or(Event::new());
145 event.record(Some(self));
146 event
147 }
148
149 pub fn query(&self) -> bool {
151 Python::with_gil(|py| {
152 let stream_obj = self.inner.bind(py);
153 stream_obj
154 .call_method0("query")
155 .unwrap()
156 .extract::<bool>()
157 .unwrap()
158 })
159 }
160
161 pub fn synchronize(&self) {
163 Python::with_gil(|py| {
164 let stream_obj = self.inner.bind(py);
165 stream_obj.call_method0("synchronize").unwrap();
166 })
167 }
168
169 pub fn stream(&self) -> cudaStream_t {
170 Python::with_gil(|py| {
171 let stream_obj = self.inner.bind(py);
172 let cuda_stream = stream_obj.getattr("cuda_stream").unwrap();
173
174 let ptr = cuda_stream.extract::<usize>().unwrap();
176
177 ptr as cudaStream_t
178 })
179 }
180}
181
182impl PartialEq for Stream {
183 fn eq(&self, other: &Self) -> bool {
184 self.stream() == other.stream()
185 }
186}
187
188#[derive(Debug)]
201pub struct Event {
202 inner: PyObject,
203}
204
205impl Clone for Event {
206 fn clone(&self) -> Self {
207 Python::with_gil(|py| Self {
208 inner: self.inner.clone_ref(py),
209 })
210 }
211}
212
213impl Event {
214 pub fn new() -> Self {
217 Python::with_gil(|py| {
218 let event = cuda_event_class(py).call0().unwrap();
219 Self {
220 inner: event.into(),
221 }
222 })
223 }
224
225 pub fn record(&mut self, stream: Option<&Stream>) {
229 Python::with_gil(|py| {
230 let event_obj = self.inner.bind(py);
231
232 match stream {
233 Some(stream) => {
234 let stream_obj = stream.inner.bind(py);
235 event_obj.call_method1("record", (stream_obj,)).unwrap();
236 }
237 None => {
238 event_obj.call_method0("record").unwrap();
239 }
240 }
241 })
242 }
243
244 pub fn wait(&mut self, stream: Option<&Stream>) {
248 Python::with_gil(|py| {
249 let event_obj = self.inner.bind(py);
250
251 match stream {
252 Some(stream) => {
253 let stream_obj = stream.inner.bind(py);
254 event_obj.call_method1("wait", (stream_obj,)).unwrap();
255 }
256 None => {
257 event_obj.call_method0("wait").unwrap();
258 }
259 }
260 })
261 }
262
263 pub fn query(&self) -> bool {
265 Python::with_gil(|py| {
266 let event_obj = self.inner.bind(py);
267 event_obj
268 .call_method0("query")
269 .unwrap()
270 .extract::<bool>()
271 .unwrap()
272 })
273 }
274
275 pub fn elapsed_time(&self, end_event: &Event) -> Duration {
280 Python::with_gil(|py| {
281 let event_obj = self.inner.bind(py);
282 let end_event_obj = end_event.inner.bind(py);
283
284 let elapsed_ms = event_obj
285 .call_method1("elapsed_time", (end_event_obj,))
286 .unwrap()
287 .extract::<f64>()
288 .unwrap();
289
290 Duration::from_millis(elapsed_ms as u64)
291 })
292 }
293
294 pub fn synchronize(&self) {
298 Python::with_gil(|py| {
299 let event_obj = self.inner.bind(py);
300 event_obj.call_method0("synchronize").unwrap();
301 })
302 }
303}
304
305#[derive(Debug, Error)]
307pub enum CudaError {
308 #[error(
309 "one or more parameters passed to the API call is not within an acceptable range of values"
310 )]
311 InvalidValue,
312 #[error("the API call failed due to insufficient memory or resources")]
313 MemoryAllocation,
314 #[error("failed to initialize the CUDA driver and runtime")]
315 InitializationError,
316 #[error("CUDA Runtime API call was executed after the CUDA driver has been unloaded")]
317 CudartUnloading,
318 #[error("profiler is not initialized for this run, possibly due to an external profiling tool")]
319 ProfilerDisabled,
320 #[error("deprecated. Attempted to enable/disable profiling without initialization")]
321 ProfilerNotInitialized,
322 #[error("deprecated. Profiling is already started")]
323 ProfilerAlreadyStarted,
324 #[error("deprecated. Profiling is already stopped")]
325 ProfilerAlreadyStopped,
326 #[error("kernel launch requested resources that cannot be satisfied by the current device")]
327 InvalidConfiguration,
328 #[error("one or more of the pitch-related parameters passed to the API call is out of range")]
329 InvalidPitchValue,
330 #[error("the symbol name/identifier passed to the API call is invalid")]
331 InvalidSymbol,
332 #[error("the host pointer passed to the API call is invalid")]
333 InvalidHostPointer,
334 #[error("the device pointer passed to the API call is invalid")]
335 InvalidDevicePointer,
336 #[error("the texture passed to the API call is invalid")]
337 InvalidTexture,
338 #[error("the texture binding is invalid")]
339 InvalidTextureBinding,
340 #[error("the channel descriptor passed to the API call is invalid")]
341 InvalidChannelDescriptor,
342 #[error("the direction of the memcpy operation is invalid")]
343 InvalidMemcpyDirection,
344 #[error(
345 "attempted to take the address of a constant variable, which is forbidden before CUDA 3.1"
346 )]
347 AddressOfConstant,
348 #[error("deprecated. A texture fetch operation failed")]
349 TextureFetchFailed,
350 #[error("deprecated. The texture is not bound for access")]
351 TextureNotBound,
352 #[error("a synchronization operation failed")]
353 SynchronizationError,
354 #[error(
355 "a non-float texture was accessed with linear filtering, which is not supported by CUDA"
356 )]
357 InvalidFilterSetting,
358 #[error(
359 "attempted to read a non-float texture as a normalized float, which is not supported by CUDA"
360 )]
361 InvalidNormSetting,
362 #[error("the API call is not yet implemented")]
363 NotYetImplemented,
364 #[error("an emulated device pointer exceeded the 32-bit address range")]
365 MemoryValueTooLarge,
366 #[error("the CUDA driver is a stub library")]
367 StubLibrary,
368 #[error("the installed NVIDIA CUDA driver is older than the CUDA runtime library")]
369 InsufficientDriver,
370 #[error("the API call requires a newer CUDA driver")]
371 CallRequiresNewerDriver,
372 #[error("the surface passed to the API call is invalid")]
373 InvalidSurface,
374 #[error("multiple global or constant variables share the same string name")]
375 DuplicateVariableName,
376 #[error("multiple textures share the same string name")]
377 DuplicateTextureName,
378 #[error("multiple surfaces share the same string name")]
379 DuplicateSurfaceName,
380 #[error("all CUDA devices are currently busy or unavailable")]
381 DevicesUnavailable,
382 #[error("the current CUDA context is not compatible with the runtime")]
383 IncompatibleDriverContext,
384 #[error("the device function being invoked was not previously configured")]
385 MissingConfiguration,
386 #[error("a previous kernel launch failed")]
387 PriorLaunchFailure,
388 #[error(
389 "the depth of the child grid exceeded the maximum supported number of nested grid launches"
390 )]
391 LaunchMaxDepthExceeded,
392 #[error("a grid launch did not occur because file-scoped textures are unsupported")]
393 LaunchFileScopedTex,
394 #[error("a grid launch did not occur because file-scoped surfaces are unsupported")]
395 LaunchFileScopedSurf,
396 #[error("a call to cudaDeviceSynchronize failed due to exceeding the sync depth")]
397 SyncDepthExceeded,
398 #[error(
399 "a grid launch failed because the launch exceeded the limit of pending device runtime launches"
400 )]
401 LaunchPendingCountExceeded,
402 #[error(
403 "the requested device function does not exist or is not compiled for the proper device architecture"
404 )]
405 InvalidDeviceFunction,
406 #[error("no CUDA-capable devices were detected")]
407 NoDevice,
408 #[error("the device ordinal supplied does not correspond to a valid CUDA device")]
409 InvalidDevice,
410 #[error("the device does not have a valid Grid License")]
411 DeviceNotLicensed,
412 #[error("an internal startup failure occurred in the CUDA runtime")]
413 StartupFailure,
414 #[error("the device kernel image is invalid")]
415 InvalidKernelImage,
416 #[error("the device is not initialized")]
417 DeviceUninitialized,
418 #[error("the buffer object could not be mapped")]
419 MapBufferObjectFailed,
420 #[error("the buffer object could not be unmapped")]
421 UnmapBufferObjectFailed,
422 #[error("the specified array is currently mapped and cannot be destroyed")]
423 ArrayIsMapped,
424 #[error("the resource is already mapped")]
425 AlreadyMapped,
426 #[error("there is no kernel image available that is suitable for the device")]
427 NoKernelImageForDevice,
428 #[error("the resource has already been acquired")]
429 AlreadyAcquired,
430 #[error("the resource is not mapped")]
431 NotMapped,
432 #[error("the mapped resource is not available for access as an array")]
433 NotMappedAsArray,
434 #[error("the mapped resource is not available for access as a pointer")]
435 NotMappedAsPointer,
436 #[error("an uncorrectable ECC error was detected")]
437 ECCUncorrectable,
438 #[error("the specified cudaLimit is not supported by the device")]
439 UnsupportedLimit,
440 #[error("a call tried to access an exclusive-thread device that is already in use")]
441 DeviceAlreadyInUse,
442 #[error("P2P access is not supported across the given devices")]
443 PeerAccessUnsupported,
444 #[error("a PTX compilation failed")]
445 InvalidPtx,
446 #[error("an error occurred with the OpenGL or DirectX context")]
447 InvalidGraphicsContext,
448 #[error("an uncorrectable NVLink error was detected during execution")]
449 NvlinkUncorrectable,
450 #[error("the PTX JIT compiler library was not found")]
451 JitCompilerNotFound,
452 #[error("the provided PTX was compiled with an unsupported toolchain")]
453 UnsupportedPtxVersion,
454 #[error("JIT compilation was disabled")]
455 JitCompilationDisabled,
456 #[error("the provided execution affinity is not supported by the device")]
457 UnsupportedExecAffinity,
458 #[error("the operation is not permitted when the stream is capturing")]
459 StreamCaptureUnsupported,
460 #[error(
461 "the current capture sequence on the stream has been invalidated due to a previous error"
462 )]
463 StreamCaptureInvalidated,
464 #[error("a merge of two independent capture sequences was not allowed")]
465 StreamCaptureMerge,
466 #[error("the capture was not initiated in this stream")]
467 StreamCaptureUnmatched,
468 #[error("a stream capture sequence was passed to cudaStreamEndCapture in a different thread")]
469 StreamCaptureWrongThread,
470 #[error("the wait operation has timed out")]
471 Timeout,
472 #[error("an unknown internal error occurred")]
473 Unknown,
474 #[error("the API call returned a failure")]
475 ApiFailureBase,
476}
477
478pub fn cuda_check(result: cudaError_t) -> Result<(), CudaError> {
479 match result.0 {
480 0 => Ok(()),
481 1 => Err(CudaError::InvalidValue),
482 2 => Err(CudaError::MemoryAllocation),
483 3 => Err(CudaError::InitializationError),
484 4 => Err(CudaError::CudartUnloading),
485 5 => Err(CudaError::ProfilerDisabled),
486 6 => Err(CudaError::ProfilerNotInitialized),
487 7 => Err(CudaError::ProfilerAlreadyStarted),
488 8 => Err(CudaError::ProfilerAlreadyStopped),
489 9 => Err(CudaError::InvalidConfiguration),
490 12 => Err(CudaError::InvalidPitchValue),
491 13 => Err(CudaError::InvalidSymbol),
492 16 => Err(CudaError::InvalidHostPointer),
493 17 => Err(CudaError::InvalidDevicePointer),
494 18 => Err(CudaError::InvalidTexture),
495 19 => Err(CudaError::InvalidTextureBinding),
496 20 => Err(CudaError::InvalidChannelDescriptor),
497 21 => Err(CudaError::InvalidMemcpyDirection),
498 22 => Err(CudaError::AddressOfConstant),
499 23 => Err(CudaError::TextureFetchFailed),
500 24 => Err(CudaError::TextureNotBound),
501 25 => Err(CudaError::SynchronizationError),
502 26 => Err(CudaError::InvalidFilterSetting),
503 27 => Err(CudaError::InvalidNormSetting),
504 31 => Err(CudaError::NotYetImplemented),
505 32 => Err(CudaError::MemoryValueTooLarge),
506 34 => Err(CudaError::StubLibrary),
507 35 => Err(CudaError::InsufficientDriver),
508 36 => Err(CudaError::CallRequiresNewerDriver),
509 37 => Err(CudaError::InvalidSurface),
510 43 => Err(CudaError::DuplicateVariableName),
511 44 => Err(CudaError::DuplicateTextureName),
512 45 => Err(CudaError::DuplicateSurfaceName),
513 46 => Err(CudaError::DevicesUnavailable),
514 49 => Err(CudaError::IncompatibleDriverContext),
515 52 => Err(CudaError::MissingConfiguration),
516 53 => Err(CudaError::PriorLaunchFailure),
517 65 => Err(CudaError::LaunchMaxDepthExceeded),
518 66 => Err(CudaError::LaunchFileScopedTex),
519 67 => Err(CudaError::LaunchFileScopedSurf),
520 68 => Err(CudaError::SyncDepthExceeded),
521 69 => Err(CudaError::LaunchPendingCountExceeded),
522 98 => Err(CudaError::InvalidDeviceFunction),
523 100 => Err(CudaError::NoDevice),
524 101 => Err(CudaError::InvalidDevice),
525 102 => Err(CudaError::DeviceNotLicensed),
526 127 => Err(CudaError::StartupFailure),
527 200 => Err(CudaError::InvalidKernelImage),
528 201 => Err(CudaError::DeviceUninitialized),
529 205 => Err(CudaError::MapBufferObjectFailed),
530 206 => Err(CudaError::UnmapBufferObjectFailed),
531 207 => Err(CudaError::ArrayIsMapped),
532 208 => Err(CudaError::AlreadyMapped),
533 209 => Err(CudaError::NoKernelImageForDevice),
534 210 => Err(CudaError::AlreadyAcquired),
535 211 => Err(CudaError::NotMapped),
536 212 => Err(CudaError::NotMappedAsArray),
537 213 => Err(CudaError::NotMappedAsPointer),
538 214 => Err(CudaError::ECCUncorrectable),
539 215 => Err(CudaError::UnsupportedLimit),
540 216 => Err(CudaError::DeviceAlreadyInUse),
541 217 => Err(CudaError::PeerAccessUnsupported),
542 218 => Err(CudaError::InvalidPtx),
543 219 => Err(CudaError::InvalidGraphicsContext),
544 220 => Err(CudaError::NvlinkUncorrectable),
545 221 => Err(CudaError::JitCompilerNotFound),
546 222 => Err(CudaError::UnsupportedPtxVersion),
547 223 => Err(CudaError::JitCompilationDisabled),
548 224 => Err(CudaError::UnsupportedExecAffinity),
549 900 => Err(CudaError::StreamCaptureUnsupported),
550 901 => Err(CudaError::StreamCaptureInvalidated),
551 902 => Err(CudaError::StreamCaptureMerge),
552 903 => Err(CudaError::StreamCaptureUnmatched),
553 904 => Err(CudaError::StreamCaptureWrongThread),
554 909 => Err(CudaError::Timeout),
555 999 => Err(CudaError::Unknown),
556 _ => panic!("Unknown cudaError_t: {:?}", result.0),
557 }
558}
559
560pub fn set_device(device: CudaDevice) -> Result<(), CudaError> {
561 let index: i8 = device.index().into();
562 unsafe { cuda_check(cudaSetDevice(index.into())) }
564}