feat(GPU-runtime): add per device cache of BS and KS keys.

This commit is contained in:
Antoniu Pop
2023-02-16 10:30:25 +00:00
committed by Quentin Bourgerie
parent e42d7bbe64
commit 6eb8841652
2 changed files with 32 additions and 30 deletions

View File

@@ -10,6 +10,7 @@
#include <map>
#include <mutex>
#include <pthread.h>
#include <vector>
#include "concretelang/ClientLib/EvaluationKeys.h"
#include "concretelang/Common/Error.h"
@@ -39,15 +40,14 @@ typedef struct FFT {
typedef struct RuntimeContext {
RuntimeContext() = delete;
RuntimeContext(::concretelang::clientlib::EvaluationKeys evaluationKeys);
~RuntimeContext() {
#ifdef CONCRETELANG_CUDA_SUPPORT
if (bsk_gpu != nullptr) {
cuda_drop(bsk_gpu, 0);
}
if (ksk_gpu != nullptr) {
cuda_drop(ksk_gpu, 0);
for (int i = 0; i < num_devices; ++i) {
if (bsk_gpu[i] != nullptr)
cuda_drop(bsk_gpu[i], i);
if (ksk_gpu[i] != nullptr)
cuda_drop(ksk_gpu[i], i);
}
#endif
};
@@ -80,13 +80,13 @@ public:
void *get_bsk_gpu(uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level,
uint32_t glwe_dim, uint32_t gpu_idx, void *stream) {
if (bsk_gpu != nullptr) {
return bsk_gpu;
if (bsk_gpu[gpu_idx] != nullptr) {
return bsk_gpu[gpu_idx];
}
const std::lock_guard<std::mutex> guard(bsk_gpu_mutex);
const std::lock_guard<std::mutex> guard(*bsk_gpu_mutex[gpu_idx]);
if (bsk_gpu != nullptr) {
return bsk_gpu;
if (bsk_gpu[gpu_idx] != nullptr) {
return bsk_gpu[gpu_idx];
}
auto bsk = evaluationKeys.getBootstrapKey(0);
@@ -103,20 +103,20 @@ public:
// we have to free CPU memory after
// conversion
cuda_synchronize_device(gpu_idx);
bsk_gpu = bsk_gpu_tmp;
return bsk_gpu;
bsk_gpu[gpu_idx] = bsk_gpu_tmp;
return bsk_gpu[gpu_idx];
}
void *get_ksk_gpu(uint32_t level, uint32_t input_lwe_dim,
uint32_t output_lwe_dim, uint32_t gpu_idx, void *stream) {
if (ksk_gpu != nullptr) {
return ksk_gpu;
if (ksk_gpu[gpu_idx] != nullptr) {
return ksk_gpu[gpu_idx];
}
const std::lock_guard<std::mutex> guard(ksk_gpu_mutex);
if (ksk_gpu != nullptr) {
return ksk_gpu;
const std::lock_guard<std::mutex> guard(*ksk_gpu_mutex[gpu_idx]);
if (ksk_gpu[gpu_idx] != nullptr) {
return ksk_gpu[gpu_idx];
}
auto ksk = evaluationKeys.getKeyswitchKey(0);
@@ -126,21 +126,18 @@ public:
cuda_memcpy_async_to_gpu(ksk_gpu_tmp, (void *)ksk.buffer(), ksk_buffer_size,
stream, gpu_idx);
// This is currently not 100% async as
// we have to free CPU memory after
// conversion
cuda_synchronize_device(gpu_idx);
ksk_gpu = ksk_gpu_tmp;
return ksk_gpu;
ksk_gpu[gpu_idx] = ksk_gpu_tmp;
return ksk_gpu[gpu_idx];
}
private:
std::mutex bsk_gpu_mutex;
void *bsk_gpu;
std::mutex ksk_gpu_mutex;
void *ksk_gpu;
std::vector<std::unique_ptr<std::mutex>> bsk_gpu_mutex;
std::vector<void *> bsk_gpu;
std::vector<std::unique_ptr<std::mutex>> ksk_gpu_mutex;
std::vector<void *> ksk_gpu;
int num_devices;
#endif
} RuntimeContext;
} // namespace concretelang

View File

@@ -71,8 +71,13 @@ RuntimeContext::RuntimeContext(clientlib::EvaluationKeys evaluationKeys)
}
#ifdef CONCRETELANG_CUDA_SUPPORT
bsk_gpu = nullptr;
ksk_gpu = nullptr;
assert(cudaGetDeviceCount(&num_devices) == cudaSuccess);
bsk_gpu.resize(num_devices, nullptr);
ksk_gpu.resize(num_devices, nullptr);
for (int i = 0; i < num_devices; ++i) {
bsk_gpu_mutex.push_back(std::make_unique<std::mutex>());
ksk_gpu_mutex.push_back(std::make_unique<std::mutex>());
}
#endif
}
}