fix(gpu): fix bug in are_all_comparisons_block_true when number of blocks is 0

This commit is contained in:
Beka Barbakadze
2025-10-16 14:23:55 +04:00
committed by Agnès Leroy
parent eed5a6c5ba
commit 39862c2861
2 changed files with 34 additions and 0 deletions

View File

@@ -76,6 +76,13 @@ __host__ void are_all_comparisons_block_true(
auto message_modulus = params.message_modulus;
auto carry_modulus = params.carry_modulus;
if (num_radix_blocks == 0) {
set_single_scalar_trivial_radix_ciphertext_async<Torus>(
streams.stream(0), streams.gpu_index(0), lwe_array_out, 1,
message_modulus, carry_modulus);
return;
}
auto are_all_block_true_buffer =
mem_ptr->eq_buffer->are_all_block_true_buffer;
auto tmp_out = are_all_block_true_buffer->tmp_out;

View File

@@ -207,6 +207,33 @@ __host__ void set_trivial_radix_ciphertext_async(
}
}
// Set a single trivial value for a radix ciphertext.
//
// The full block range [0, num_radix_blocks) of lwe_array_out is handed to
// set_zero_radix_ciphertext_slice_async, then one Torus element holding
// delta * scalar is written at index lwe_dimension of the output buffer
// (presumably the body of the first block - confirm against the LWE layout).
// degrees[0] is updated on the host to `scalar`.
//
// Parameters:
//  - stream, gpu_index: forwarded to the zeroing and set-value helpers.
//  - lwe_array_out: ciphertext to overwrite; left untouched when it holds no
//    block.
//  - scalar: clear value to encode in the first block.
//  - message_modulus, carry_modulus: their product divides 2^(nbits - 1) to
//    obtain the scaling factor delta.
//
// NOTE(review): going by the `_async` suffix of the helpers, the device writes
// are presumably only enqueued on `stream`; no synchronization happens here.
template <typename Torus>
__host__ void set_single_scalar_trivial_radix_ciphertext_async(
    cudaStream_t stream, uint32_t gpu_index,
    CudaRadixCiphertextFFI *lwe_array_out, Torus const scalar,
    Torus message_modulus, Torus carry_modulus) {
  // An empty ciphertext has nothing to clear and no first block to write to.
  // Return before any device work so the zeroing helper is never handed an
  // empty [0, 0) range.
  if (lwe_array_out->num_radix_blocks == 0)
    return;

  set_zero_radix_ciphertext_slice_async<Torus>(
      stream, gpu_index, lwe_array_out, 0, lwe_array_out->num_radix_blocks);

  // Value of the shift we multiply our messages by
  // If message_modulus and carry_modulus are always powers of 2 we can simplify
  // this
  auto nbits = sizeof(Torus) * 8;
  Torus delta = (static_cast<Torus>(1) << (nbits - 1)) /
                (message_modulus * carry_modulus);

  // Write exactly one element: the encoded scalar, at offset lwe_dimension.
  cuda_set_value_async<Torus>(stream, gpu_index,
                              (Torus *)lwe_array_out->ptr +
                                  lwe_array_out->lwe_dimension,
                              delta * scalar, 1);
  check_cuda_error(cudaGetLastError());

  // Host-side metadata: the first block's degree is the scalar itself.
  lwe_array_out->degrees[0] = scalar;
}
// Copy the last radix block of radix_in to the first block of radix_out and
// decrease radix_in num_radix_blocks by 1
template <typename Torus>