mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-10 07:08:03 -05:00
chore(gpu): create lut once for all layers in sum_ct_vec
This commit is contained in:
@@ -249,7 +249,9 @@ __host__ void host_integer_sum_ciphertexts_vec_kb(
|
||||
|
||||
auto max_shared_memory = cuda_get_max_shared_memory(gpu_indexes[0]);
|
||||
|
||||
int_radix_lut<Torus> *luts_message_carry;
|
||||
// create lut object for message and carry
|
||||
// we allocate luts_message_carry in the host function (instead of scratch)
|
||||
// to reduce average memory consumption
|
||||
bool release_reused_lut = false;
|
||||
if (reused_lut == nullptr) {
|
||||
release_reused_lut = true;
|
||||
@@ -260,6 +262,27 @@ __host__ void host_integer_sum_ciphertexts_vec_kb(
|
||||
mem_ptr->params, 2,
|
||||
2 * ch_amount * num_blocks, true);
|
||||
}
|
||||
int_radix_lut<Torus> *luts_message_carry = reused_lut;
|
||||
auto message_acc = luts_message_carry->get_lut(gpu_indexes[0], 0);
|
||||
auto carry_acc = luts_message_carry->get_lut(gpu_indexes[0], 1);
|
||||
|
||||
// define functions for each accumulator
|
||||
auto lut_f_message = [message_modulus](Torus x) -> Torus {
|
||||
return x % message_modulus;
|
||||
};
|
||||
auto lut_f_carry = [message_modulus](Torus x) -> Torus {
|
||||
return x / message_modulus;
|
||||
};
|
||||
|
||||
// generate accumulators
|
||||
generate_device_accumulator<Torus>(
|
||||
streams[0], gpu_indexes[0], message_acc, glwe_dimension, polynomial_size,
|
||||
message_modulus, carry_modulus, lut_f_message);
|
||||
generate_device_accumulator<Torus>(
|
||||
streams[0], gpu_indexes[0], carry_acc, glwe_dimension, polynomial_size,
|
||||
message_modulus, carry_modulus, lut_f_carry);
|
||||
luts_message_carry->broadcast_lut(streams, gpu_indexes, gpu_indexes[0]);
|
||||
|
||||
while (r > 2) {
|
||||
size_t cur_total_blocks = r * num_blocks;
|
||||
size_t ch_amount = r / chunk_size;
|
||||
@@ -288,32 +311,6 @@ __host__ void host_integer_sum_ciphertexts_vec_kb(
|
||||
h_smart_copy_out, ch_amount, r, num_blocks, chunk_size, message_max,
|
||||
total_count, message_count, carry_count, sm_copy_count);
|
||||
|
||||
// create lut object for message and carry
|
||||
// we allocate luts_message_carry in the host function (instead of scratch)
|
||||
// to reduce average memory consumption
|
||||
luts_message_carry =
|
||||
new int_radix_lut<Torus>(streams, gpu_indexes, gpu_count,
|
||||
mem_ptr->params, 2, total_count, reused_lut);
|
||||
|
||||
auto message_acc = luts_message_carry->get_lut(gpu_indexes[0], 0);
|
||||
auto carry_acc = luts_message_carry->get_lut(gpu_indexes[0], 1);
|
||||
|
||||
// define functions for each accumulator
|
||||
auto lut_f_message = [message_modulus](Torus x) -> Torus {
|
||||
return x % message_modulus;
|
||||
};
|
||||
auto lut_f_carry = [message_modulus](Torus x) -> Torus {
|
||||
return x / message_modulus;
|
||||
};
|
||||
|
||||
// generate accumulators
|
||||
generate_device_accumulator<Torus>(
|
||||
streams[0], gpu_indexes[0], message_acc, glwe_dimension,
|
||||
polynomial_size, message_modulus, carry_modulus, lut_f_message);
|
||||
generate_device_accumulator<Torus>(
|
||||
streams[0], gpu_indexes[0], carry_acc, glwe_dimension, polynomial_size,
|
||||
message_modulus, carry_modulus, lut_f_carry);
|
||||
|
||||
auto lwe_indexes_in = luts_message_carry->lwe_indexes_in;
|
||||
auto lwe_indexes_out = luts_message_carry->lwe_indexes_out;
|
||||
|
||||
@@ -373,8 +370,6 @@ __host__ void host_integer_sum_ciphertexts_vec_kb(
|
||||
gpu_indexes[0]);
|
||||
std::swap(new_blocks, old_blocks);
|
||||
r = (new_blocks_created + rem_blocks) / num_blocks;
|
||||
luts_message_carry->release(streams, gpu_indexes, gpu_count);
|
||||
delete (luts_message_carry);
|
||||
}
|
||||
if (release_reused_lut) {
|
||||
reused_lut->release(streams, gpu_indexes, gpu_count);
|
||||
|
||||
Reference in New Issue
Block a user