diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/comparison.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/comparison.cuh
index 8656de36b..eb7521582 100644
--- a/backends/tfhe-cuda-backend/cuda/src/integer/comparison.cuh
+++ b/backends/tfhe-cuda-backend/cuda/src/integer/comparison.cuh
@@ -76,6 +76,13 @@ __host__ void are_all_comparisons_block_true(
   auto message_modulus = params.message_modulus;
   auto carry_modulus = params.carry_modulus;
 
+  if (num_radix_blocks == 0) {
+    set_single_scalar_trivial_radix_ciphertext_async<Torus>(
+        streams.stream(0), streams.gpu_index(0), lwe_array_out, 1,
+        message_modulus, carry_modulus);
+    return;
+  }
+
   auto are_all_block_true_buffer =
       mem_ptr->eq_buffer->are_all_block_true_buffer;
   auto tmp_out = are_all_block_true_buffer->tmp_out;
diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/radix_ciphertext.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/radix_ciphertext.cuh
index 25373b943..0cf04ddff 100644
--- a/backends/tfhe-cuda-backend/cuda/src/integer/radix_ciphertext.cuh
+++ b/backends/tfhe-cuda-backend/cuda/src/integer/radix_ciphertext.cuh
@@ -207,6 +207,41 @@ __host__ void set_trivial_radix_ciphertext_async(
   }
 }
 
+// Set a radix ciphertext to the trivial encoding of a single scalar.
+//
+// All blocks of lwe_array_out are zeroed first; then delta * scalar is written
+// at element offset lwe_dimension from the start of the buffer, through
+// cuda_set_value_async on `stream` / `gpu_index`, and degrees[0] is set to
+// scalar. If lwe_array_out has no blocks, only the zeroing call is made.
+//
+// NOTE(review): there is no check that scalar < message_modulus * carry_modulus
+// or that this product is non-zero - confirm callers guarantee both.
+template <typename Torus>
+__host__ void set_single_scalar_trivial_radix_ciphertext_async(
+    cudaStream_t stream, uint32_t gpu_index,
+    CudaRadixCiphertextFFI *lwe_array_out, Torus const scalar,
+    Torus message_modulus, Torus carry_modulus) {
+
+  set_zero_radix_ciphertext_slice_async<Torus>(
+      stream, gpu_index, lwe_array_out, 0, lwe_array_out->num_radix_blocks);
+  if (lwe_array_out->num_radix_blocks == 0)
+    return;
+
+  // Value of the shift we multiply our messages by
+  // If message_modulus and carry_modulus are always powers of 2 we can simplify
+  // this
+  auto nbits = sizeof(Torus) * 8;
+  Torus delta = (static_cast<Torus>(1) << (nbits - 1)) /
+                (message_modulus * carry_modulus);
+
+  cuda_set_value_async(stream, gpu_index,
+                       (Torus *)lwe_array_out->ptr +
+                           lwe_array_out->lwe_dimension,
+                       delta * scalar, 1);
+  check_cuda_error(cudaGetLastError());
+  lwe_array_out->degrees[0] = scalar;
+}
+
 // Copy the last radix block of radix_in to the first block of radix_out and
 // decrease radix_in num_radix_blocks by 1
 template <typename Torus>