From 0898cdd05b6526f82e07bb110a9faad1a8c28174 Mon Sep 17 00:00:00 2001
From: Andrei Stoian
Date: Mon, 25 Nov 2024 17:19:39 +0100
Subject: [PATCH] feat(gpu): add function to check if a cuda device is available

---
 backends/tfhe-cuda-backend/cuda/include/device.h | 2 ++
 backends/tfhe-cuda-backend/cuda/src/device.cu    | 3 +++
 backends/tfhe-cuda-backend/src/cuda_bind.rs      | 2 ++
 tfhe/src/core_crypto/gpu/mod.rs                  | 6 ++++++
 4 files changed, 13 insertions(+)

diff --git a/backends/tfhe-cuda-backend/cuda/include/device.h b/backends/tfhe-cuda-backend/cuda/include/device.h
index 3c3a61b8f..431b72536 100644
--- a/backends/tfhe-cuda-backend/cuda/include/device.h
+++ b/backends/tfhe-cuda-backend/cuda/include/device.h
@@ -42,6 +42,8 @@ void cuda_destroy_stream(cudaStream_t stream, uint32_t gpu_index);
 
 void cuda_synchronize_stream(cudaStream_t stream, uint32_t gpu_index);
 
+uint32_t cuda_is_available();
+
 void *cuda_malloc(uint64_t size, uint32_t gpu_index);
 
 void *cuda_malloc_async(uint64_t size, cudaStream_t stream, uint32_t gpu_index);
diff --git a/backends/tfhe-cuda-backend/cuda/src/device.cu b/backends/tfhe-cuda-backend/cuda/src/device.cu
index e15fc7218..041e228b3 100644
--- a/backends/tfhe-cuda-backend/cuda/src/device.cu
+++ b/backends/tfhe-cuda-backend/cuda/src/device.cu
@@ -45,6 +45,9 @@ void cuda_synchronize_stream(cudaStream_t stream, uint32_t gpu_index) {
   check_cuda_error(cudaStreamSynchronize(stream));
 }
 
+// Returns 1 if cudaSetDevice(0) succeeds at runtime (a device is usable), else 0
+uint32_t cuda_is_available() { return cudaSetDevice(0) == cudaSuccess; }
+
 /// Unsafe function that will try to allocate even if gpu_index is invalid
 /// or if there's not enough memory. A safe wrapper around it must call
 /// cuda_check_valid_malloc() first
diff --git a/backends/tfhe-cuda-backend/src/cuda_bind.rs b/backends/tfhe-cuda-backend/src/cuda_bind.rs
index 3c359de5e..28e88f253 100644
--- a/backends/tfhe-cuda-backend/src/cuda_bind.rs
+++ b/backends/tfhe-cuda-backend/src/cuda_bind.rs
@@ -9,6 +9,8 @@ extern "C" {
 
     pub fn cuda_synchronize_stream(stream: *mut c_void, gpu_index: u32);
 
+    pub fn cuda_is_available() -> u32;
+
     pub fn cuda_malloc(size: u64, gpu_index: u32) -> *mut c_void;
 
     pub fn cuda_malloc_async(size: u64, stream: *mut c_void, gpu_index: u32) -> *mut c_void;
diff --git a/tfhe/src/core_crypto/gpu/mod.rs b/tfhe/src/core_crypto/gpu/mod.rs
index 8fd40719b..721673e16 100644
--- a/tfhe/src/core_crypto/gpu/mod.rs
+++ b/tfhe/src/core_crypto/gpu/mod.rs
@@ -678,6 +678,12 @@ pub fn synchronize_devices(gpu_count: u32) {
     }
 }
 
+/// Returns `true` if the CUDA backend reports an available device at runtime.
+pub fn is_cuda_available() -> bool {
+    // SAFETY: FFI call that takes no arguments and returns a plain `u32` flag.
+    unsafe { cuda_is_available() != 0 }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;