diff --git a/backends/tfhe-cuda-backend/cuda/include/integer/div_rem.h b/backends/tfhe-cuda-backend/cuda/include/integer/div_rem.h index d91514b76..8499b16c6 100644 --- a/backends/tfhe-cuda-backend/cuda/include/integer/div_rem.h +++ b/backends/tfhe-cuda-backend/cuda/include/integer/div_rem.h @@ -1044,7 +1044,7 @@ template struct unsigned_int_div_rem_memory { }; int_radix_lut *luts[2] = {message_extract_lut_1, - message_extract_lut_2}; + message_extract_lut_2}; auto active_streams = streams.active_gpu_subset(num_blocks, params.pbs_type); for (int j = 0; j < 2; j++) { @@ -1134,7 +1134,8 @@ template struct unsigned_int_div_rem_memory { zero_out_if_overflow_happened[1]->broadcast_lut(active_streams); // merge_overflow_flags_luts - merge_overflow_flags_luts = new int_radix_lut *[num_bits_in_message]; + merge_overflow_flags_luts = + new int_radix_lut *[num_bits_in_message]; auto active_gpu_count_for_bits = streams.active_gpu_subset(1, params.pbs_type); for (int i = 0; i < num_bits_in_message; i++) { diff --git a/backends/tfhe-cuda-backend/cuda/include/integer/integer.h b/backends/tfhe-cuda-backend/cuda/include/integer/integer.h index 5b4acc883..a035d22b1 100644 --- a/backends/tfhe-cuda-backend/cuda/include/integer/integer.h +++ b/backends/tfhe-cuda-backend/cuda/include/integer/integer.h @@ -274,6 +274,16 @@ uint64_t scratch_cuda_comparison_64( bool is_signed, bool allocate_gpu_memory, PBS_MS_REDUCTION_T noise_reduction_type); +uint64_t scratch_cuda_comparison_64_ks32( + CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension, + uint32_t polynomial_size, uint32_t big_lwe_dimension, + uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log, + uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor, + uint32_t lwe_ciphertext_count, uint32_t message_modulus, + uint32_t carry_modulus, PBS_TYPE pbs_type, COMPARISON_TYPE op_type, + bool is_signed, bool allocate_gpu_memory, + PBS_MS_REDUCTION_T noise_reduction_type); + void 
cuda_comparison_ciphertext_64(CudaStreamsFFI streams, CudaRadixCiphertextFFI *lwe_array_out, CudaRadixCiphertextFFI const *lwe_array_1, @@ -281,6 +291,12 @@ void cuda_comparison_ciphertext_64(CudaStreamsFFI streams, int8_t *mem_ptr, void *const *bsks, void *const *ksks); +void cuda_comparison_ciphertext_64_ks32( + CudaStreamsFFI streams, CudaRadixCiphertextFFI *lwe_array_out, + CudaRadixCiphertextFFI const *lwe_array_1, + CudaRadixCiphertextFFI const *lwe_array_2, int8_t *mem_ptr, + void *const *bsks, void *const *ksks); + void cuda_scalar_comparison_ciphertext_64( CudaStreamsFFI streams, CudaRadixCiphertextFFI *lwe_array_out, CudaRadixCiphertextFFI const *lwe_array_in, void const *scalar_blocks, @@ -323,6 +339,20 @@ void cuda_boolean_bitnot_ciphertext_64(CudaStreamsFFI streams, int8_t *mem_ptr, void *const *bsks, void *const *ksks); +uint64_t scratch_cuda_boolean_bitnot_64_ks32( + CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension, + uint32_t polynomial_size, uint32_t big_lwe_dimension, + uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log, + uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor, + uint32_t message_modulus, uint32_t carry_modulus, PBS_TYPE pbs_type, + uint32_t lwe_ciphertext_count, bool is_unchecked, bool allocate_gpu_memory, + PBS_MS_REDUCTION_T noise_reduction_type); + +void cuda_boolean_bitnot_ciphertext_64_ks32(CudaStreamsFFI streams, + CudaRadixCiphertextFFI *lwe_array, + int8_t *mem_ptr, void *const *bsks, + void *const *ksks); + void cleanup_cuda_boolean_bitnot(CudaStreamsFFI streams, int8_t **mem_ptr_void); void cuda_bitnot_ciphertext_64(CudaStreamsFFI streams, @@ -340,6 +370,15 @@ uint64_t scratch_cuda_bitop_64( uint32_t carry_modulus, PBS_TYPE pbs_type, BITOP_TYPE op_type, bool allocate_gpu_memory, PBS_MS_REDUCTION_T noise_reduction_type); +uint64_t scratch_cuda_bitop_64_ks32( + CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension, + uint32_t polynomial_size, uint32_t 
big_lwe_dimension, + uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log, + uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor, + uint32_t lwe_ciphertext_count, uint32_t message_modulus, + uint32_t carry_modulus, PBS_TYPE pbs_type, BITOP_TYPE op_type, + bool allocate_gpu_memory, PBS_MS_REDUCTION_T noise_reduction_type); + void cuda_scalar_bitop_ciphertext_64( CudaStreamsFFI streams, CudaRadixCiphertextFFI *lwe_array_out, CudaRadixCiphertextFFI const *lwe_array_input, void const *clear_blocks, @@ -353,6 +392,13 @@ void cuda_bitop_ciphertext_64(CudaStreamsFFI streams, int8_t *mem_ptr, void *const *bsks, void *const *ksks); +void cuda_bitop_ciphertext_64_ks32(CudaStreamsFFI streams, + CudaRadixCiphertextFFI *lwe_array_out, + CudaRadixCiphertextFFI const *lwe_array_1, + CudaRadixCiphertextFFI const *lwe_array_2, + int8_t *mem_ptr, void *const *bsks, + void *const *ksks); + void cleanup_cuda_integer_bitop(CudaStreamsFFI streams, int8_t **mem_ptr_void); uint64_t scratch_cuda_cmux_64(CudaStreamsFFI streams, int8_t **mem_ptr, @@ -366,6 +412,15 @@ uint64_t scratch_cuda_cmux_64(CudaStreamsFFI streams, int8_t **mem_ptr, PBS_TYPE pbs_type, bool allocate_gpu_memory, PBS_MS_REDUCTION_T noise_reduction_type); +uint64_t scratch_cuda_cmux_64_ks32( + CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension, + uint32_t polynomial_size, uint32_t big_lwe_dimension, + uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log, + uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor, + uint32_t lwe_ciphertext_count, uint32_t message_modulus, + uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory, + PBS_MS_REDUCTION_T noise_reduction_type); + void cuda_cmux_ciphertext_64(CudaStreamsFFI streams, CudaRadixCiphertextFFI *lwe_array_out, CudaRadixCiphertextFFI const *lwe_condition, @@ -374,6 +429,14 @@ void cuda_cmux_ciphertext_64(CudaStreamsFFI streams, int8_t *mem_ptr, void *const *bsks, void *const *ksks); 
+void cuda_cmux_ciphertext_64_ks32(CudaStreamsFFI streams, + CudaRadixCiphertextFFI *lwe_array_out, + CudaRadixCiphertextFFI const *lwe_condition, + CudaRadixCiphertextFFI const *lwe_array_true, + CudaRadixCiphertextFFI const *lwe_array_false, + int8_t *mem_ptr, void *const *bsks, + void *const *ksks); + void cleanup_cuda_cmux(CudaStreamsFFI streams, int8_t **mem_ptr_void); uint64_t scratch_cuda_scalar_rotate_64( @@ -452,6 +515,15 @@ uint64_t scratch_cuda_integer_overflowing_sub_64_inplace( PBS_TYPE pbs_type, uint32_t compute_overflow, bool allocate_gpu_memory, PBS_MS_REDUCTION_T noise_reduction_type); +uint64_t scratch_cuda_integer_overflowing_sub_64_ks32_inplace( + CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension, + uint32_t polynomial_size, uint32_t big_lwe_dimension, + uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log, + uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor, + uint32_t num_blocks, uint32_t message_modulus, uint32_t carry_modulus, + PBS_TYPE pbs_type, uint32_t compute_overflow, bool allocate_gpu_memory, + PBS_MS_REDUCTION_T noise_reduction_type); + void cuda_integer_overflowing_sub_64_inplace( CudaStreamsFFI streams, CudaRadixCiphertextFFI *lhs_array, const CudaRadixCiphertextFFI *rhs_array, @@ -460,6 +532,14 @@ void cuda_integer_overflowing_sub_64_inplace( void *const *bsks, void *const *ksks, uint32_t compute_overflow, uint32_t uses_input_borrow); +void cuda_integer_overflowing_sub_64_ks32_inplace( + CudaStreamsFFI streams, CudaRadixCiphertextFFI *lhs_array, + const CudaRadixCiphertextFFI *rhs_array, + CudaRadixCiphertextFFI *overflow_block, + const CudaRadixCiphertextFFI *input_borrow, int8_t *mem_ptr, + void *const *bsks, void *const *ksks, uint32_t compute_overflow, + uint32_t uses_input_borrow); + void cleanup_cuda_integer_overflowing_sub(CudaStreamsFFI streams, int8_t **mem_ptr_void); @@ -844,12 +924,29 @@ uint64_t scratch_cuda_cast_to_unsigned_64( uint32_t carry_modulus, PBS_TYPE 
pbs_type, bool allocate_gpu_memory, PBS_MS_REDUCTION_T noise_reduction_type); +uint64_t scratch_cuda_cast_to_unsigned_64_ks32( + CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension, + uint32_t polynomial_size, uint32_t big_lwe_dimension, + uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log, + uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor, + uint32_t num_input_blocks, uint32_t target_num_blocks, bool input_is_signed, + bool requires_full_propagate, uint32_t message_modulus, + uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory, + PBS_MS_REDUCTION_T noise_reduction_type); + void cuda_cast_to_unsigned_64(CudaStreamsFFI streams, CudaRadixCiphertextFFI *output, CudaRadixCiphertextFFI *input, int8_t *mem_ptr, uint32_t target_num_blocks, bool input_is_signed, void *const *bsks, void *const *ksks); +void cuda_cast_to_unsigned_64_ks32(CudaStreamsFFI streams, + CudaRadixCiphertextFFI *output, + CudaRadixCiphertextFFI *input, + int8_t *mem_ptr, uint32_t target_num_blocks, + bool input_is_signed, void *const *bsks, + void *const *ksks); + void cleanup_cuda_cast_to_unsigned_64(CudaStreamsFFI streams, int8_t **mem_ptr_void); diff --git a/backends/tfhe-cuda-backend/cuda/src/aes/aes256.cuh b/backends/tfhe-cuda-backend/cuda/src/aes/aes256.cuh index 44075c50d..e741dce2e 100644 --- a/backends/tfhe-cuda-backend/cuda/src/aes/aes256.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/aes/aes256.cuh @@ -34,7 +34,7 @@ __host__ void vectorized_aes_256_encrypt_inplace( CudaStreams streams, CudaRadixCiphertextFFI *all_states_bitsliced, CudaRadixCiphertextFFI const *round_keys, uint32_t num_aes_inputs, int_aes_encrypt_buffer *mem, void *const *bsks, - Torus *const *ksks) { + KSTorus *const *ksks) { constexpr uint32_t BITS_PER_BYTE = 8; constexpr uint32_t STATE_BYTES = 16; @@ -186,7 +186,7 @@ __host__ void host_integer_aes_ctr_256_encrypt( CudaRadixCiphertextFFI const *iv, CudaRadixCiphertextFFI const *round_keys, const Torus 
*counter_bits_le_all_blocks, uint32_t num_aes_inputs, int_aes_encrypt_buffer *mem, void *const *bsks, - Torus *const *ksks) { + KSTorus *const *ksks) { constexpr uint32_t NUM_BITS = 128; @@ -245,7 +245,7 @@ __host__ void host_integer_key_expansion_256( CudaStreams streams, CudaRadixCiphertextFFI *expanded_keys, CudaRadixCiphertextFFI const *key, int_key_expansion_256_buffer *mem, void *const *bsks, - Torus *const *ksks) { + KSTorus *const *ksks) { constexpr uint32_t BITS_PER_WORD = 32; constexpr uint32_t BITS_PER_BYTE = 8; diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/abs.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/abs.cuh index dcf476c35..8db8e445d 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/abs.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/integer/abs.cuh @@ -23,10 +23,10 @@ __host__ uint64_t scratch_cuda_integer_abs( return size_tracker; } -template +template __host__ void host_integer_abs(CudaStreams streams, CudaRadixCiphertextFFI *ct, - void *const *bsks, uint64_t *const *ksks, - int_abs_buffer *mem_ptr, + void *const *bsks, KSTorus *const *ksks, + int_abs_buffer *mem_ptr, bool is_signed) { if (!is_signed) return; @@ -40,7 +40,7 @@ __host__ void host_integer_abs(CudaStreams streams, CudaRadixCiphertextFFI *ct, copy_radix_ciphertext_async(streams.stream(0), streams.gpu_index(0), mask, ct); - host_arithmetic_scalar_shift_inplace( + host_arithmetic_scalar_shift_inplace( streams, mask, num_bits_in_ciphertext - 1, mem_ptr->arithmetic_scalar_shift_mem, bsks, ksks); host_addition(streams.stream(0), streams.gpu_index(0), ct, mask, ct, @@ -49,11 +49,12 @@ __host__ void host_integer_abs(CudaStreams streams, CudaRadixCiphertextFFI *ct, uint32_t requested_flag = outputFlag::FLAG_NONE; uint32_t uses_carry = 0; - host_propagate_single_carry(streams, ct, nullptr, nullptr, - mem_ptr->scp_mem, bsks, ksks, - requested_flag, uses_carry); + host_propagate_single_carry(streams, ct, nullptr, nullptr, + mem_ptr->scp_mem, bsks, ksks, + requested_flag, 
uses_carry); - host_bitop(streams, ct, mask, ct, mem_ptr->bitxor_mem, bsks, ksks); + host_bitop(streams, ct, mask, ct, mem_ptr->bitxor_mem, bsks, + ksks); } #endif // TFHE_RS_ABS_CUH diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/bitwise_ops.cu b/backends/tfhe-cuda-backend/cuda/src/integer/bitwise_ops.cu index d2fc56f6a..d6562e4d6 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/bitwise_ops.cu +++ b/backends/tfhe-cuda-backend/cuda/src/integer/bitwise_ops.cu @@ -63,6 +63,26 @@ uint64_t scratch_cuda_boolean_bitnot_64( lwe_ciphertext_count, is_unchecked, allocate_gpu_memory); } +uint64_t scratch_cuda_boolean_bitnot_64_ks32( + CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension, + uint32_t polynomial_size, uint32_t big_lwe_dimension, + uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log, + uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor, + uint32_t message_modulus, uint32_t carry_modulus, PBS_TYPE pbs_type, + uint32_t lwe_ciphertext_count, bool is_unchecked, bool allocate_gpu_memory, + PBS_MS_REDUCTION_T noise_reduction_type) { + + int_radix_params params(pbs_type, glwe_dimension, polynomial_size, + big_lwe_dimension, small_lwe_dimension, ks_level, + ks_base_log, pbs_level, pbs_base_log, grouping_factor, + message_modulus, carry_modulus, noise_reduction_type); + + return scratch_cuda_boolean_bitnot( + CudaStreams(streams), + (boolean_bitnot_buffer **)mem_ptr, params, + lwe_ciphertext_count, is_unchecked, allocate_gpu_memory); +} + void cuda_boolean_bitnot_ciphertext_64(CudaStreamsFFI streams, CudaRadixCiphertextFFI *lwe_array, int8_t *mem_ptr, void *const *bsks, @@ -73,6 +93,16 @@ void cuda_boolean_bitnot_ciphertext_64(CudaStreamsFFI streams, (uint64_t **)(ksks)); } +void cuda_boolean_bitnot_ciphertext_64_ks32(CudaStreamsFFI streams, + CudaRadixCiphertextFFI *lwe_array, + int8_t *mem_ptr, void *const *bsks, + void *const *ksks) { + host_boolean_bitnot( + CudaStreams(streams), lwe_array, + 
(boolean_bitnot_buffer *)mem_ptr, bsks, + (uint32_t **)(ksks)); +} + void cleanup_cuda_boolean_bitnot(CudaStreamsFFI streams, int8_t **mem_ptr_void) { @@ -102,6 +132,25 @@ uint64_t scratch_cuda_bitop_64( lwe_ciphertext_count, params, op_type, allocate_gpu_memory); } +uint64_t scratch_cuda_bitop_64_ks32( + CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension, + uint32_t polynomial_size, uint32_t big_lwe_dimension, + uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log, + uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor, + uint32_t lwe_ciphertext_count, uint32_t message_modulus, + uint32_t carry_modulus, PBS_TYPE pbs_type, BITOP_TYPE op_type, + bool allocate_gpu_memory, PBS_MS_REDUCTION_T noise_reduction_type) { + + int_radix_params params(pbs_type, glwe_dimension, polynomial_size, + big_lwe_dimension, small_lwe_dimension, ks_level, + ks_base_log, pbs_level, pbs_base_log, grouping_factor, + message_modulus, carry_modulus, noise_reduction_type); + + return scratch_cuda_bitop( + CudaStreams(streams), (int_bitop_buffer **)mem_ptr, + lwe_ciphertext_count, params, op_type, allocate_gpu_memory); +} + void cuda_bitnot_ciphertext_64(CudaStreamsFFI streams, CudaRadixCiphertextFFI *radix_ciphertext, uint32_t ct_message_modulus, @@ -126,6 +175,19 @@ void cuda_bitop_ciphertext_64(CudaStreamsFFI streams, (uint64_t **)(ksks)); } +void cuda_bitop_ciphertext_64_ks32(CudaStreamsFFI streams, + CudaRadixCiphertextFFI *lwe_array_out, + CudaRadixCiphertextFFI const *lwe_array_1, + CudaRadixCiphertextFFI const *lwe_array_2, + int8_t *mem_ptr, void *const *bsks, + void *const *ksks) { + + host_bitop(CudaStreams(streams), lwe_array_out, lwe_array_1, + lwe_array_2, + (int_bitop_buffer *)mem_ptr, bsks, + (uint32_t **)(ksks)); +} + void cleanup_cuda_integer_bitop(CudaStreamsFFI streams, int8_t **mem_ptr_void) { int_bitop_buffer *mem_ptr = diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/cast.cu 
b/backends/tfhe-cuda-backend/cuda/src/integer/cast.cu index 6383681f7..bf770a767 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/cast.cu +++ b/backends/tfhe-cuda-backend/cuda/src/integer/cast.cu @@ -49,6 +49,28 @@ uint64_t scratch_cuda_cast_to_unsigned_64( requires_full_propagate, allocate_gpu_memory); } +uint64_t scratch_cuda_cast_to_unsigned_64_ks32( + CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension, + uint32_t polynomial_size, uint32_t big_lwe_dimension, + uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log, + uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor, + uint32_t num_input_blocks, uint32_t target_num_blocks, bool input_is_signed, + bool requires_full_propagate, uint32_t message_modulus, + uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory, + PBS_MS_REDUCTION_T noise_reduction_type) { + + int_radix_params params(pbs_type, glwe_dimension, polynomial_size, + big_lwe_dimension, small_lwe_dimension, ks_level, + ks_base_log, pbs_level, pbs_base_log, grouping_factor, + message_modulus, carry_modulus, noise_reduction_type); + + return scratch_cuda_cast_to_unsigned( + CudaStreams(streams), + (int_cast_to_unsigned_buffer **)mem_ptr, params, + num_input_blocks, target_num_blocks, input_is_signed, + requires_full_propagate, allocate_gpu_memory); +} + void cuda_cast_to_unsigned_64(CudaStreamsFFI streams, CudaRadixCiphertextFFI *output, CudaRadixCiphertextFFI *input, int8_t *mem_ptr, @@ -61,6 +83,19 @@ void cuda_cast_to_unsigned_64(CudaStreamsFFI streams, target_num_blocks, input_is_signed, bsks, (uint64_t **)ksks); } +void cuda_cast_to_unsigned_64_ks32(CudaStreamsFFI streams, + CudaRadixCiphertextFFI *output, + CudaRadixCiphertextFFI *input, + int8_t *mem_ptr, uint32_t target_num_blocks, + bool input_is_signed, void *const *bsks, + void *const *ksks) { + + host_cast_to_unsigned( + CudaStreams(streams), output, input, + (int_cast_to_unsigned_buffer *)mem_ptr, + target_num_blocks, 
input_is_signed, bsks, (uint32_t **)ksks); +} + void cleanup_cuda_cast_to_unsigned_64(CudaStreamsFFI streams, int8_t **mem_ptr_void) { int_cast_to_unsigned_buffer *mem_ptr = diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/cast.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/cast.cuh index de59f5b42..a226c6e49 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/cast.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/integer/cast.cuh @@ -130,7 +130,7 @@ host_cast_to_unsigned(CudaStreams streams, CudaRadixCiphertextFFI *output, CudaRadixCiphertextFFI *input, int_cast_to_unsigned_buffer *mem_ptr, uint32_t target_num_blocks, bool input_is_signed, - void *const *bsks, Torus *const *ksks) { + void *const *bsks, KSTorus *const *ksks) { uint32_t current_num_blocks = input->num_radix_blocks; @@ -143,9 +143,9 @@ host_cast_to_unsigned(CudaStreams streams, CudaRadixCiphertextFFI *output, uint32_t num_blocks_to_add = target_num_blocks - current_num_blocks; if (input_is_signed) { - host_extend_radix_with_sign_msb( + host_extend_radix_with_sign_msb( streams, output, input, mem_ptr->extend_buffer, num_blocks_to_add, - bsks, (Torus **)ksks); + bsks, ksks); } else { host_extend_radix_with_trivial_zero_blocks_msb(output, input, streams); diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/cmux.cu b/backends/tfhe-cuda-backend/cuda/src/integer/cmux.cu index 93918c2ec..277ee2b2f 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/cmux.cu +++ b/backends/tfhe-cuda-backend/cuda/src/integer/cmux.cu @@ -26,6 +26,30 @@ uint64_t scratch_cuda_cmux_64(CudaStreamsFFI streams, int8_t **mem_ptr, return ret; } +uint64_t scratch_cuda_cmux_64_ks32( + CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension, + uint32_t polynomial_size, uint32_t big_lwe_dimension, + uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log, + uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor, + uint32_t lwe_ciphertext_count, uint32_t message_modulus, + uint32_t 
carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory, + PBS_MS_REDUCTION_T noise_reduction_type) { + PUSH_RANGE("scratch cmux") + int_radix_params params(pbs_type, glwe_dimension, polynomial_size, + big_lwe_dimension, small_lwe_dimension, ks_level, + ks_base_log, pbs_level, pbs_base_log, grouping_factor, + message_modulus, carry_modulus, noise_reduction_type); + + std::function predicate_lut_f = + [](uint64_t x) -> uint64_t { return x == 1; }; + + uint64_t ret = scratch_cuda_cmux( + CudaStreams(streams), (int_cmux_buffer **)mem_ptr, + predicate_lut_f, lwe_ciphertext_count, params, allocate_gpu_memory); + POP_RANGE() + return ret; +} + void cuda_cmux_ciphertext_64(CudaStreamsFFI streams, CudaRadixCiphertextFFI *lwe_array_out, CudaRadixCiphertextFFI const *lwe_condition, @@ -34,10 +58,25 @@ void cuda_cmux_ciphertext_64(CudaStreamsFFI streams, int8_t *mem_ptr, void *const *bsks, void *const *ksks) { PUSH_RANGE("cmux") - host_cmux(CudaStreams(streams), lwe_array_out, lwe_condition, - lwe_array_true, lwe_array_false, - (int_cmux_buffer *)mem_ptr, bsks, - (uint64_t **)(ksks)); + host_cmux(CudaStreams(streams), lwe_array_out, + lwe_condition, lwe_array_true, lwe_array_false, + (int_cmux_buffer *)mem_ptr, + bsks, (uint64_t **)(ksks)); + POP_RANGE() +} + +void cuda_cmux_ciphertext_64_ks32(CudaStreamsFFI streams, + CudaRadixCiphertextFFI *lwe_array_out, + CudaRadixCiphertextFFI const *lwe_condition, + CudaRadixCiphertextFFI const *lwe_array_true, + CudaRadixCiphertextFFI const *lwe_array_false, + int8_t *mem_ptr, void *const *bsks, + void *const *ksks) { + PUSH_RANGE("cmux") + host_cmux(CudaStreams(streams), lwe_array_out, + lwe_condition, lwe_array_true, lwe_array_false, + (int_cmux_buffer *)mem_ptr, + bsks, (uint32_t **)(ksks)); POP_RANGE() } diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/comparison.cu b/backends/tfhe-cuda-backend/cuda/src/integer/comparison.cu index b7400362f..620d39da4 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/comparison.cu 
+++ b/backends/tfhe-cuda-backend/cuda/src/integer/comparison.cu @@ -39,6 +39,45 @@ uint64_t scratch_cuda_comparison_64( return size_tracker; } +uint64_t scratch_cuda_comparison_64_ks32( + CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension, + uint32_t polynomial_size, uint32_t big_lwe_dimension, + uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log, + uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor, + uint32_t num_radix_blocks, uint32_t message_modulus, uint32_t carry_modulus, + PBS_TYPE pbs_type, COMPARISON_TYPE op_type, bool is_signed, + bool allocate_gpu_memory, PBS_MS_REDUCTION_T noise_reduction_type) { + PUSH_RANGE("scratch comparison") + int_radix_params params(pbs_type, glwe_dimension, polynomial_size, + big_lwe_dimension, small_lwe_dimension, ks_level, + ks_base_log, pbs_level, pbs_base_log, grouping_factor, + message_modulus, carry_modulus, noise_reduction_type); + + uint64_t size_tracker = 0; + switch (op_type) { + case EQ: + case NE: + size_tracker += scratch_cuda_comparison_check( + CudaStreams(streams), + (int_comparison_buffer **)mem_ptr, num_radix_blocks, + params, op_type, false, allocate_gpu_memory); + break; + case GT: + case GE: + case LT: + case LE: + case MAX: + case MIN: + size_tracker += scratch_cuda_comparison_check( + CudaStreams(streams), + (int_comparison_buffer **)mem_ptr, num_radix_blocks, + params, op_type, is_signed, allocate_gpu_memory); + break; + } + POP_RANGE() + return size_tracker; +} + void cuda_comparison_ciphertext_64(CudaStreamsFFI streams, CudaRadixCiphertextFFI *lwe_array_out, CudaRadixCiphertextFFI const *lwe_array_1, @@ -87,6 +126,53 @@ void cuda_comparison_ciphertext_64(CudaStreamsFFI streams, POP_RANGE() } +void cuda_comparison_ciphertext_64_ks32( + CudaStreamsFFI streams, CudaRadixCiphertextFFI *lwe_array_out, + CudaRadixCiphertextFFI const *lwe_array_1, + CudaRadixCiphertextFFI const *lwe_array_2, int8_t *mem_ptr, + void *const *bsks, void *const *ksks) { + 
PUSH_RANGE("comparison") + if (lwe_array_1->num_radix_blocks != lwe_array_2->num_radix_blocks) + PANIC("Cuda error: input num radix blocks must be the same") + // The output ciphertext might be a boolean block or a radix ciphertext + // depending on the case (eq/gt vs max/min) so the amount of blocks to + // consider for calculation is the one of the input + auto num_radix_blocks = lwe_array_1->num_radix_blocks; + int_comparison_buffer *buffer = + (int_comparison_buffer *)mem_ptr; + switch (buffer->op) { + case EQ: + case NE: + host_equality_check(CudaStreams(streams), lwe_array_out, + lwe_array_1, lwe_array_2, buffer, bsks, + (uint32_t **)(ksks), num_radix_blocks); + break; + case GT: + case GE: + case LT: + case LE: + if (num_radix_blocks % 2 != 0) + PANIC("Cuda error (comparisons): the number of radix blocks has to be " + "even.") + host_difference_check(CudaStreams(streams), lwe_array_out, + lwe_array_1, lwe_array_2, buffer, + buffer->diff_buffer->operator_f, bsks, + (uint32_t **)(ksks), num_radix_blocks); + break; + case MAX: + case MIN: + if (num_radix_blocks % 2 != 0) + PANIC("Cuda error (max/min): the number of radix blocks has to be even.") + host_maxmin(CudaStreams(streams), lwe_array_out, lwe_array_1, + lwe_array_2, buffer, bsks, (uint32_t **)(ksks), + num_radix_blocks); + break; + default: + PANIC("Cuda error: integer operation not supported") + } + POP_RANGE() +} + void cleanup_cuda_integer_comparison(CudaStreamsFFI streams, int8_t **mem_ptr_void) { PUSH_RANGE("cleanup comparison") diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/div_rem.cu b/backends/tfhe-cuda-backend/cuda/src/integer/div_rem.cu index b2589b84d..640de2a14 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/div_rem.cu +++ b/backends/tfhe-cuda-backend/cuda/src/integer/div_rem.cu @@ -21,6 +21,27 @@ uint64_t scratch_cuda_integer_div_rem_radix_ciphertext_64( POP_RANGE() } +uint64_t scratch_cuda_integer_div_rem_radix_ciphertext_64_ks32( + CudaStreamsFFI streams, bool is_signed, 
int8_t **mem_ptr, + uint32_t glwe_dimension, uint32_t polynomial_size, + uint32_t big_lwe_dimension, uint32_t small_lwe_dimension, uint32_t ks_level, + uint32_t ks_base_log, uint32_t pbs_level, uint32_t pbs_base_log, + uint32_t grouping_factor, uint32_t num_blocks, uint32_t message_modulus, + uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory, + PBS_MS_REDUCTION_T noise_reduction_type) { + PUSH_RANGE("scratch div") + int_radix_params params(pbs_type, glwe_dimension, polynomial_size, + big_lwe_dimension, small_lwe_dimension, ks_level, + ks_base_log, pbs_level, pbs_base_log, grouping_factor, + message_modulus, carry_modulus, noise_reduction_type); + + return scratch_cuda_integer_div_rem( + CudaStreams(streams), is_signed, + (int_div_rem_memory **)mem_ptr, num_blocks, params, + allocate_gpu_memory); + POP_RANGE() +} + void cuda_integer_div_rem_radix_ciphertext_64( CudaStreamsFFI streams, CudaRadixCiphertextFFI *quotient, CudaRadixCiphertextFFI *remainder, CudaRadixCiphertextFFI const *numerator, @@ -35,6 +56,20 @@ void cuda_integer_div_rem_radix_ciphertext_64( POP_RANGE() } +void cuda_integer_div_rem_radix_ciphertext_64_ks32( + CudaStreamsFFI streams, CudaRadixCiphertextFFI *quotient, + CudaRadixCiphertextFFI *remainder, CudaRadixCiphertextFFI const *numerator, + CudaRadixCiphertextFFI const *divisor, bool is_signed, int8_t *mem_ptr, + void *const *bsks, void *const *ksks) { + PUSH_RANGE("div") + auto mem = (int_div_rem_memory *)mem_ptr; + + host_integer_div_rem( + CudaStreams(streams), quotient, remainder, numerator, divisor, is_signed, + bsks, (uint32_t **)(ksks), mem); + POP_RANGE() +} + void cleanup_cuda_integer_div_rem(CudaStreamsFFI streams, int8_t **mem_ptr_void) { PUSH_RANGE("cleanup div") diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/div_rem.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/div_rem.cuh index c72ac9682..3fa776d27 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/div_rem.cuh +++ 
b/backends/tfhe-cuda-backend/cuda/src/integer/div_rem.cuh @@ -31,8 +31,8 @@ __host__ void host_unsigned_integer_div_rem_block_by_block_2_2( CudaStreams streams, CudaRadixCiphertextFFI *quotient, CudaRadixCiphertextFFI *remainder, CudaRadixCiphertextFFI const *numerator, CudaRadixCiphertextFFI const *divisor, void *const *bsks, - uint64_t *const *ksks, - unsigned_int_div_rem_2_2_memory *mem_ptr) { + KSTorus *const *ksks, + unsigned_int_div_rem_2_2_memory *mem_ptr) { if (streams.count() < 4) { PANIC("GPU count should be greater than 4 when using div_rem_2_2"); @@ -476,8 +476,8 @@ __host__ void host_unsigned_integer_div_rem( CudaStreams streams, CudaRadixCiphertextFFI *quotient, CudaRadixCiphertextFFI *remainder, CudaRadixCiphertextFFI const *numerator, CudaRadixCiphertextFFI const *divisor, void *const *bsks, - uint64_t *const *ksks, - unsigned_int_div_rem_memory *mem_ptr) { + KSTorus *const *ksks, + unsigned_int_div_rem_memory *mem_ptr) { if (remainder->num_radix_blocks != numerator->num_radix_blocks || remainder->num_radix_blocks != divisor->num_radix_blocks || @@ -740,7 +740,7 @@ __host__ void host_unsigned_integer_div_rem( mem_ptr->overflow_sub_mem->update_lut_indexes( streams, first_indexes, second_indexes, scalar_indexes, merged_interesting_remainder->num_radix_blocks); - host_integer_overflowing_sub( + host_integer_overflowing_sub( streams, new_remainder, merged_interesting_remainder, interesting_divisor, subtraction_overflowed, (const CudaRadixCiphertextFFI *)nullptr, mem_ptr->overflow_sub_mem, @@ -903,13 +903,11 @@ __host__ void host_unsigned_integer_div_rem( } template -__host__ void -host_integer_div_rem(CudaStreams streams, CudaRadixCiphertextFFI *quotient, - CudaRadixCiphertextFFI *remainder, - CudaRadixCiphertextFFI const *numerator, - CudaRadixCiphertextFFI const *divisor, bool is_signed, - void *const *bsks, uint64_t *const *ksks, - int_div_rem_memory *int_mem_ptr) { +__host__ void host_integer_div_rem( + CudaStreams streams, CudaRadixCiphertextFFI 
*quotient, + CudaRadixCiphertextFFI *remainder, CudaRadixCiphertextFFI const *numerator, + CudaRadixCiphertextFFI const *divisor, bool is_signed, void *const *bsks, + KSTorus *const *ksks, int_div_rem_memory *int_mem_ptr) { if (remainder->num_radix_blocks != numerator->num_radix_blocks || remainder->num_radix_blocks != divisor->num_radix_blocks || remainder->num_radix_blocks != quotient->num_radix_blocks) diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/integer.cu b/backends/tfhe-cuda-backend/cuda/src/integer/integer.cu index 68c96718a..139c5835b 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/integer.cu +++ b/backends/tfhe-cuda-backend/cuda/src/integer/integer.cu @@ -116,6 +116,25 @@ uint64_t scratch_cuda_integer_overflowing_sub_64_inplace( params, compute_overflow, allocate_gpu_memory); } +uint64_t scratch_cuda_integer_overflowing_sub_64_ks32_inplace( + CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension, + uint32_t polynomial_size, uint32_t big_lwe_dimension, + uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log, + uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor, + uint32_t num_blocks, uint32_t message_modulus, uint32_t carry_modulus, + PBS_TYPE pbs_type, uint32_t compute_overflow, bool allocate_gpu_memory, + PBS_MS_REDUCTION_T noise_reduction_type) { + int_radix_params params(pbs_type, glwe_dimension, polynomial_size, + big_lwe_dimension, small_lwe_dimension, ks_level, + ks_base_log, pbs_level, pbs_base_log, grouping_factor, + message_modulus, carry_modulus, noise_reduction_type); + + return scratch_cuda_integer_overflowing_sub( + CudaStreams(streams), + (int_borrow_prop_memory **)mem_ptr, num_blocks, + params, compute_overflow, allocate_gpu_memory); +} + void cuda_propagate_single_carry_64_inplace( CudaStreamsFFI streams, CudaRadixCiphertextFFI *lwe_array, CudaRadixCiphertextFFI *carry_out, const CudaRadixCiphertextFFI *carry_in, @@ -128,6 +147,18 @@ void 
cuda_propagate_single_carry_64_inplace( (uint64_t **)(ksks), requested_flag, uses_carry); } +void cuda_propagate_single_carry_64_ks32_inplace( + CudaStreamsFFI streams, CudaRadixCiphertextFFI *lwe_array, + CudaRadixCiphertextFFI *carry_out, const CudaRadixCiphertextFFI *carry_in, + int8_t *mem_ptr, void *const *bsks, void *const *ksks, + uint32_t requested_flag, uint32_t uses_carry) { + + host_propagate_single_carry( + CudaStreams(streams), lwe_array, carry_out, carry_in, + (int_sc_prop_memory *)mem_ptr, bsks, + (uint32_t **)(ksks), requested_flag, uses_carry); +} + void cuda_add_and_propagate_single_carry_64_inplace( CudaStreamsFFI streams, CudaRadixCiphertextFFI *lhs_array, const CudaRadixCiphertextFFI *rhs_array, CudaRadixCiphertextFFI *carry_out, @@ -146,10 +177,10 @@ void cuda_add_and_propagate_single_carry_64_ks32_inplace( const CudaRadixCiphertextFFI *carry_in, int8_t *mem_ptr, void *const *bsks, void *const *ksks, uint32_t requested_flag, uint32_t uses_carry) { - host_add_and_propagate_single_carry( - CudaStreams(streams), lhs_array, rhs_array, carry_out, carry_in, - (int_sc_prop_memory *)mem_ptr, bsks, - (uint32_t **)(ksks), requested_flag, uses_carry); + host_add_and_propagate_single_carry( + CudaStreams(streams), lhs_array, rhs_array, carry_out, carry_in, + (int_sc_prop_memory *)mem_ptr, bsks, + (uint32_t **)(ksks), requested_flag, uses_carry); } void cuda_integer_overflowing_sub_64_inplace( @@ -167,6 +198,21 @@ void cuda_integer_overflowing_sub_64_inplace( POP_RANGE() } +void cuda_integer_overflowing_sub_64_ks32_inplace( + CudaStreamsFFI streams, CudaRadixCiphertextFFI *lhs_array, + const CudaRadixCiphertextFFI *rhs_array, + CudaRadixCiphertextFFI *overflow_block, + const CudaRadixCiphertextFFI *input_borrow, int8_t *mem_ptr, + void *const *bsks, void *const *ksks, uint32_t compute_overflow, + uint32_t uses_input_borrow) { + PUSH_RANGE("overflow sub") + host_integer_overflowing_sub( + CudaStreams(streams), lhs_array, lhs_array, rhs_array, 
overflow_block, + input_borrow, (int_borrow_prop_memory *)mem_ptr, bsks, + (uint32_t **)ksks, compute_overflow, uses_input_borrow); + POP_RANGE() +} + void cleanup_cuda_propagate_single_carry(CudaStreamsFFI streams, int8_t **mem_ptr_void) { PUSH_RANGE("cleanup propagate sc") diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/oprf.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/oprf.cuh index 998abf350..b87e30d72 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/oprf.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/integer/oprf.cuh @@ -115,7 +115,7 @@ void host_integer_grouped_oprf_custom_range( const Torus *decomposed_scalar, const Torus *has_at_least_one_set, uint32_t num_scalars, uint32_t shift, int_grouped_oprf_custom_range_memory *mem_ptr, - void *const *bsks, Torus *const *ksks) { + void *const *bsks, KSTorus *const *ksks) { CudaRadixCiphertextFFI *computation_buffer = mem_ptr->tmp_oprf_output; set_zero_radix_ciphertext_slice_async( diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/subtraction.cu b/backends/tfhe-cuda-backend/cuda/src/integer/subtraction.cu index 050cf7d2d..2b9990efe 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/subtraction.cu +++ b/backends/tfhe-cuda-backend/cuda/src/integer/subtraction.cu @@ -20,6 +20,26 @@ uint64_t scratch_cuda_sub_and_propagate_single_carry_64_inplace( requested_flag, allocate_gpu_memory); } +uint64_t scratch_cuda_sub_and_propagate_single_carry_64_ks32_inplace( + CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension, + uint32_t polynomial_size, uint32_t big_lwe_dimension, + uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log, + uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor, + uint32_t num_blocks, uint32_t message_modulus, uint32_t carry_modulus, + PBS_TYPE pbs_type, uint32_t requested_flag, bool allocate_gpu_memory, + PBS_MS_REDUCTION_T noise_reduction_type) { + + int_radix_params params(pbs_type, glwe_dimension, polynomial_size, + 
big_lwe_dimension, small_lwe_dimension, ks_level, + ks_base_log, pbs_level, pbs_base_log, grouping_factor, + message_modulus, carry_modulus, noise_reduction_type); + + return scratch_cuda_sub_and_propagate_single_carry( + CudaStreams(streams), + (int_sub_and_propagate **)mem_ptr, num_blocks, params, + requested_flag, allocate_gpu_memory); +} + void cuda_sub_and_propagate_single_carry_64_inplace( CudaStreamsFFI streams, CudaRadixCiphertextFFI *lhs_array, const CudaRadixCiphertextFFI *rhs_array, CudaRadixCiphertextFFI *carry_out, @@ -33,6 +53,19 @@ void cuda_sub_and_propagate_single_carry_64_inplace( POP_RANGE() } +void cuda_sub_and_propagate_single_carry_64_ks32_inplace( + CudaStreamsFFI streams, CudaRadixCiphertextFFI *lhs_array, + const CudaRadixCiphertextFFI *rhs_array, CudaRadixCiphertextFFI *carry_out, + const CudaRadixCiphertextFFI *carry_in, int8_t *mem_ptr, void *const *bsks, + void *const *ksks, uint32_t requested_flag, uint32_t uses_carry) { + PUSH_RANGE("sub") + host_sub_and_propagate_single_carry( + CudaStreams(streams), lhs_array, rhs_array, carry_out, carry_in, + (int_sub_and_propagate *)mem_ptr, bsks, + (uint32_t **)(ksks), requested_flag, uses_carry); + POP_RANGE() +} + void cleanup_cuda_sub_and_propagate_single_carry(CudaStreamsFFI streams, int8_t **mem_ptr_void) { PUSH_RANGE("cleanup sub") diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/subtraction.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/subtraction.cuh index a6e0b1e75..189280fbc 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/subtraction.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/integer/subtraction.cuh @@ -95,7 +95,8 @@ __host__ void host_integer_overflowing_sub( CudaRadixCiphertextFFI *overflow_block, const CudaRadixCiphertextFFI *input_borrow, int_borrow_prop_memory *mem_ptr, void *const *bsks, - Torus *const *ksks, uint32_t compute_overflow, uint32_t uses_input_borrow) { + KSTorus *const *ksks, uint32_t compute_overflow, + uint32_t uses_input_borrow) { 
PUSH_RANGE("overflowing sub") if (output->num_radix_blocks != input_left->num_radix_blocks || output->num_radix_blocks != input_right->num_radix_blocks) @@ -124,8 +125,8 @@ __host__ void host_integer_overflowing_sub( host_single_borrow_propagate( streams, output, overflow_block, input_borrow, - (int_borrow_prop_memory *)mem_ptr, bsks, (Torus **)(ksks), - num_groups, compute_overflow, uses_input_borrow); + (int_borrow_prop_memory *)mem_ptr, bsks, + (KSTorus **)(ksks), num_groups, compute_overflow, uses_input_borrow); POP_RANGE() } diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/vector_comparison.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/vector_comparison.cuh index ebe9b75d0..0808e460b 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/vector_comparison.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/integer/vector_comparison.cuh @@ -25,7 +25,7 @@ __host__ void host_unchecked_all_eq_slices( CudaRadixCiphertextFFI const *lhs, CudaRadixCiphertextFFI const *rhs, uint32_t num_inputs, uint32_t num_blocks, int_unchecked_all_eq_slices_buffer *mem_ptr, - void *const *bsks, Torus *const *ksks) { + void *const *bsks, KSTorus *const *ksks) { // sync_from(streams) // @@ -104,7 +104,7 @@ __host__ void host_unchecked_contains_sub_slice( CudaRadixCiphertextFFI const *lhs, CudaRadixCiphertextFFI const *rhs, uint32_t num_rhs, uint32_t num_blocks, int_unchecked_contains_sub_slice_buffer *mem_ptr, - void *const *bsks, Torus *const *ksks) { + void *const *bsks, KSTorus *const *ksks) { uint32_t num_windows = mem_ptr->num_windows; diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/vector_find.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/vector_find.cuh index 90d01b3e5..19490b91c 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/vector_find.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/integer/vector_find.cuh @@ -14,7 +14,7 @@ __host__ void host_compute_equality_selectors( CudaRadixCiphertextFFI const *lwe_array_in, uint32_t num_blocks, const uint64_t 
*h_decomposed_cleartexts, int_equality_selectors_buffer *mem_ptr, void *const *bsks, - Torus *const *ksks) { + KSTorus *const *ksks) { uint32_t num_possible_values = mem_ptr->num_possible_values; uint32_t message_modulus = mem_ptr->params.message_modulus; @@ -91,7 +91,7 @@ __host__ void host_create_possible_results( CudaRadixCiphertextFFI const *lwe_array_in_list, uint32_t num_possible_values, const uint64_t *h_decomposed_cleartexts, uint32_t num_blocks, int_possible_results_buffer *mem_ptr, - void *const *bsks, Torus *const *ksks) { + void *const *bsks, KSTorus *const *ksks) { uint32_t max_packed_value = mem_ptr->max_packed_value; uint32_t max_luts_per_call = mem_ptr->max_luts_per_call; @@ -174,7 +174,7 @@ __host__ void host_aggregate_one_hot_vector( CudaRadixCiphertextFFI const *lwe_array_in_list, uint32_t num_input_ciphertexts, uint32_t num_blocks, int_aggregate_one_hot_buffer *mem_ptr, void *const *bsks, - Torus *const *ksks) { + KSTorus *const *ksks) { int_radix_params params = mem_ptr->params; uint32_t chunk_size = mem_ptr->chunk_size; @@ -345,7 +345,7 @@ __host__ void host_unchecked_match_value( CudaRadixCiphertextFFI const *lwe_array_in_ct, const uint64_t *h_match_inputs, const uint64_t *h_match_outputs, int_unchecked_match_buffer *mem_ptr, void *const *bsks, - Torus *const *ksks) { + KSTorus *const *ksks) { host_compute_equality_selectors( streams, mem_ptr->selectors_list, lwe_array_in_ct, mem_ptr->num_input_blocks, h_match_inputs, mem_ptr->eq_selectors_buffer, @@ -422,7 +422,7 @@ __host__ void host_unchecked_match_value_or( const uint64_t *h_match_inputs, const uint64_t *h_match_outputs, const uint64_t *h_or_value, int_unchecked_match_value_or_buffer *mem_ptr, - void *const *bsks, Torus *const *ksks) { + void *const *bsks, KSTorus *const *ksks) { host_unchecked_match_value(streams, mem_ptr->tmp_match_result, mem_ptr->tmp_match_bool, lwe_array_in_ct, @@ -465,7 +465,7 @@ host_unchecked_contains(CudaStreams streams, CudaRadixCiphertextFFI *output, 
CudaRadixCiphertextFFI const *value, uint32_t num_inputs, uint32_t num_blocks, int_unchecked_contains_buffer *mem_ptr, - void *const *bsks, Torus *const *ksks) { + void *const *bsks, KSTorus *const *ksks) { mem_ptr->internal_cuda_streams.internal_streams_wait_for_main_stream_0( streams); @@ -516,7 +516,7 @@ __host__ void host_unchecked_contains_clear( CudaRadixCiphertextFFI const *inputs, const uint64_t *h_clear_val, uint32_t num_inputs, uint32_t num_blocks, int_unchecked_contains_clear_buffer *mem_ptr, - void *const *bsks, Torus *const *ksks) { + void *const *bsks, KSTorus *const *ksks) { cuda_memcpy_async_to_gpu(mem_ptr->d_clear_val, h_clear_val, num_blocks * sizeof(Torus), streams.stream(0), @@ -577,7 +577,7 @@ __host__ void host_unchecked_is_in_clears( CudaRadixCiphertextFFI const *input, const uint64_t *h_cleartexts, uint32_t num_clears, uint32_t num_blocks, int_unchecked_is_in_clears_buffer *mem_ptr, - void *const *bsks, Torus *const *ksks) { + void *const *bsks, KSTorus *const *ksks) { host_compute_equality_selectors(streams, mem_ptr->unpacked_selectors, input, num_blocks, h_cleartexts, @@ -594,7 +594,7 @@ __host__ void host_compute_final_index_from_selectors( CudaRadixCiphertextFFI *match_ct, CudaRadixCiphertextFFI const *selectors, uint32_t num_inputs, uint32_t num_blocks_index, int_final_index_from_selectors_buffer *mem_ptr, - void *const *bsks, Torus *const *ksks) { + void *const *bsks, KSTorus *const *ksks) { for (uint32_t i = 0; i < num_inputs; i++) { CudaRadixCiphertextFFI const *src_selector = &selectors[i]; @@ -657,7 +657,7 @@ __host__ void host_unchecked_index_in_clears( const uint64_t *h_cleartexts, uint32_t num_clears, uint32_t num_blocks, uint32_t num_blocks_index, int_unchecked_index_in_clears_buffer *mem_ptr, - void *const *bsks, Torus *const *ksks) { + void *const *bsks, KSTorus *const *ksks) { host_compute_equality_selectors( streams, mem_ptr->final_index_buf->unpacked_selectors, input, num_blocks, @@ -704,7 +704,7 @@ __host__ void 
host_unchecked_first_index_in_clears( const uint64_t *h_unique_values, const uint64_t *h_unique_indices, uint32_t num_unique, uint32_t num_blocks, uint32_t num_blocks_index, int_unchecked_first_index_in_clears_buffer *mem_ptr, - void *const *bsks, Torus *const *ksks) { + void *const *bsks, KSTorus *const *ksks) { host_compute_equality_selectors(streams, mem_ptr->unpacked_selectors, input, num_blocks, h_unique_values, @@ -748,7 +748,7 @@ __host__ void host_unchecked_first_index_of_clear( const uint64_t *h_clear_val, uint32_t num_inputs, uint32_t num_blocks, uint32_t num_blocks_index, int_unchecked_first_index_of_clear_buffer *mem_ptr, - void *const *bsks, Torus *const *ksks) { + void *const *bsks, KSTorus *const *ksks) { cuda_memcpy_async_to_gpu(mem_ptr->d_clear_val, h_clear_val, num_blocks * sizeof(Torus), streams.stream(0), @@ -841,7 +841,7 @@ __host__ void host_unchecked_first_index_of( CudaRadixCiphertextFFI const *value, uint32_t num_inputs, uint32_t num_blocks, uint32_t num_blocks_index, int_unchecked_first_index_of_buffer *mem_ptr, - void *const *bsks, Torus *const *ksks) { + void *const *bsks, KSTorus *const *ksks) { mem_ptr->internal_cuda_streams.internal_streams_wait_for_main_stream_0( streams); @@ -924,7 +924,7 @@ __host__ void host_unchecked_index_of( CudaRadixCiphertextFFI const *value, uint32_t num_inputs, uint32_t num_blocks, uint32_t num_blocks_index, int_unchecked_index_of_buffer *mem_ptr, void *const *bsks, - Torus *const *ksks) { + KSTorus *const *ksks) { mem_ptr->internal_cuda_streams.internal_streams_wait_for_main_stream_0( streams); @@ -992,7 +992,7 @@ __host__ void host_unchecked_index_of_clear( uint32_t num_inputs, uint32_t num_blocks, uint32_t num_scalar_blocks, uint32_t num_blocks_index, int_unchecked_index_of_clear_buffer *mem_ptr, - void *const *bsks, Torus *const *ksks) { + void *const *bsks, KSTorus *const *ksks) { CudaRadixCiphertextFFI *packed_selectors = mem_ptr->final_index_buf->packed_selectors; diff --git 
a/backends/tfhe-cuda-backend/src/bindings.rs b/backends/tfhe-cuda-backend/src/bindings.rs index 21099b083..b543d8f97 100644 --- a/backends/tfhe-cuda-backend/src/bindings.rs +++ b/backends/tfhe-cuda-backend/src/bindings.rs @@ -651,6 +651,29 @@ unsafe extern "C" { noise_reduction_type: PBS_MS_REDUCTION_T, ) -> u64; } +unsafe extern "C" { + pub fn scratch_cuda_comparison_64_ks32( + streams: CudaStreamsFFI, + mem_ptr: *mut *mut i8, + glwe_dimension: u32, + polynomial_size: u32, + big_lwe_dimension: u32, + small_lwe_dimension: u32, + ks_level: u32, + ks_base_log: u32, + pbs_level: u32, + pbs_base_log: u32, + grouping_factor: u32, + lwe_ciphertext_count: u32, + message_modulus: u32, + carry_modulus: u32, + pbs_type: PBS_TYPE, + op_type: COMPARISON_TYPE, + is_signed: bool, + allocate_gpu_memory: bool, + noise_reduction_type: PBS_MS_REDUCTION_T, + ) -> u64; +} unsafe extern "C" { pub fn cuda_comparison_ciphertext_64( streams: CudaStreamsFFI, @@ -662,6 +685,17 @@ unsafe extern "C" { ksks: *const *mut ffi::c_void, ); } +unsafe extern "C" { + pub fn cuda_comparison_ciphertext_64_ks32( + streams: CudaStreamsFFI, + lwe_array_out: *mut CudaRadixCiphertextFFI, + lwe_array_1: *const CudaRadixCiphertextFFI, + lwe_array_2: *const CudaRadixCiphertextFFI, + mem_ptr: *mut i8, + bsks: *const *mut ffi::c_void, + ksks: *const *mut ffi::c_void, + ); +} unsafe extern "C" { pub fn cuda_scalar_comparison_ciphertext_64( streams: CudaStreamsFFI, @@ -746,6 +780,37 @@ unsafe extern "C" { ksks: *const *mut ffi::c_void, ); } +unsafe extern "C" { + pub fn scratch_cuda_boolean_bitnot_64_ks32( + streams: CudaStreamsFFI, + mem_ptr: *mut *mut i8, + glwe_dimension: u32, + polynomial_size: u32, + big_lwe_dimension: u32, + small_lwe_dimension: u32, + ks_level: u32, + ks_base_log: u32, + pbs_level: u32, + pbs_base_log: u32, + grouping_factor: u32, + message_modulus: u32, + carry_modulus: u32, + pbs_type: PBS_TYPE, + lwe_ciphertext_count: u32, + is_unchecked: bool, + allocate_gpu_memory: bool, + 
noise_reduction_type: PBS_MS_REDUCTION_T, + ) -> u64; +} +unsafe extern "C" { + pub fn cuda_boolean_bitnot_ciphertext_64_ks32( + streams: CudaStreamsFFI, + lwe_array: *mut CudaRadixCiphertextFFI, + mem_ptr: *mut i8, + bsks: *const *mut ffi::c_void, + ksks: *const *mut ffi::c_void, + ); +} unsafe extern "C" { pub fn cleanup_cuda_boolean_bitnot(streams: CudaStreamsFFI, mem_ptr_void: *mut *mut i8); } @@ -780,6 +845,28 @@ unsafe extern "C" { noise_reduction_type: PBS_MS_REDUCTION_T, ) -> u64; } +unsafe extern "C" { + pub fn scratch_cuda_bitop_64_ks32( + streams: CudaStreamsFFI, + mem_ptr: *mut *mut i8, + glwe_dimension: u32, + polynomial_size: u32, + big_lwe_dimension: u32, + small_lwe_dimension: u32, + ks_level: u32, + ks_base_log: u32, + pbs_level: u32, + pbs_base_log: u32, + grouping_factor: u32, + lwe_ciphertext_count: u32, + message_modulus: u32, + carry_modulus: u32, + pbs_type: PBS_TYPE, + op_type: BITOP_TYPE, + allocate_gpu_memory: bool, + noise_reduction_type: PBS_MS_REDUCTION_T, + ) -> u64; +} unsafe extern "C" { pub fn cuda_scalar_bitop_ciphertext_64( streams: CudaStreamsFFI, @@ -804,6 +891,17 @@ unsafe extern "C" { ksks: *const *mut ffi::c_void, ); } +unsafe extern "C" { + pub fn cuda_bitop_ciphertext_64_ks32( + streams: CudaStreamsFFI, + lwe_array_out: *mut CudaRadixCiphertextFFI, + lwe_array_1: *const CudaRadixCiphertextFFI, + lwe_array_2: *const CudaRadixCiphertextFFI, + mem_ptr: *mut i8, + bsks: *const *mut ffi::c_void, + ksks: *const *mut ffi::c_void, + ); +} unsafe extern "C" { pub fn cleanup_cuda_integer_bitop(streams: CudaStreamsFFI, mem_ptr_void: *mut *mut i8); } @@ -828,6 +926,27 @@ unsafe extern "C" { noise_reduction_type: PBS_MS_REDUCTION_T, ) -> u64; } +unsafe extern "C" { + pub fn scratch_cuda_cmux_64_ks32( + streams: CudaStreamsFFI, + mem_ptr: *mut *mut i8, + glwe_dimension: u32, + polynomial_size: u32, + big_lwe_dimension: u32, + small_lwe_dimension: u32, + ks_level: u32, + ks_base_log: u32, + pbs_level: u32, + pbs_base_log: u32, + 
grouping_factor: u32, + lwe_ciphertext_count: u32, + message_modulus: u32, + carry_modulus: u32, + pbs_type: PBS_TYPE, + allocate_gpu_memory: bool, + noise_reduction_type: PBS_MS_REDUCTION_T, + ) -> u64; +} unsafe extern "C" { pub fn cuda_cmux_ciphertext_64( streams: CudaStreamsFFI, @@ -840,6 +959,18 @@ unsafe extern "C" { ksks: *const *mut ffi::c_void, ); } +unsafe extern "C" { + pub fn cuda_cmux_ciphertext_64_ks32( + streams: CudaStreamsFFI, + lwe_array_out: *mut CudaRadixCiphertextFFI, + lwe_condition: *const CudaRadixCiphertextFFI, + lwe_array_true: *const CudaRadixCiphertextFFI, + lwe_array_false: *const CudaRadixCiphertextFFI, + mem_ptr: *mut i8, + bsks: *const *mut ffi::c_void, + ksks: *const *mut ffi::c_void, + ); +} unsafe extern "C" { pub fn cleanup_cuda_cmux(streams: CudaStreamsFFI, mem_ptr_void: *mut *mut i8); } @@ -1016,6 +1147,28 @@ unsafe extern "C" { noise_reduction_type: PBS_MS_REDUCTION_T, ) -> u64; } +unsafe extern "C" { + pub fn scratch_cuda_integer_overflowing_sub_64_ks32_inplace( + streams: CudaStreamsFFI, + mem_ptr: *mut *mut i8, + glwe_dimension: u32, + polynomial_size: u32, + big_lwe_dimension: u32, + small_lwe_dimension: u32, + ks_level: u32, + ks_base_log: u32, + pbs_level: u32, + pbs_base_log: u32, + grouping_factor: u32, + num_blocks: u32, + message_modulus: u32, + carry_modulus: u32, + pbs_type: PBS_TYPE, + compute_overflow: u32, + allocate_gpu_memory: bool, + noise_reduction_type: PBS_MS_REDUCTION_T, + ) -> u64; +} unsafe extern "C" { pub fn cuda_integer_overflowing_sub_64_inplace( streams: CudaStreamsFFI, @@ -1030,6 +1183,20 @@ unsafe extern "C" { uses_input_borrow: u32, ); } +unsafe extern "C" { + pub fn cuda_integer_overflowing_sub_64_ks32_inplace( + streams: CudaStreamsFFI, + lhs_array: *mut CudaRadixCiphertextFFI, + rhs_array: *const CudaRadixCiphertextFFI, + overflow_block: *mut CudaRadixCiphertextFFI, + input_borrow: *const CudaRadixCiphertextFFI, + mem_ptr: *mut i8, + bsks: *const *mut ffi::c_void, + ksks: *const *mut 
ffi::c_void, + compute_overflow: u32, + uses_input_borrow: u32, + ); +} unsafe extern "C" { pub fn cleanup_cuda_integer_overflowing_sub( streams: CudaStreamsFFI, @@ -1169,6 +1336,28 @@ unsafe extern "C" { noise_reduction_type: PBS_MS_REDUCTION_T, ) -> u64; } +unsafe extern "C" { + pub fn scratch_cuda_integer_div_rem_radix_ciphertext_64_ks32( + streams: CudaStreamsFFI, + is_signed: bool, + mem_ptr: *mut *mut i8, + glwe_dimension: u32, + polynomial_size: u32, + big_lwe_dimension: u32, + small_lwe_dimension: u32, + ks_level: u32, + ks_base_log: u32, + pbs_level: u32, + pbs_base_log: u32, + grouping_factor: u32, + num_blocks: u32, + message_modulus: u32, + carry_modulus: u32, + pbs_type: PBS_TYPE, + allocate_gpu_memory: bool, + noise_reduction_type: PBS_MS_REDUCTION_T, + ) -> u64; +} unsafe extern "C" { pub fn cuda_integer_div_rem_radix_ciphertext_64( streams: CudaStreamsFFI, @@ -1182,6 +1371,19 @@ unsafe extern "C" { ksks: *const *mut ffi::c_void, ); } +unsafe extern "C" { + pub fn cuda_integer_div_rem_radix_ciphertext_64_ks32( + streams: CudaStreamsFFI, + quotient: *mut CudaRadixCiphertextFFI, + remainder: *mut CudaRadixCiphertextFFI, + numerator: *const CudaRadixCiphertextFFI, + divisor: *const CudaRadixCiphertextFFI, + is_signed: bool, + mem_ptr: *mut i8, + bsks: *const *mut ffi::c_void, + ksks: *const *mut ffi::c_void, + ); +} unsafe extern "C" { pub fn cleanup_cuda_integer_div_rem(streams: CudaStreamsFFI, mem_ptr_void: *mut *mut i8); } @@ -1826,6 +2028,30 @@ unsafe extern "C" { noise_reduction_type: PBS_MS_REDUCTION_T, ) -> u64; } +unsafe extern "C" { + pub fn scratch_cuda_cast_to_unsigned_64_ks32( + streams: CudaStreamsFFI, + mem_ptr: *mut *mut i8, + glwe_dimension: u32, + polynomial_size: u32, + big_lwe_dimension: u32, + small_lwe_dimension: u32, + ks_level: u32, + ks_base_log: u32, + pbs_level: u32, + pbs_base_log: u32, + grouping_factor: u32, + num_input_blocks: u32, + target_num_blocks: u32, + input_is_signed: bool, + requires_full_propagate: bool, + 
message_modulus: u32, + carry_modulus: u32, + pbs_type: PBS_TYPE, + allocate_gpu_memory: bool, + noise_reduction_type: PBS_MS_REDUCTION_T, + ) -> u64; +} unsafe extern "C" { pub fn cuda_cast_to_unsigned_64( streams: CudaStreamsFFI, @@ -1838,6 +2064,18 @@ unsafe extern "C" { ksks: *const *mut ffi::c_void, ); } +unsafe extern "C" { + pub fn cuda_cast_to_unsigned_64_ks32( + streams: CudaStreamsFFI, + output: *mut CudaRadixCiphertextFFI, + input: *mut CudaRadixCiphertextFFI, + mem_ptr: *mut i8, + target_num_blocks: u32, + input_is_signed: bool, + bsks: *const *mut ffi::c_void, + ksks: *const *mut ffi::c_void, + ); +} unsafe extern "C" { pub fn cleanup_cuda_cast_to_unsigned_64(streams: CudaStreamsFFI, mem_ptr_void: *mut *mut i8); } diff --git a/tfhe-benchmark/benches/high_level_api/erc20.rs b/tfhe-benchmark/benches/high_level_api/erc20.rs index 1e8e81b3b..08eefbe32 100644 --- a/tfhe-benchmark/benches/high_level_api/erc20.rs +++ b/tfhe-benchmark/benches/high_level_api/erc20.rs @@ -866,6 +866,9 @@ fn main() { ParamType::Classical => { benchmark::params_aliases::BENCH_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128.into() }, + ParamType::ClassicalKs32 => { + benchmark::params_aliases::BENCH_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_TUNIFORM_2M64.into() + }, _ => { benchmark::params_aliases::BENCH_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128.into() } diff --git a/tfhe-benchmark/benches/integer/bench.rs b/tfhe-benchmark/benches/integer/bench.rs index bc260f050..11368fe64 100644 --- a/tfhe-benchmark/benches/integer/bench.rs +++ b/tfhe-benchmark/benches/integer/bench.rs @@ -2719,7 +2719,13 @@ mod cuda { rng_func: default_scalar ); - criterion_group!(cuda_ops_support_ks32, cuda_mul, cuda_add, cuda_sub,); + criterion_group!( + cuda_ops_support_ks32, + cuda_mul, + cuda_add, + cuda_sub, + cuda_div_rem, + ); criterion_group!( unchecked_cuda_ops, diff --git a/tfhe/src/integer/gpu/mod.rs b/tfhe/src/integer/gpu/mod.rs index 75e624952..fcd90ba2e 100644 --- 
a/tfhe/src/integer/gpu/mod.rs +++ b/tfhe/src/integer/gpu/mod.rs @@ -1492,35 +1492,70 @@ pub(crate) unsafe fn cuda_backend_unchecked_bitop_assign() == TypeId::of::() { + scratch_cuda_bitop_64( + streams.ffi(), + std::ptr::addr_of_mut!(mem_ptr), + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + big_lwe_dimension.0 as u32, + small_lwe_dimension.0 as u32, + ks_level.0 as u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + grouping_factor.0 as u32, + num_blocks, + message_modulus.0 as u32, + carry_modulus.0 as u32, + pbs_type as u32, + op as u32, + true, + noise_reduction_type as u32, + ); + cuda_bitop_ciphertext_64( + streams.ffi(), + &raw mut cuda_ffi_radix_lwe_left, + &raw const cuda_ffi_radix_lwe_left, + &raw const cuda_ffi_radix_lwe_right, + mem_ptr, + bootstrapping_key.ptr.as_ptr(), + keyswitch_key.ptr.as_ptr(), + ); + } else if TypeId::of::() == TypeId::of::() { + scratch_cuda_bitop_64_ks32( + streams.ffi(), + std::ptr::addr_of_mut!(mem_ptr), + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + big_lwe_dimension.0 as u32, + small_lwe_dimension.0 as u32, + ks_level.0 as u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + grouping_factor.0 as u32, + num_blocks, + message_modulus.0 as u32, + carry_modulus.0 as u32, + pbs_type as u32, + op as u32, + true, + noise_reduction_type as u32, + ); + cuda_bitop_ciphertext_64_ks32( + streams.ffi(), + &raw mut cuda_ffi_radix_lwe_left, + &raw const cuda_ffi_radix_lwe_left, + &raw const cuda_ffi_radix_lwe_right, + mem_ptr, + bootstrapping_key.ptr.as_ptr(), + keyswitch_key.ptr.as_ptr(), + ); + } else { + panic!("Unsupported KS dtype"); + } cleanup_cuda_integer_bitop(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr)); update_noise_degree(radix_lwe_left, &cuda_ffi_radix_lwe_left); } @@ -2093,37 +2128,73 @@ pub(crate) unsafe fn cuda_backend_unchecked_comparison() == TypeId::of::() { + scratch_cuda_comparison_64( + streams.ffi(), + std::ptr::addr_of_mut!(mem_ptr), + 
glwe_dimension.0 as u32, + polynomial_size.0 as u32, + big_lwe_dimension.0 as u32, + small_lwe_dimension.0 as u32, + ks_level.0 as u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + grouping_factor.0 as u32, + radix_lwe_left.d_blocks.lwe_ciphertext_count().0 as u32, + message_modulus.0 as u32, + carry_modulus.0 as u32, + pbs_type as u32, + op as u32, + is_signed, + true, + noise_reduction_type as u32, + ); - cuda_comparison_ciphertext_64( - streams.ffi(), - &raw mut cuda_ffi_radix_lwe_out, - &raw const cuda_ffi_radix_lwe_left, - &raw const cuda_ffi_radix_lwe_right, - mem_ptr, - bootstrapping_key.ptr.as_ptr(), - keyswitch_key.ptr.as_ptr(), - ); + cuda_comparison_ciphertext_64( + streams.ffi(), + &raw mut cuda_ffi_radix_lwe_out, + &raw const cuda_ffi_radix_lwe_left, + &raw const cuda_ffi_radix_lwe_right, + mem_ptr, + bootstrapping_key.ptr.as_ptr(), + keyswitch_key.ptr.as_ptr(), + ); + } else if TypeId::of::() == TypeId::of::() { + scratch_cuda_comparison_64_ks32( + streams.ffi(), + std::ptr::addr_of_mut!(mem_ptr), + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + big_lwe_dimension.0 as u32, + small_lwe_dimension.0 as u32, + ks_level.0 as u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + grouping_factor.0 as u32, + radix_lwe_left.d_blocks.lwe_ciphertext_count().0 as u32, + message_modulus.0 as u32, + carry_modulus.0 as u32, + pbs_type as u32, + op as u32, + is_signed, + true, + noise_reduction_type as u32, + ); + + cuda_comparison_ciphertext_64_ks32( + streams.ffi(), + &raw mut cuda_ffi_radix_lwe_out, + &raw const cuda_ffi_radix_lwe_left, + &raw const cuda_ffi_radix_lwe_right, + mem_ptr, + bootstrapping_key.ptr.as_ptr(), + keyswitch_key.ptr.as_ptr(), + ); + } else { + panic!("Unsupported KS datatype"); + } cleanup_cuda_integer_comparison(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr)); update_noise_degree(radix_lwe_out, &cuda_ffi_radix_lwe_out); @@ -5565,35 +5636,71 @@ pub(crate) unsafe fn 
cuda_backend_unchecked_cmux &mut condition_noise_levels, ); let mut mem_ptr: *mut i8 = std::ptr::null_mut(); - scratch_cuda_cmux_64( - streams.ffi(), - std::ptr::addr_of_mut!(mem_ptr), - glwe_dimension.0 as u32, - polynomial_size.0 as u32, - big_lwe_dimension.0 as u32, - small_lwe_dimension.0 as u32, - ks_level.0 as u32, - ks_base_log.0 as u32, - pbs_level.0 as u32, - pbs_base_log.0 as u32, - grouping_factor.0 as u32, - num_blocks, - message_modulus.0 as u32, - carry_modulus.0 as u32, - pbs_type as u32, - true, - noise_reduction_type as u32, - ); - cuda_cmux_ciphertext_64( - streams.ffi(), - &raw mut cuda_ffi_radix_lwe_out, - &raw const cuda_ffi_condition, - &raw const cuda_ffi_radix_lwe_true, - &raw const cuda_ffi_radix_lwe_false, - mem_ptr, - bootstrapping_key.ptr.as_ptr(), - keyswitch_key.ptr.as_ptr(), - ); + + if TypeId::of::() == TypeId::of::() { + scratch_cuda_cmux_64( + streams.ffi(), + std::ptr::addr_of_mut!(mem_ptr), + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + big_lwe_dimension.0 as u32, + small_lwe_dimension.0 as u32, + ks_level.0 as u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + grouping_factor.0 as u32, + num_blocks, + message_modulus.0 as u32, + carry_modulus.0 as u32, + pbs_type as u32, + true, + noise_reduction_type as u32, + ); + cuda_cmux_ciphertext_64( + streams.ffi(), + &raw mut cuda_ffi_radix_lwe_out, + &raw const cuda_ffi_condition, + &raw const cuda_ffi_radix_lwe_true, + &raw const cuda_ffi_radix_lwe_false, + mem_ptr, + bootstrapping_key.ptr.as_ptr(), + keyswitch_key.ptr.as_ptr(), + ); + } else if TypeId::of::() == TypeId::of::() { + scratch_cuda_cmux_64_ks32( + streams.ffi(), + std::ptr::addr_of_mut!(mem_ptr), + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + big_lwe_dimension.0 as u32, + small_lwe_dimension.0 as u32, + ks_level.0 as u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + grouping_factor.0 as u32, + num_blocks, + message_modulus.0 as u32, + 
carry_modulus.0 as u32, + pbs_type as u32, + true, + noise_reduction_type as u32, + ); + cuda_cmux_ciphertext_64_ks32( + streams.ffi(), + &raw mut cuda_ffi_radix_lwe_out, + &raw const cuda_ffi_condition, + &raw const cuda_ffi_radix_lwe_true, + &raw const cuda_ffi_radix_lwe_false, + mem_ptr, + bootstrapping_key.ptr.as_ptr(), + keyswitch_key.ptr.as_ptr(), + ); + } else { + panic!("Unsupported KS dtype"); + } + cleanup_cuda_cmux(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr)); update_noise_degree(radix_lwe_out, &cuda_ffi_radix_lwe_out); } @@ -6963,38 +7070,76 @@ pub(crate) unsafe fn cuda_backend_unchecked_unsigned_overflowing_sub_assign< .collect(); let cuda_ffi_carry_in = prepare_cuda_radix_ffi(carry_in, &mut carry_in_degrees, &mut carry_in_noise_levels); - scratch_cuda_integer_overflowing_sub_64_inplace( - streams.ffi(), - std::ptr::addr_of_mut!(mem_ptr), - glwe_dimension.0 as u32, - polynomial_size.0 as u32, - big_lwe_dimension, - lwe_dimension.0 as u32, - ks_level.0 as u32, - ks_base_log.0 as u32, - pbs_level.0 as u32, - pbs_base_log.0 as u32, - grouping_factor.0 as u32, - radix_lwe_left.d_blocks.lwe_ciphertext_count().0 as u32, - message_modulus.0 as u32, - carry_modulus.0 as u32, - pbs_type as u32, - compute_overflow as u32, - true, - noise_reduction_type as u32, - ); - cuda_integer_overflowing_sub_64_inplace( - streams.ffi(), - &raw mut cuda_ffi_radix_lwe_left, - &raw const cuda_ffi_radix_lwe_right, - &raw mut cuda_ffi_carry_out, - &raw const cuda_ffi_carry_in, - mem_ptr, - bootstrapping_key.ptr.as_ptr(), - keyswitch_key.ptr.as_ptr(), - compute_overflow as u32, - uses_input_borrow, - ); + + if TypeId::of::() == TypeId::of::() { + scratch_cuda_integer_overflowing_sub_64_inplace( + streams.ffi(), + std::ptr::addr_of_mut!(mem_ptr), + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + big_lwe_dimension, + lwe_dimension.0 as u32, + ks_level.0 as u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + grouping_factor.0 as u32, + 
radix_lwe_left.d_blocks.lwe_ciphertext_count().0 as u32, + message_modulus.0 as u32, + carry_modulus.0 as u32, + pbs_type as u32, + compute_overflow as u32, + true, + noise_reduction_type as u32, + ); + cuda_integer_overflowing_sub_64_inplace( + streams.ffi(), + &raw mut cuda_ffi_radix_lwe_left, + &raw const cuda_ffi_radix_lwe_right, + &raw mut cuda_ffi_carry_out, + &raw const cuda_ffi_carry_in, + mem_ptr, + bootstrapping_key.ptr.as_ptr(), + keyswitch_key.ptr.as_ptr(), + compute_overflow as u32, + uses_input_borrow, + ); + } else if TypeId::of::() == TypeId::of::() { + scratch_cuda_integer_overflowing_sub_64_ks32_inplace( + streams.ffi(), + std::ptr::addr_of_mut!(mem_ptr), + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + big_lwe_dimension, + lwe_dimension.0 as u32, + ks_level.0 as u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + grouping_factor.0 as u32, + radix_lwe_left.d_blocks.lwe_ciphertext_count().0 as u32, + message_modulus.0 as u32, + carry_modulus.0 as u32, + pbs_type as u32, + compute_overflow as u32, + true, + noise_reduction_type as u32, + ); + cuda_integer_overflowing_sub_64_ks32_inplace( + streams.ffi(), + &raw mut cuda_ffi_radix_lwe_left, + &raw const cuda_ffi_radix_lwe_right, + &raw mut cuda_ffi_carry_out, + &raw const cuda_ffi_carry_in, + mem_ptr, + bootstrapping_key.ptr.as_ptr(), + keyswitch_key.ptr.as_ptr(), + compute_overflow as u32, + uses_input_borrow, + ); + } else { + panic!("Unsupported KS dtype"); + } cleanup_cuda_integer_overflowing_sub(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr)); update_noise_degree(radix_lwe_left, &cuda_ffi_radix_lwe_left); update_noise_degree(carry_out, &cuda_ffi_carry_out); @@ -8291,34 +8436,68 @@ pub(crate) unsafe fn cuda_backend_boolean_bitnot_assign() == TypeId::of::() { + scratch_cuda_boolean_bitnot_64( + streams.ffi(), + std::ptr::addr_of_mut!(mem_ptr), + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + big_lwe_dimension.0 as u32, + small_lwe_dimension.0 as u32, + 
ks_level.0 as u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + grouping_factor.0 as u32, + message_modulus.0 as u32, + carry_modulus.0 as u32, + pbs_type as u32, + 1u32, + is_unchecked, + true, + noise_reduction_type as u32, + ); + + cuda_boolean_bitnot_ciphertext_64( + streams.ffi(), + &raw mut cuda_ffi_ciphertext, + mem_ptr, + bootstrapping_key.ptr.as_ptr(), + keyswitch_key.ptr.as_ptr(), + ); + } else if TypeId::of::() == TypeId::of::() { + scratch_cuda_boolean_bitnot_64_ks32( + streams.ffi(), + std::ptr::addr_of_mut!(mem_ptr), + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + big_lwe_dimension.0 as u32, + small_lwe_dimension.0 as u32, + ks_level.0 as u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + grouping_factor.0 as u32, + message_modulus.0 as u32, + carry_modulus.0 as u32, + pbs_type as u32, + 1u32, + is_unchecked, + true, + noise_reduction_type as u32, + ); + + cuda_boolean_bitnot_ciphertext_64_ks32( + streams.ffi(), + &raw mut cuda_ffi_ciphertext, + mem_ptr, + bootstrapping_key.ptr.as_ptr(), + keyswitch_key.ptr.as_ptr(), + ); + } else { + panic!("Unsupported KS dtype"); + } - cuda_boolean_bitnot_ciphertext_64( - streams.ffi(), - &raw mut cuda_ffi_ciphertext, - mem_ptr, - bootstrapping_key.ptr.as_ptr(), - keyswitch_key.ptr.as_ptr(), - ); cleanup_cuda_boolean_bitnot(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr)); update_noise_degree(ciphertext, &cuda_ffi_ciphertext); } @@ -8792,39 +8971,75 @@ pub(crate) unsafe fn cuda_backend_cast_to_unsigned() == TypeId::of::() { + scratch_cuda_cast_to_unsigned_64_ks32( + streams.ffi(), + std::ptr::addr_of_mut!(mem_ptr), + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + big_lwe_dimension.0 as u32, + small_lwe_dimension.0 as u32, + ks_level.0 as u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + grouping_factor.0 as u32, + num_input_blocks, + target_num_blocks, + input_is_signed, + requires_full_propagate, + message_modulus.0 
as u32, + carry_modulus.0 as u32, + pbs_type as u32, + true, + noise_reduction_type as u32, + ); - cuda_cast_to_unsigned_64( - streams.ffi(), - &raw mut cuda_ffi_output, - &raw mut cuda_ffi_input, - mem_ptr, - target_num_blocks, - input_is_signed, - bootstrapping_key.ptr.as_ptr(), - keyswitch_key.ptr.as_ptr(), - ); + cuda_cast_to_unsigned_64_ks32( + streams.ffi(), + &raw mut cuda_ffi_output, + &raw mut cuda_ffi_input, + mem_ptr, + target_num_blocks, + input_is_signed, + bootstrapping_key.ptr.as_ptr(), + keyswitch_key.ptr.as_ptr(), + ); + } else if TypeId::of::() == TypeId::of::() { + scratch_cuda_cast_to_unsigned_64( + streams.ffi(), + std::ptr::addr_of_mut!(mem_ptr), + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + big_lwe_dimension.0 as u32, + small_lwe_dimension.0 as u32, + ks_level.0 as u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + grouping_factor.0 as u32, + num_input_blocks, + target_num_blocks, + input_is_signed, + requires_full_propagate, + message_modulus.0 as u32, + carry_modulus.0 as u32, + pbs_type as u32, + true, + noise_reduction_type as u32, + ); + + cuda_cast_to_unsigned_64( + streams.ffi(), + &raw mut cuda_ffi_output, + &raw mut cuda_ffi_input, + mem_ptr, + target_num_blocks, + input_is_signed, + bootstrapping_key.ptr.as_ptr(), + keyswitch_key.ptr.as_ptr(), + ); + } cleanup_cuda_cast_to_unsigned_64(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr)); diff --git a/tfhe/src/integer/gpu/server_key/radix/bitwise_op.rs b/tfhe/src/integer/gpu/server_key/radix/bitwise_op.rs index 0fa21f3d9..b9105e94a 100644 --- a/tfhe/src/integer/gpu/server_key/radix/bitwise_op.rs +++ b/tfhe/src/integer/gpu/server_key/radix/bitwise_op.rs @@ -408,57 +408,103 @@ impl CudaServerKey { is_unchecked: bool, streams: &CudaStreams, ) { - let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else { - panic!("Only the standard atomic pattern is supported on GPU") - }; - - unsafe { - match &self.bootstrapping_key { - 
CudaBootstrappingKey::Classic(d_bsk) => { - cuda_backend_boolean_bitnot_assign( - streams, - &mut ct.0.ciphertext as &mut CudaRadixCiphertext, - is_unchecked, - &d_bsk.d_vec, - &computing_ks_key.d_vec, - self.message_modulus, - self.carry_modulus, - d_bsk.glwe_dimension(), - d_bsk.polynomial_size(), - d_bsk.output_lwe_dimension(), - d_bsk.input_lwe_dimension(), - computing_ks_key.decomposition_level_count(), - computing_ks_key.decomposition_base_log(), - d_bsk.decomp_level_count(), - d_bsk.decomp_base_log(), - PBSType::Classical, - LweBskGroupingFactor(0), - d_bsk.ms_noise_reduction_configuration.as_ref(), - ); + match &self.key_switching_key { + CudaDynamicKeyswitchingKey::Standard(computing_ks_key) => unsafe { + match &self.bootstrapping_key { + CudaBootstrappingKey::Classic(d_bsk) => { + cuda_backend_boolean_bitnot_assign( + streams, + &mut ct.0.ciphertext as &mut CudaRadixCiphertext, + is_unchecked, + &d_bsk.d_vec, + &computing_ks_key.d_vec, + self.message_modulus, + self.carry_modulus, + d_bsk.glwe_dimension(), + d_bsk.polynomial_size(), + d_bsk.output_lwe_dimension(), + d_bsk.input_lwe_dimension(), + computing_ks_key.decomposition_level_count(), + computing_ks_key.decomposition_base_log(), + d_bsk.decomp_level_count(), + d_bsk.decomp_base_log(), + PBSType::Classical, + LweBskGroupingFactor(0), + d_bsk.ms_noise_reduction_configuration.as_ref(), + ); + } + CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { + cuda_backend_boolean_bitnot_assign( + streams, + &mut ct.0.ciphertext as &mut CudaRadixCiphertext, + is_unchecked, + &d_multibit_bsk.d_vec, + &computing_ks_key.d_vec, + self.message_modulus, + self.carry_modulus, + d_multibit_bsk.glwe_dimension(), + d_multibit_bsk.polynomial_size(), + d_multibit_bsk.output_lwe_dimension(), + d_multibit_bsk.input_lwe_dimension(), + computing_ks_key.decomposition_level_count(), + computing_ks_key.decomposition_base_log(), + d_multibit_bsk.decomp_level_count(), + d_multibit_bsk.decomp_base_log(), + PBSType::MultiBit, + 
d_multibit_bsk.grouping_factor, + None, + ); + } } - CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { - cuda_backend_boolean_bitnot_assign( - streams, - &mut ct.0.ciphertext as &mut CudaRadixCiphertext, - is_unchecked, - &d_multibit_bsk.d_vec, - &computing_ks_key.d_vec, - self.message_modulus, - self.carry_modulus, - d_multibit_bsk.glwe_dimension(), - d_multibit_bsk.polynomial_size(), - d_multibit_bsk.output_lwe_dimension(), - d_multibit_bsk.input_lwe_dimension(), - computing_ks_key.decomposition_level_count(), - computing_ks_key.decomposition_base_log(), - d_multibit_bsk.decomp_level_count(), - d_multibit_bsk.decomp_base_log(), - PBSType::MultiBit, - d_multibit_bsk.grouping_factor, - None, - ); + }, + CudaDynamicKeyswitchingKey::KeySwitch32(computing_ks_key) => unsafe { + match &self.bootstrapping_key { + CudaBootstrappingKey::Classic(d_bsk) => { + cuda_backend_boolean_bitnot_assign( + streams, + &mut ct.0.ciphertext as &mut CudaRadixCiphertext, + is_unchecked, + &d_bsk.d_vec, + &computing_ks_key.d_vec, + self.message_modulus, + self.carry_modulus, + d_bsk.glwe_dimension(), + d_bsk.polynomial_size(), + d_bsk.output_lwe_dimension(), + d_bsk.input_lwe_dimension(), + computing_ks_key.decomposition_level_count(), + computing_ks_key.decomposition_base_log(), + d_bsk.decomp_level_count(), + d_bsk.decomp_base_log(), + PBSType::Classical, + LweBskGroupingFactor(0), + d_bsk.ms_noise_reduction_configuration.as_ref(), + ); + } + CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { + cuda_backend_boolean_bitnot_assign( + streams, + &mut ct.0.ciphertext as &mut CudaRadixCiphertext, + is_unchecked, + &d_multibit_bsk.d_vec, + &computing_ks_key.d_vec, + self.message_modulus, + self.carry_modulus, + d_multibit_bsk.glwe_dimension(), + d_multibit_bsk.polynomial_size(), + d_multibit_bsk.output_lwe_dimension(), + d_multibit_bsk.input_lwe_dimension(), + computing_ks_key.decomposition_level_count(), + computing_ks_key.decomposition_base_log(), + d_multibit_bsk.decomp_level_count(), + 
d_multibit_bsk.decomp_base_log(), + PBSType::MultiBit, + d_multibit_bsk.grouping_factor, + None, + ); + } } - } + }, } } @@ -534,61 +580,111 @@ impl CudaServerKey { let lwe_ciphertext_count = ct_left.as_ref().d_blocks.lwe_ciphertext_count(); - let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else { - panic!("Only the standard atomic pattern is supported on GPU") - }; - - unsafe { - match &self.bootstrapping_key { - CudaBootstrappingKey::Classic(d_bsk) => { - cuda_backend_unchecked_bitop_assign( - streams, - ct_left.as_mut(), - ct_right.as_ref(), - &d_bsk.d_vec, - &computing_ks_key.d_vec, - self.message_modulus, - self.carry_modulus, - d_bsk.glwe_dimension, - d_bsk.polynomial_size, - computing_ks_key.input_key_lwe_size().to_lwe_dimension(), - computing_ks_key.output_key_lwe_size().to_lwe_dimension(), - computing_ks_key.decomposition_level_count(), - computing_ks_key.decomposition_base_log(), - d_bsk.decomp_level_count, - d_bsk.decomp_base_log, - op, - lwe_ciphertext_count.0 as u32, - PBSType::Classical, - LweBskGroupingFactor(0), - d_bsk.ms_noise_reduction_configuration.as_ref(), - ); + match &self.key_switching_key { + CudaDynamicKeyswitchingKey::Standard(computing_ks_key) => unsafe { + match &self.bootstrapping_key { + CudaBootstrappingKey::Classic(d_bsk) => { + cuda_backend_unchecked_bitop_assign( + streams, + ct_left.as_mut(), + ct_right.as_ref(), + &d_bsk.d_vec, + &computing_ks_key.d_vec, + self.message_modulus, + self.carry_modulus, + d_bsk.glwe_dimension, + d_bsk.polynomial_size, + computing_ks_key.input_key_lwe_size().to_lwe_dimension(), + computing_ks_key.output_key_lwe_size().to_lwe_dimension(), + computing_ks_key.decomposition_level_count(), + computing_ks_key.decomposition_base_log(), + d_bsk.decomp_level_count, + d_bsk.decomp_base_log, + op, + lwe_ciphertext_count.0 as u32, + PBSType::Classical, + LweBskGroupingFactor(0), + d_bsk.ms_noise_reduction_configuration.as_ref(), + ); + } + 
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { + cuda_backend_unchecked_bitop_assign( + streams, + ct_left.as_mut(), + ct_right.as_ref(), + &d_multibit_bsk.d_vec, + &computing_ks_key.d_vec, + self.message_modulus, + self.carry_modulus, + d_multibit_bsk.glwe_dimension, + d_multibit_bsk.polynomial_size, + computing_ks_key.input_key_lwe_size().to_lwe_dimension(), + computing_ks_key.output_key_lwe_size().to_lwe_dimension(), + computing_ks_key.decomposition_level_count(), + computing_ks_key.decomposition_base_log(), + d_multibit_bsk.decomp_level_count, + d_multibit_bsk.decomp_base_log, + op, + lwe_ciphertext_count.0 as u32, + PBSType::MultiBit, + d_multibit_bsk.grouping_factor, + None, + ); + } } - CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { - cuda_backend_unchecked_bitop_assign( - streams, - ct_left.as_mut(), - ct_right.as_ref(), - &d_multibit_bsk.d_vec, - &computing_ks_key.d_vec, - self.message_modulus, - self.carry_modulus, - d_multibit_bsk.glwe_dimension, - d_multibit_bsk.polynomial_size, - computing_ks_key.input_key_lwe_size().to_lwe_dimension(), - computing_ks_key.output_key_lwe_size().to_lwe_dimension(), - computing_ks_key.decomposition_level_count(), - computing_ks_key.decomposition_base_log(), - d_multibit_bsk.decomp_level_count, - d_multibit_bsk.decomp_base_log, - op, - lwe_ciphertext_count.0 as u32, - PBSType::MultiBit, - d_multibit_bsk.grouping_factor, - None, - ); + }, + CudaDynamicKeyswitchingKey::KeySwitch32(computing_ks_key) => unsafe { + match &self.bootstrapping_key { + CudaBootstrappingKey::Classic(d_bsk) => { + cuda_backend_unchecked_bitop_assign( + streams, + ct_left.as_mut(), + ct_right.as_ref(), + &d_bsk.d_vec, + &computing_ks_key.d_vec, + self.message_modulus, + self.carry_modulus, + d_bsk.glwe_dimension, + d_bsk.polynomial_size, + computing_ks_key.input_key_lwe_size().to_lwe_dimension(), + computing_ks_key.output_key_lwe_size().to_lwe_dimension(), + computing_ks_key.decomposition_level_count(), + 
computing_ks_key.decomposition_base_log(), + d_bsk.decomp_level_count, + d_bsk.decomp_base_log, + op, + lwe_ciphertext_count.0 as u32, + PBSType::Classical, + LweBskGroupingFactor(0), + d_bsk.ms_noise_reduction_configuration.as_ref(), + ); + } + CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { + cuda_backend_unchecked_bitop_assign( + streams, + ct_left.as_mut(), + ct_right.as_ref(), + &d_multibit_bsk.d_vec, + &computing_ks_key.d_vec, + self.message_modulus, + self.carry_modulus, + d_multibit_bsk.glwe_dimension, + d_multibit_bsk.polynomial_size, + computing_ks_key.input_key_lwe_size().to_lwe_dimension(), + computing_ks_key.output_key_lwe_size().to_lwe_dimension(), + computing_ks_key.decomposition_level_count(), + computing_ks_key.decomposition_base_log(), + d_multibit_bsk.decomp_level_count, + d_multibit_bsk.decomp_base_log, + op, + lwe_ciphertext_count.0 as u32, + PBSType::MultiBit, + d_multibit_bsk.grouping_factor, + None, + ); + } } - } + }, } } diff --git a/tfhe/src/integer/gpu/server_key/radix/cmux.rs b/tfhe/src/integer/gpu/server_key/radix/cmux.rs index b6a0701a2..a577ddb9c 100644 --- a/tfhe/src/integer/gpu/server_key/radix/cmux.rs +++ b/tfhe/src/integer/gpu/server_key/radix/cmux.rs @@ -20,64 +20,117 @@ impl CudaServerKey { let mut result: T = self .create_trivial_zero_radix(true_ct.as_ref().d_blocks.lwe_ciphertext_count().0, stream); - let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else { - panic!("Only the standard atomic pattern is supported on GPU") - }; - - unsafe { - match &self.bootstrapping_key { - CudaBootstrappingKey::Classic(d_bsk) => { - cuda_backend_unchecked_cmux( - stream, - result.as_mut(), - condition, - true_ct.as_ref(), - false_ct.as_ref(), - &d_bsk.d_vec, - &computing_ks_key.d_vec, - self.message_modulus, - self.carry_modulus, - d_bsk.glwe_dimension, - d_bsk.polynomial_size, - computing_ks_key.input_key_lwe_size().to_lwe_dimension(), - computing_ks_key.output_key_lwe_size().to_lwe_dimension(), - 
computing_ks_key.decomposition_level_count(), - computing_ks_key.decomposition_base_log(), - d_bsk.decomp_level_count, - d_bsk.decomp_base_log, - lwe_ciphertext_count.0 as u32, - PBSType::Classical, - LweBskGroupingFactor(0), - d_bsk.ms_noise_reduction_configuration.as_ref(), - ); + match &self.key_switching_key { + CudaDynamicKeyswitchingKey::Standard(computing_ks_key) => unsafe { + match &self.bootstrapping_key { + CudaBootstrappingKey::Classic(d_bsk) => { + cuda_backend_unchecked_cmux( + stream, + result.as_mut(), + condition, + true_ct.as_ref(), + false_ct.as_ref(), + &d_bsk.d_vec, + &computing_ks_key.d_vec, + self.message_modulus, + self.carry_modulus, + d_bsk.glwe_dimension, + d_bsk.polynomial_size, + computing_ks_key.input_key_lwe_size().to_lwe_dimension(), + computing_ks_key.output_key_lwe_size().to_lwe_dimension(), + computing_ks_key.decomposition_level_count(), + computing_ks_key.decomposition_base_log(), + d_bsk.decomp_level_count, + d_bsk.decomp_base_log, + lwe_ciphertext_count.0 as u32, + PBSType::Classical, + LweBskGroupingFactor(0), + d_bsk.ms_noise_reduction_configuration.as_ref(), + ); + } + CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { + cuda_backend_unchecked_cmux( + stream, + result.as_mut(), + condition, + true_ct.as_ref(), + false_ct.as_ref(), + &d_multibit_bsk.d_vec, + &computing_ks_key.d_vec, + self.message_modulus, + self.carry_modulus, + d_multibit_bsk.glwe_dimension, + d_multibit_bsk.polynomial_size, + computing_ks_key.input_key_lwe_size().to_lwe_dimension(), + computing_ks_key.output_key_lwe_size().to_lwe_dimension(), + computing_ks_key.decomposition_level_count(), + computing_ks_key.decomposition_base_log(), + d_multibit_bsk.decomp_level_count, + d_multibit_bsk.decomp_base_log, + lwe_ciphertext_count.0 as u32, + PBSType::MultiBit, + d_multibit_bsk.grouping_factor, + None, + ); + } } - CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { - cuda_backend_unchecked_cmux( - stream, - result.as_mut(), - condition, - true_ct.as_ref(), - 
false_ct.as_ref(), - &d_multibit_bsk.d_vec, - &computing_ks_key.d_vec, - self.message_modulus, - self.carry_modulus, - d_multibit_bsk.glwe_dimension, - d_multibit_bsk.polynomial_size, - computing_ks_key.input_key_lwe_size().to_lwe_dimension(), - computing_ks_key.output_key_lwe_size().to_lwe_dimension(), - computing_ks_key.decomposition_level_count(), - computing_ks_key.decomposition_base_log(), - d_multibit_bsk.decomp_level_count, - d_multibit_bsk.decomp_base_log, - lwe_ciphertext_count.0 as u32, - PBSType::MultiBit, - d_multibit_bsk.grouping_factor, - None, - ); + }, + CudaDynamicKeyswitchingKey::KeySwitch32(computing_ks_key) => unsafe { + match &self.bootstrapping_key { + CudaBootstrappingKey::Classic(d_bsk) => { + cuda_backend_unchecked_cmux( + stream, + result.as_mut(), + condition, + true_ct.as_ref(), + false_ct.as_ref(), + &d_bsk.d_vec, + &computing_ks_key.d_vec, + self.message_modulus, + self.carry_modulus, + d_bsk.glwe_dimension, + d_bsk.polynomial_size, + computing_ks_key.input_key_lwe_size().to_lwe_dimension(), + computing_ks_key.output_key_lwe_size().to_lwe_dimension(), + computing_ks_key.decomposition_level_count(), + computing_ks_key.decomposition_base_log(), + d_bsk.decomp_level_count, + d_bsk.decomp_base_log, + lwe_ciphertext_count.0 as u32, + PBSType::Classical, + LweBskGroupingFactor(0), + d_bsk.ms_noise_reduction_configuration.as_ref(), + ); + } + CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { + cuda_backend_unchecked_cmux( + stream, + result.as_mut(), + condition, + true_ct.as_ref(), + false_ct.as_ref(), + &d_multibit_bsk.d_vec, + &computing_ks_key.d_vec, + self.message_modulus, + self.carry_modulus, + d_multibit_bsk.glwe_dimension, + d_multibit_bsk.polynomial_size, + computing_ks_key.input_key_lwe_size().to_lwe_dimension(), + computing_ks_key.output_key_lwe_size().to_lwe_dimension(), + computing_ks_key.decomposition_level_count(), + computing_ks_key.decomposition_base_log(), + d_multibit_bsk.decomp_level_count, + 
d_multibit_bsk.decomp_base_log, + lwe_ciphertext_count.0 as u32, + PBSType::MultiBit, + d_multibit_bsk.grouping_factor, + None, + ); + } } - } + }, } + result } diff --git a/tfhe/src/integer/gpu/server_key/radix/comparison.rs b/tfhe/src/integer/gpu/server_key/radix/comparison.rs index 489e6e7f6..5ef3d6a96 100644 --- a/tfhe/src/integer/gpu/server_key/radix/comparison.rs +++ b/tfhe/src/integer/gpu/server_key/radix/comparison.rs @@ -45,65 +45,117 @@ impl CudaServerKey { let mut result = CudaBooleanBlock::from_cuda_radix_ciphertext(CudaRadixCiphertext::new(block, ct_info)); - let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else { - panic!("Only the standard atomic pattern is supported on GPU") + match &self.key_switching_key { + CudaDynamicKeyswitchingKey::Standard(computing_ks_key) => unsafe { + match &self.bootstrapping_key { + CudaBootstrappingKey::Classic(d_bsk) => { + cuda_backend_unchecked_comparison( + streams, + result.as_mut().as_mut(), + ct_left.as_ref(), + ct_right.as_ref(), + &d_bsk.d_vec, + &computing_ks_key.d_vec, + self.message_modulus, + self.carry_modulus, + d_bsk.glwe_dimension, + d_bsk.polynomial_size, + computing_ks_key.input_key_lwe_size().to_lwe_dimension(), + computing_ks_key.output_key_lwe_size().to_lwe_dimension(), + computing_ks_key.decomposition_level_count(), + computing_ks_key.decomposition_base_log(), + d_bsk.decomp_level_count, + d_bsk.decomp_base_log, + op, + T::IS_SIGNED, + PBSType::Classical, + LweBskGroupingFactor(0), + d_bsk.ms_noise_reduction_configuration.as_ref(), + ); + } + CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { + cuda_backend_unchecked_comparison( + streams, + result.as_mut().as_mut(), + ct_left.as_ref(), + ct_right.as_ref(), + &d_multibit_bsk.d_vec, + &computing_ks_key.d_vec, + self.message_modulus, + self.carry_modulus, + d_multibit_bsk.glwe_dimension, + d_multibit_bsk.polynomial_size, + computing_ks_key.input_key_lwe_size().to_lwe_dimension(), + 
computing_ks_key.output_key_lwe_size().to_lwe_dimension(), + computing_ks_key.decomposition_level_count(), + computing_ks_key.decomposition_base_log(), + d_multibit_bsk.decomp_level_count, + d_multibit_bsk.decomp_base_log, + op, + T::IS_SIGNED, + PBSType::MultiBit, + d_multibit_bsk.grouping_factor, + None, + ); + } + } + }, + CudaDynamicKeyswitchingKey::KeySwitch32(computing_ks_key) => unsafe { + match &self.bootstrapping_key { + CudaBootstrappingKey::Classic(d_bsk) => { + cuda_backend_unchecked_comparison( + streams, + result.as_mut().as_mut(), + ct_left.as_ref(), + ct_right.as_ref(), + &d_bsk.d_vec, + &computing_ks_key.d_vec, + self.message_modulus, + self.carry_modulus, + d_bsk.glwe_dimension, + d_bsk.polynomial_size, + computing_ks_key.input_key_lwe_size().to_lwe_dimension(), + computing_ks_key.output_key_lwe_size().to_lwe_dimension(), + computing_ks_key.decomposition_level_count(), + computing_ks_key.decomposition_base_log(), + d_bsk.decomp_level_count, + d_bsk.decomp_base_log, + op, + T::IS_SIGNED, + PBSType::Classical, + LweBskGroupingFactor(0), + d_bsk.ms_noise_reduction_configuration.as_ref(), + ); + } + CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { + cuda_backend_unchecked_comparison( + streams, + result.as_mut().as_mut(), + ct_left.as_ref(), + ct_right.as_ref(), + &d_multibit_bsk.d_vec, + &computing_ks_key.d_vec, + self.message_modulus, + self.carry_modulus, + d_multibit_bsk.glwe_dimension, + d_multibit_bsk.polynomial_size, + computing_ks_key.input_key_lwe_size().to_lwe_dimension(), + computing_ks_key.output_key_lwe_size().to_lwe_dimension(), + computing_ks_key.decomposition_level_count(), + computing_ks_key.decomposition_base_log(), + d_multibit_bsk.decomp_level_count, + d_multibit_bsk.decomp_base_log, + op, + T::IS_SIGNED, + PBSType::MultiBit, + d_multibit_bsk.grouping_factor, + None, + ); + } + } + }, }; - unsafe { - match &self.bootstrapping_key { - CudaBootstrappingKey::Classic(d_bsk) => { - cuda_backend_unchecked_comparison( - streams, - 
result.as_mut().as_mut(), - ct_left.as_ref(), - ct_right.as_ref(), - &d_bsk.d_vec, - &computing_ks_key.d_vec, - self.message_modulus, - self.carry_modulus, - d_bsk.glwe_dimension, - d_bsk.polynomial_size, - computing_ks_key.input_key_lwe_size().to_lwe_dimension(), - computing_ks_key.output_key_lwe_size().to_lwe_dimension(), - computing_ks_key.decomposition_level_count(), - computing_ks_key.decomposition_base_log(), - d_bsk.decomp_level_count, - d_bsk.decomp_base_log, - op, - T::IS_SIGNED, - PBSType::Classical, - LweBskGroupingFactor(0), - d_bsk.ms_noise_reduction_configuration.as_ref(), - ); - } - CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { - cuda_backend_unchecked_comparison( - streams, - result.as_mut().as_mut(), - ct_left.as_ref(), - ct_right.as_ref(), - &d_multibit_bsk.d_vec, - &computing_ks_key.d_vec, - self.message_modulus, - self.carry_modulus, - d_multibit_bsk.glwe_dimension, - d_multibit_bsk.polynomial_size, - computing_ks_key.input_key_lwe_size().to_lwe_dimension(), - computing_ks_key.output_key_lwe_size().to_lwe_dimension(), - computing_ks_key.decomposition_level_count(), - computing_ks_key.decomposition_base_log(), - d_multibit_bsk.decomp_level_count, - d_multibit_bsk.decomp_base_log, - op, - T::IS_SIGNED, - PBSType::MultiBit, - d_multibit_bsk.grouping_factor, - None, - ); - } - } - } - result } diff --git a/tfhe/src/integer/gpu/server_key/radix/mod.rs b/tfhe/src/integer/gpu/server_key/radix/mod.rs index acb2eb27e..2313cfd24 100644 --- a/tfhe/src/integer/gpu/server_key/radix/mod.rs +++ b/tfhe/src/integer/gpu/server_key/radix/mod.rs @@ -1019,61 +1019,109 @@ impl CudaServerKey { self.create_trivial_zero_radix(target_num_blocks, streams); let requires_full_propagate = !source.block_carries_are_empty(); - let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else { - panic!("Only the standard atomic pattern is supported on GPU") - }; - unsafe { - match &self.bootstrapping_key { - CudaBootstrappingKey::Classic(d_bsk) => 
{ - cuda_backend_cast_to_unsigned( - streams, - result.as_mut(), - source.as_mut(), - T::IS_SIGNED, - requires_full_propagate, - target_num_blocks as u32, - &d_bsk.d_vec, - &computing_ks_key.d_vec, - d_bsk.glwe_dimension, - d_bsk.polynomial_size, - computing_ks_key.input_key_lwe_size().to_lwe_dimension(), - computing_ks_key.output_key_lwe_size().to_lwe_dimension(), - computing_ks_key.decomposition_level_count(), - computing_ks_key.decomposition_base_log(), - d_bsk.decomp_level_count, - d_bsk.decomp_base_log, - PBSType::Classical, - LweBskGroupingFactor(0), - d_bsk.ms_noise_reduction_configuration.as_ref(), - ); + match &self.key_switching_key { + CudaDynamicKeyswitchingKey::Standard(computing_ks_key) => unsafe { + match &self.bootstrapping_key { + CudaBootstrappingKey::Classic(d_bsk) => { + cuda_backend_cast_to_unsigned( + streams, + result.as_mut(), + source.as_mut(), + T::IS_SIGNED, + requires_full_propagate, + target_num_blocks as u32, + &d_bsk.d_vec, + &computing_ks_key.d_vec, + d_bsk.glwe_dimension, + d_bsk.polynomial_size, + computing_ks_key.input_key_lwe_size().to_lwe_dimension(), + computing_ks_key.output_key_lwe_size().to_lwe_dimension(), + computing_ks_key.decomposition_level_count(), + computing_ks_key.decomposition_base_log(), + d_bsk.decomp_level_count, + d_bsk.decomp_base_log, + PBSType::Classical, + LweBskGroupingFactor(0), + d_bsk.ms_noise_reduction_configuration.as_ref(), + ); + } + CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { + cuda_backend_cast_to_unsigned( + streams, + result.as_mut(), + source.as_mut(), + T::IS_SIGNED, + requires_full_propagate, + target_num_blocks as u32, + &d_multibit_bsk.d_vec, + &computing_ks_key.d_vec, + d_multibit_bsk.glwe_dimension, + d_multibit_bsk.polynomial_size, + computing_ks_key.input_key_lwe_size().to_lwe_dimension(), + computing_ks_key.output_key_lwe_size().to_lwe_dimension(), + computing_ks_key.decomposition_level_count(), + computing_ks_key.decomposition_base_log(), + d_multibit_bsk.decomp_level_count, + 
d_multibit_bsk.decomp_base_log, + PBSType::MultiBit, + d_multibit_bsk.grouping_factor, + None, + ); + } } - CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { - cuda_backend_cast_to_unsigned( - streams, - result.as_mut(), - source.as_mut(), - T::IS_SIGNED, - requires_full_propagate, - target_num_blocks as u32, - &d_multibit_bsk.d_vec, - &computing_ks_key.d_vec, - d_multibit_bsk.glwe_dimension, - d_multibit_bsk.polynomial_size, - computing_ks_key.input_key_lwe_size().to_lwe_dimension(), - computing_ks_key.output_key_lwe_size().to_lwe_dimension(), - computing_ks_key.decomposition_level_count(), - computing_ks_key.decomposition_base_log(), - d_multibit_bsk.decomp_level_count, - d_multibit_bsk.decomp_base_log, - PBSType::MultiBit, - d_multibit_bsk.grouping_factor, - None, - ); + }, + CudaDynamicKeyswitchingKey::KeySwitch32(computing_ks_key) => unsafe { + match &self.bootstrapping_key { + CudaBootstrappingKey::Classic(d_bsk) => { + cuda_backend_cast_to_unsigned( + streams, + result.as_mut(), + source.as_mut(), + T::IS_SIGNED, + requires_full_propagate, + target_num_blocks as u32, + &d_bsk.d_vec, + &computing_ks_key.d_vec, + d_bsk.glwe_dimension, + d_bsk.polynomial_size, + computing_ks_key.input_key_lwe_size().to_lwe_dimension(), + computing_ks_key.output_key_lwe_size().to_lwe_dimension(), + computing_ks_key.decomposition_level_count(), + computing_ks_key.decomposition_base_log(), + d_bsk.decomp_level_count, + d_bsk.decomp_base_log, + PBSType::Classical, + LweBskGroupingFactor(0), + d_bsk.ms_noise_reduction_configuration.as_ref(), + ); + } + CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { + cuda_backend_cast_to_unsigned( + streams, + result.as_mut(), + source.as_mut(), + T::IS_SIGNED, + requires_full_propagate, + target_num_blocks as u32, + &d_multibit_bsk.d_vec, + &computing_ks_key.d_vec, + d_multibit_bsk.glwe_dimension, + d_multibit_bsk.polynomial_size, + computing_ks_key.input_key_lwe_size().to_lwe_dimension(), + 
computing_ks_key.output_key_lwe_size().to_lwe_dimension(), + computing_ks_key.decomposition_level_count(), + computing_ks_key.decomposition_base_log(), + d_multibit_bsk.decomp_level_count, + d_multibit_bsk.decomp_base_log, + PBSType::MultiBit, + d_multibit_bsk.grouping_factor, + None, + ); + } } - } + }, } - result } diff --git a/tfhe/src/integer/gpu/server_key/radix/sub.rs b/tfhe/src/integer/gpu/server_key/radix/sub.rs index 720fea173..6232468c4 100644 --- a/tfhe/src/integer/gpu/server_key/radix/sub.rs +++ b/tfhe/src/integer/gpu/server_key/radix/sub.rs @@ -296,64 +296,118 @@ impl CudaServerKey { let aux_block: CudaUnsignedRadixCiphertext = self.create_trivial_zero_radix(1, stream); let in_carry_dvec = INPUT_BORROW.map_or_else(|| aux_block.as_ref(), |block| block.as_ref().as_ref()); - let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else { - panic!("Only the standard atomic pattern is supported on GPU") - }; - unsafe { - match &self.bootstrapping_key { - CudaBootstrappingKey::Classic(d_bsk) => { - cuda_backend_unchecked_unsigned_overflowing_sub_assign( - stream, - ciphertext, - rhs.as_ref(), - overflow_block.as_mut(), - in_carry_dvec, - &d_bsk.d_vec, - &computing_ks_key.d_vec, - d_bsk.input_lwe_dimension(), - d_bsk.glwe_dimension(), - d_bsk.polynomial_size(), - computing_ks_key.decomposition_level_count(), - computing_ks_key.decomposition_base_log(), - d_bsk.decomp_level_count(), - d_bsk.decomp_base_log(), - ciphertext.info.blocks.first().unwrap().message_modulus, - ciphertext.info.blocks.first().unwrap().carry_modulus, - PBSType::Classical, - LweBskGroupingFactor(0), - compute_overflow, - uses_input_borrow, - d_bsk.ms_noise_reduction_configuration.as_ref(), - ); + match &self.key_switching_key { + CudaDynamicKeyswitchingKey::Standard(computing_ks_key) => unsafe { + match &self.bootstrapping_key { + CudaBootstrappingKey::Classic(d_bsk) => { + cuda_backend_unchecked_unsigned_overflowing_sub_assign( + stream, + ciphertext, + 
rhs.as_ref(), + overflow_block.as_mut(), + in_carry_dvec, + &d_bsk.d_vec, + &computing_ks_key.d_vec, + d_bsk.input_lwe_dimension(), + d_bsk.glwe_dimension(), + d_bsk.polynomial_size(), + computing_ks_key.decomposition_level_count(), + computing_ks_key.decomposition_base_log(), + d_bsk.decomp_level_count(), + d_bsk.decomp_base_log(), + ciphertext.info.blocks.first().unwrap().message_modulus, + ciphertext.info.blocks.first().unwrap().carry_modulus, + PBSType::Classical, + LweBskGroupingFactor(0), + compute_overflow, + uses_input_borrow, + d_bsk.ms_noise_reduction_configuration.as_ref(), + ); + } + CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { + cuda_backend_unchecked_unsigned_overflowing_sub_assign( + stream, + ciphertext, + rhs.as_ref(), + overflow_block.as_mut(), + in_carry_dvec, + &d_multibit_bsk.d_vec, + &computing_ks_key.d_vec, + d_multibit_bsk.input_lwe_dimension(), + d_multibit_bsk.glwe_dimension(), + d_multibit_bsk.polynomial_size(), + computing_ks_key.decomposition_level_count(), + computing_ks_key.decomposition_base_log(), + d_multibit_bsk.decomp_level_count(), + d_multibit_bsk.decomp_base_log(), + ciphertext.info.blocks.first().unwrap().message_modulus, + ciphertext.info.blocks.first().unwrap().carry_modulus, + PBSType::MultiBit, + d_multibit_bsk.grouping_factor, + compute_overflow, + uses_input_borrow, + None, + ); + } } - CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { - cuda_backend_unchecked_unsigned_overflowing_sub_assign( - stream, - ciphertext, - rhs.as_ref(), - overflow_block.as_mut(), - in_carry_dvec, - &d_multibit_bsk.d_vec, - &computing_ks_key.d_vec, - d_multibit_bsk.input_lwe_dimension(), - d_multibit_bsk.glwe_dimension(), - d_multibit_bsk.polynomial_size(), - computing_ks_key.decomposition_level_count(), - computing_ks_key.decomposition_base_log(), - d_multibit_bsk.decomp_level_count(), - d_multibit_bsk.decomp_base_log(), - ciphertext.info.blocks.first().unwrap().message_modulus, - 
ciphertext.info.blocks.first().unwrap().carry_modulus, - PBSType::MultiBit, - d_multibit_bsk.grouping_factor, - compute_overflow, - uses_input_borrow, - None, - ); + }, + CudaDynamicKeyswitchingKey::KeySwitch32(computing_ks_key) => unsafe { + match &self.bootstrapping_key { + CudaBootstrappingKey::Classic(d_bsk) => { + cuda_backend_unchecked_unsigned_overflowing_sub_assign( + stream, + ciphertext, + rhs.as_ref(), + overflow_block.as_mut(), + in_carry_dvec, + &d_bsk.d_vec, + &computing_ks_key.d_vec, + d_bsk.input_lwe_dimension(), + d_bsk.glwe_dimension(), + d_bsk.polynomial_size(), + computing_ks_key.decomposition_level_count(), + computing_ks_key.decomposition_base_log(), + d_bsk.decomp_level_count(), + d_bsk.decomp_base_log(), + ciphertext.info.blocks.first().unwrap().message_modulus, + ciphertext.info.blocks.first().unwrap().carry_modulus, + PBSType::Classical, + LweBskGroupingFactor(0), + compute_overflow, + uses_input_borrow, + d_bsk.ms_noise_reduction_configuration.as_ref(), + ); + } + CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { + cuda_backend_unchecked_unsigned_overflowing_sub_assign( + stream, + ciphertext, + rhs.as_ref(), + overflow_block.as_mut(), + in_carry_dvec, + &d_multibit_bsk.d_vec, + &computing_ks_key.d_vec, + d_multibit_bsk.input_lwe_dimension(), + d_multibit_bsk.glwe_dimension(), + d_multibit_bsk.polynomial_size(), + computing_ks_key.decomposition_level_count(), + computing_ks_key.decomposition_base_log(), + d_multibit_bsk.decomp_level_count(), + d_multibit_bsk.decomp_base_log(), + ciphertext.info.blocks.first().unwrap().message_modulus, + ciphertext.info.blocks.first().unwrap().carry_modulus, + PBSType::MultiBit, + d_multibit_bsk.grouping_factor, + compute_overflow, + uses_input_borrow, + None, + ); + } } - } + }, } + let ct_overflowed = CudaBooleanBlock::from_cuda_radix_ciphertext(overflow_block.ciphertext); (ct_res, ct_overflowed)