refactor(gpu): make it possible to reuse memory in sum_ct_vec

This commit is contained in:
Agnes Leroy
2024-07-08 10:11:49 +02:00
committed by Agnès Leroy
parent f6845a988b
commit dd74063959

View File

@@ -211,7 +211,8 @@ __host__ void host_integer_sum_ciphertexts_vec_kb(
cudaStream_t *streams, uint32_t *gpu_indexes, uint32_t gpu_count,
Torus *radix_lwe_out, Torus *terms, int *terms_degree, void **bsks,
uint64_t **ksks, int_sum_ciphertexts_vec_memory<uint64_t> *mem_ptr,
uint32_t num_blocks_in_radix, uint32_t num_radix_in_vec) {
uint32_t num_blocks_in_radix, uint32_t num_radix_in_vec,
int_radix_lut<Torus> *reused_lut = nullptr) {
auto new_blocks = mem_ptr->new_blocks;
auto old_blocks = mem_ptr->old_blocks;
@@ -248,6 +249,17 @@ __host__ void host_integer_sum_ciphertexts_vec_kb(
auto max_shared_memory = cuda_get_max_shared_memory(gpu_indexes[0]);
int_radix_lut<Torus> *luts_message_carry;
bool release_reused_lut = false;
if (reused_lut == nullptr) {
release_reused_lut = true;
size_t ch_amount = r / chunk_size;
if (!ch_amount)
ch_amount++;
reused_lut = new int_radix_lut<Torus>(streams, gpu_indexes, gpu_count,
mem_ptr->params, 2,
2 * ch_amount * num_blocks, true);
}
while (r > 2) {
size_t cur_total_blocks = r * num_blocks;
size_t ch_amount = r / chunk_size;
@@ -279,8 +291,9 @@ __host__ void host_integer_sum_ciphertexts_vec_kb(
// create lut object for message and carry
// we allocate luts_message_carry in the host function (instead of scratch)
// to reduce average memory consumption
auto luts_message_carry = new int_radix_lut<Torus>(
streams, gpu_indexes, gpu_count, mem_ptr->params, 2, total_count, true);
luts_message_carry =
new int_radix_lut<Torus>(streams, gpu_indexes, gpu_count,
mem_ptr->params, 2, total_count, reused_lut);
auto message_acc = luts_message_carry->get_lut(gpu_indexes[0], 0);
auto carry_acc = luts_message_carry->get_lut(gpu_indexes[0], 1);
@@ -349,7 +362,6 @@ __host__ void host_integer_sum_ciphertexts_vec_kb(
mem_ptr->params.pbs_base_log, mem_ptr->params.pbs_level,
mem_ptr->params.grouping_factor, total_count, 2, 0,
max_shared_memory, mem_ptr->params.pbs_type, true);
luts_message_carry->release(streams, gpu_indexes, gpu_count);
int rem_blocks = (r > chunk_size) ? r % chunk_size * num_blocks : 0;
int new_blocks_created = 2 * ch_amount * num_blocks;
@@ -361,6 +373,12 @@ __host__ void host_integer_sum_ciphertexts_vec_kb(
gpu_indexes[0]);
std::swap(new_blocks, old_blocks);
r = (new_blocks_created + rem_blocks) / num_blocks;
luts_message_carry->release(streams, gpu_indexes, gpu_count);
delete (luts_message_carry);
}
if (release_reused_lut) {
reused_lut->release(streams, gpu_indexes, gpu_count);
delete (reused_lut);
}
host_addition(streams[0], gpu_indexes[0], radix_lwe_out, old_blocks,
@@ -484,7 +502,7 @@ __host__ void host_integer_mult_radix_kb(
host_integer_sum_ciphertexts_vec_kb<Torus, params>(
streams, gpu_indexes, gpu_count, radix_lwe_out, vector_result_sb,
terms_degree, bsks, ksks, mem_ptr->sum_ciphertexts_mem, num_blocks,
2 * num_blocks);
2 * num_blocks, mem_ptr->luts_array);
}
template <typename Torus>