mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-08 06:13:58 -05:00
fix(gpu): fix some CPU memory leaks due to the use of new without delete
This commit is contained in:
@@ -102,21 +102,21 @@ __host__ void host_boolean_bitop(CudaStreams streams,
|
||||
return false;
|
||||
};
|
||||
|
||||
CudaRadixCiphertextFFI *lwe_array_left = new CudaRadixCiphertextFFI;
|
||||
CudaRadixCiphertextFFI *lwe_array_right = new CudaRadixCiphertextFFI;
|
||||
CudaRadixCiphertextFFI lwe_array_left;
|
||||
CudaRadixCiphertextFFI lwe_array_right;
|
||||
|
||||
if (needs_noise_reduction(lwe_array_1)) {
|
||||
copy_radix_ciphertext_slice_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), mem_ptr->tmp_lwe_left, 0,
|
||||
lwe_array_1->num_radix_blocks, lwe_array_1, 0,
|
||||
lwe_array_1->num_radix_blocks);
|
||||
as_radix_ciphertext_slice<Torus>(lwe_array_left, mem_ptr->tmp_lwe_left, 0,
|
||||
as_radix_ciphertext_slice<Torus>(&lwe_array_left, mem_ptr->tmp_lwe_left, 0,
|
||||
lwe_array_1->num_radix_blocks);
|
||||
integer_radix_apply_univariate_lookup_table<Torus>(
|
||||
streams, lwe_array_left, lwe_array_left, bsks, ksks,
|
||||
mem_ptr->message_extract_lut, lwe_array_left->num_radix_blocks);
|
||||
streams, &lwe_array_left, &lwe_array_left, bsks, ksks,
|
||||
mem_ptr->message_extract_lut, lwe_array_left.num_radix_blocks);
|
||||
} else {
|
||||
as_radix_ciphertext_slice<Torus>(lwe_array_left, lwe_array_1, 0,
|
||||
as_radix_ciphertext_slice<Torus>(&lwe_array_left, lwe_array_1, 0,
|
||||
lwe_array_1->num_radix_blocks);
|
||||
}
|
||||
|
||||
@@ -125,37 +125,37 @@ __host__ void host_boolean_bitop(CudaStreams streams,
|
||||
streams.stream(0), streams.gpu_index(0), mem_ptr->tmp_lwe_right, 0,
|
||||
lwe_array_2->num_radix_blocks, lwe_array_2, 0,
|
||||
lwe_array_2->num_radix_blocks);
|
||||
as_radix_ciphertext_slice<Torus>(lwe_array_right, mem_ptr->tmp_lwe_right, 0,
|
||||
lwe_array_2->num_radix_blocks);
|
||||
as_radix_ciphertext_slice<Torus>(&lwe_array_right, mem_ptr->tmp_lwe_right,
|
||||
0, lwe_array_2->num_radix_blocks);
|
||||
integer_radix_apply_univariate_lookup_table<Torus>(
|
||||
streams, lwe_array_right, lwe_array_right, bsks, ksks,
|
||||
mem_ptr->message_extract_lut, lwe_array_right->num_radix_blocks);
|
||||
streams, &lwe_array_right, &lwe_array_right, bsks, ksks,
|
||||
mem_ptr->message_extract_lut, lwe_array_right.num_radix_blocks);
|
||||
} else {
|
||||
as_radix_ciphertext_slice<Torus>(lwe_array_right, lwe_array_2, 0,
|
||||
as_radix_ciphertext_slice<Torus>(&lwe_array_right, lwe_array_2, 0,
|
||||
lwe_array_2->num_radix_blocks);
|
||||
}
|
||||
|
||||
auto lut = mem_ptr->lut;
|
||||
uint64_t degrees[lwe_array_left->num_radix_blocks];
|
||||
uint64_t degrees[lwe_array_left.num_radix_blocks];
|
||||
if (mem_ptr->op == BITOP_TYPE::BITAND) {
|
||||
update_degrees_after_bitand(degrees, lwe_array_left->degrees,
|
||||
lwe_array_right->degrees,
|
||||
lwe_array_left->num_radix_blocks);
|
||||
update_degrees_after_bitand(degrees, lwe_array_left.degrees,
|
||||
lwe_array_right.degrees,
|
||||
lwe_array_left.num_radix_blocks);
|
||||
} else if (mem_ptr->op == BITOP_TYPE::BITOR) {
|
||||
update_degrees_after_bitor(degrees, lwe_array_left->degrees,
|
||||
lwe_array_right->degrees,
|
||||
lwe_array_left->num_radix_blocks);
|
||||
update_degrees_after_bitor(degrees, lwe_array_left.degrees,
|
||||
lwe_array_right.degrees,
|
||||
lwe_array_left.num_radix_blocks);
|
||||
} else if (mem_ptr->op == BITOP_TYPE::BITXOR) {
|
||||
update_degrees_after_bitxor(degrees, lwe_array_left->degrees,
|
||||
lwe_array_right->degrees,
|
||||
lwe_array_left->num_radix_blocks);
|
||||
update_degrees_after_bitxor(degrees, lwe_array_left.degrees,
|
||||
lwe_array_right.degrees,
|
||||
lwe_array_left.num_radix_blocks);
|
||||
}
|
||||
|
||||
// shift argument is hardcoded as 2 here, because natively message modulus for
|
||||
// boolean block should be 2. lookup table is generated with same factor.
|
||||
integer_radix_apply_bivariate_lookup_table<Torus>(
|
||||
streams, lwe_array_out, lwe_array_left, lwe_array_right, bsks, ksks, lut,
|
||||
lwe_array_out->num_radix_blocks, 2);
|
||||
streams, lwe_array_out, &lwe_array_left, &lwe_array_right, bsks, ksks,
|
||||
lut, lwe_array_out->num_radix_blocks, 2);
|
||||
|
||||
memcpy(lwe_array_out->degrees, degrees,
|
||||
lwe_array_out->num_radix_blocks * sizeof(uint64_t));
|
||||
|
||||
@@ -169,21 +169,21 @@ __host__ void host_unsigned_integer_div_rem_block_by_block_2_2(
|
||||
CudaRadixCiphertextFFI *comparison_blocks,
|
||||
CudaRadixCiphertextFFI *d,
|
||||
int_comparison_buffer<Torus> *comparison_buffer) {
|
||||
CudaRadixCiphertextFFI *d_msb = new CudaRadixCiphertextFFI;
|
||||
CudaRadixCiphertextFFI d_msb;
|
||||
uint32_t slice_start = num_blocks - block_index;
|
||||
uint32_t slice_end = d->num_radix_blocks;
|
||||
as_radix_ciphertext_slice<Torus>(d_msb, d, slice_start, slice_end);
|
||||
comparison_blocks->num_radix_blocks = d_msb->num_radix_blocks;
|
||||
if (d_msb->num_radix_blocks == 0) {
|
||||
as_radix_ciphertext_slice<Torus>(&d_msb, d, slice_start, slice_end);
|
||||
comparison_blocks->num_radix_blocks = d_msb.num_radix_blocks;
|
||||
if (d_msb.num_radix_blocks == 0) {
|
||||
cuda_memset_async(
|
||||
(Torus *)out_boolean_block->ptr, 0,
|
||||
sizeof(Torus) * (out_boolean_block->lwe_dimension + 1),
|
||||
streams.stream(gpu_index), streams.gpu_index(gpu_index));
|
||||
} else {
|
||||
host_compare_blocks_with_zero<Torus>(
|
||||
streams.get_ith(gpu_index), comparison_blocks, d_msb,
|
||||
streams.get_ith(gpu_index), comparison_blocks, &d_msb,
|
||||
comparison_buffer, &bsks[gpu_index], &ksks[gpu_index],
|
||||
d_msb->num_radix_blocks, comparison_buffer->is_zero_lut);
|
||||
d_msb.num_radix_blocks, comparison_buffer->is_zero_lut);
|
||||
are_all_comparisons_block_true(
|
||||
streams.get_ith(gpu_index), out_boolean_block, comparison_blocks,
|
||||
comparison_buffer, &bsks[gpu_index], &ksks[gpu_index],
|
||||
@@ -202,7 +202,6 @@ __host__ void host_unsigned_integer_div_rem_block_by_block_2_2(
|
||||
(Torus *)out_boolean_block->ptr, (Torus *)out_boolean_block->ptr,
|
||||
encoded_scalar, radix_params.big_lwe_dimension, 1);
|
||||
}
|
||||
delete d_msb;
|
||||
};
|
||||
|
||||
for (uint j = 0; j < 3; j++) {
|
||||
|
||||
@@ -63,14 +63,14 @@ void rerand_inplace(
|
||||
|
||||
// Add ks output to ct
|
||||
// Check sizes
|
||||
auto lwes_ffi = new CudaRadixCiphertextFFI;
|
||||
into_radix_ciphertext(lwes_ffi, lwe_array, num_lwes, output_dimension);
|
||||
auto ksed_zero_lwes_ffi = new CudaRadixCiphertextFFI;
|
||||
into_radix_ciphertext(ksed_zero_lwes_ffi, ksed_zero_lwes, num_lwes,
|
||||
CudaRadixCiphertextFFI lwes_ffi;
|
||||
into_radix_ciphertext(&lwes_ffi, lwe_array, num_lwes, output_dimension);
|
||||
CudaRadixCiphertextFFI ksed_zero_lwes_ffi;
|
||||
into_radix_ciphertext(&ksed_zero_lwes_ffi, ksed_zero_lwes, num_lwes,
|
||||
output_dimension);
|
||||
host_addition<Torus>(streams.stream(0), streams.gpu_index(0), lwes_ffi,
|
||||
lwes_ffi, ksed_zero_lwes_ffi, num_lwes, message_modulus,
|
||||
carry_modulus);
|
||||
host_addition<Torus>(streams.stream(0), streams.gpu_index(0), &lwes_ffi,
|
||||
&lwes_ffi, &ksed_zero_lwes_ffi, num_lwes,
|
||||
message_modulus, carry_modulus);
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
|
||||
@@ -91,12 +91,12 @@ __host__ void host_expand_without_verification(
|
||||
cuda_memset_async(lwe_array_out, 0,
|
||||
(lwe_dimension + 1) * num_lwes * 2 * sizeof(Torus),
|
||||
streams.stream(0), streams.gpu_index(0));
|
||||
auto output = new CudaRadixCiphertextFFI;
|
||||
into_radix_ciphertext(output, lwe_array_out, 2 * num_lwes, lwe_dimension);
|
||||
auto input = new CudaRadixCiphertextFFI;
|
||||
into_radix_ciphertext(input, lwe_array_input, 2 * num_lwes, lwe_dimension);
|
||||
CudaRadixCiphertextFFI output;
|
||||
into_radix_ciphertext(&output, lwe_array_out, 2 * num_lwes, lwe_dimension);
|
||||
CudaRadixCiphertextFFI input;
|
||||
into_radix_ciphertext(&input, lwe_array_input, 2 * num_lwes, lwe_dimension);
|
||||
integer_radix_apply_univariate_lookup_table<Torus>(
|
||||
streams, output, input, bsks, ksks, message_and_carry_extract_luts,
|
||||
streams, &output, &input, bsks, ksks, message_and_carry_extract_luts,
|
||||
2 * num_lwes);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user