1use std::time::Duration;
11
12use derive_more::Into;
13use monarch_types::py_global;
14use pyo3::Py;
15use pyo3::PyAny;
16use pyo3::prelude::*;
17use thiserror::Error;
18use torch_sys2::CudaDevice;
19
// Handles to the `torch.cuda` Python APIs wrapped by this module.
// Each `py_global!(name, module, attr)` invocation is used below as a function
// `name(py)` that yields `module.attr` as a Python object which can then be
// called (e.g. `cuda_stream_class(py).call0()`).
// NOTE(review): caching/lookup behavior is defined by `monarch_types::py_global`.
py_global!(cuda_stream_class, "torch.cuda", "Stream");
py_global!(cuda_event_class, "torch.cuda", "Event");
py_global!(cuda_current_stream, "torch.cuda", "current_stream");
py_global!(cuda_set_stream, "torch.cuda", "set_stream");
py_global!(cuda_current_device, "torch.cuda", "current_device");
py_global!(cuda_set_device, "torch.cuda", "set_device");
27
/// A Rust handle to a Python `torch.cuda.Stream` object.
///
/// Every operation on this type acquires the Python interpreter
/// (`Python::attach`) and delegates to the wrapped Python object.
#[derive(Debug, Into)]
// derive_more: `#[into(ref)]` generates a conversion from `&Stream` to a
// reference to the wrapped `Py<PyAny>`, giving callers borrowed access to it.
#[into(ref)]
pub struct Stream {
    // The wrapped `torch.cuda.Stream` instance.
    inner: Py<PyAny>,
}
38
39impl Clone for Stream {
40 fn clone(&self) -> Self {
41 Python::attach(|py| Self {
42 inner: self.inner.clone_ref(py),
43 })
44 }
45}
46
47impl Stream {
48 pub fn new() -> Self {
50 Python::attach(|py| {
51 let stream = cuda_stream_class(py).call0().unwrap();
52 Self {
53 inner: stream.into(),
54 }
55 })
56 }
57
58 pub fn clone_ref(&self, py: Python<'_>) -> Self {
60 Self {
61 inner: self.inner.clone_ref(py),
62 }
63 }
64
65 pub fn new_with_device(device: CudaDevice) -> Self {
67 Python::attach(|py| {
68 let device_idx: i8 = device.index().into();
69 let stream = cuda_stream_class(py).call1((device_idx,)).unwrap();
70 Self {
71 inner: stream.into(),
72 }
73 })
74 }
75
76 pub fn get_current_stream() -> Self {
78 Python::attach(|py| {
79 let stream = cuda_current_stream(py).call0().unwrap();
80 Self {
81 inner: stream.into(),
82 }
83 })
84 }
85
86 pub fn get_current_stream_on_device(device: CudaDevice) -> Self {
88 Python::attach(|py| {
89 let device_idx: i8 = device.index().into();
90 let stream = cuda_current_stream(py).call1((device_idx,)).unwrap();
91 Self {
92 inner: stream.into(),
93 }
94 })
95 }
96
97 pub fn set_current_stream(stream: &Stream) {
100 Python::attach(|py| {
101 let stream_obj = stream.inner.bind(py);
102
103 let current_device = cuda_current_device(py)
105 .call0()
106 .unwrap()
107 .extract::<i64>()
108 .unwrap();
109
110 let stream_device = stream_obj
111 .getattr("device_index")
112 .unwrap()
113 .extract::<i64>()
114 .unwrap();
115
116 if current_device != stream_device {
118 cuda_set_device(py).call1((stream_device,)).unwrap();
119 }
120
121 cuda_set_stream(py).call1((stream_obj,)).unwrap();
123 })
124 }
125
126 pub fn wait_event(&self, event: &mut Event) {
128 event.wait(Some(self))
129 }
130
131 pub fn wait_stream(&self, stream: &Stream) {
136 self.wait_event(&mut stream.record_event(None))
137 }
138
139 pub fn record_event(&self, event: Option<Event>) -> Event {
142 let mut event = event.unwrap_or(Event::new());
143 event.record(Some(self));
144 event
145 }
146
147 pub fn query(&self) -> bool {
149 Python::attach(|py| {
150 let stream_obj = self.inner.bind(py);
151 stream_obj
152 .call_method0("query")
153 .unwrap()
154 .extract::<bool>()
155 .unwrap()
156 })
157 }
158
159 pub fn synchronize(&self) {
161 Python::attach(|py| {
162 let stream_obj = self.inner.bind(py);
163 stream_obj.call_method0("synchronize").unwrap();
164 })
165 }
166
167 #[cfg(feature = "cuda")]
169 pub fn stream(&self) -> nccl_sys::cudaStream_t {
170 Python::attach(|py| {
171 let stream_obj = self.inner.bind(py);
172 let cuda_stream = stream_obj.getattr("cuda_stream").unwrap();
173
174 let ptr = cuda_stream.extract::<usize>().unwrap();
176
177 ptr as nccl_sys::cudaStream_t
178 })
179 }
180}
181
182impl PartialEq for Stream {
183 fn eq(&self, other: &Self) -> bool {
184 #[cfg(feature = "cuda")]
185 {
186 self.stream() == other.stream()
187 }
188
189 #[cfg(not(feature = "cuda"))]
190 {
191 self.inner.as_ptr() == other.inner.as_ptr()
192 }
193 }
194}
195
/// A Rust handle to a Python `torch.cuda.Event` object.
///
/// Every operation on this type acquires the Python interpreter
/// (`Python::attach`) and delegates to the wrapped Python object.
#[derive(Debug)]
pub struct Event {
    // The wrapped `torch.cuda.Event` instance.
    inner: Py<PyAny>,
}
212
213impl Clone for Event {
214 fn clone(&self) -> Self {
215 Python::attach(|py| Self {
216 inner: self.inner.clone_ref(py),
217 })
218 }
219}
220
221impl Event {
222 pub fn new() -> Self {
225 Python::attach(|py| {
226 let event = cuda_event_class(py).call0().unwrap();
227 Self {
228 inner: event.into(),
229 }
230 })
231 }
232
233 pub fn record(&mut self, stream: Option<&Stream>) {
237 Python::attach(|py| {
238 let event_obj = self.inner.bind(py);
239
240 match stream {
241 Some(stream) => {
242 let stream_obj = stream.inner.bind(py);
243 event_obj.call_method1("record", (stream_obj,)).unwrap();
244 }
245 None => {
246 event_obj.call_method0("record").unwrap();
247 }
248 }
249 })
250 }
251
252 pub fn wait(&mut self, stream: Option<&Stream>) {
256 Python::attach(|py| {
257 let event_obj = self.inner.bind(py);
258
259 match stream {
260 Some(stream) => {
261 let stream_obj = stream.inner.bind(py);
262 event_obj.call_method1("wait", (stream_obj,)).unwrap();
263 }
264 None => {
265 event_obj.call_method0("wait").unwrap();
266 }
267 }
268 })
269 }
270
271 pub fn query(&self) -> bool {
273 Python::attach(|py| {
274 let event_obj = self.inner.bind(py);
275 event_obj
276 .call_method0("query")
277 .unwrap()
278 .extract::<bool>()
279 .unwrap()
280 })
281 }
282
283 pub fn elapsed_time(&self, end_event: &Event) -> Duration {
288 Python::attach(|py| {
289 let event_obj = self.inner.bind(py);
290 let end_event_obj = end_event.inner.bind(py);
291
292 let elapsed_ms = event_obj
293 .call_method1("elapsed_time", (end_event_obj,))
294 .unwrap()
295 .extract::<f64>()
296 .unwrap();
297
298 Duration::from_millis(elapsed_ms as u64)
299 })
300 }
301
302 pub fn synchronize(&self) {
306 Python::attach(|py| {
307 let event_obj = self.inner.bind(py);
308 event_obj.call_method0("synchronize").unwrap();
309 })
310 }
311}
312
313#[derive(Debug, Error)]
315pub enum CudaError {
316 #[error(
317 "one or more parameters passed to the API call is not within an acceptable range of values"
318 )]
319 InvalidValue,
320 #[error("the API call failed due to insufficient memory or resources")]
321 MemoryAllocation,
322 #[error("failed to initialize the CUDA driver and runtime")]
323 InitializationError,
324 #[error("CUDA Runtime API call was executed after the CUDA driver has been unloaded")]
325 CudartUnloading,
326 #[error("profiler is not initialized for this run, possibly due to an external profiling tool")]
327 ProfilerDisabled,
328 #[error("deprecated. Attempted to enable/disable profiling without initialization")]
329 ProfilerNotInitialized,
330 #[error("deprecated. Profiling is already started")]
331 ProfilerAlreadyStarted,
332 #[error("deprecated. Profiling is already stopped")]
333 ProfilerAlreadyStopped,
334 #[error("kernel launch requested resources that cannot be satisfied by the current device")]
335 InvalidConfiguration,
336 #[error("one or more of the pitch-related parameters passed to the API call is out of range")]
337 InvalidPitchValue,
338 #[error("the symbol name/identifier passed to the API call is invalid")]
339 InvalidSymbol,
340 #[error("the host pointer passed to the API call is invalid")]
341 InvalidHostPointer,
342 #[error("the device pointer passed to the API call is invalid")]
343 InvalidDevicePointer,
344 #[error("the texture passed to the API call is invalid")]
345 InvalidTexture,
346 #[error("the texture binding is invalid")]
347 InvalidTextureBinding,
348 #[error("the channel descriptor passed to the API call is invalid")]
349 InvalidChannelDescriptor,
350 #[error("the direction of the memcpy operation is invalid")]
351 InvalidMemcpyDirection,
352 #[error(
353 "attempted to take the address of a constant variable, which is forbidden before CUDA 3.1"
354 )]
355 AddressOfConstant,
356 #[error("deprecated. A texture fetch operation failed")]
357 TextureFetchFailed,
358 #[error("deprecated. The texture is not bound for access")]
359 TextureNotBound,
360 #[error("a synchronization operation failed")]
361 SynchronizationError,
362 #[error(
363 "a non-float texture was accessed with linear filtering, which is not supported by CUDA"
364 )]
365 InvalidFilterSetting,
366 #[error(
367 "attempted to read a non-float texture as a normalized float, which is not supported by CUDA"
368 )]
369 InvalidNormSetting,
370 #[error("the API call is not yet implemented")]
371 NotYetImplemented,
372 #[error("an emulated device pointer exceeded the 32-bit address range")]
373 MemoryValueTooLarge,
374 #[error("the CUDA driver is a stub library")]
375 StubLibrary,
376 #[error("the installed NVIDIA CUDA driver is older than the CUDA runtime library")]
377 InsufficientDriver,
378 #[error("the API call requires a newer CUDA driver")]
379 CallRequiresNewerDriver,
380 #[error("the surface passed to the API call is invalid")]
381 InvalidSurface,
382 #[error("multiple global or constant variables share the same string name")]
383 DuplicateVariableName,
384 #[error("multiple textures share the same string name")]
385 DuplicateTextureName,
386 #[error("multiple surfaces share the same string name")]
387 DuplicateSurfaceName,
388 #[error("all CUDA devices are currently busy or unavailable")]
389 DevicesUnavailable,
390 #[error("the current CUDA context is not compatible with the runtime")]
391 IncompatibleDriverContext,
392 #[error("the device function being invoked was not previously configured")]
393 MissingConfiguration,
394 #[error("a previous kernel launch failed")]
395 PriorLaunchFailure,
396 #[error(
397 "the depth of the child grid exceeded the maximum supported number of nested grid launches"
398 )]
399 LaunchMaxDepthExceeded,
400 #[error("a grid launch did not occur because file-scoped textures are unsupported")]
401 LaunchFileScopedTex,
402 #[error("a grid launch did not occur because file-scoped surfaces are unsupported")]
403 LaunchFileScopedSurf,
404 #[error("a call to cudaDeviceSynchronize failed due to exceeding the sync depth")]
405 SyncDepthExceeded,
406 #[error(
407 "a grid launch failed because the launch exceeded the limit of pending device runtime launches"
408 )]
409 LaunchPendingCountExceeded,
410 #[error(
411 "the requested device function does not exist or is not compiled for the proper device architecture"
412 )]
413 InvalidDeviceFunction,
414 #[error("no CUDA-capable devices were detected")]
415 NoDevice,
416 #[error("the device ordinal supplied does not correspond to a valid CUDA device")]
417 InvalidDevice,
418 #[error("the device does not have a valid Grid License")]
419 DeviceNotLicensed,
420 #[error("an internal startup failure occurred in the CUDA runtime")]
421 StartupFailure,
422 #[error("the device kernel image is invalid")]
423 InvalidKernelImage,
424 #[error("the device is not initialized")]
425 DeviceUninitialized,
426 #[error("the buffer object could not be mapped")]
427 MapBufferObjectFailed,
428 #[error("the buffer object could not be unmapped")]
429 UnmapBufferObjectFailed,
430 #[error("the specified array is currently mapped and cannot be destroyed")]
431 ArrayIsMapped,
432 #[error("the resource is already mapped")]
433 AlreadyMapped,
434 #[error("there is no kernel image available that is suitable for the device")]
435 NoKernelImageForDevice,
436 #[error("the resource has already been acquired")]
437 AlreadyAcquired,
438 #[error("the resource is not mapped")]
439 NotMapped,
440 #[error("the mapped resource is not available for access as an array")]
441 NotMappedAsArray,
442 #[error("the mapped resource is not available for access as a pointer")]
443 NotMappedAsPointer,
444 #[error("an uncorrectable ECC error was detected")]
445 ECCUncorrectable,
446 #[error("the specified cudaLimit is not supported by the device")]
447 UnsupportedLimit,
448 #[error("a call tried to access an exclusive-thread device that is already in use")]
449 DeviceAlreadyInUse,
450 #[error("P2P access is not supported across the given devices")]
451 PeerAccessUnsupported,
452 #[error("a PTX compilation failed")]
453 InvalidPtx,
454 #[error("an error occurred with the OpenGL or DirectX context")]
455 InvalidGraphicsContext,
456 #[error("an uncorrectable NVLink error was detected during execution")]
457 NvlinkUncorrectable,
458 #[error("the PTX JIT compiler library was not found")]
459 JitCompilerNotFound,
460 #[error("the provided PTX was compiled with an unsupported toolchain")]
461 UnsupportedPtxVersion,
462 #[error("JIT compilation was disabled")]
463 JitCompilationDisabled,
464 #[error("the provided execution affinity is not supported by the device")]
465 UnsupportedExecAffinity,
466 #[error("the operation is not permitted when the stream is capturing")]
467 StreamCaptureUnsupported,
468 #[error(
469 "the current capture sequence on the stream has been invalidated due to a previous error"
470 )]
471 StreamCaptureInvalidated,
472 #[error("a merge of two independent capture sequences was not allowed")]
473 StreamCaptureMerge,
474 #[error("the capture was not initiated in this stream")]
475 StreamCaptureUnmatched,
476 #[error("a stream capture sequence was passed to cudaStreamEndCapture in a different thread")]
477 StreamCaptureWrongThread,
478 #[error("the wait operation has timed out")]
479 Timeout,
480 #[error("an unknown internal error occurred")]
481 Unknown,
482 #[error("the API call returned a failure")]
483 ApiFailureBase,
484}
485
486#[cfg(feature = "cuda")]
487pub fn cuda_check(result: nccl_sys::cudaError_t) -> Result<(), CudaError> {
488 match result.0 {
489 0 => Ok(()),
490 1 => Err(CudaError::InvalidValue),
491 2 => Err(CudaError::MemoryAllocation),
492 3 => Err(CudaError::InitializationError),
493 4 => Err(CudaError::CudartUnloading),
494 5 => Err(CudaError::ProfilerDisabled),
495 6 => Err(CudaError::ProfilerNotInitialized),
496 7 => Err(CudaError::ProfilerAlreadyStarted),
497 8 => Err(CudaError::ProfilerAlreadyStopped),
498 9 => Err(CudaError::InvalidConfiguration),
499 12 => Err(CudaError::InvalidPitchValue),
500 13 => Err(CudaError::InvalidSymbol),
501 16 => Err(CudaError::InvalidHostPointer),
502 17 => Err(CudaError::InvalidDevicePointer),
503 18 => Err(CudaError::InvalidTexture),
504 19 => Err(CudaError::InvalidTextureBinding),
505 20 => Err(CudaError::InvalidChannelDescriptor),
506 21 => Err(CudaError::InvalidMemcpyDirection),
507 22 => Err(CudaError::AddressOfConstant),
508 23 => Err(CudaError::TextureFetchFailed),
509 24 => Err(CudaError::TextureNotBound),
510 25 => Err(CudaError::SynchronizationError),
511 26 => Err(CudaError::InvalidFilterSetting),
512 27 => Err(CudaError::InvalidNormSetting),
513 31 => Err(CudaError::NotYetImplemented),
514 32 => Err(CudaError::MemoryValueTooLarge),
515 34 => Err(CudaError::StubLibrary),
516 35 => Err(CudaError::InsufficientDriver),
517 36 => Err(CudaError::CallRequiresNewerDriver),
518 37 => Err(CudaError::InvalidSurface),
519 43 => Err(CudaError::DuplicateVariableName),
520 44 => Err(CudaError::DuplicateTextureName),
521 45 => Err(CudaError::DuplicateSurfaceName),
522 46 => Err(CudaError::DevicesUnavailable),
523 49 => Err(CudaError::IncompatibleDriverContext),
524 52 => Err(CudaError::MissingConfiguration),
525 53 => Err(CudaError::PriorLaunchFailure),
526 65 => Err(CudaError::LaunchMaxDepthExceeded),
527 66 => Err(CudaError::LaunchFileScopedTex),
528 67 => Err(CudaError::LaunchFileScopedSurf),
529 68 => Err(CudaError::SyncDepthExceeded),
530 69 => Err(CudaError::LaunchPendingCountExceeded),
531 98 => Err(CudaError::InvalidDeviceFunction),
532 100 => Err(CudaError::NoDevice),
533 101 => Err(CudaError::InvalidDevice),
534 102 => Err(CudaError::DeviceNotLicensed),
535 127 => Err(CudaError::StartupFailure),
536 200 => Err(CudaError::InvalidKernelImage),
537 201 => Err(CudaError::DeviceUninitialized),
538 205 => Err(CudaError::MapBufferObjectFailed),
539 206 => Err(CudaError::UnmapBufferObjectFailed),
540 207 => Err(CudaError::ArrayIsMapped),
541 208 => Err(CudaError::AlreadyMapped),
542 209 => Err(CudaError::NoKernelImageForDevice),
543 210 => Err(CudaError::AlreadyAcquired),
544 211 => Err(CudaError::NotMapped),
545 212 => Err(CudaError::NotMappedAsArray),
546 213 => Err(CudaError::NotMappedAsPointer),
547 214 => Err(CudaError::ECCUncorrectable),
548 215 => Err(CudaError::UnsupportedLimit),
549 216 => Err(CudaError::DeviceAlreadyInUse),
550 217 => Err(CudaError::PeerAccessUnsupported),
551 218 => Err(CudaError::InvalidPtx),
552 219 => Err(CudaError::InvalidGraphicsContext),
553 220 => Err(CudaError::NvlinkUncorrectable),
554 221 => Err(CudaError::JitCompilerNotFound),
555 222 => Err(CudaError::UnsupportedPtxVersion),
556 223 => Err(CudaError::JitCompilationDisabled),
557 224 => Err(CudaError::UnsupportedExecAffinity),
558 900 => Err(CudaError::StreamCaptureUnsupported),
559 901 => Err(CudaError::StreamCaptureInvalidated),
560 902 => Err(CudaError::StreamCaptureMerge),
561 903 => Err(CudaError::StreamCaptureUnmatched),
562 904 => Err(CudaError::StreamCaptureWrongThread),
563 909 => Err(CudaError::Timeout),
564 999 => Err(CudaError::Unknown),
565 _ => panic!("Unknown cudaError_t: {:?}", result.0),
566 }
567}
568
/// Makes `device` the current CUDA device via `cudaSetDevice`.
///
/// # Errors
/// Returns the [`CudaError`] produced by [`cuda_check`] for a non-success
/// status.
#[cfg(feature = "cuda")]
pub fn set_device(device: CudaDevice) -> Result<(), CudaError> {
    let ordinal: i8 = device.index().into();
    // SAFETY: `cudaSetDevice` takes a plain integer ordinal and dereferences
    // no caller memory; an invalid ordinal is reported through its status code.
    let status = unsafe { nccl_sys::cudaSetDevice(ordinal.into()) };
    cuda_check(status)
}
575
576#[cfg(not(feature = "cuda"))]
577pub fn set_device(_device: CudaDevice) -> Result<(), CudaError> {
578 panic!("set_device requires the `cuda` feature (CUDA/ROCm runtime)")
579}