fix(gpu): fix the logic of host_compare_with_zero_equality in the CUDA backend to match the CPU implementation

This commit is contained in:
Agnes Leroy
2025-03-14 17:26:48 +01:00
committed by Agnès Leroy
parent 912af0e87e
commit 5258acc08f
5 changed files with 35 additions and 23 deletions

View File

@@ -169,6 +169,7 @@ __host__ void are_all_comparisons_block_true(
is_max_value_lut->broadcast_lut(streams, gpu_indexes, 0);
cuda_synchronize_stream(streams[0], gpu_indexes[0]);
free(h_lut_indexes);
reset_radix_ciphertext_blocks(lwe_array_out, 1);
return;
} else {
integer_radix_apply_univariate_lookup_table_kb<Torus>(
@@ -254,11 +255,8 @@ __host__ void is_at_least_one_comparisons_block_true(
}
}
// FIXME This function should be improved as it outputs a single LWE ciphertext
// but requires the output to have enough blocks allocated to compute
// intermediate values
template <typename Torus>
__host__ void host_compare_with_zero_equality(
__host__ void host_compare_blocks_with_zero(
cudaStream_t const *streams, uint32_t const *gpu_indexes,
uint32_t gpu_count, CudaRadixCiphertextFFI *lwe_array_out,
CudaRadixCiphertextFFI const *lwe_array_in,
@@ -320,11 +318,9 @@ __host__ void host_compare_with_zero_equality(
}
integer_radix_apply_univariate_lookup_table_kb<Torus>(
streams, gpu_indexes, gpu_count, sum, sum, bsks, ksks, zero_comparison,
num_sum_blocks);
are_all_comparisons_block_true<Torus>(streams, gpu_indexes, gpu_count,
lwe_array_out, sum, mem_ptr, bsks, ksks,
num_sum_blocks);
streams, gpu_indexes, gpu_count, lwe_array_out, sum, bsks, ksks,
zero_comparison, num_sum_blocks);
reset_radix_ciphertext_blocks(lwe_array_out, num_sum_blocks);
}
template <typename Torus>

View File

@@ -70,7 +70,6 @@ __host__ void host_unsigned_integer_div_rem_kb(
auto did_not_overflow = mem_ptr->did_not_overflow;
auto overflow_sum = mem_ptr->overflow_sum;
auto overflow_sum_radix = mem_ptr->overflow_sum_radix;
auto tmp_1 = mem_ptr->tmp_1;
auto at_least_one_upper_block_is_non_zero =
mem_ptr->at_least_one_upper_block_is_non_zero;
auto cleaned_merged_interesting_remainder =
@@ -334,16 +333,17 @@ __host__ void host_unsigned_integer_div_rem_kb(
// We could call unchecked_scalar_ne
// But we are in the special case where scalar == 0
// So we can skip some stuff
host_compare_with_zero_equality<Torus>(
streams, gpu_indexes, gpu_count, tmp_1, trivial_blocks,
host_compare_blocks_with_zero<Torus>(
streams, gpu_indexes, gpu_count, mem_ptr->tmp_1, trivial_blocks,
mem_ptr->comparison_buffer, bsks, ksks,
trivial_blocks->num_radix_blocks,
mem_ptr->comparison_buffer->eq_buffer->is_non_zero_lut);
is_at_least_one_comparisons_block_true<Torus>(
streams, gpu_indexes, gpu_count,
at_least_one_upper_block_is_non_zero, tmp_1,
mem_ptr->comparison_buffer, bsks, ksks, tmp_1->num_radix_blocks);
at_least_one_upper_block_is_non_zero, mem_ptr->tmp_1,
mem_ptr->comparison_buffer, bsks, ksks,
mem_ptr->tmp_1->num_radix_blocks);
}
};

View File

@@ -12,7 +12,7 @@ void release_radix_ciphertext(cudaStream_t const stream,
// Sets the logical number of blocks of a radix ciphertext descriptor.
//
// Only the `num_radix_blocks` field of `data` is written; the underlying
// buffer is not touched by this function.
//
// data:           descriptor whose block count is updated.
// new_num_blocks: requested block count; must not exceed
//                 `data->max_num_radix_blocks`.
//
// Panics (via PANIC) when `new_num_blocks` is larger than
// `data->max_num_radix_blocks`.
void reset_radix_ciphertext_blocks(CudaRadixCiphertextFFI *data,
                                   uint32_t new_num_blocks) {
  // A single guarded PANIC: a second, unguarded PANIC after the `if` would
  // fire on every call and make the assignment below unreachable.
  if (new_num_blocks > data->max_num_radix_blocks)
    PANIC("Cuda error: new num blocks should be lower or equal than the "
          "radix' maximum number of blocks")
  data->num_radix_blocks = new_num_blocks;
}

View File

@@ -56,6 +56,7 @@ void as_radix_ciphertext_slice(CudaRadixCiphertextFFI *output_radix,
auto lwe_size = input_radix->lwe_dimension + 1;
output_radix->num_radix_blocks = end_input_lwe_index - start_input_lwe_index;
output_radix->max_num_radix_blocks = input_radix->max_num_radix_blocks;
output_radix->lwe_dimension = input_radix->lwe_dimension;
Torus *in_ptr = (Torus *)input_radix->ptr;
output_radix->ptr = (void *)(in_ptr + start_input_lwe_index * lwe_size);

View File

@@ -131,10 +131,14 @@ __host__ void integer_radix_unsigned_scalar_difference_check_kb(
if (num_scalar_blocks == 0) {
// We only have to compare blocks with zero
// means scalar is zero
host_compare_with_zero_equality<Torus>(
host_compare_blocks_with_zero<Torus>(
streams, gpu_indexes, gpu_count, mem_ptr->tmp_lwe_array_out,
lwe_array_in, mem_ptr, bsks, ksks, num_radix_blocks,
mem_ptr->is_zero_lut);
are_all_comparisons_block_true<Torus>(
streams, gpu_indexes, gpu_count, mem_ptr->tmp_lwe_array_out,
mem_ptr->tmp_lwe_array_out, mem_ptr, bsks, ksks,
mem_ptr->tmp_lwe_array_out->num_radix_blocks);
auto scalar_last_leaf_lut_f = [sign_handler_f](Torus x) -> Torus {
x = (x == 1 ? IS_EQUAL : IS_SUPERIOR);
@@ -217,9 +221,13 @@ __host__ void integer_radix_unsigned_scalar_difference_check_kb(
num_lsb_radix_blocks);
//////////////
// msb
host_compare_with_zero_equality<Torus>(
host_compare_blocks_with_zero<Torus>(
msb_streams, gpu_indexes, gpu_count, &lwe_array_msb_out, &msb, mem_ptr,
bsks, ksks, num_msb_radix_blocks, mem_ptr->is_zero_lut);
are_all_comparisons_block_true<Torus>(
msb_streams, gpu_indexes, gpu_count, &lwe_array_msb_out,
&lwe_array_msb_out, mem_ptr, bsks, ksks,
lwe_array_msb_out.num_radix_blocks);
for (uint j = 0; j < mem_ptr->active_gpu_count; j++) {
cuda_synchronize_stream(lsb_streams[j], gpu_indexes[j]);
cuda_synchronize_stream(msb_streams[j], gpu_indexes[j]);
@@ -372,9 +380,12 @@ __host__ void integer_radix_signed_scalar_difference_check_kb(
// We only have to compare blocks with zero
// means scalar is zero
auto are_all_msb_zeros = mem_ptr->tmp_lwe_array_out;
host_compare_with_zero_equality<Torus>(
host_compare_blocks_with_zero<Torus>(
streams, gpu_indexes, gpu_count, are_all_msb_zeros, lwe_array_in,
mem_ptr, bsks, ksks, num_radix_blocks, mem_ptr->is_zero_lut);
are_all_comparisons_block_true<Torus>(
streams, gpu_indexes, gpu_count, are_all_msb_zeros, are_all_msb_zeros,
mem_ptr, bsks, ksks, are_all_msb_zeros->num_radix_blocks);
CudaRadixCiphertextFFI sign_block;
as_radix_ciphertext_slice<Torus>(&sign_block, lwe_array_in,
num_radix_blocks - 1, num_radix_blocks);
@@ -485,9 +496,13 @@ __host__ void integer_radix_signed_scalar_difference_check_kb(
// msb
// We remove the last block (which is the sign)
auto are_all_msb_zeros = lwe_array_msb_out;
host_compare_with_zero_equality<Torus>(
host_compare_blocks_with_zero<Torus>(
msb_streams, gpu_indexes, gpu_count, &are_all_msb_zeros, &msb, mem_ptr,
bsks, ksks, num_msb_radix_blocks, mem_ptr->is_zero_lut);
are_all_comparisons_block_true<Torus>(
msb_streams, gpu_indexes, gpu_count, &are_all_msb_zeros,
&are_all_msb_zeros, mem_ptr, bsks, ksks,
are_all_msb_zeros.num_radix_blocks);
auto sign_bit_pos = (int)log2(message_modulus) - 1;
@@ -813,9 +828,9 @@ __host__ void host_integer_radix_scalar_equality_check_kb(
PANIC("Cuda error: integer operation not supported")
}
host_compare_with_zero_equality<Torus>(msb_streams, gpu_indexes, gpu_count,
&msb_out, &msb_in, mem_ptr, bsks,
ksks, num_msb_radix_blocks, msb_lut);
host_compare_blocks_with_zero<Torus>(msb_streams, gpu_indexes, gpu_count,
&msb_out, &msb_in, mem_ptr, bsks, ksks,
num_msb_radix_blocks, msb_lut);
}
for (uint j = 0; j < mem_ptr->active_gpu_count; j++) {