mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-09 14:47:56 -05:00
fix(gpu): fix bug in are_all_comparisons_block_true when number of blocks is 0
This commit is contained in:
committed by
Agnès Leroy
parent
eed5a6c5ba
commit
39862c2861
@@ -76,6 +76,13 @@ __host__ void are_all_comparisons_block_true(
|
||||
auto message_modulus = params.message_modulus;
|
||||
auto carry_modulus = params.carry_modulus;
|
||||
|
||||
if (num_radix_blocks == 0) {
|
||||
set_single_scalar_trivial_radix_ciphertext_async<Torus>(
|
||||
streams.stream(0), streams.gpu_index(0), lwe_array_out, 1,
|
||||
message_modulus, carry_modulus);
|
||||
return;
|
||||
}
|
||||
|
||||
auto are_all_block_true_buffer =
|
||||
mem_ptr->eq_buffer->are_all_block_true_buffer;
|
||||
auto tmp_out = are_all_block_true_buffer->tmp_out;
|
||||
|
||||
@@ -207,6 +207,33 @@ __host__ void set_trivial_radix_ciphertext_async(
|
||||
}
|
||||
}
|
||||
|
||||
// set single trivial value for a radix ciphertext
|
||||
// Store one trivial scalar value in a radix ciphertext.
//
// Every block of `lwe_array_out` is first cleared with
// set_zero_radix_ciphertext_slice_async. When the ciphertext has at least one
// block, the encoded value `delta * scalar` is then written to the element
// located `lwe_dimension` entries past `lwe_array_out->ptr`, and `degrees[0]`
// is set to `scalar`.
//
// `delta` is 2^(nbits - 1) / (message_modulus * carry_modulus), where nbits is
// the bit width of Torus.
//
// NOTE(review): the written element looks like the body coefficient of the
// first block (index lwe_dimension of an lwe_dimension + 1 sized block) --
// confirm against the CudaRadixCiphertextFFI layout.
// NOTE(review): the `_async` suffix suggests the device work is only enqueued
// on `stream`; confirm whether callers must synchronize before reading.
template <typename Torus>
__host__ void set_single_scalar_trivial_radix_ciphertext_async(
    cudaStream_t stream, uint32_t gpu_index,
    CudaRadixCiphertextFFI *lwe_array_out, Torus const scalar,
    Torus message_modulus, Torus carry_modulus) {

  // Clear the whole ciphertext before (possibly) writing the single value.
  set_zero_radix_ciphertext_slice_async<Torus>(
      stream, gpu_index, lwe_array_out, 0, lwe_array_out->num_radix_blocks);

  // Nothing to write into when the ciphertext holds no blocks.
  if (lwe_array_out->num_radix_blocks != 0) {
    // Scaling factor applied to the message before it is stored.
    // If message_modulus and carry_modulus are always powers of 2 this
    // division could be simplified.
    const auto torus_bit_width = sizeof(Torus) * 8;
    const auto top_bit = static_cast<Torus>(1) << (torus_bit_width - 1);
    const auto plaintext_modulus = message_modulus * carry_modulus;
    Torus delta = top_bit / plaintext_modulus;

    // Destination element: lwe_dimension entries past the start of the data.
    Torus *destination =
        (Torus *)lwe_array_out->ptr + lwe_array_out->lwe_dimension;

    cuda_set_value_async<Torus>(stream, gpu_index, destination, delta * scalar,
                                1);
    check_cuda_error(cudaGetLastError());

    // Record the degree of the block that now carries the scalar.
    lwe_array_out->degrees[0] = scalar;
  }
}
|
||||
|
||||
// Copy the last radix block of radix_in to the first block of radix_out and
|
||||
// decrease radix_in num_radix_blocks by 1
|
||||
template <typename Torus>
|
||||
|
||||
Reference in New Issue
Block a user