mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-10 07:08:03 -05:00
refactor(gpu): make it possible to reuse memory in sum_ct_vec
This commit is contained in:
@@ -211,7 +211,8 @@ __host__ void host_integer_sum_ciphertexts_vec_kb(
|
||||
cudaStream_t *streams, uint32_t *gpu_indexes, uint32_t gpu_count,
|
||||
Torus *radix_lwe_out, Torus *terms, int *terms_degree, void **bsks,
|
||||
uint64_t **ksks, int_sum_ciphertexts_vec_memory<uint64_t> *mem_ptr,
|
||||
uint32_t num_blocks_in_radix, uint32_t num_radix_in_vec) {
|
||||
uint32_t num_blocks_in_radix, uint32_t num_radix_in_vec,
|
||||
int_radix_lut<Torus> *reused_lut = nullptr) {
|
||||
|
||||
auto new_blocks = mem_ptr->new_blocks;
|
||||
auto old_blocks = mem_ptr->old_blocks;
|
||||
@@ -248,6 +249,17 @@ __host__ void host_integer_sum_ciphertexts_vec_kb(
|
||||
|
||||
auto max_shared_memory = cuda_get_max_shared_memory(gpu_indexes[0]);
|
||||
|
||||
int_radix_lut<Torus> *luts_message_carry;
|
||||
bool release_reused_lut = false;
|
||||
if (reused_lut == nullptr) {
|
||||
release_reused_lut = true;
|
||||
size_t ch_amount = r / chunk_size;
|
||||
if (!ch_amount)
|
||||
ch_amount++;
|
||||
reused_lut = new int_radix_lut<Torus>(streams, gpu_indexes, gpu_count,
|
||||
mem_ptr->params, 2,
|
||||
2 * ch_amount * num_blocks, true);
|
||||
}
|
||||
while (r > 2) {
|
||||
size_t cur_total_blocks = r * num_blocks;
|
||||
size_t ch_amount = r / chunk_size;
|
||||
@@ -279,8 +291,9 @@ __host__ void host_integer_sum_ciphertexts_vec_kb(
|
||||
// create lut object for message and carry
|
||||
// we allocate luts_message_carry in the host function (instead of scratch)
|
||||
// to reduce average memory consumption
|
||||
auto luts_message_carry = new int_radix_lut<Torus>(
|
||||
streams, gpu_indexes, gpu_count, mem_ptr->params, 2, total_count, true);
|
||||
luts_message_carry =
|
||||
new int_radix_lut<Torus>(streams, gpu_indexes, gpu_count,
|
||||
mem_ptr->params, 2, total_count, reused_lut);
|
||||
|
||||
auto message_acc = luts_message_carry->get_lut(gpu_indexes[0], 0);
|
||||
auto carry_acc = luts_message_carry->get_lut(gpu_indexes[0], 1);
|
||||
@@ -349,7 +362,6 @@ __host__ void host_integer_sum_ciphertexts_vec_kb(
|
||||
mem_ptr->params.pbs_base_log, mem_ptr->params.pbs_level,
|
||||
mem_ptr->params.grouping_factor, total_count, 2, 0,
|
||||
max_shared_memory, mem_ptr->params.pbs_type, true);
|
||||
luts_message_carry->release(streams, gpu_indexes, gpu_count);
|
||||
|
||||
int rem_blocks = (r > chunk_size) ? r % chunk_size * num_blocks : 0;
|
||||
int new_blocks_created = 2 * ch_amount * num_blocks;
|
||||
@@ -361,6 +373,12 @@ __host__ void host_integer_sum_ciphertexts_vec_kb(
|
||||
gpu_indexes[0]);
|
||||
std::swap(new_blocks, old_blocks);
|
||||
r = (new_blocks_created + rem_blocks) / num_blocks;
|
||||
luts_message_carry->release(streams, gpu_indexes, gpu_count);
|
||||
delete (luts_message_carry);
|
||||
}
|
||||
if (release_reused_lut) {
|
||||
reused_lut->release(streams, gpu_indexes, gpu_count);
|
||||
delete (reused_lut);
|
||||
}
|
||||
|
||||
host_addition(streams[0], gpu_indexes[0], radix_lwe_out, old_blocks,
|
||||
@@ -484,7 +502,7 @@ __host__ void host_integer_mult_radix_kb(
|
||||
host_integer_sum_ciphertexts_vec_kb<Torus, params>(
|
||||
streams, gpu_indexes, gpu_count, radix_lwe_out, vector_result_sb,
|
||||
terms_degree, bsks, ksks, mem_ptr->sum_ciphertexts_mem, num_blocks,
|
||||
2 * num_blocks);
|
||||
2 * num_blocks, mem_ptr->luts_array);
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
|
||||
Reference in New Issue
Block a user