fix(gpu): fix perf regression introduced in 1936ec6d84

Author:    Agnes Leroy
Date:      2025-06-19 10:52:01 +02:00
Committer: Agnès Leroy
Parent:    f5f7213289
Commit:    e5a9145cce

2 changed files with 55 additions and 33 deletions
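The regression appears to have come from pairing the host-side LWE index arrays of int_radix_lut with cudaMallocHost / cudaFreeHost: page-locked (pinned) allocation goes through the CUDA driver and is considerably more expensive than a plain malloc, and these arrays are allocated every time a LUT structure is built. The fix keeps the long-lived arrays pageable and only uses short-lived pinned buffers around the asynchronous device-to-host copies that need them. A minimal sketch of the first half of that idea, assuming a hypothetical helper (make_host_indexes is illustrative and not part of the codebase):

#include <cstdint>
#include <cstdlib>

// Long-lived host-side index arrays do not need to be page-locked, so a
// plain malloc avoids the cost of cudaMallocHost on every LUT construction.
// Illustrative sketch only; error handling is omitted and the caller frees
// the array with free().
template <typename Torus>
Torus *make_host_indexes(uint32_t num_radix_blocks) {
  Torus *h_indexes =
      static_cast<Torus *>(malloc(num_radix_blocks * sizeof(Torus)));
  for (uint32_t i = 0; i < num_radix_blocks; i++)
    h_indexes[i] = i; // trivial (identity) LWE indexes
  return h_indexes;
}

The matching release path then calls free() instead of cudaFreeHost(), as the first file below shows.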


@@ -249,11 +249,8 @@ template <typename Torus> struct int_radix_lut {
         num_radix_blocks * sizeof(Torus), streams[0], gpu_indexes[0],
         size_tracker, allocate_gpu_memory);
-    cudaMallocHost((void **)&h_lwe_indexes_in,
-                   num_radix_blocks * sizeof(Torus));
-    cudaMallocHost((void **)&h_lwe_indexes_out,
-                   num_radix_blocks * sizeof(Torus));
+    h_lwe_indexes_in = (Torus *)malloc(num_radix_blocks * sizeof(Torus));
+    h_lwe_indexes_out = (Torus *)malloc(num_radix_blocks * sizeof(Torus));
     for (int i = 0; i < num_radix_blocks; i++)
       h_lwe_indexes_in[i] = i;
@@ -372,10 +369,8 @@ template <typename Torus> struct int_radix_lut {
         num_radix_blocks * sizeof(Torus), streams[0], gpu_indexes[0],
         size_tracker, allocate_gpu_memory);
-    cudaMallocHost((void **)&h_lwe_indexes_in,
-                   num_radix_blocks * sizeof(Torus));
-    cudaMallocHost((void **)&h_lwe_indexes_out,
-                   num_radix_blocks * sizeof(Torus));
+    h_lwe_indexes_in = (Torus *)malloc(num_radix_blocks * sizeof(Torus));
+    h_lwe_indexes_out = (Torus *)malloc(num_radix_blocks * sizeof(Torus));
     for (int i = 0; i < num_radix_blocks; i++)
       h_lwe_indexes_in[i] = i;
@@ -470,10 +465,8 @@ template <typename Torus> struct int_radix_lut {
         num_radix_blocks * sizeof(Torus), streams[0], gpu_indexes[0],
         size_tracker, allocate_gpu_memory);
-    cudaMallocHost((void **)&h_lwe_indexes_in,
-                   num_radix_blocks * sizeof(Torus));
-    cudaMallocHost((void **)&h_lwe_indexes_out,
-                   num_radix_blocks * sizeof(Torus));
+    h_lwe_indexes_in = (Torus *)malloc(num_radix_blocks * sizeof(Torus));
+    h_lwe_indexes_out = (Torus *)malloc(num_radix_blocks * sizeof(Torus));
     for (int i = 0; i < num_radix_blocks; i++)
       h_lwe_indexes_in[i] = i;
@@ -611,8 +604,8 @@ template <typename Torus> struct int_radix_lut {
     cuda_synchronize_stream(streams[0], gpu_indexes[0]);
     lut_vec.clear();
     lut_indexes_vec.clear();
-    cudaFreeHost(h_lwe_indexes_in);
-    cudaFreeHost(h_lwe_indexes_out);
+    free(h_lwe_indexes_in);
+    free(h_lwe_indexes_out);
     if (!mem_reuse) {
       release_radix_ciphertext_async(streams[0], gpu_indexes[0],

@@ -521,18 +521,32 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb(
         total_ciphertexts, mem_ptr->params.pbs_type, num_many_lut,
         lut_stride);
   } else {
-    cuda_memcpy_async_to_cpu(luts_message_carry->h_lwe_indexes_in,
-                             luts_message_carry->lwe_indexes_in,
-                             total_ciphertexts * sizeof(Torus), streams[0],
-                             gpu_indexes[0]);
-    cuda_memcpy_async_to_cpu(luts_message_carry->h_lwe_indexes_out,
-                             luts_message_carry->lwe_indexes_out,
-                             total_ciphertexts * sizeof(Torus), streams[0],
-                             gpu_indexes[0]);
+    Torus *h_lwe_indexes_in_pinned;
+    Torus *h_lwe_indexes_out_pinned;
+    cudaMallocHost((void **)&h_lwe_indexes_in_pinned,
+                   total_ciphertexts * sizeof(Torus));
+    cudaMallocHost((void **)&h_lwe_indexes_out_pinned,
+                   total_ciphertexts * sizeof(Torus));
+    for (uint32_t i = 0; i < total_ciphertexts; i++) {
+      h_lwe_indexes_in_pinned[i] = luts_message_carry->h_lwe_indexes_in[i];
+      h_lwe_indexes_out_pinned[i] = luts_message_carry->h_lwe_indexes_out[i];
+    }
+    cuda_memcpy_async_to_cpu(
+        h_lwe_indexes_in_pinned, luts_message_carry->lwe_indexes_in,
+        total_ciphertexts * sizeof(Torus), streams[0], gpu_indexes[0]);
+    cuda_memcpy_async_to_cpu(
+        h_lwe_indexes_out_pinned, luts_message_carry->lwe_indexes_out,
+        total_ciphertexts * sizeof(Torus), streams[0], gpu_indexes[0]);
+    cuda_synchronize_stream(streams[0], gpu_indexes[0]);
+    for (uint32_t i = 0; i < total_ciphertexts; i++) {
+      luts_message_carry->h_lwe_indexes_in[i] = h_lwe_indexes_in_pinned[i];
+      luts_message_carry->h_lwe_indexes_out[i] = h_lwe_indexes_out_pinned[i];
+    }
+    cudaFreeHost(h_lwe_indexes_in_pinned);
+    cudaFreeHost(h_lwe_indexes_out_pinned);
+    luts_message_carry->using_trivial_lwe_indexes = false;
     luts_message_carry->broadcast_lut(streams, gpu_indexes, 0);
-    luts_message_carry->using_trivial_lwe_indexes = false;
     integer_radix_apply_univariate_lookup_table_kb<Torus>(
         streams, gpu_indexes, active_gpu_count, current_blocks,
@@ -580,15 +594,30 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb(
         2 * num_radix_blocks, mem_ptr->params.pbs_type, num_many_lut,
         lut_stride);
   } else {
-    cuda_memcpy_async_to_cpu(luts_message_carry->h_lwe_indexes_in,
-                             luts_message_carry->lwe_indexes_in,
-                             2 * num_radix_blocks * sizeof(Torus), streams[0],
-                             gpu_indexes[0]);
-    cuda_memcpy_async_to_cpu(luts_message_carry->h_lwe_indexes_out,
-                             luts_message_carry->lwe_indexes_out,
-                             2 * num_radix_blocks * sizeof(Torus), streams[0],
-                             gpu_indexes[0]);
+    uint32_t num_blocks_in_apply_lut = 2 * num_radix_blocks;
+    Torus *h_lwe_indexes_in_pinned;
+    Torus *h_lwe_indexes_out_pinned;
+    cudaMallocHost((void **)&h_lwe_indexes_in_pinned,
+                   num_blocks_in_apply_lut * sizeof(Torus));
+    cudaMallocHost((void **)&h_lwe_indexes_out_pinned,
+                   num_blocks_in_apply_lut * sizeof(Torus));
+    for (uint32_t i = 0; i < num_blocks_in_apply_lut; i++) {
+      h_lwe_indexes_in_pinned[i] = luts_message_carry->h_lwe_indexes_in[i];
+      h_lwe_indexes_out_pinned[i] = luts_message_carry->h_lwe_indexes_out[i];
+    }
+    cuda_memcpy_async_to_cpu(
+        h_lwe_indexes_in_pinned, luts_message_carry->lwe_indexes_in,
+        num_blocks_in_apply_lut * sizeof(Torus), streams[0], gpu_indexes[0]);
+    cuda_memcpy_async_to_cpu(
+        h_lwe_indexes_out_pinned, luts_message_carry->lwe_indexes_out,
+        num_blocks_in_apply_lut * sizeof(Torus), streams[0], gpu_indexes[0]);
+    cuda_synchronize_stream(streams[0], gpu_indexes[0]);
+    for (uint32_t i = 0; i < num_blocks_in_apply_lut; i++) {
+      luts_message_carry->h_lwe_indexes_in[i] = h_lwe_indexes_in_pinned[i];
+      luts_message_carry->h_lwe_indexes_out[i] = h_lwe_indexes_out_pinned[i];
+    }
+    cudaFreeHost(h_lwe_indexes_in_pinned);
+    cudaFreeHost(h_lwe_indexes_out_pinned);
     luts_message_carry->broadcast_lut(streams, gpu_indexes, 0);
     luts_message_carry->using_trivial_lwe_indexes = false;
@@ -596,7 +625,7 @@ __host__ void host_integer_partial_sum_ciphertexts_vec_kb(
     integer_radix_apply_univariate_lookup_table_kb<Torus>(
         streams, gpu_indexes, active_gpu_count, current_blocks, radix_lwe_out,
         bsks, ksks, ms_noise_reduction_key, luts_message_carry,
-        2 * num_radix_blocks);
+        num_blocks_in_apply_lut);
   }
   calculate_final_degrees(radix_lwe_out->degrees, terms->degrees,
                           num_radix_blocks, num_radix_in_vec, chunk_size,
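Note on the second file: cuda_memcpy_async_to_cpu now writes into short-lived pinned staging buffers instead of directly into the (now pageable) h_lwe_indexes_in/out arrays, presumably so the asynchronous copy still targets page-locked memory (async copies into pageable host memory fall back to a synchronous path). The stream is then synchronized, the values are copied into the long-lived arrays, and the pinned buffers are freed. A minimal, self-contained sketch of that staging pattern using the plain CUDA runtime API (the function name and the use of cudaMemcpyAsync / cudaStreamSynchronize instead of the project's wrappers are illustrative assumptions):

#include <cstdint>
#include <cuda_runtime.h>

// Stage an async device-to-host copy through a temporary pinned buffer,
// then move the data into a long-lived pageable array (illustrative only;
// error checking omitted).
void copy_indexes_to_host(const uint64_t *d_indexes, uint64_t *h_indexes,
                          uint32_t n, cudaStream_t stream) {
  uint64_t *pinned;
  cudaMallocHost(reinterpret_cast<void **>(&pinned), n * sizeof(uint64_t));
  cudaMemcpyAsync(pinned, d_indexes, n * sizeof(uint64_t),
                  cudaMemcpyDeviceToHost, stream);
  cudaStreamSynchronize(stream); // wait for the copy before reading pinned
  for (uint32_t i = 0; i < n; i++)
    h_indexes[i] = pinned[i];
  cudaFreeHost(pinned);
}

This keeps the expensive cudaMallocHost call off the LUT construction path while still giving the asynchronous copy the page-locked destination it benefits from.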