chore(gpu): add option to pre-release some buffers in scalar mul

This commit is contained in:
Agnes Leroy
2024-12-18 15:19:42 +01:00
committed by Agnès Leroy
parent 33d5091025
commit 9b43a9459a
2 changed files with 15 additions and 5 deletions

View File

@@ -4271,12 +4271,15 @@ template <typename Torus> struct int_scalar_mul_buffer {
Torus *preshifted_buffer;
Torus *all_shifted_buffer;
int_sc_prop_memory<Torus> *sc_prop_mem;
bool anticipated_buffers_drop;
int_scalar_mul_buffer(cudaStream_t const *streams,
uint32_t const *gpu_indexes, uint32_t gpu_count,
int_radix_params params, uint32_t num_radix_blocks,
bool allocate_gpu_memory) {
bool allocate_gpu_memory,
bool anticipated_buffer_drop) {
this->params = params;
this->anticipated_buffers_drop = anticipated_buffer_drop;
if (allocate_gpu_memory) {
uint32_t msg_bits = (uint32_t)std::log2(params.message_modulus);
@@ -4324,6 +4327,11 @@ template <typename Torus> struct int_scalar_mul_buffer {
delete sum_ciphertexts_vec_mem;
delete sc_prop_mem;
cuda_drop_async(all_shifted_buffer, streams[0], gpu_indexes[0]);
if (!anticipated_buffers_drop) {
cuda_drop_async(preshifted_buffer, streams[0], gpu_indexes[0]);
logical_scalar_shift_buffer->release(streams, gpu_indexes, gpu_count);
delete (logical_scalar_shift_buffer);
}
}
};

View File

@@ -36,7 +36,7 @@ __host__ void scratch_cuda_integer_radix_scalar_mul_kb(
*mem_ptr =
new int_scalar_mul_buffer<T>(streams, gpu_indexes, gpu_count, params,
num_radix_blocks, allocate_gpu_memory);
num_radix_blocks, allocate_gpu_memory, true);
}
template <typename T, class params>
@@ -94,9 +94,11 @@ __host__ void host_integer_scalar_mul_radix(
}
cuda_synchronize_stream(streams[0], gpu_indexes[0]);
cuda_drop_async(preshifted_buffer, streams[0], gpu_indexes[0]);
mem->logical_scalar_shift_buffer->release(streams, gpu_indexes, gpu_count);
delete (mem->logical_scalar_shift_buffer);
if (mem->anticipated_buffers_drop) {
cuda_drop_async(preshifted_buffer, streams[0], gpu_indexes[0]);
mem->logical_scalar_shift_buffer->release(streams, gpu_indexes, gpu_count);
delete (mem->logical_scalar_shift_buffer);
}
if (j == 0) {
// lwe array = 0