mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-09 14:47:56 -05:00
fix(gpu): create message extract lut only when needed
This commit is contained in:
@@ -2357,9 +2357,11 @@ template <typename Torus> struct int_sc_prop_memory {
|
||||
grouping_size, num_groups, allocate_gpu_memory, size_tracker);
|
||||
|
||||
// Step 3 elements
|
||||
int num_luts_message_extract =
|
||||
requested_flag == outputFlag::FLAG_NONE ? 1 : 2;
|
||||
lut_message_extract = new int_radix_lut<Torus>(
|
||||
streams, gpu_indexes, gpu_count, params, 2, num_radix_blocks + 1,
|
||||
allocate_gpu_memory, size_tracker);
|
||||
streams, gpu_indexes, gpu_count, params, num_luts_message_extract,
|
||||
num_radix_blocks + 1, allocate_gpu_memory, size_tracker);
|
||||
// lut for the first block in the first grouping
|
||||
auto f_message_extract = [message_modulus](Torus block) -> Torus {
|
||||
return (block >> 1) % message_modulus;
|
||||
@@ -2372,8 +2374,6 @@ template <typename Torus> struct int_sc_prop_memory {
|
||||
message_modulus, carry_modulus, f_message_extract,
|
||||
gpu_memory_allocated);
|
||||
|
||||
lut_message_extract->broadcast_lut(streams, gpu_indexes);
|
||||
|
||||
// This store a single block that with be used to store the overflow or
|
||||
// carry results
|
||||
output_flag = new CudaRadixCiphertextFFI;
|
||||
@@ -2473,8 +2473,6 @@ template <typename Torus> struct int_sc_prop_memory {
|
||||
lut_message_extract->get_lut_indexes(0, 0), h_lut_indexes,
|
||||
(num_radix_blocks + 1) * sizeof(Torus), streams[0], gpu_indexes[0],
|
||||
allocate_gpu_memory);
|
||||
|
||||
lut_message_extract->broadcast_lut(streams, gpu_indexes);
|
||||
}
|
||||
if (requested_flag == outputFlag::FLAG_CARRY) { // Carry case
|
||||
|
||||
@@ -2501,9 +2499,8 @@ template <typename Torus> struct int_sc_prop_memory {
|
||||
lut_message_extract->get_lut_indexes(0, 0), h_lut_indexes,
|
||||
(num_radix_blocks + 1) * sizeof(Torus), streams[0], gpu_indexes[0],
|
||||
allocate_gpu_memory);
|
||||
|
||||
lut_message_extract->broadcast_lut(streams, gpu_indexes);
|
||||
}
|
||||
lut_message_extract->broadcast_lut(streams, gpu_indexes);
|
||||
};
|
||||
|
||||
void release(cudaStream_t const *streams, uint32_t const *gpu_indexes,
|
||||
|
||||
Reference in New Issue
Block a user