fix(gpu): create message extract lut only when needed

This commit is contained in:
Guillermo Oyarzun
2025-08-08 11:25:14 +02:00
parent a63207af9e
commit 4a3be71bd7

View File

@@ -2357,9 +2357,11 @@ template <typename Torus> struct int_sc_prop_memory {
grouping_size, num_groups, allocate_gpu_memory, size_tracker);
// Step 3 elements
int num_luts_message_extract =
requested_flag == outputFlag::FLAG_NONE ? 1 : 2;
lut_message_extract = new int_radix_lut<Torus>(
streams, gpu_indexes, gpu_count, params, 2, num_radix_blocks + 1,
allocate_gpu_memory, size_tracker);
streams, gpu_indexes, gpu_count, params, num_luts_message_extract,
num_radix_blocks + 1, allocate_gpu_memory, size_tracker);
// lut for the first block in the first grouping
auto f_message_extract = [message_modulus](Torus block) -> Torus {
return (block >> 1) % message_modulus;
@@ -2372,8 +2374,6 @@ template <typename Torus> struct int_sc_prop_memory {
message_modulus, carry_modulus, f_message_extract,
gpu_memory_allocated);
lut_message_extract->broadcast_lut(streams, gpu_indexes);
// This stores a single block that will be used to store the overflow or
// carry results
output_flag = new CudaRadixCiphertextFFI;
@@ -2473,8 +2473,6 @@ template <typename Torus> struct int_sc_prop_memory {
lut_message_extract->get_lut_indexes(0, 0), h_lut_indexes,
(num_radix_blocks + 1) * sizeof(Torus), streams[0], gpu_indexes[0],
allocate_gpu_memory);
lut_message_extract->broadcast_lut(streams, gpu_indexes);
}
if (requested_flag == outputFlag::FLAG_CARRY) { // Carry case
@@ -2501,9 +2499,8 @@ template <typename Torus> struct int_sc_prop_memory {
lut_message_extract->get_lut_indexes(0, 0), h_lut_indexes,
(num_radix_blocks + 1) * sizeof(Torus), streams[0], gpu_indexes[0],
allocate_gpu_memory);
lut_message_extract->broadcast_lut(streams, gpu_indexes);
}
lut_message_extract->broadcast_lut(streams, gpu_indexes);
};
void release(cudaStream_t const *streams, uint32_t const *gpu_indexes,