chore(gpu): create lut once for all layers in sum_ct_vec

This commit is contained in:
Agnes Leroy
2024-07-09 10:37:39 +02:00
committed by Agnès Leroy
parent dd74063959
commit 7542c89679

View File

@@ -249,7 +249,9 @@ __host__ void host_integer_sum_ciphertexts_vec_kb(
auto max_shared_memory = cuda_get_max_shared_memory(gpu_indexes[0]);
int_radix_lut<Torus> *luts_message_carry;
// create lut object for message and carry
// we allocate luts_message_carry in the host function (instead of scratch)
// to reduce average memory consumption
bool release_reused_lut = false;
if (reused_lut == nullptr) {
release_reused_lut = true;
@@ -260,6 +262,27 @@ __host__ void host_integer_sum_ciphertexts_vec_kb(
mem_ptr->params, 2,
2 * ch_amount * num_blocks, true);
}
int_radix_lut<Torus> *luts_message_carry = reused_lut;
auto message_acc = luts_message_carry->get_lut(gpu_indexes[0], 0);
auto carry_acc = luts_message_carry->get_lut(gpu_indexes[0], 1);
// define functions for each accumulator
auto lut_f_message = [message_modulus](Torus x) -> Torus {
return x % message_modulus;
};
auto lut_f_carry = [message_modulus](Torus x) -> Torus {
return x / message_modulus;
};
// generate accumulators
generate_device_accumulator<Torus>(
streams[0], gpu_indexes[0], message_acc, glwe_dimension, polynomial_size,
message_modulus, carry_modulus, lut_f_message);
generate_device_accumulator<Torus>(
streams[0], gpu_indexes[0], carry_acc, glwe_dimension, polynomial_size,
message_modulus, carry_modulus, lut_f_carry);
luts_message_carry->broadcast_lut(streams, gpu_indexes, gpu_indexes[0]);
while (r > 2) {
size_t cur_total_blocks = r * num_blocks;
size_t ch_amount = r / chunk_size;
@@ -288,32 +311,6 @@ __host__ void host_integer_sum_ciphertexts_vec_kb(
h_smart_copy_out, ch_amount, r, num_blocks, chunk_size, message_max,
total_count, message_count, carry_count, sm_copy_count);
// create lut object for message and carry
// we allocate luts_message_carry in the host function (instead of scratch)
// to reduce average memory consumption
luts_message_carry =
new int_radix_lut<Torus>(streams, gpu_indexes, gpu_count,
mem_ptr->params, 2, total_count, reused_lut);
auto message_acc = luts_message_carry->get_lut(gpu_indexes[0], 0);
auto carry_acc = luts_message_carry->get_lut(gpu_indexes[0], 1);
// define functions for each accumulator
auto lut_f_message = [message_modulus](Torus x) -> Torus {
return x % message_modulus;
};
auto lut_f_carry = [message_modulus](Torus x) -> Torus {
return x / message_modulus;
};
// generate accumulators
generate_device_accumulator<Torus>(
streams[0], gpu_indexes[0], message_acc, glwe_dimension,
polynomial_size, message_modulus, carry_modulus, lut_f_message);
generate_device_accumulator<Torus>(
streams[0], gpu_indexes[0], carry_acc, glwe_dimension, polynomial_size,
message_modulus, carry_modulus, lut_f_carry);
auto lwe_indexes_in = luts_message_carry->lwe_indexes_in;
auto lwe_indexes_out = luts_message_carry->lwe_indexes_out;
@@ -373,8 +370,6 @@ __host__ void host_integer_sum_ciphertexts_vec_kb(
gpu_indexes[0]);
std::swap(new_blocks, old_blocks);
r = (new_blocks_created + rem_blocks) / num_blocks;
luts_message_carry->release(streams, gpu_indexes, gpu_count);
delete (luts_message_carry);
}
if (release_reused_lut) {
reused_lut->release(streams, gpu_indexes, gpu_count);