From b3b2787c4a9fbe9f38f67eb052c6d63ded282ff5 Mon Sep 17 00:00:00 2001 From: Andrei Stoian Date: Mon, 29 Dec 2025 23:38:12 +0100 Subject: [PATCH] fix(gpu): two ops supporting ks32, add ks32 in tests --- Makefile | 7 +- .../cuda/include/integer/integer.h | 27 +++ .../cuda/include/integer/integer_utilities.h | 4 +- .../cuda/include/integer/multiplication.h | 2 +- .../cuda/src/crypto/keyswitch.cuh | 8 +- .../cuda/src/integer/integer.cuh | 48 ++--- .../cuda/src/integer/multiplication.cuh | 30 +-- .../cuda/src/integer/rerand.cuh | 2 +- .../cuda/src/integer/scalar_mul.cu | 32 +++ .../cuda/src/integer/scalar_mul.cuh | 8 +- .../cuda/src/integer/scalar_shifts.cu | 34 ++- .../cuda/src/pbs/programmable_bootstrap.cuh | 6 +- backends/tfhe-cuda-backend/cuda/src/zk/zk.cuh | 2 +- backends/tfhe-cuda-backend/src/bindings.rs | 67 ++++++ tfhe/src/integer/gpu/mod.rs | 189 +++++++++++------ tfhe/src/integer/gpu/server_key/mod.rs | 194 ++++++++++++------ .../gpu/server_key/radix/scalar_mul.rs | 147 ++++++++----- .../gpu/server_key/radix/scalar_shift.rs | 151 +++++++++----- .../radix/tests_signed/test_scalar_shift.rs | 1 + .../server_key/radix/tests_unsigned/mod.rs | 1 + 20 files changed, 678 insertions(+), 282 deletions(-) diff --git a/Makefile b/Makefile index adc7d67d1..2a0ab4c63 100644 --- a/Makefile +++ b/Makefile @@ -721,9 +721,10 @@ test_core_crypto_gpu: .PHONY: test_integer_gpu # Run the tests of the integer module including experimental on the gpu backend test_integer_gpu: RUSTFLAGS="$(RUSTFLAGS)" cargo test --profile $(CARGO_PROFILE) \ - --features=integer,gpu -p tfhe -- integer::gpu::server_key:: --test-threads=2 - RUSTFLAGS="$(RUSTFLAGS)" cargo test --doc --profile $(CARGO_PROFILE) \ - --features=integer,gpu -p tfhe -- integer::gpu::server_key:: --test-threads=4 + --features=integer,gpu -p tfhe -- integer::gpu::server_key::radix::tests_signed::test_scalar_mul --test-threads=2 + #test_scalar_shift::test_gpu_integer_signed_unchecked_scalar_left_shift_test_param_message_2_carry_2_ks32_pbs_tuniform_2m128 +# RUSTFLAGS="$(RUSTFLAGS)" cargo test --doc --profile $(CARGO_PROFILE) \ +# --features=integer,gpu -p tfhe -- integer::gpu::server_key:: --test-threads=4 .PHONY: test_integer_gpu_debug # Run the tests of the integer module with Debug flags for CUDA test_integer_gpu_debug: diff --git a/backends/tfhe-cuda-backend/cuda/include/integer/integer.h b/backends/tfhe-cuda-backend/cuda/include/integer/integer.h index e8831978c..bef5067b2 100644 --- a/backends/tfhe-cuda-backend/cuda/include/integer/integer.h +++ b/backends/tfhe-cuda-backend/cuda/include/integer/integer.h @@ -191,11 +191,24 @@ uint64_t scratch_cuda_logical_scalar_shift_64( PBS_TYPE pbs_type, SHIFT_OR_ROTATE_TYPE shift_type, bool allocate_gpu_memory, PBS_MS_REDUCTION_T noise_reduction_type); +uint64_t scratch_cuda_logical_scalar_shift_64_ks32( + CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension, + uint32_t polynomial_size, uint32_t big_lwe_dimension, + uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log, + uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor, + uint32_t num_blocks, uint32_t message_modulus, uint32_t carry_modulus, + PBS_TYPE pbs_type, SHIFT_OR_ROTATE_TYPE shift_type, + bool allocate_gpu_memory, PBS_MS_REDUCTION_T noise_reduction_type); + void cuda_logical_scalar_shift_64_inplace(CudaStreamsFFI streams, CudaRadixCiphertextFFI *lwe_array, uint32_t shift, int8_t *mem_ptr, void *const *bsks, void *const *ksks); +void cuda_logical_scalar_shift_64_ks32_inplace( + CudaStreamsFFI streams, CudaRadixCiphertextFFI *lwe_array, uint32_t shift, + int8_t *mem_ptr, void *const *bsks, void *const *ksks); + uint64_t scratch_cuda_arithmetic_scalar_shift_64( CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t big_lwe_dimension, @@ -446,12 +459,26 @@ uint64_t scratch_cuda_integer_scalar_mul_64( uint32_t carry_modulus, PBS_TYPE pbs_type, uint32_t num_scalar_bits, bool allocate_gpu_memory, PBS_MS_REDUCTION_T noise_reduction_type); +uint64_t scratch_cuda_integer_scalar_mul_64_ks32( + CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension, + uint32_t polynomial_size, uint32_t lwe_dimension, uint32_t ks_level, + uint32_t ks_base_log, uint32_t pbs_level, uint32_t pbs_base_log, + uint32_t grouping_factor, uint32_t num_blocks, uint32_t message_modulus, + uint32_t carry_modulus, PBS_TYPE pbs_type, uint32_t num_scalar_bits, + bool allocate_gpu_memory, PBS_MS_REDUCTION_T noise_reduction_type); + void cuda_scalar_multiplication_ciphertext_64_inplace( CudaStreamsFFI streams, CudaRadixCiphertextFFI *lwe_array, uint64_t const *decomposed_scalar, uint64_t const *has_at_least_one_set, int8_t *mem_ptr, void *const *bsks, void *const *ksks, uint32_t polynomial_size, uint32_t message_modulus, uint32_t num_scalars); +void cuda_scalar_multiplication_ciphertext_64_ks32_inplace( + CudaStreamsFFI streams, CudaRadixCiphertextFFI *lwe_array, + uint64_t const *decomposed_scalar, uint64_t const *has_at_least_one_set, + int8_t *mem_ptr, void *const *bsks, void *const *ksks, + uint32_t polynomial_size, uint32_t message_modulus, uint32_t num_scalars); + void cleanup_cuda_scalar_mul(CudaStreamsFFI streams, int8_t **mem_ptr_void); uint64_t scratch_cuda_integer_div_rem_radix_ciphertext_64( diff --git a/backends/tfhe-cuda-backend/cuda/include/integer/integer_utilities.h b/backends/tfhe-cuda-backend/cuda/include/integer/integer_utilities.h index 950390916..a7b053c8e 100644 --- a/backends/tfhe-cuda-backend/cuda/include/integer/integer_utilities.h +++ b/backends/tfhe-cuda-backend/cuda/include/integer/integer_utilities.h @@ -1109,7 +1109,7 @@ template struct int_fullprop_buffer { // tmp_small_lwe_vector = new CudaRadixCiphertextFFI; - create_zero_radix_ciphertext_async( + create_zero_radix_ciphertext_async( streams.stream(0), streams.gpu_index(0), tmp_small_lwe_vector, 2, params.small_lwe_dimension, size_tracker, allocate_gpu_memory); tmp_big_lwe_vector = new CudaRadixCiphertextFFI; @@ -1309,7 +1309,7 @@ struct int_sum_ciphertexts_vec_memory { max_total_blocks_in_vec, params.big_lwe_dimension, size_tracker, allocate_gpu_memory); small_lwe_vector = new CudaRadixCiphertextFFI; - create_zero_radix_ciphertext_async( + create_zero_radix_ciphertext_async( streams.stream(0), streams.gpu_index(0), small_lwe_vector, max_total_blocks_in_vec, params.small_lwe_dimension, size_tracker, allocate_gpu_memory); diff --git a/backends/tfhe-cuda-backend/cuda/include/integer/multiplication.h b/backends/tfhe-cuda-backend/cuda/include/integer/multiplication.h index 6c39da705..2695d9036 100644 --- a/backends/tfhe-cuda-backend/cuda/include/integer/multiplication.h +++ b/backends/tfhe-cuda-backend/cuda/include/integer/multiplication.h @@ -81,7 +81,7 @@ template struct int_mul_memory { 2 * total_block_count, params.big_lwe_dimension, size_tracker, allocate_gpu_memory); small_lwe_vector = new CudaRadixCiphertextFFI; - create_zero_radix_ciphertext_async( + create_zero_radix_ciphertext_async( streams.stream(0), streams.gpu_index(0), small_lwe_vector, 2 * total_block_count, params.small_lwe_dimension, size_tracker, allocate_gpu_memory); diff --git a/backends/tfhe-cuda-backend/cuda/src/crypto/keyswitch.cuh b/backends/tfhe-cuda-backend/cuda/src/crypto/keyswitch.cuh index efb83da0c..e2059ee62 100644 --- a/backends/tfhe-cuda-backend/cuda/src/crypto/keyswitch.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/crypto/keyswitch.cuh @@ -546,7 +546,7 @@ __host__ void host_gemm_keyswitch_lwe_ciphertext_vector( template void execute_keyswitch_async( - CudaStreams streams, const LweArrayVariant &lwe_array_out, + CudaStreams streams, const LweArrayVariant &lwe_array_out, const LweArrayVariant &lwe_output_indexes, const LweArrayVariant &lwe_array_in, const LweArrayVariant &lwe_input_indexes, KSTorus *const *ksks, @@ -560,7 +560,7 @@ void execute_keyswitch_async( int num_samples_on_gpu = get_num_inputs_on_gpu(num_samples, i, streams.count()); - Torus *current_lwe_array_out = get_variant_element(lwe_array_out, i); + KSTorus *current_lwe_array_out = get_variant_element(lwe_array_out, i); Torus *current_lwe_output_indexes = get_variant_element(lwe_output_indexes, i); Torus *current_lwe_array_in = get_variant_element(lwe_array_in, i); @@ -585,7 +585,7 @@ void execute_keyswitch_async( lwe_dimension_out); // Compute Keyswitch - host_gemm_keyswitch_lwe_ciphertext_vector( + host_gemm_keyswitch_lwe_ciphertext_vector( streams.stream(i), streams.gpu_index(i), current_lwe_array_out, current_lwe_output_indexes, current_lwe_array_in, current_lwe_input_indexes, ksks[i], lwe_dimension_in, @@ -594,7 +594,7 @@ void execute_keyswitch_async( } else { // Compute Keyswitch - host_keyswitch_lwe_ciphertext_vector( + host_keyswitch_lwe_ciphertext_vector( streams.stream(i), streams.gpu_index(i), current_lwe_array_out, current_lwe_output_indexes, current_lwe_array_in, current_lwe_input_indexes, ksks[i], lwe_dimension_in, diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/integer.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/integer.cuh index 9193ba09f..b26f6a760 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/integer.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/integer/integer.cuh @@ -544,7 +544,7 @@ __host__ void integer_radix_apply_univariate_lookup_table( auto active_streams = streams.active_gpu_subset(num_radix_blocks); if (active_streams.count() == 1) { - execute_keyswitch_async( + execute_keyswitch_async( streams.get_ith(0), lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0], (Torus *)lwe_array_in->ptr, lut->lwe_indexes_in, ksks, big_lwe_dimension, small_lwe_dimension, ks_base_log, ks_level, @@ -552,7 +552,7 @@ __host__ void integer_radix_apply_univariate_lookup_table( /// Apply PBS to apply a LUT, reduce the noise and go from a small LWE /// dimension to a big LWE dimension - execute_pbs_async( + execute_pbs_async( streams.get_ith(0), (Torus *)lwe_array_out->ptr, lut->lwe_indexes_out, lut->lut_vec, lut->lut_indexes_vec, lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0], bsks, lut->buffer, glwe_dimension, @@ -573,7 +573,7 @@ __host__ void integer_radix_apply_univariate_lookup_table( big_lwe_dimension + 1); POP_RANGE() /// Apply KS to go from a big LWE dimension to a small LWE dimension - execute_keyswitch_async( + execute_keyswitch_async( active_streams, lwe_after_ks_vec, lwe_trivial_indexes_vec, lwe_array_in_vec, lwe_trivial_indexes_vec, ksks, big_lwe_dimension, small_lwe_dimension, ks_base_log, ks_level, num_radix_blocks, true, @@ -581,7 +581,7 @@ __host__ void integer_radix_apply_univariate_lookup_table( /// Apply PBS to apply a LUT, reduce the noise and go from a small LWE /// dimension to a big LWE dimension - execute_pbs_async( + execute_pbs_async( active_streams, lwe_after_pbs_vec, lwe_trivial_indexes_vec, lut->lut_vec, lut->lut_indexes_vec, lwe_after_ks_vec, lwe_trivial_indexes_vec, bsks, lut->buffer, glwe_dimension, @@ -641,13 +641,13 @@ __host__ void integer_radix_apply_many_univariate_lookup_table( /// For multi GPU execution we create vectors of pointers for inputs and /// outputs std::vector lwe_array_in_vec = lut->lwe_array_in_vec; - std::vector lwe_after_ks_vec = lut->lwe_after_ks_vec; + std::vector lwe_after_ks_vec = lut->lwe_after_ks_vec; std::vector lwe_after_pbs_vec = lut->lwe_after_pbs_vec; std::vector lwe_trivial_indexes_vec = lut->lwe_trivial_indexes_vec; auto active_streams = streams.active_gpu_subset(num_radix_blocks); if (active_streams.count() == 1) { - execute_keyswitch_async( + execute_keyswitch_async( streams.get_ith(0), lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0], (Torus *)lwe_array_in->ptr, lut->lwe_indexes_in, ksks, big_lwe_dimension, small_lwe_dimension, ks_base_log, ks_level, @@ -655,7 +655,7 @@ __host__ void integer_radix_apply_many_univariate_lookup_table( /// Apply PBS to apply a LUT, reduce the noise and go from a small LWE /// dimension to a big LWE dimension - execute_pbs_async( + execute_pbs_async( streams.get_ith(0), (Torus *)lwe_array_out->ptr, lut->lwe_indexes_out, lut->lut_vec, lut->lut_indexes_vec, lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0], bsks, lut->buffer, glwe_dimension, @@ -676,7 +676,7 @@ __host__ void integer_radix_apply_many_univariate_lookup_table( big_lwe_dimension + 1); POP_RANGE() /// Apply KS to go from a big LWE dimension to a small LWE dimension - execute_keyswitch_async( + execute_keyswitch_async( active_streams, lwe_after_ks_vec, lwe_trivial_indexes_vec, lwe_array_in_vec, lwe_trivial_indexes_vec, ksks, big_lwe_dimension, small_lwe_dimension, ks_base_log, ks_level, num_radix_blocks, true, @@ -684,7 +684,7 @@ __host__ void integer_radix_apply_many_univariate_lookup_table( /// Apply PBS to apply a LUT, reduce the noise and go from a small LWE /// dimension to a big LWE dimension - execute_pbs_async( + execute_pbs_async( active_streams, lwe_after_pbs_vec, lwe_trivial_indexes_vec, lut->lut_vec, lut->lut_indexes_vec, lwe_after_ks_vec, lwe_trivial_indexes_vec, bsks, lut->buffer, glwe_dimension, @@ -760,13 +760,13 @@ __host__ void integer_radix_apply_bivariate_lookup_table( /// For multi GPU execution we create vectors of pointers for inputs and /// outputs std::vector lwe_array_in_vec = lut->lwe_array_in_vec; - std::vector lwe_after_ks_vec = lut->lwe_after_ks_vec; + std::vector lwe_after_ks_vec = lut->lwe_after_ks_vec; std::vector lwe_after_pbs_vec = lut->lwe_after_pbs_vec; std::vector lwe_trivial_indexes_vec = lut->lwe_trivial_indexes_vec; auto active_streams = streams.active_gpu_subset(num_radix_blocks); if (active_streams.count() == 1) { - execute_keyswitch_async( + execute_keyswitch_async( streams.get_ith(0), lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0], (Torus *)lwe_array_pbs_in->ptr, lut->lwe_indexes_in, ksks, big_lwe_dimension, small_lwe_dimension, ks_base_log, ks_level, @@ -774,7 +774,7 @@ __host__ void integer_radix_apply_bivariate_lookup_table( /// Apply PBS to apply a LUT, reduce the noise and go from a small LWE /// dimension to a big LWE dimension - execute_pbs_async( + execute_pbs_async( streams.get_ith(0), (Torus *)(lwe_array_out->ptr), lut->lwe_indexes_out, lut->lut_vec, lut->lut_indexes_vec, lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0], bsks, lut->buffer, glwe_dimension, @@ -792,7 +792,7 @@ __host__ void integer_radix_apply_bivariate_lookup_table( big_lwe_dimension + 1); POP_RANGE() /// Apply KS to go from a big LWE dimension to a small LWE dimension - execute_keyswitch_async( + execute_keyswitch_async( active_streams, lwe_after_ks_vec, lwe_trivial_indexes_vec, lwe_array_in_vec, lwe_trivial_indexes_vec, ksks, big_lwe_dimension, small_lwe_dimension, ks_base_log, ks_level, num_radix_blocks, true, @@ -800,7 +800,7 @@ __host__ void integer_radix_apply_bivariate_lookup_table( /// Apply PBS to apply a LUT, reduce the noise and go from a small LWE /// dimension to a big LWE dimension - execute_pbs_async( + execute_pbs_async( active_streams, lwe_after_pbs_vec, lwe_trivial_indexes_vec, lut->lut_vec, lut->lut_indexes_vec, lwe_after_ks_vec, lwe_trivial_indexes_vec, bsks, lut->buffer, glwe_dimension, @@ -1524,22 +1524,22 @@ void host_full_propagate_inplace(CudaStreams streams, as_radix_ciphertext_slice(&cur_input_block, input_blocks, i, i + 1); /// Since the keyswitch is done on one input only, use only 1 GPU - execute_keyswitch_async( - streams.get_ith(0), (Torus *)(mem_ptr->tmp_small_lwe_vector->ptr), + execute_keyswitch_async( + streams.get_ith(0), (KSTorus *)(mem_ptr->tmp_small_lwe_vector->ptr), mem_ptr->lut->lwe_trivial_indexes, (Torus *)cur_input_block.ptr, mem_ptr->lut->lwe_trivial_indexes, ksks, params.big_lwe_dimension, params.small_lwe_dimension, params.ks_base_log, params.ks_level, 1, mem_ptr->lut->using_trivial_lwe_indexes, mem_ptr->lut->ks_tmp_buf_vec); - copy_radix_ciphertext_slice_async( + copy_radix_ciphertext_slice_async( streams.stream(0), streams.gpu_index(0), mem_ptr->tmp_small_lwe_vector, 1, 2, mem_ptr->tmp_small_lwe_vector, 0, 1); - execute_pbs_async( + execute_pbs_async( streams.get_ith(0), (Torus *)mem_ptr->tmp_big_lwe_vector->ptr, mem_ptr->lut->lwe_trivial_indexes, mem_ptr->lut->lut_vec, mem_ptr->lut->lut_indexes_vec, - (Torus *)mem_ptr->tmp_small_lwe_vector->ptr, + (KSTorus *)mem_ptr->tmp_small_lwe_vector->ptr, mem_ptr->lut->lwe_trivial_indexes, bsks, mem_ptr->lut->buffer, params.glwe_dimension, params.small_lwe_dimension, params.polynomial_size, params.pbs_base_log, params.pbs_level, @@ -2325,9 +2325,9 @@ __host__ void integer_radix_apply_noise_squashing( /// outputs auto lwe_array_pbs_in = lut->tmp_lwe_before_ks; std::vector lwe_array_in_vec = lut->lwe_array_in_vec; - std::vector lwe_after_ks_vec = lut->lwe_after_ks_vec; + std::vector lwe_after_ks_vec = lut->lwe_after_ks_vec; std::vector<__uint128_t *> lwe_after_pbs_vec = lut->lwe_after_pbs_vec; - std::vector lwe_trivial_indexes_vec = + std::vector lwe_trivial_indexes_vec = lut->lwe_trivial_indexes_vec; // We know carry is empty so we can pack two blocks in one @@ -2340,7 +2340,7 @@ __host__ void integer_radix_apply_noise_squashing( auto active_streams = streams.active_gpu_subset(lwe_array_out->num_radix_blocks); if (active_streams.count() == 1) { - execute_keyswitch_async( + execute_keyswitch_async( streams.get_ith(0), lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0], (InputTorus *)lwe_array_pbs_in->ptr, lut->lwe_indexes_in, ksks, lut->input_big_lwe_dimension, small_lwe_dimension, ks_base_log, @@ -2352,7 +2352,7 @@ __host__ void integer_radix_apply_noise_squashing( /// /// int_noise_squashing_lut doesn't support a different output or lut /// indexing than the trivial - execute_pbs_async( + execute_pbs_async( streams.get_ith(0), (__uint128_t *)lwe_array_out->ptr, lwe_trivial_indexes_vec[0], lut->lut_vec, lwe_trivial_indexes_vec, lwe_after_ks_vec[0], lwe_trivial_indexes_vec[0], bsks, lut->buffer, @@ -2371,7 +2371,7 @@ __host__ void integer_radix_apply_noise_squashing( lut->lwe_aligned_scatter_vec, lut->active_streams.count(), lwe_array_out->num_radix_blocks, lut->input_big_lwe_dimension + 1); - execute_keyswitch_async( + execute_keyswitch_async( active_streams, lwe_after_ks_vec, lwe_trivial_indexes_vec, lwe_array_in_vec, lwe_trivial_indexes_vec, ksks, lut->input_big_lwe_dimension, small_lwe_dimension, ks_base_log, diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/multiplication.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/multiplication.cuh index 10c02c36a..219f5fd52 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/multiplication.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/integer/multiplication.cuh @@ -286,7 +286,7 @@ __host__ uint64_t scratch_cuda_integer_partial_sum_ciphertexts_vec( template __host__ void host_integer_partial_sum_ciphertexts_vec( CudaStreams streams, CudaRadixCiphertextFFI *radix_lwe_out, - CudaRadixCiphertextFFI *terms, void *const *bsks, uint64_t *const *ksks, + CudaRadixCiphertextFFI *terms, void *const *bsks, KSTorus *const *ksks, int_sum_ciphertexts_vec_memory *mem_ptr, uint32_t num_radix_blocks, uint32_t num_radix_in_vec) { auto big_lwe_dimension = mem_ptr->params.big_lwe_dimension; @@ -394,17 +394,17 @@ __host__ void host_integer_partial_sum_ciphertexts_vec( "SUM CT"); if (active_streams.count() == 1) { - execute_keyswitch_async( - streams.get_ith(0), (Torus *)small_lwe_vector->ptr, d_pbs_indexes_in, - (Torus *)current_blocks->ptr, d_pbs_indexes_in, ksks, - big_lwe_dimension, small_lwe_dimension, mem_ptr->params.ks_base_log, - mem_ptr->params.ks_level, total_messages, false, - mem_ptr->luts_message_carry->ks_tmp_buf_vec); + execute_keyswitch_async( + streams.get_ith(0), (KSTorus *)small_lwe_vector->ptr, + d_pbs_indexes_in, (Torus *)current_blocks->ptr, d_pbs_indexes_in, + ksks, big_lwe_dimension, small_lwe_dimension, + mem_ptr->params.ks_base_log, mem_ptr->params.ks_level, total_messages, + false, mem_ptr->luts_message_carry->ks_tmp_buf_vec); - execute_pbs_async( + execute_pbs_async( streams.get_ith(0), (Torus *)current_blocks->ptr, d_pbs_indexes_out, luts_message_carry->lut_vec, luts_message_carry->lut_indexes_vec, - (Torus *)small_lwe_vector->ptr, d_pbs_indexes_in, bsks, + (KSTorus *)small_lwe_vector->ptr, d_pbs_indexes_in, bsks, luts_message_carry->buffer, glwe_dimension, small_lwe_dimension, polynomial_size, mem_ptr->params.pbs_base_log, mem_ptr->params.pbs_level, mem_ptr->params.grouping_factor, @@ -416,7 +416,7 @@ __host__ void host_integer_partial_sum_ciphertexts_vec( luts_message_carry->broadcast_lut(active_streams, false); luts_message_carry->using_trivial_lwe_indexes = false; - integer_radix_apply_univariate_lookup_table( + integer_radix_apply_univariate_lookup_table( streams, current_blocks, current_blocks, bsks, ksks, luts_message_carry, total_ciphertexts); } @@ -446,17 +446,17 @@ __host__ void host_integer_partial_sum_ciphertexts_vec( auto active_streams = streams.active_gpu_subset(2 * num_radix_blocks); if (active_streams.count() == 1) { - execute_keyswitch_async( - streams.get_ith(0), (Torus *)small_lwe_vector->ptr, d_pbs_indexes_in, - (Torus *)radix_lwe_out->ptr, d_pbs_indexes_in, ksks, + execute_keyswitch_async( + streams.get_ith(0), (KSTorus *)small_lwe_vector->ptr, + d_pbs_indexes_in, (Torus *)radix_lwe_out->ptr, d_pbs_indexes_in, ksks, big_lwe_dimension, small_lwe_dimension, mem_ptr->params.ks_base_log, mem_ptr->params.ks_level, num_radix_blocks, false, mem_ptr->luts_message_carry->ks_tmp_buf_vec); - execute_pbs_async( + execute_pbs_async( streams.get_ith(0), (Torus *)current_blocks->ptr, d_pbs_indexes_out, luts_message_carry->lut_vec, luts_message_carry->lut_indexes_vec, - (Torus *)small_lwe_vector->ptr, d_pbs_indexes_in, bsks, + (KSTorus *)small_lwe_vector->ptr, d_pbs_indexes_in, bsks, luts_message_carry->buffer, glwe_dimension, small_lwe_dimension, polynomial_size, mem_ptr->params.pbs_base_log, mem_ptr->params.pbs_level, mem_ptr->params.grouping_factor, diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/rerand.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/rerand.cuh index 7d2d189c5..87f97099e 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/rerand.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/integer/rerand.cuh @@ -56,7 +56,7 @@ void rerand_inplace( zero_lwes, d_expand_jobs, num_lwes); // Keyswitch - execute_keyswitch_async( + execute_keyswitch_async( streams.get_ith(0), ksed_zero_lwes, lwe_trivial_indexes, zero_lwes, lwe_trivial_indexes, ksk, input_dimension, output_dimension, ks_base_log, ks_level, num_lwes, true, mem_ptr->ks_tmp_buf_vec); diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/scalar_mul.cu b/backends/tfhe-cuda-backend/cuda/src/integer/scalar_mul.cu index 016c21605..c84aee9e1 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/scalar_mul.cu +++ b/backends/tfhe-cuda-backend/cuda/src/integer/scalar_mul.cu @@ -20,6 +20,38 @@ uint64_t scratch_cuda_integer_scalar_mul_64( num_scalar_bits, allocate_gpu_memory); } +uint64_t scratch_cuda_integer_scalar_mul_64_ks32( + CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension, + uint32_t polynomial_size, uint32_t lwe_dimension, uint32_t ks_level, + uint32_t ks_base_log, uint32_t pbs_level, uint32_t pbs_base_log, + uint32_t grouping_factor, uint32_t num_blocks, uint32_t message_modulus, + uint32_t carry_modulus, PBS_TYPE pbs_type, uint32_t num_scalar_bits, + bool allocate_gpu_memory, PBS_MS_REDUCTION_T noise_reduction_type) { + + int_radix_params params(pbs_type, glwe_dimension, polynomial_size, + glwe_dimension * polynomial_size, lwe_dimension, + ks_level, ks_base_log, pbs_level, pbs_base_log, + grouping_factor, message_modulus, carry_modulus, + noise_reduction_type); + + return scratch_cuda_scalar_mul( + CudaStreams(streams), + (int_scalar_mul_buffer **)mem_ptr, num_blocks, params, + num_scalar_bits, allocate_gpu_memory); +} + +void cuda_scalar_multiplication_ciphertext_64_ks32_inplace( + CudaStreamsFFI streams, CudaRadixCiphertextFFI *lwe_array, + uint64_t const *decomposed_scalar, uint64_t const *has_at_least_one_set, + int8_t *mem, void *const *bsks, void *const *ksks, uint32_t polynomial_size, + uint32_t message_modulus, uint32_t num_scalars) { + + host_integer_scalar_mul_radix( + CudaStreams(streams), lwe_array, decomposed_scalar, has_at_least_one_set, + reinterpret_cast *>(mem), bsks, + (uint32_t **)(ksks), message_modulus, num_scalars); +} + void cuda_scalar_multiplication_ciphertext_64_inplace( CudaStreamsFFI streams, CudaRadixCiphertextFFI *lwe_array, uint64_t const *decomposed_scalar, uint64_t const *has_at_least_one_set, diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/scalar_mul.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/scalar_mul.cuh index fc8ed5e1d..a87be3e3b 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/scalar_mul.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/integer/scalar_mul.cuh @@ -111,16 +111,16 @@ __host__ void host_integer_scalar_mul_radix( streams.gpu_index(0), lwe_array, 0, num_radix_blocks); } else { - host_integer_partial_sum_ciphertexts_vec( + host_integer_partial_sum_ciphertexts_vec( streams, lwe_array, all_shifted_buffer, bsks, ksks, mem->sum_ciphertexts_vec_mem, num_radix_blocks, j); auto scp_mem_ptr = mem->sc_prop_mem; uint32_t requested_flag = outputFlag::FLAG_NONE; uint32_t uses_carry = 0; - host_propagate_single_carry(streams, lwe_array, nullptr, nullptr, - scp_mem_ptr, bsks, ksks, requested_flag, - uses_carry); + host_propagate_single_carry(streams, lwe_array, nullptr, + nullptr, scp_mem_ptr, bsks, ksks, + requested_flag, uses_carry); } } diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/scalar_shifts.cu b/backends/tfhe-cuda-backend/cuda/src/integer/scalar_shifts.cu index f14893d7d..57e235319 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/scalar_shifts.cu +++ b/backends/tfhe-cuda-backend/cuda/src/integer/scalar_shifts.cu @@ -14,12 +14,32 @@ uint64_t scratch_cuda_logical_scalar_shift_64( ks_base_log, pbs_level, pbs_base_log, grouping_factor, message_modulus, carry_modulus, noise_reduction_type); - return scratch_cuda_logical_scalar_shift( + return scratch_cuda_logical_scalar_shift( CudaStreams(streams), (int_logical_scalar_shift_buffer **)mem_ptr, num_blocks, params, shift_type, allocate_gpu_memory); } +uint64_t scratch_cuda_logical_scalar_shift_64_ks32( + CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension, + uint32_t polynomial_size, uint32_t big_lwe_dimension, + uint32_t small_lwe_dimension, uint32_t ks_level, uint32_t ks_base_log, + uint32_t pbs_level, uint32_t pbs_base_log, uint32_t grouping_factor, + uint32_t num_blocks, uint32_t message_modulus, uint32_t carry_modulus, + PBS_TYPE pbs_type, SHIFT_OR_ROTATE_TYPE shift_type, + bool allocate_gpu_memory, PBS_MS_REDUCTION_T noise_reduction_type) { + + int_radix_params params(pbs_type, glwe_dimension, polynomial_size, + big_lwe_dimension, small_lwe_dimension, ks_level, + ks_base_log, pbs_level, pbs_base_log, grouping_factor, + message_modulus, carry_modulus, noise_reduction_type); + + return scratch_cuda_logical_scalar_shift( + CudaStreams(streams), + (int_logical_scalar_shift_buffer **)mem_ptr, + num_blocks, params, shift_type, allocate_gpu_memory); +} + /// The logical scalar shift is the one used for unsigned integers, and /// for the left scalar shift. It is constituted of a rotation, followed by /// the application of a PBS onto the rotated blocks up to num_blocks - @@ -30,12 +50,22 @@ void cuda_logical_scalar_shift_64_inplace(CudaStreamsFFI streams, void *const *bsks, void *const *ksks) { - host_logical_scalar_shift_inplace( + host_logical_scalar_shift_inplace( CudaStreams(streams), lwe_array, shift, (int_logical_scalar_shift_buffer *)mem_ptr, bsks, (uint64_t **)(ksks), lwe_array->num_radix_blocks); } +void cuda_logical_scalar_shift_64_ks32_inplace( + CudaStreamsFFI streams, CudaRadixCiphertextFFI *lwe_array, uint32_t shift, + int8_t *mem_ptr, void *const *bsks, void *const *ksks) { + + host_logical_scalar_shift_inplace( + CudaStreams(streams), lwe_array, shift, + (int_logical_scalar_shift_buffer *)mem_ptr, bsks, + (uint32_t **)(ksks), lwe_array->num_radix_blocks); +} + uint64_t scratch_cuda_arithmetic_scalar_shift_64( CudaStreamsFFI streams, int8_t **mem_ptr, uint32_t glwe_dimension, uint32_t polynomial_size, uint32_t big_lwe_dimension, diff --git a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap.cuh b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap.cuh index 8af383f02..d9f1e2972 100644 --- a/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/pbs/programmable_bootstrap.cuh @@ -249,11 +249,11 @@ __device__ void mul_ggsw_glwe_in_fourier_domain_2_2_params_classical( template void execute_pbs_async(CudaStreams streams, const LweArrayVariant &lwe_array_out, - const LweArrayVariant &lwe_output_indexes, + const LweArrayVariant &lwe_output_indexes, const std::vector lut_vec, - const std::vector lut_indexes_vec, + const std::vector &lut_indexes_vec, const LweArrayVariant &lwe_array_in, - const LweArrayVariant &lwe_input_indexes, + const LweArrayVariant &lwe_input_indexes, void *const *bootstrapping_keys, std::vector pbs_buffer, uint32_t glwe_dimension, uint32_t lwe_dimension, diff --git a/backends/tfhe-cuda-backend/cuda/src/zk/zk.cuh b/backends/tfhe-cuda-backend/cuda/src/zk/zk.cuh index 7fe288736..ebec7db9a 100644 --- a/backends/tfhe-cuda-backend/cuda/src/zk/zk.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/zk/zk.cuh @@ -77,7 +77,7 @@ host_expand_without_verification(CudaStreams streams, Torus *lwe_array_out, auto casting_ks_base_log = casting_params.ks_base_log; // apply keyswitch to BIG - execute_keyswitch_async( + execute_keyswitch_async( streams.get_ith(0), ksed_small_to_big_expanded_lwes, lwe_trivial_indexes_vec[0], expanded_lwes, lwe_trivial_indexes_vec[0], casting_keys, casting_input_dimension, casting_output_dimension, diff --git a/backends/tfhe-cuda-backend/src/bindings.rs b/backends/tfhe-cuda-backend/src/bindings.rs index 5eed4f0f8..4ce367c65 100644 --- a/backends/tfhe-cuda-backend/src/bindings.rs +++ b/backends/tfhe-cuda-backend/src/bindings.rs @@ -472,6 +472,28 @@ unsafe extern "C" { noise_reduction_type: PBS_MS_REDUCTION_T, ) -> u64; } +unsafe extern "C" { + pub fn scratch_cuda_logical_scalar_shift_64_ks32( + streams: CudaStreamsFFI, + mem_ptr: *mut *mut i8, + glwe_dimension: u32, + polynomial_size: u32, + big_lwe_dimension: u32, + small_lwe_dimension: u32, + ks_level: u32, + ks_base_log: u32, + pbs_level: u32, + pbs_base_log: u32, + grouping_factor: u32, + num_blocks: u32, + message_modulus: u32, + carry_modulus: u32, + pbs_type: PBS_TYPE, + shift_type: SHIFT_OR_ROTATE_TYPE, + allocate_gpu_memory: bool, + noise_reduction_type: PBS_MS_REDUCTION_T, + ) -> u64; +} unsafe extern "C" { pub fn cuda_logical_scalar_shift_64_inplace( streams: CudaStreamsFFI, @@ -482,6 +504,16 @@ unsafe extern "C" { ksks: *const *mut ffi::c_void, ); } +unsafe extern "C" { + pub fn cuda_logical_scalar_shift_64_ks32_inplace( + streams: CudaStreamsFFI, + lwe_array: *mut CudaRadixCiphertextFFI, + shift: u32, + mem_ptr: *mut i8, + bsks: *const *mut ffi::c_void, + ksks: *const *mut ffi::c_void, + ); +} unsafe extern "C" { pub fn scratch_cuda_arithmetic_scalar_shift_64( streams: CudaStreamsFFI, @@ -990,6 +1022,27 @@ unsafe extern "C" { noise_reduction_type: PBS_MS_REDUCTION_T, ) -> u64; } +unsafe extern "C" { + pub fn scratch_cuda_integer_scalar_mul_64_ks32( + streams: CudaStreamsFFI, + mem_ptr: *mut *mut i8, + glwe_dimension: u32, + polynomial_size: u32, + lwe_dimension: u32, + ks_level: u32, + ks_base_log: u32, + pbs_level: u32, + pbs_base_log: u32, + grouping_factor: u32, + num_blocks: u32, + message_modulus: u32, + carry_modulus: u32, + pbs_type: PBS_TYPE, + num_scalar_bits: u32, + allocate_gpu_memory: bool, + noise_reduction_type: PBS_MS_REDUCTION_T, + ) -> u64; +} unsafe extern "C" { pub fn cuda_scalar_multiplication_ciphertext_64_inplace( streams: CudaStreamsFFI, @@ -1004,6 +1057,20 @@ unsafe extern "C" { num_scalars: u32, ); } +unsafe extern "C" { + pub fn cuda_scalar_multiplication_ciphertext_64_ks32_inplace( + streams: CudaStreamsFFI, + lwe_array: *mut CudaRadixCiphertextFFI, + decomposed_scalar: *const u64, + has_at_least_one_set: *const u64, + mem_ptr: *mut i8, + bsks: *const *mut ffi::c_void, + ksks: *const *mut ffi::c_void, + polynomial_size: u32, + message_modulus: u32, + num_scalars: u32, + ); +} unsafe extern "C" { pub fn cleanup_cuda_scalar_mul(streams: CudaStreamsFFI, mem_ptr_void: *mut *mut i8); } diff --git a/tfhe/src/integer/gpu/mod.rs b/tfhe/src/integer/gpu/mod.rs index 14bdd6508..5adcdce27 100644 --- a/tfhe/src/integer/gpu/mod.rs +++ b/tfhe/src/integer/gpu/mod.rs @@ -435,38 +435,75 @@ pub(crate) unsafe fn cuda_backend_unchecked_scalar_mul< .filter(|&&rhs_bit| rhs_bit == T::ONE) .count() as u32; - scratch_cuda_integer_scalar_mul_64( - streams.ffi(), - std::ptr::addr_of_mut!(mem_ptr), - glwe_dimension.0 as u32, - polynomial_size.0 as u32, - lwe_dimension.0 as u32, - ks_level.0 as u32, - ks_base_log.0 as u32, - pbs_level.0 as u32, - pbs_base_log.0 as u32, - grouping_factor.0 as u32, - lwe_array.d_blocks.0.lwe_ciphertext_count.0 as u32, - message_modulus.0 as u32, - carry_modulus.0 as u32, - pbs_type as u32, - num_scalar_bits, - true, - noise_reduction_type as u32, - ); + if TypeId::of::() == TypeId::of::() { + scratch_cuda_integer_scalar_mul_64_ks32( + streams.ffi(), + std::ptr::addr_of_mut!(mem_ptr), + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + lwe_dimension.0 as u32, + ks_level.0 as u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + grouping_factor.0 as u32, + lwe_array.d_blocks.0.lwe_ciphertext_count.0 as u32, + message_modulus.0 as u32, + carry_modulus.0 as u32, + pbs_type as u32, + num_scalar_bits, + true, + noise_reduction_type as u32, + ); - cuda_scalar_multiplication_ciphertext_64_inplace( - streams.ffi(), - &raw mut cuda_ffi_lwe_array, - decomposed_scalar.as_ptr().cast::(), - has_at_least_one_set.as_ptr().cast::(), - mem_ptr, - bootstrapping_key.ptr.as_ptr(), - keyswitch_key.ptr.as_ptr(), - polynomial_size.0 as u32, - message_modulus.0 as u32, - num_scalars, - ); + cuda_scalar_multiplication_ciphertext_64_ks32_inplace( + streams.ffi(), + &raw mut cuda_ffi_lwe_array, + decomposed_scalar.as_ptr().cast::(), + has_at_least_one_set.as_ptr().cast::(), + mem_ptr, + bootstrapping_key.ptr.as_ptr(), + keyswitch_key.ptr.as_ptr(), + polynomial_size.0 as u32, + message_modulus.0 as u32, + num_scalars, + ); + } else if TypeId::of::() == TypeId::of::() { + scratch_cuda_integer_scalar_mul_64( + streams.ffi(), + std::ptr::addr_of_mut!(mem_ptr), + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + lwe_dimension.0 as u32, + ks_level.0 as u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + grouping_factor.0 as u32, + lwe_array.d_blocks.0.lwe_ciphertext_count.0 as u32, + message_modulus.0 as u32, + carry_modulus.0 as u32, + pbs_type as u32, + num_scalar_bits, + true, + noise_reduction_type as u32, + ); + + cuda_scalar_multiplication_ciphertext_64_inplace( + streams.ffi(), + &raw mut cuda_ffi_lwe_array, + decomposed_scalar.as_ptr().cast::(), + has_at_least_one_set.as_ptr().cast::(), + mem_ptr, + bootstrapping_key.ptr.as_ptr(), + keyswitch_key.ptr.as_ptr(), + polynomial_size.0 as u32, + message_modulus.0 as u32, + num_scalars, + ); + } else { + panic!("Unknown KS dtype"); + } cleanup_cuda_scalar_mul(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr)); update_noise_degree(lwe_array, &cuda_ffi_lwe_array); @@ -4163,34 +4200,68 @@ pub(crate) unsafe fn cuda_backend_unchecked_scalar_left_shift_assign< &mut radix_lwe_left_noise_levels, ); - scratch_cuda_logical_scalar_shift_64( - streams.ffi(), - std::ptr::addr_of_mut!(mem_ptr), - glwe_dimension.0 as u32, - polynomial_size.0 as u32, - big_lwe_dimension.0 as u32, - small_lwe_dimension.0 as u32, - ks_level.0 as u32, - ks_base_log.0 as u32, - pbs_level.0 as u32, - pbs_base_log.0 as u32, - grouping_factor.0 as u32, - num_blocks, - message_modulus.0 as u32, - carry_modulus.0 as u32, - pbs_type as u32, - ShiftRotateType::LeftShift as u32, - true, - noise_reduction_type as u32, - ); - cuda_logical_scalar_shift_64_inplace( - streams.ffi(), - &raw mut cuda_ffi_radix_lwe_left, - shift, - mem_ptr, - bootstrapping_key.ptr.as_ptr(), - keyswitch_key.ptr.as_ptr(), - ); + if TypeId::of::() == TypeId::of::() { + scratch_cuda_logical_scalar_shift_64_ks32( + streams.ffi(), + std::ptr::addr_of_mut!(mem_ptr), + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + big_lwe_dimension.0 as u32, + small_lwe_dimension.0 as u32, + ks_level.0 as u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + grouping_factor.0 as u32, + num_blocks, + message_modulus.0 as u32, + carry_modulus.0 as u32, + pbs_type as u32, + ShiftRotateType::LeftShift as u32, + true, + noise_reduction_type as u32, + ); + cuda_logical_scalar_shift_64_ks32_inplace( + streams.ffi(), + &raw mut cuda_ffi_radix_lwe_left, + shift, + mem_ptr, + bootstrapping_key.ptr.as_ptr(), + keyswitch_key.ptr.as_ptr(), + ); + } else if TypeId::of::() == TypeId::of::() { + scratch_cuda_logical_scalar_shift_64( + streams.ffi(), + std::ptr::addr_of_mut!(mem_ptr), + glwe_dimension.0 as u32, + polynomial_size.0 as u32, + big_lwe_dimension.0 as u32, + small_lwe_dimension.0 as u32, + ks_level.0 as u32, + ks_base_log.0 as u32, + pbs_level.0 as u32, + pbs_base_log.0 as u32, + grouping_factor.0 as u32, + num_blocks, + message_modulus.0 as u32, + carry_modulus.0 as u32, + pbs_type as u32, + ShiftRotateType::LeftShift as u32, + true, + noise_reduction_type as u32, + ); + cuda_logical_scalar_shift_64_inplace( + streams.ffi(), + &raw mut cuda_ffi_radix_lwe_left, + shift, + mem_ptr, + bootstrapping_key.ptr.as_ptr(), + keyswitch_key.ptr.as_ptr(), + ); + } else { + panic!("Unknown KS dtype"); + } + cleanup_cuda_logical_scalar_shift(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr)); update_noise_degree(input, &cuda_ffi_radix_lwe_left); } diff --git a/tfhe/src/integer/gpu/server_key/mod.rs b/tfhe/src/integer/gpu/server_key/mod.rs index 752156feb..2faa4e4cf 100644 --- a/tfhe/src/integer/gpu/server_key/mod.rs +++ b/tfhe/src/integer/gpu/server_key/mod.rs @@ -16,7 +16,7 @@ use crate::shortint::atomic_pattern::compressed::CompressedAtomicPatternServerKe use crate::shortint::ciphertext::{MaxDegree, MaxNoiseLevel}; use crate::shortint::client_key::atomic_pattern::AtomicPatternClientKey; use crate::shortint::engine::ShortintEngine; -use crate::shortint::parameters::ModulusSwitchType; +use crate::shortint::parameters::{CiphertextModulus32, ModulusSwitchType}; use crate::shortint::server_key::CompressedModulusSwitchConfiguration; use crate::shortint::{CarryModulus, CiphertextModulus, MessageModulus, PBSOrder}; @@ -101,27 +101,113 @@ impl CudaServerKey { ) -> Self { let mut engine = ShortintEngine::new(); - // Generate a regular keyset and convert to the GPU - let AtomicPatternClientKey::Standard(std_cks) = &cks.key.atomic_pattern else { - panic!("Only the standard atomic pattern is supported on GPU") - }; + match &cks.key.atomic_pattern { + AtomicPatternClientKey::Standard(std_cks) => { + let pbs_params_base = std_cks.parameters; - let pbs_params_base = std_cks.parameters; + let d_bootstrapping_key = match pbs_params_base { + crate::shortint::PBSParameters::PBS(pbs_params) => { + let h_bootstrap_key: LweBootstrapKeyOwned = + par_allocate_and_generate_new_lwe_bootstrap_key( + &std_cks.lwe_secret_key, + &std_cks.glwe_secret_key, + pbs_params.pbs_base_log, + pbs_params.pbs_level, + pbs_params.glwe_noise_distribution, + pbs_params.ciphertext_modulus, + &mut engine.encryption_generator, + ); + let modulus_switch_noise_reduction_configuration = + match pbs_params.modulus_switch_noise_reduction_params { + ModulusSwitchType::Standard => None, + ModulusSwitchType::DriftTechniqueNoiseReduction( + _modulus_switch_noise_reduction_params, + ) => { + panic!("Drift noise reduction is not supported on GPU") + } + ModulusSwitchType::CenteredMeanNoiseReduction => { + Some(CudaModulusSwitchNoiseReductionConfiguration::Centered) + } + }; - let d_bootstrapping_key = match pbs_params_base { - crate::shortint::PBSParameters::PBS(pbs_params) => { + let d_bootstrap_key = CudaLweBootstrapKey::from_lwe_bootstrap_key( + &h_bootstrap_key, + modulus_switch_noise_reduction_configuration, + streams, + ); + + CudaBootstrappingKey::Classic(d_bootstrap_key) + } + crate::shortint::PBSParameters::MultiBitPBS(pbs_params) => { + let h_bootstrap_key: LweMultiBitBootstrapKeyOwned = + par_allocate_and_generate_new_lwe_multi_bit_bootstrap_key( + &std_cks.lwe_secret_key, + &std_cks.glwe_secret_key, + pbs_params.pbs_base_log, + pbs_params.pbs_level, + pbs_params.grouping_factor, + pbs_params.glwe_noise_distribution, + pbs_params.ciphertext_modulus, + &mut engine.encryption_generator, + ); + + let d_bootstrap_key = + CudaLweMultiBitBootstrapKey::from_lwe_multi_bit_bootstrap_key( + &h_bootstrap_key, + streams, + ); + + CudaBootstrappingKey::MultiBit(d_bootstrap_key) + } + }; + + // Creation of the key switching key + let h_key_switching_key = allocate_and_generate_new_lwe_keyswitch_key( + &std_cks.large_lwe_secret_key(), + &std_cks.small_lwe_secret_key(), + std_cks.parameters.ks_base_log(), + std_cks.parameters.ks_level(), + std_cks.parameters.lwe_noise_distribution(), + std_cks.parameters.ciphertext_modulus(), + &mut engine.encryption_generator, + ); + + let d_key_switching_key = + CudaLweKeyswitchKey::from_lwe_keyswitch_key(&h_key_switching_key, streams); + + assert!(matches!( + std_cks.parameters.encryption_key_choice().into(), + PBSOrder::KeyswitchBootstrap + )); + + // Pack the keys in the server key set: + Self { + key_switching_key: CudaDynamicKeyswitchingKey::Standard(d_key_switching_key), + bootstrapping_key: d_bootstrapping_key, + message_modulus: std_cks.parameters.message_modulus(), + carry_modulus: std_cks.parameters.carry_modulus(), + max_degree, + max_noise_level: std_cks.parameters.max_noise_level(), + ciphertext_modulus: std_cks.parameters.ciphertext_modulus(), + pbs_order: std_cks.parameters.encryption_key_choice().into(), + } + } + AtomicPatternClientKey::KeySwitch32(ks32_cks) => { + let pbs_params_base = ks32_cks.parameters; + + // let d_bootstrapping_key let h_bootstrap_key: LweBootstrapKeyOwned = par_allocate_and_generate_new_lwe_bootstrap_key( - &std_cks.lwe_secret_key, - &std_cks.glwe_secret_key, - pbs_params.pbs_base_log, - pbs_params.pbs_level, - pbs_params.glwe_noise_distribution, - pbs_params.ciphertext_modulus, + &ks32_cks.lwe_secret_key, + &ks32_cks.glwe_secret_key, + pbs_params_base.pbs_base_log, + pbs_params_base.pbs_level, + pbs_params_base.glwe_noise_distribution, + pbs_params_base.ciphertext_modulus, &mut engine.encryption_generator, ); let modulus_switch_noise_reduction_configuration = - match pbs_params.modulus_switch_noise_reduction_params { + match pbs_params_base.modulus_switch_noise_reduction_params { ModulusSwitchType::Standard => None, ModulusSwitchType::DriftTechniqueNoiseReduction( _modulus_switch_noise_reduction_params, @@ -139,59 +225,39 @@ impl CudaServerKey { streams, ); - CudaBootstrappingKey::Classic(d_bootstrap_key) - } - crate::shortint::PBSParameters::MultiBitPBS(pbs_params) => { - let h_bootstrap_key: LweMultiBitBootstrapKeyOwned = - par_allocate_and_generate_new_lwe_multi_bit_bootstrap_key( - &std_cks.lwe_secret_key, - &std_cks.glwe_secret_key, - pbs_params.pbs_base_log, - pbs_params.pbs_level, - pbs_params.grouping_factor, - pbs_params.glwe_noise_distribution, - pbs_params.ciphertext_modulus, - &mut engine.encryption_generator, - ); + let d_bootstrapping_key = CudaBootstrappingKey::Classic(d_bootstrap_key); - let d_bootstrap_key = CudaLweMultiBitBootstrapKey::from_lwe_multi_bit_bootstrap_key( - &h_bootstrap_key, - streams, + // Creation of the key switching key + let h_key_switching_key = allocate_and_generate_new_lwe_keyswitch_key( + &ks32_cks.large_lwe_secret_key(), + &ks32_cks.small_lwe_secret_key(), + ks32_cks.parameters.ks_base_log(), + ks32_cks.parameters.ks_level(), + ks32_cks.parameters.lwe_noise_distribution(), + CiphertextModulus32::new_native(), + &mut engine.encryption_generator, ); - CudaBootstrappingKey::MultiBit(d_bootstrap_key) + let d_key_switching_key = + CudaLweKeyswitchKey::from_lwe_keyswitch_key(&h_key_switching_key, streams); + + assert!(matches!( + ks32_cks.parameters.encryption_key_choice().into(), + PBSOrder::KeyswitchBootstrap + )); + + // Pack the keys in the server key set: + Self { + key_switching_key: CudaDynamicKeyswitchingKey::KeySwitch32(d_key_switching_key), + bootstrapping_key: d_bootstrapping_key, + message_modulus: ks32_cks.parameters.message_modulus(), + carry_modulus: ks32_cks.parameters.carry_modulus(), + max_degree, + max_noise_level: ks32_cks.parameters.max_noise_level(), + ciphertext_modulus: ks32_cks.parameters.ciphertext_modulus(), + pbs_order: ks32_cks.parameters.encryption_key_choice().into(), + } } - }; - - // Creation of the key switching key - let h_key_switching_key = allocate_and_generate_new_lwe_keyswitch_key( - &std_cks.large_lwe_secret_key(), - &std_cks.small_lwe_secret_key(), - std_cks.parameters.ks_base_log(), - std_cks.parameters.ks_level(), - std_cks.parameters.lwe_noise_distribution(), - std_cks.parameters.ciphertext_modulus(), - &mut engine.encryption_generator, - ); - - let d_key_switching_key = - CudaLweKeyswitchKey::from_lwe_keyswitch_key(&h_key_switching_key, streams); - - assert!(matches!( - std_cks.parameters.encryption_key_choice().into(), - PBSOrder::KeyswitchBootstrap - )); - - // Pack the keys in the server key set: - Self { - key_switching_key: CudaDynamicKeyswitchingKey::Standard(d_key_switching_key), - bootstrapping_key: d_bootstrapping_key, - message_modulus: std_cks.parameters.message_modulus(), - carry_modulus: std_cks.parameters.carry_modulus(), - max_degree, - max_noise_level: std_cks.parameters.max_noise_level(), - ciphertext_modulus: std_cks.parameters.ciphertext_modulus(), - pbs_order: std_cks.parameters.encryption_key_choice().into(), } } diff --git a/tfhe/src/integer/gpu/server_key/radix/scalar_mul.rs b/tfhe/src/integer/gpu/server_key/radix/scalar_mul.rs index c42719846..576acef89 100644 --- a/tfhe/src/integer/gpu/server_key/radix/scalar_mul.rs +++ b/tfhe/src/integer/gpu/server_key/radix/scalar_mul.rs @@ -109,57 +109,108 @@ impl CudaServerKey { if decomposed_scalar.is_empty() { return; } - let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else { - panic!("Only the standard atomic pattern is supported on GPU") - }; unsafe { - match &self.bootstrapping_key { - CudaBootstrappingKey::Classic(d_bsk) => { - cuda_backend_unchecked_scalar_mul( - streams, - ct.as_mut(), - decomposed_scalar.as_slice(), - has_at_least_one_set.as_slice(), - &d_bsk.d_vec, - &computing_ks_key.d_vec, - self.message_modulus, - self.carry_modulus, - d_bsk.glwe_dimension, - d_bsk.polynomial_size, - computing_ks_key.output_key_lwe_size().to_lwe_dimension(), - d_bsk.decomp_base_log, - d_bsk.decomp_level_count, - computing_ks_key.decomposition_base_log(), - computing_ks_key.decomposition_level_count(), - decomposed_scalar.len() as u32, - PBSType::Classical, - LweBskGroupingFactor(0), - d_bsk.ms_noise_reduction_configuration.as_ref(), - ); + match &self.key_switching_key { + CudaDynamicKeyswitchingKey::Standard(computing_ks_key) => { + match &self.bootstrapping_key { + CudaBootstrappingKey::Classic(d_bsk) => { + cuda_backend_unchecked_scalar_mul( + streams, + ct.as_mut(), + decomposed_scalar.as_slice(), + has_at_least_one_set.as_slice(), + &d_bsk.d_vec, + &computing_ks_key.d_vec, + self.message_modulus, + self.carry_modulus, + d_bsk.glwe_dimension, + d_bsk.polynomial_size, + computing_ks_key.output_key_lwe_size().to_lwe_dimension(), + d_bsk.decomp_base_log, + d_bsk.decomp_level_count, + computing_ks_key.decomposition_base_log(), + computing_ks_key.decomposition_level_count(), + decomposed_scalar.len() as u32, + PBSType::Classical, + LweBskGroupingFactor(0), + d_bsk.ms_noise_reduction_configuration.as_ref(), + ); + } + CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { + cuda_backend_unchecked_scalar_mul( + streams, + ct.as_mut(), + decomposed_scalar.as_slice(), + has_at_least_one_set.as_slice(), + &d_multibit_bsk.d_vec, + &computing_ks_key.d_vec, + self.message_modulus, + self.carry_modulus, + d_multibit_bsk.glwe_dimension, + d_multibit_bsk.polynomial_size, + computing_ks_key.output_key_lwe_size().to_lwe_dimension(), + d_multibit_bsk.decomp_base_log, + d_multibit_bsk.decomp_level_count, + computing_ks_key.decomposition_base_log(), + computing_ks_key.decomposition_level_count(), + decomposed_scalar.len() as u32, + PBSType::MultiBit, + d_multibit_bsk.grouping_factor, + None, + ); + } + } } - CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { - cuda_backend_unchecked_scalar_mul( - streams, - ct.as_mut(), - decomposed_scalar.as_slice(), - has_at_least_one_set.as_slice(), - &d_multibit_bsk.d_vec, - &computing_ks_key.d_vec, - self.message_modulus, - self.carry_modulus, - d_multibit_bsk.glwe_dimension, - d_multibit_bsk.polynomial_size, - computing_ks_key.output_key_lwe_size().to_lwe_dimension(), - d_multibit_bsk.decomp_base_log, - d_multibit_bsk.decomp_level_count, - computing_ks_key.decomposition_base_log(), - computing_ks_key.decomposition_level_count(), - decomposed_scalar.len() as u32, - PBSType::MultiBit, - d_multibit_bsk.grouping_factor, - None, - ); + CudaDynamicKeyswitchingKey::KeySwitch32(computing_ks_key) => { + match &self.bootstrapping_key { + CudaBootstrappingKey::Classic(d_bsk) => { + cuda_backend_unchecked_scalar_mul( + streams, + ct.as_mut(), + decomposed_scalar.as_slice(), + has_at_least_one_set.as_slice(), + &d_bsk.d_vec, + &computing_ks_key.d_vec, + self.message_modulus, + self.carry_modulus, + d_bsk.glwe_dimension, + d_bsk.polynomial_size, + computing_ks_key.output_key_lwe_size().to_lwe_dimension(), + d_bsk.decomp_base_log, + d_bsk.decomp_level_count, + computing_ks_key.decomposition_base_log(), + computing_ks_key.decomposition_level_count(), + decomposed_scalar.len() as u32, + PBSType::Classical, + LweBskGroupingFactor(0), + d_bsk.ms_noise_reduction_configuration.as_ref(), + ); + } + CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { + cuda_backend_unchecked_scalar_mul( + streams, + ct.as_mut(), + decomposed_scalar.as_slice(), + has_at_least_one_set.as_slice(), + &d_multibit_bsk.d_vec, + &computing_ks_key.d_vec, + self.message_modulus, + self.carry_modulus, + d_multibit_bsk.glwe_dimension, + d_multibit_bsk.polynomial_size, + computing_ks_key.output_key_lwe_size().to_lwe_dimension(), + d_multibit_bsk.decomp_base_log, + d_multibit_bsk.decomp_level_count, + computing_ks_key.decomposition_base_log(), + computing_ks_key.decomposition_level_count(), + decomposed_scalar.len() as u32, + PBSType::MultiBit, + d_multibit_bsk.grouping_factor, + None, + ); + } + } } } } diff --git a/tfhe/src/integer/gpu/server_key/radix/scalar_shift.rs b/tfhe/src/integer/gpu/server_key/radix/scalar_shift.rs index 68e299b4f..3aca4fbd6 100644 --- a/tfhe/src/integer/gpu/server_key/radix/scalar_shift.rs +++ b/tfhe/src/integer/gpu/server_key/radix/scalar_shift.rs @@ -76,60 +76,109 @@ impl CudaServerKey { T: CudaIntegerRadixCiphertext, { let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count(); - let CudaDynamicKeyswitchingKey::Standard(computing_ks_key) = &self.key_switching_key else { - panic!("Only the standard atomic pattern is supported on GPU") - }; - unsafe { - match &self.bootstrapping_key { - CudaBootstrappingKey::Classic(d_bsk) => { - cuda_backend_unchecked_scalar_left_shift_assign( - streams, - ct.as_mut(), - u32::cast_from(shift), - &d_bsk.d_vec, - &computing_ks_key.d_vec, - self.message_modulus, - self.carry_modulus, - d_bsk.glwe_dimension, - d_bsk.polynomial_size, - computing_ks_key.input_key_lwe_size().to_lwe_dimension(), - computing_ks_key.output_key_lwe_size().to_lwe_dimension(), - computing_ks_key.decomposition_level_count(), - computing_ks_key.decomposition_base_log(), - d_bsk.decomp_level_count, - d_bsk.decomp_base_log, - lwe_ciphertext_count.0 as u32, - PBSType::Classical, - LweBskGroupingFactor(0), - d_bsk.ms_noise_reduction_configuration.as_ref(), - ); + match &self.key_switching_key { + CudaDynamicKeyswitchingKey::Standard(computing_ks_key) => unsafe { + match &self.bootstrapping_key { + CudaBootstrappingKey::Classic(d_bsk) => { + cuda_backend_unchecked_scalar_left_shift_assign( + streams, + ct.as_mut(), + u32::cast_from(shift), + &d_bsk.d_vec, + &computing_ks_key.d_vec, + self.message_modulus, + self.carry_modulus, + d_bsk.glwe_dimension, + d_bsk.polynomial_size, + computing_ks_key.input_key_lwe_size().to_lwe_dimension(), + computing_ks_key.output_key_lwe_size().to_lwe_dimension(), + computing_ks_key.decomposition_level_count(), + computing_ks_key.decomposition_base_log(), + d_bsk.decomp_level_count, + d_bsk.decomp_base_log, + lwe_ciphertext_count.0 as u32, + PBSType::Classical, + LweBskGroupingFactor(0), + d_bsk.ms_noise_reduction_configuration.as_ref(), + ); + } + CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { + cuda_backend_unchecked_scalar_left_shift_assign( + streams, + ct.as_mut(), + u32::cast_from(shift), + &d_multibit_bsk.d_vec, + &computing_ks_key.d_vec, + self.message_modulus, + self.carry_modulus, + d_multibit_bsk.glwe_dimension, + d_multibit_bsk.polynomial_size, + computing_ks_key.input_key_lwe_size().to_lwe_dimension(), + computing_ks_key.output_key_lwe_size().to_lwe_dimension(), + computing_ks_key.decomposition_level_count(), + computing_ks_key.decomposition_base_log(), + d_multibit_bsk.decomp_level_count, + d_multibit_bsk.decomp_base_log, + lwe_ciphertext_count.0 as u32, + PBSType::MultiBit, + d_multibit_bsk.grouping_factor, + None, + ); + } } - CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { - cuda_backend_unchecked_scalar_left_shift_assign( - streams, - ct.as_mut(), - u32::cast_from(shift), - &d_multibit_bsk.d_vec, - &computing_ks_key.d_vec, - self.message_modulus, - self.carry_modulus, - d_multibit_bsk.glwe_dimension, - d_multibit_bsk.polynomial_size, - computing_ks_key.input_key_lwe_size().to_lwe_dimension(), - computing_ks_key.output_key_lwe_size().to_lwe_dimension(), - computing_ks_key.decomposition_level_count(), - computing_ks_key.decomposition_base_log(), - d_multibit_bsk.decomp_level_count, - d_multibit_bsk.decomp_base_log, - lwe_ciphertext_count.0 as u32, - PBSType::MultiBit, - d_multibit_bsk.grouping_factor, - None, - ); + }, + CudaDynamicKeyswitchingKey::KeySwitch32(computing_ks_key) => unsafe { + match &self.bootstrapping_key { + CudaBootstrappingKey::Classic(d_bsk) => { + cuda_backend_unchecked_scalar_left_shift_assign( + streams, + ct.as_mut(), + u32::cast_from(shift), + &d_bsk.d_vec, + &computing_ks_key.d_vec, + self.message_modulus, + self.carry_modulus, + d_bsk.glwe_dimension, + d_bsk.polynomial_size, + computing_ks_key.input_key_lwe_size().to_lwe_dimension(), + computing_ks_key.output_key_lwe_size().to_lwe_dimension(), + computing_ks_key.decomposition_level_count(), + computing_ks_key.decomposition_base_log(), + d_bsk.decomp_level_count, + d_bsk.decomp_base_log, + lwe_ciphertext_count.0 as u32, + PBSType::Classical, + LweBskGroupingFactor(0), + d_bsk.ms_noise_reduction_configuration.as_ref(), + ); + } + CudaBootstrappingKey::MultiBit(d_multibit_bsk) => { + cuda_backend_unchecked_scalar_left_shift_assign( + streams, + ct.as_mut(), + u32::cast_from(shift), + &d_multibit_bsk.d_vec, + &computing_ks_key.d_vec, + self.message_modulus, + self.carry_modulus, + d_multibit_bsk.glwe_dimension, + d_multibit_bsk.polynomial_size, + computing_ks_key.input_key_lwe_size().to_lwe_dimension(), + computing_ks_key.output_key_lwe_size().to_lwe_dimension(), + computing_ks_key.decomposition_level_count(), + computing_ks_key.decomposition_base_log(), + d_multibit_bsk.decomp_level_count, + d_multibit_bsk.decomp_base_log, + lwe_ciphertext_count.0 as u32, + PBSType::MultiBit, + d_multibit_bsk.grouping_factor, + None, + ); + } } - } - } + }, + }; } /// Computes homomorphically a right shift by a scalar. diff --git a/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_scalar_shift.rs b/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_scalar_shift.rs index 355f835f7..2580ea94f 100644 --- a/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_scalar_shift.rs +++ b/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_scalar_shift.rs @@ -10,6 +10,7 @@ use crate::shortint::parameters::test_params::*; use crate::shortint::parameters::*; create_gpu_parameterized_test!(integer_signed_unchecked_scalar_left_shift); + create_gpu_parameterized_test!(integer_signed_scalar_left_shift); create_gpu_parameterized_test!(integer_signed_unchecked_scalar_right_shift); create_gpu_parameterized_test!(integer_signed_scalar_right_shift); diff --git a/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/mod.rs b/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/mod.rs index ed7fbe931..feaa40eeb 100644 --- a/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/mod.rs +++ b/tfhe/src/integer/gpu/server_key/radix/tests_unsigned/mod.rs @@ -47,6 +47,7 @@ macro_rules! create_gpu_parameterized_test{ create_gpu_parameterized_test!($name { PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128, + TEST_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_TUNIFORM_2M128, TEST_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128, }); };