mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 11:35:02 -05:00
feat(GPU-runtime): add per device cache of BS and KS keys.
This commit is contained in:
committed by
Quentin Bourgerie
parent
e42d7bbe64
commit
6eb8841652
@@ -10,6 +10,7 @@
|
||||
#include <map>
|
||||
#include <mutex>
|
||||
#include <pthread.h>
|
||||
#include <vector>
|
||||
|
||||
#include "concretelang/ClientLib/EvaluationKeys.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
@@ -39,15 +40,14 @@ typedef struct FFT {
|
||||
typedef struct RuntimeContext {
|
||||
|
||||
RuntimeContext() = delete;
|
||||
|
||||
RuntimeContext(::concretelang::clientlib::EvaluationKeys evaluationKeys);
|
||||
~RuntimeContext() {
|
||||
#ifdef CONCRETELANG_CUDA_SUPPORT
|
||||
if (bsk_gpu != nullptr) {
|
||||
cuda_drop(bsk_gpu, 0);
|
||||
}
|
||||
if (ksk_gpu != nullptr) {
|
||||
cuda_drop(ksk_gpu, 0);
|
||||
for (int i = 0; i < num_devices; ++i) {
|
||||
if (bsk_gpu[i] != nullptr)
|
||||
cuda_drop(bsk_gpu[i], i);
|
||||
if (ksk_gpu[i] != nullptr)
|
||||
cuda_drop(ksk_gpu[i], i);
|
||||
}
|
||||
#endif
|
||||
};
|
||||
@@ -80,13 +80,13 @@ public:
|
||||
void *get_bsk_gpu(uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level,
|
||||
uint32_t glwe_dim, uint32_t gpu_idx, void *stream) {
|
||||
|
||||
if (bsk_gpu != nullptr) {
|
||||
return bsk_gpu;
|
||||
if (bsk_gpu[gpu_idx] != nullptr) {
|
||||
return bsk_gpu[gpu_idx];
|
||||
}
|
||||
const std::lock_guard<std::mutex> guard(bsk_gpu_mutex);
|
||||
const std::lock_guard<std::mutex> guard(*bsk_gpu_mutex[gpu_idx]);
|
||||
|
||||
if (bsk_gpu != nullptr) {
|
||||
return bsk_gpu;
|
||||
if (bsk_gpu[gpu_idx] != nullptr) {
|
||||
return bsk_gpu[gpu_idx];
|
||||
}
|
||||
|
||||
auto bsk = evaluationKeys.getBootstrapKey(0);
|
||||
@@ -103,20 +103,20 @@ public:
|
||||
// we have to free CPU memory after
|
||||
// conversion
|
||||
cuda_synchronize_device(gpu_idx);
|
||||
bsk_gpu = bsk_gpu_tmp;
|
||||
return bsk_gpu;
|
||||
bsk_gpu[gpu_idx] = bsk_gpu_tmp;
|
||||
return bsk_gpu[gpu_idx];
|
||||
}
|
||||
|
||||
void *get_ksk_gpu(uint32_t level, uint32_t input_lwe_dim,
|
||||
uint32_t output_lwe_dim, uint32_t gpu_idx, void *stream) {
|
||||
|
||||
if (ksk_gpu != nullptr) {
|
||||
return ksk_gpu;
|
||||
if (ksk_gpu[gpu_idx] != nullptr) {
|
||||
return ksk_gpu[gpu_idx];
|
||||
}
|
||||
|
||||
const std::lock_guard<std::mutex> guard(ksk_gpu_mutex);
|
||||
if (ksk_gpu != nullptr) {
|
||||
return ksk_gpu;
|
||||
const std::lock_guard<std::mutex> guard(*ksk_gpu_mutex[gpu_idx]);
|
||||
if (ksk_gpu[gpu_idx] != nullptr) {
|
||||
return ksk_gpu[gpu_idx];
|
||||
}
|
||||
auto ksk = evaluationKeys.getKeyswitchKey(0);
|
||||
|
||||
@@ -126,21 +126,18 @@ public:
|
||||
|
||||
cuda_memcpy_async_to_gpu(ksk_gpu_tmp, (void *)ksk.buffer(), ksk_buffer_size,
|
||||
stream, gpu_idx);
|
||||
// This is currently not 100% async as
|
||||
// we have to free CPU memory after
|
||||
// conversion
|
||||
cuda_synchronize_device(gpu_idx);
|
||||
ksk_gpu = ksk_gpu_tmp;
|
||||
return ksk_gpu;
|
||||
ksk_gpu[gpu_idx] = ksk_gpu_tmp;
|
||||
return ksk_gpu[gpu_idx];
|
||||
}
|
||||
|
||||
private:
|
||||
std::mutex bsk_gpu_mutex;
|
||||
void *bsk_gpu;
|
||||
std::mutex ksk_gpu_mutex;
|
||||
void *ksk_gpu;
|
||||
std::vector<std::unique_ptr<std::mutex>> bsk_gpu_mutex;
|
||||
std::vector<void *> bsk_gpu;
|
||||
std::vector<std::unique_ptr<std::mutex>> ksk_gpu_mutex;
|
||||
std::vector<void *> ksk_gpu;
|
||||
int num_devices;
|
||||
#endif
|
||||
|
||||
} RuntimeContext;
|
||||
|
||||
} // namespace concretelang
|
||||
|
||||
@@ -71,8 +71,13 @@ RuntimeContext::RuntimeContext(clientlib::EvaluationKeys evaluationKeys)
|
||||
}
|
||||
|
||||
#ifdef CONCRETELANG_CUDA_SUPPORT
|
||||
bsk_gpu = nullptr;
|
||||
ksk_gpu = nullptr;
|
||||
assert(cudaGetDeviceCount(&num_devices) == cudaSuccess);
|
||||
bsk_gpu.resize(num_devices, nullptr);
|
||||
ksk_gpu.resize(num_devices, nullptr);
|
||||
for (int i = 0; i < num_devices; ++i) {
|
||||
bsk_gpu_mutex.push_back(std::make_unique<std::mutex>());
|
||||
ksk_gpu_mutex.push_back(std::make_unique<std::mutex>());
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user