mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-09 14:47:56 -05:00
fix(gpu): fix the logic of host_compare_with_zero_equality in Cuda to match the CPU
This commit is contained in:
@@ -169,6 +169,7 @@ __host__ void are_all_comparisons_block_true(
|
||||
is_max_value_lut->broadcast_lut(streams, gpu_indexes, 0);
|
||||
cuda_synchronize_stream(streams[0], gpu_indexes[0]);
|
||||
free(h_lut_indexes);
|
||||
reset_radix_ciphertext_blocks(lwe_array_out, 1);
|
||||
return;
|
||||
} else {
|
||||
integer_radix_apply_univariate_lookup_table_kb<Torus>(
|
||||
@@ -254,11 +255,8 @@ __host__ void is_at_least_one_comparisons_block_true(
|
||||
}
|
||||
}
|
||||
|
||||
// FIXME This function should be improved as it outputs a single LWE ciphertext
|
||||
// but requires the output to have enough blocks allocated to compute
|
||||
// intermediate values
|
||||
template <typename Torus>
|
||||
__host__ void host_compare_with_zero_equality(
|
||||
__host__ void host_compare_blocks_with_zero(
|
||||
cudaStream_t const *streams, uint32_t const *gpu_indexes,
|
||||
uint32_t gpu_count, CudaRadixCiphertextFFI *lwe_array_out,
|
||||
CudaRadixCiphertextFFI const *lwe_array_in,
|
||||
@@ -320,11 +318,9 @@ __host__ void host_compare_with_zero_equality(
|
||||
}
|
||||
|
||||
integer_radix_apply_univariate_lookup_table_kb<Torus>(
|
||||
streams, gpu_indexes, gpu_count, sum, sum, bsks, ksks, zero_comparison,
|
||||
num_sum_blocks);
|
||||
are_all_comparisons_block_true<Torus>(streams, gpu_indexes, gpu_count,
|
||||
lwe_array_out, sum, mem_ptr, bsks, ksks,
|
||||
num_sum_blocks);
|
||||
streams, gpu_indexes, gpu_count, lwe_array_out, sum, bsks, ksks,
|
||||
zero_comparison, num_sum_blocks);
|
||||
reset_radix_ciphertext_blocks(lwe_array_out, num_sum_blocks);
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
|
||||
@@ -70,7 +70,6 @@ __host__ void host_unsigned_integer_div_rem_kb(
|
||||
auto did_not_overflow = mem_ptr->did_not_overflow;
|
||||
auto overflow_sum = mem_ptr->overflow_sum;
|
||||
auto overflow_sum_radix = mem_ptr->overflow_sum_radix;
|
||||
auto tmp_1 = mem_ptr->tmp_1;
|
||||
auto at_least_one_upper_block_is_non_zero =
|
||||
mem_ptr->at_least_one_upper_block_is_non_zero;
|
||||
auto cleaned_merged_interesting_remainder =
|
||||
@@ -334,16 +333,17 @@ __host__ void host_unsigned_integer_div_rem_kb(
|
||||
// We could call unchecked_scalar_ne
|
||||
// But we are in the special case where scalar == 0
|
||||
// So we can skip some stuff
|
||||
host_compare_with_zero_equality<Torus>(
|
||||
streams, gpu_indexes, gpu_count, tmp_1, trivial_blocks,
|
||||
host_compare_blocks_with_zero<Torus>(
|
||||
streams, gpu_indexes, gpu_count, mem_ptr->tmp_1, trivial_blocks,
|
||||
mem_ptr->comparison_buffer, bsks, ksks,
|
||||
trivial_blocks->num_radix_blocks,
|
||||
mem_ptr->comparison_buffer->eq_buffer->is_non_zero_lut);
|
||||
|
||||
is_at_least_one_comparisons_block_true<Torus>(
|
||||
streams, gpu_indexes, gpu_count,
|
||||
at_least_one_upper_block_is_non_zero, tmp_1,
|
||||
mem_ptr->comparison_buffer, bsks, ksks, tmp_1->num_radix_blocks);
|
||||
at_least_one_upper_block_is_non_zero, mem_ptr->tmp_1,
|
||||
mem_ptr->comparison_buffer, bsks, ksks,
|
||||
mem_ptr->tmp_1->num_radix_blocks);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ void release_radix_ciphertext(cudaStream_t const stream,
|
||||
void reset_radix_ciphertext_blocks(CudaRadixCiphertextFFI *data,
|
||||
uint32_t new_num_blocks) {
|
||||
if (new_num_blocks > data->max_num_radix_blocks)
|
||||
PANIC("Cuda error: new num blocks should be lower or equal to previous num "
|
||||
"blocks")
|
||||
PANIC("Cuda error: new num blocks should be lower or equal than the "
|
||||
"radix' maximum number of blocks")
|
||||
data->num_radix_blocks = new_num_blocks;
|
||||
}
|
||||
|
||||
@@ -56,6 +56,7 @@ void as_radix_ciphertext_slice(CudaRadixCiphertextFFI *output_radix,
|
||||
|
||||
auto lwe_size = input_radix->lwe_dimension + 1;
|
||||
output_radix->num_radix_blocks = end_input_lwe_index - start_input_lwe_index;
|
||||
output_radix->max_num_radix_blocks = input_radix->max_num_radix_blocks;
|
||||
output_radix->lwe_dimension = input_radix->lwe_dimension;
|
||||
Torus *in_ptr = (Torus *)input_radix->ptr;
|
||||
output_radix->ptr = (void *)(in_ptr + start_input_lwe_index * lwe_size);
|
||||
|
||||
@@ -131,10 +131,14 @@ __host__ void integer_radix_unsigned_scalar_difference_check_kb(
|
||||
if (num_scalar_blocks == 0) {
|
||||
// We only have to compare blocks with zero
|
||||
// means scalar is zero
|
||||
host_compare_with_zero_equality<Torus>(
|
||||
host_compare_blocks_with_zero<Torus>(
|
||||
streams, gpu_indexes, gpu_count, mem_ptr->tmp_lwe_array_out,
|
||||
lwe_array_in, mem_ptr, bsks, ksks, num_radix_blocks,
|
||||
mem_ptr->is_zero_lut);
|
||||
are_all_comparisons_block_true<Torus>(
|
||||
streams, gpu_indexes, gpu_count, mem_ptr->tmp_lwe_array_out,
|
||||
mem_ptr->tmp_lwe_array_out, mem_ptr, bsks, ksks,
|
||||
mem_ptr->tmp_lwe_array_out->num_radix_blocks);
|
||||
|
||||
auto scalar_last_leaf_lut_f = [sign_handler_f](Torus x) -> Torus {
|
||||
x = (x == 1 ? IS_EQUAL : IS_SUPERIOR);
|
||||
@@ -217,9 +221,13 @@ __host__ void integer_radix_unsigned_scalar_difference_check_kb(
|
||||
num_lsb_radix_blocks);
|
||||
//////////////
|
||||
// msb
|
||||
host_compare_with_zero_equality<Torus>(
|
||||
host_compare_blocks_with_zero<Torus>(
|
||||
msb_streams, gpu_indexes, gpu_count, &lwe_array_msb_out, &msb, mem_ptr,
|
||||
bsks, ksks, num_msb_radix_blocks, mem_ptr->is_zero_lut);
|
||||
are_all_comparisons_block_true<Torus>(
|
||||
msb_streams, gpu_indexes, gpu_count, &lwe_array_msb_out,
|
||||
&lwe_array_msb_out, mem_ptr, bsks, ksks,
|
||||
lwe_array_msb_out.num_radix_blocks);
|
||||
for (uint j = 0; j < mem_ptr->active_gpu_count; j++) {
|
||||
cuda_synchronize_stream(lsb_streams[j], gpu_indexes[j]);
|
||||
cuda_synchronize_stream(msb_streams[j], gpu_indexes[j]);
|
||||
@@ -372,9 +380,12 @@ __host__ void integer_radix_signed_scalar_difference_check_kb(
|
||||
// We only have to compare blocks with zero
|
||||
// means scalar is zero
|
||||
auto are_all_msb_zeros = mem_ptr->tmp_lwe_array_out;
|
||||
host_compare_with_zero_equality<Torus>(
|
||||
host_compare_blocks_with_zero<Torus>(
|
||||
streams, gpu_indexes, gpu_count, are_all_msb_zeros, lwe_array_in,
|
||||
mem_ptr, bsks, ksks, num_radix_blocks, mem_ptr->is_zero_lut);
|
||||
are_all_comparisons_block_true<Torus>(
|
||||
streams, gpu_indexes, gpu_count, are_all_msb_zeros, are_all_msb_zeros,
|
||||
mem_ptr, bsks, ksks, are_all_msb_zeros->num_radix_blocks);
|
||||
CudaRadixCiphertextFFI sign_block;
|
||||
as_radix_ciphertext_slice<Torus>(&sign_block, lwe_array_in,
|
||||
num_radix_blocks - 1, num_radix_blocks);
|
||||
@@ -485,9 +496,13 @@ __host__ void integer_radix_signed_scalar_difference_check_kb(
|
||||
// msb
|
||||
// We remove the last block (which is the sign)
|
||||
auto are_all_msb_zeros = lwe_array_msb_out;
|
||||
host_compare_with_zero_equality<Torus>(
|
||||
host_compare_blocks_with_zero<Torus>(
|
||||
msb_streams, gpu_indexes, gpu_count, &are_all_msb_zeros, &msb, mem_ptr,
|
||||
bsks, ksks, num_msb_radix_blocks, mem_ptr->is_zero_lut);
|
||||
are_all_comparisons_block_true<Torus>(
|
||||
msb_streams, gpu_indexes, gpu_count, &are_all_msb_zeros,
|
||||
&are_all_msb_zeros, mem_ptr, bsks, ksks,
|
||||
are_all_msb_zeros.num_radix_blocks);
|
||||
|
||||
auto sign_bit_pos = (int)log2(message_modulus) - 1;
|
||||
|
||||
@@ -813,9 +828,9 @@ __host__ void host_integer_radix_scalar_equality_check_kb(
|
||||
PANIC("Cuda error: integer operation not supported")
|
||||
}
|
||||
|
||||
host_compare_with_zero_equality<Torus>(msb_streams, gpu_indexes, gpu_count,
|
||||
&msb_out, &msb_in, mem_ptr, bsks,
|
||||
ksks, num_msb_radix_blocks, msb_lut);
|
||||
host_compare_blocks_with_zero<Torus>(msb_streams, gpu_indexes, gpu_count,
|
||||
&msb_out, &msb_in, mem_ptr, bsks, ksks,
|
||||
num_msb_radix_blocks, msb_lut);
|
||||
}
|
||||
|
||||
for (uint j = 0; j < mem_ptr->active_gpu_count; j++) {
|
||||
|
||||
Reference in New Issue
Block a user