feat(gpu): add function to check if a cuda device is available

This commit is contained in:
Andrei Stoian
2024-11-25 17:19:39 +01:00
committed by Agnès Leroy
parent 9584f57dca
commit 0898cdd05b
4 changed files with 13 additions and 0 deletions

View File

@@ -42,6 +42,8 @@ void cuda_destroy_stream(cudaStream_t stream, uint32_t gpu_index);
void cuda_synchronize_stream(cudaStream_t stream, uint32_t gpu_index);
uint32_t cuda_is_available();
void *cuda_malloc(uint64_t size, uint32_t gpu_index);
void *cuda_malloc_async(uint64_t size, cudaStream_t stream, uint32_t gpu_index);

View File

@@ -45,6 +45,9 @@ void cuda_synchronize_stream(cudaStream_t stream, uint32_t gpu_index) {
check_cuda_error(cudaStreamSynchronize(stream));
}
// Determine if a CUDA device is available at runtime
uint32_t cuda_is_available() { return cudaSetDevice(0) == cudaSuccess; }
/// Unsafe function that will try to allocate even if gpu_index is invalid
/// or if there's not enough memory. A safe wrapper around it must call
/// cuda_check_valid_malloc() first

View File

@@ -9,6 +9,8 @@ extern "C" {
pub fn cuda_synchronize_stream(stream: *mut c_void, gpu_index: u32);
pub fn cuda_is_available() -> u32;
pub fn cuda_malloc(size: u64, gpu_index: u32) -> *mut c_void;
pub fn cuda_malloc_async(size: u64, stream: *mut c_void, gpu_index: u32) -> *mut c_void;

View File

@@ -678,6 +678,12 @@ pub fn synchronize_devices(gpu_count: u32) {
}
}
// Determine if a cuda device is available, at runtime
pub fn is_cuda_available() -> bool {
let result = unsafe { cuda_is_available() };
result == 1u32
}
#[cfg(test)]
mod tests {
use super::*;