enhance(runtime/gpu): Cache keys copy to gpu

This commit is contained in:
Quentin Bourgerie
2022-10-14 20:40:49 +02:00
parent f36e1fe882
commit 792d46fa80
2 changed files with 101 additions and 54 deletions

View File

@@ -17,6 +17,19 @@
#include "concrete-core-ffi.h"
#include "concretelang/Common/Error.h"
#ifdef CONCRETELANG_CUDA_SUPPORT
// We need to define the double2 struct from the CUDA backend header files
// This shouldn't be defined here, but included along with concrete-cuda header
// files
typedef struct double2 {
double x, y;
} double2;
// From concrete-cuda
#include "bootstrap.h"
#include "device.h"
#include "keyswitch.h"
#endif
namespace mlir {
namespace concretelang {
@@ -24,16 +37,14 @@ typedef struct RuntimeContext {
RuntimeContext() {
CAPI_ASSERT_ERROR(new_default_engine(best_seeder, &default_engine));
#ifdef CONCRETELANG_CUDA_SUPPORT
bsk_gpu = nullptr;
ksk_gpu = nullptr;
#endif
}
/// Ensure that the engines map is not copied
RuntimeContext(const RuntimeContext &ctx)
: evaluationKeys(ctx.evaluationKeys) {
CAPI_ASSERT_ERROR(new_default_engine(best_seeder, &default_engine));
}
RuntimeContext(const RuntimeContext &&other)
: evaluationKeys(other.evaluationKeys),
default_engine(other.default_engine) {}
RuntimeContext(const RuntimeContext &ctx){};
~RuntimeContext() {
CAPI_ASSERT_ERROR(destroy_default_engine(default_engine));
@@ -43,6 +54,14 @@ typedef struct RuntimeContext {
if (fbsk != nullptr) {
CAPI_ASSERT_ERROR(destroy_fft_fourier_lwe_bootstrap_key_u64(fbsk));
}
#ifdef CONCRETELANG_CUDA_SUPPORT
if (bsk_gpu != nullptr) {
cuda_drop(bsk_gpu, 0);
}
if (ksk_gpu != nullptr) {
cuda_drop(ksk_gpu, 0);
}
#endif
}
FftEngine *get_fft_engine() {
@@ -80,6 +99,70 @@ typedef struct RuntimeContext {
return fbsk;
}
#ifdef CONCRETELANG_CUDA_SUPPORT
void *get_bsk_gpu(uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level,
uint32_t glwe_dim, uint32_t gpu_idx, void *stream) {
if (bsk_gpu != nullptr) {
return bsk_gpu;
}
const std::lock_guard<std::mutex> guard(bsk_gpu_mutex);
if (bsk_gpu != nullptr) {
return bsk_gpu;
}
LweBootstrapKey64 *bsk = get_bsk();
size_t bsk_buffer_len =
input_lwe_dim * (glwe_dim + 1) * (glwe_dim + 1) * poly_size * level;
size_t bsk_buffer_size = bsk_buffer_len * sizeof(uint64_t);
uint64_t *bsk_buffer =
(uint64_t *)aligned_alloc(U64_ALIGNMENT, bsk_buffer_size);
size_t bsk_gpu_buffer_size = bsk_buffer_len * sizeof(double);
bsk_gpu = cuda_malloc(bsk_gpu_buffer_size, gpu_idx);
CAPI_ASSERT_ERROR(
default_engine_discard_convert_lwe_bootstrap_key_to_lwe_bootstrap_key_mut_view_u64_raw_ptr_buffers(
default_engine, bsk, bsk_buffer));
cuda_initialize_twiddles(poly_size, gpu_idx);
cuda_convert_lwe_bootstrap_key_64(bsk_gpu, bsk_buffer, stream, gpu_idx,
input_lwe_dim, glwe_dim, level,
poly_size);
// This is currently not 100% async as we have to free CPU memory after
// conversion
cuda_synchronize_device(gpu_idx);
free(bsk_buffer);
return bsk_gpu;
}
void *get_ksk_gpu(uint32_t level, uint32_t input_lwe_dim,
uint32_t output_lwe_dim, uint32_t gpu_idx, void *stream) {
if (ksk_gpu != nullptr) {
return ksk_gpu;
}
const std::lock_guard<std::mutex> guard(ksk_gpu_mutex);
if (ksk_gpu != nullptr) {
return ksk_gpu;
}
LweKeyswitchKey64 *ksk = get_ksk();
size_t ksk_buffer_len = input_lwe_dim * (output_lwe_dim + 1) * level;
size_t ksk_buffer_size = sizeof(uint64_t) * ksk_buffer_len;
uint64_t *ksk_buffer =
(uint64_t *)aligned_alloc(U64_ALIGNMENT, ksk_buffer_size);
void *ksk_gpu = cuda_malloc(ksk_buffer_size, gpu_idx);
CAPI_ASSERT_ERROR(
default_engine_discard_convert_lwe_keyswitch_key_to_lwe_keyswitch_key_mut_view_u64_raw_ptr_buffers(
default_engine, ksk, ksk_buffer));
cuda_memcpy_async_to_gpu(ksk_gpu, ksk_buffer, ksk_buffer_size, stream,
gpu_idx);
// This is currently not 100% async as we have to free CPU memory after
// conversion
cuda_synchronize_device(gpu_idx);
free(ksk_buffer);
return ksk_gpu;
}
#endif
LweBootstrapKey64 *get_bsk() { return evaluationKeys.getBsk(); }
LweKeyswitchKey64 *get_ksk() { return evaluationKeys.getKsk(); }
@@ -102,6 +185,13 @@ private:
std::map<pthread_t, FftEngine *> fft_engines;
std::mutex engines_map_guard;
#ifdef CONCRETELANG_CUDA_SUPPORT
std::mutex bsk_gpu_mutex;
void *bsk_gpu;
std::mutex ksk_gpu_mutex;
void *ksk_gpu;
#endif
} RuntimeContext;
} // namespace concretelang