From 95ef13f6ce5a886607be6140e252dede80d21941 Mon Sep 17 00:00:00 2001 From: Beka Barbakadze Date: Fri, 19 Jul 2024 13:28:21 +0400 Subject: [PATCH] feat(gpu): Add signed_overflowing_scalar_add and signed_overflowing_scalar_sub --- tfhe/benches/integer/signed_bench.rs | 14 +++ .../integers/signed/overflowing_ops.rs | 22 +++-- .../gpu/server_key/radix/scalar_add.rs | 86 +++++++++++++++++- .../gpu/server_key/radix/scalar_sub.rs | 88 ++++++++++++++++++- .../gpu/server_key/radix/tests_signed/mod.rs | 37 ++++++++ .../radix/tests_signed/test_scalar_add.rs | 12 ++- .../radix/tests_signed/test_scalar_sub.rs | 13 ++- .../tests_signed/test_scalar_add.rs | 2 +- 8 files changed, 261 insertions(+), 13 deletions(-) diff --git a/tfhe/benches/integer/signed_bench.rs b/tfhe/benches/integer/signed_bench.rs index 579f11faa..1967b3a75 100644 --- a/tfhe/benches/integer/signed_bench.rs +++ b/tfhe/benches/integer/signed_bench.rs @@ -1801,6 +1801,18 @@ mod cuda { rng_func: default_signed_scalar ); + define_cuda_server_key_bench_clean_input_scalar_signed_fn!( + method_name: signed_overflowing_scalar_add, + display_name: overflowing_add, + rng_func: default_signed_scalar + ); + + define_cuda_server_key_bench_clean_input_scalar_signed_fn!( + method_name: signed_overflowing_scalar_sub, + display_name: overflowing_sub, + rng_func: default_signed_scalar + ); + //=========================================== // Default //=========================================== @@ -2132,6 +2144,8 @@ mod cuda { cuda_scalar_le, cuda_scalar_min, cuda_scalar_max, + cuda_signed_overflowing_scalar_add, + cuda_signed_overflowing_scalar_sub, ); fn cuda_bench_server_key_signed_cast_function( diff --git a/tfhe/src/high_level_api/integers/signed/overflowing_ops.rs b/tfhe/src/high_level_api/integers/signed/overflowing_ops.rs index 436354f14..e22ecf1d8 100644 --- a/tfhe/src/high_level_api/integers/signed/overflowing_ops.rs +++ b/tfhe/src/high_level_api/integers/signed/overflowing_ops.rs @@ -140,9 +140,14 @@ where 
(FheInt::new(result), FheBool::new(overflow)) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(_) => { - todo!("Cuda devices do not support signed integer"); - } + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + let (result, overflow) = cuda_key.key.signed_overflowing_scalar_add( + &self.ciphertext.on_gpu(), + other, + streams, + ); + (FheInt::new(result), FheBool::new(overflow)) + }), }) } } @@ -349,9 +354,14 @@ where (FheInt::new(result), FheBool::new(overflow)) } #[cfg(feature = "gpu")] - InternalServerKey::Cuda(_) => { - todo!("Cuda devices do not support signed integer"); - } + InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { + let (result, overflow) = cuda_key.key.signed_overflowing_scalar_sub( + &self.ciphertext.on_gpu(), + other, + streams, + ); + (FheInt::new(result), FheBool::new(overflow)) + }), }) } } diff --git a/tfhe/src/integer/gpu/server_key/radix/scalar_add.rs b/tfhe/src/integer/gpu/server_key/radix/scalar_add.rs index c88f75627..12af81d43 100644 --- a/tfhe/src/integer/gpu/server_key/radix/scalar_add.rs +++ b/tfhe/src/integer/gpu/server_key/radix/scalar_add.rs @@ -1,8 +1,11 @@ use crate::core_crypto::gpu::vec::CudaVec; use crate::core_crypto::gpu::CudaStreams; +use crate::core_crypto::prelude::SignedNumeric; use crate::integer::block_decomposition::{BlockDecomposer, DecomposableInto}; use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock; -use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext}; +use crate::integer::gpu::ciphertext::{ + CudaIntegerRadixCiphertext, CudaSignedRadixCiphertext, CudaUnsignedRadixCiphertext, +}; use crate::integer::gpu::scalar_addition_integer_radix_assign_async; use crate::integer::gpu::server_key::CudaServerKey; use crate::prelude::CastInto; @@ -279,4 +282,85 @@ impl CudaServerKey { CudaBooleanBlock::from_cuda_radix_ciphertext(carry_out.ciphertext) } } + + /// ```rust + /// use 
tfhe::core_crypto::gpu::CudaStreams; + /// use tfhe::integer::gpu::ciphertext::{CudaSignedRadixCiphertext, CudaUnsignedRadixCiphertext}; + /// use tfhe::integer::gpu::gen_keys_radix_gpu; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS; + /// + /// let gpu_index = 0; + /// let streams = CudaStreams::new_single_gpu(gpu_index); + /// + /// // Generate the client key and the server key: + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix_gpu(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks, &streams); + /// + /// let msg: i8 = 120; + /// let scalar: i8 = 8; + /// + /// let ct1 = cks.encrypt_signed(msg); + /// + /// // Copy to GPU + /// let d_ct1 = CudaSignedRadixCiphertext::from_signed_radix_ciphertext(&ct1, &streams); + /// + /// // Compute homomorphically an overflowing addition: + /// let (d_ct_res, d_ct_overflowed) = sks.signed_overflowing_scalar_add(&d_ct1, scalar, &streams); + /// + /// let ct_res = d_ct_res.to_signed_radix_ciphertext(&streams); + /// let ct_overflowed = d_ct_overflowed.to_boolean_block(&streams); + /// + /// // Decrypt: + /// let dec_result: i8 = cks.decrypt_signed(&ct_res); + /// let dec_overflowed: bool = cks.decrypt_bool(&ct_overflowed); + /// let (clear_result, clear_overflowed) = msg.overflowing_add(scalar); + /// assert_eq!(dec_result, clear_result); + /// assert_eq!(dec_overflowed, clear_overflowed); + /// ``` + pub fn signed_overflowing_scalar_add( + &self, + ct_left: &CudaSignedRadixCiphertext, + scalar: Scalar, + streams: &CudaStreams, + ) -> (CudaSignedRadixCiphertext, CudaBooleanBlock) + where + Scalar: SignedNumeric + DecomposableInto + CastInto, + { + let mut tmp_lhs; + unsafe { + tmp_lhs = ct_left.duplicate_async(streams); + if !tmp_lhs.block_carries_are_empty() { + self.full_propagate_assign_async(&mut tmp_lhs, streams); + } + } + + let trivial: CudaSignedRadixCiphertext = self.create_trivial_radix( + scalar, + ct_left.ciphertext.d_blocks.lwe_ciphertext_count().0, + streams, + ); + let (result, 
overflowed) = self.signed_overflowing_add(&tmp_lhs, &trivial, streams); + + let mut extra_scalar_block_iter = + BlockDecomposer::new(scalar, self.message_modulus.0.ilog2()) + .iter_as::() + .skip(ct_left.ciphertext.d_blocks.lwe_ciphertext_count().0); + + let extra_blocks_have_correct_value = if scalar < Scalar::ZERO { + extra_scalar_block_iter.all(|block| block == (self.message_modulus.0 as u64 - 1)) + } else { + extra_scalar_block_iter.all(|block| block == 0) + }; + + if extra_blocks_have_correct_value { + (result, overflowed) + } else { + let trivial_one: CudaSignedRadixCiphertext = self.create_trivial_radix(1, 1, streams); + // Scalar has more blocks so addition counts as overflowing + ( + result, + CudaBooleanBlock::from_cuda_radix_ciphertext(trivial_one.ciphertext), + ) + } + } } diff --git a/tfhe/src/integer/gpu/server_key/radix/scalar_sub.rs b/tfhe/src/integer/gpu/server_key/radix/scalar_sub.rs index cf68e3d18..41abd9c6f 100644 --- a/tfhe/src/integer/gpu/server_key/radix/scalar_sub.rs +++ b/tfhe/src/integer/gpu/server_key/radix/scalar_sub.rs @@ -1,7 +1,8 @@ use crate::core_crypto::gpu::CudaStreams; -use crate::core_crypto::prelude::Numeric; -use crate::integer::block_decomposition::DecomposableInto; -use crate::integer::gpu::ciphertext::CudaIntegerRadixCiphertext; +use crate::core_crypto::prelude::{Numeric, SignedNumeric}; +use crate::integer::block_decomposition::{BlockDecomposer, DecomposableInto}; +use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock; +use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaSignedRadixCiphertext}; use crate::integer::gpu::server_key::CudaServerKey; use crate::integer::server_key::TwosComplementNegation; use crate::prelude::CastInto; @@ -163,4 +164,85 @@ impl CudaServerKey { } stream.synchronize(); } + + /// ```rust + /// use tfhe::core_crypto::gpu::CudaStreams; + /// use tfhe::integer::gpu::ciphertext::{CudaSignedRadixCiphertext, CudaUnsignedRadixCiphertext}; + /// use 
tfhe::integer::gpu::gen_keys_radix_gpu; + /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS; + /// + /// let gpu_index = 0; + /// let streams = CudaStreams::new_single_gpu(gpu_index); + /// + /// // Generate the client key and the server key: + /// let num_blocks = 4; + /// let (cks, sks) = gen_keys_radix_gpu(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks, &streams); + /// + /// let msg: i8 = 120; + /// let scalar: i8 = 8; + /// + /// let ct1 = cks.encrypt_signed(msg); + /// + /// // Copy to GPU + /// let d_ct1 = CudaSignedRadixCiphertext::from_signed_radix_ciphertext(&ct1, &streams); + /// + /// // Compute homomorphically an overflowing addition: + /// let (d_ct_res, d_ct_overflowed) = sks.signed_overflowing_scalar_sub(&d_ct1, scalar, &streams); + /// + /// let ct_res = d_ct_res.to_signed_radix_ciphertext(&streams); + /// let ct_overflowed = d_ct_overflowed.to_boolean_block(&streams); + /// + /// // Decrypt: + /// let dec_result: i8 = cks.decrypt_signed(&ct_res); + /// let dec_overflowed: bool = cks.decrypt_bool(&ct_overflowed); + /// let (clear_result, clear_overflowed) = msg.overflowing_sub(scalar); + /// assert_eq!(dec_result, clear_result); + /// assert_eq!(dec_overflowed, clear_overflowed); + /// ``` + pub fn signed_overflowing_scalar_sub( + &self, + ct_left: &CudaSignedRadixCiphertext, + scalar: Scalar, + streams: &CudaStreams, + ) -> (CudaSignedRadixCiphertext, CudaBooleanBlock) + where + Scalar: SignedNumeric + DecomposableInto + CastInto, + { + let mut tmp_lhs; + unsafe { + tmp_lhs = ct_left.duplicate_async(streams); + if !tmp_lhs.block_carries_are_empty() { + self.full_propagate_assign_async(&mut tmp_lhs, streams); + } + } + + let trivial: CudaSignedRadixCiphertext = self.create_trivial_radix( + scalar, + ct_left.ciphertext.d_blocks.lwe_ciphertext_count().0, + streams, + ); + let (result, overflowed) = self.signed_overflowing_sub(&tmp_lhs, &trivial, streams); + + let mut extra_scalar_block_iter = + BlockDecomposer::new(scalar, 
self.message_modulus.0.ilog2()) + .iter_as::() + .skip(ct_left.ciphertext.d_blocks.lwe_ciphertext_count().0); + + let extra_blocks_have_correct_value = if scalar < Scalar::ZERO { + extra_scalar_block_iter.all(|block| block == (self.message_modulus.0 as u64 - 1)) + } else { + extra_scalar_block_iter.all(|block| block == 0) + }; + + if extra_blocks_have_correct_value { + (result, overflowed) + } else { + let trivial_one: CudaSignedRadixCiphertext = self.create_trivial_radix(1, 1, streams); + // Scalar has more blocks so addition counts as overflowing + ( + result, + CudaBooleanBlock::from_cuda_radix_ciphertext(trivial_one.ciphertext), + ) + } + } } diff --git a/tfhe/src/integer/gpu/server_key/radix/tests_signed/mod.rs b/tfhe/src/integer/gpu/server_key/radix/tests_signed/mod.rs index 1279e6d72..bffdec8a5 100644 --- a/tfhe/src/integer/gpu/server_key/radix/tests_signed/mod.rs +++ b/tfhe/src/integer/gpu/server_key/radix/tests_signed/mod.rs @@ -389,3 +389,40 @@ where ) } } + +// for signed overflowing scalar ops +impl<'a, F> + FunctionExecutor<(&'a SignedRadixCiphertext, i64), (SignedRadixCiphertext, BooleanBlock)> + for GpuFunctionExecutor +where + F: Fn( + &CudaServerKey, + &CudaSignedRadixCiphertext, + i64, + &CudaStreams, + ) -> (CudaSignedRadixCiphertext, CudaBooleanBlock), +{ + fn setup(&mut self, cks: &RadixClientKey, sks: Arc) { + self.setup_from_keys(cks, &sks); + } + + fn execute( + &mut self, + input: (&'a SignedRadixCiphertext, i64), + ) -> (SignedRadixCiphertext, BooleanBlock) { + let context = self + .context + .as_ref() + .expect("setup was not properly called"); + + let d_ctxt_1: CudaSignedRadixCiphertext = + CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input.0, &context.streams); + + let (d_res, d_res_bool) = (self.func)(&context.sks, &d_ctxt_1, input.1, &context.streams); + + ( + d_res.to_signed_radix_ciphertext(&context.streams), + d_res_bool.to_boolean_block(&context.streams), + ) + } +} diff --git 
a/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_scalar_add.rs b/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_scalar_add.rs index 7303db2c5..f78837905 100644 --- a/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_scalar_add.rs +++ b/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_scalar_add.rs @@ -3,12 +3,14 @@ use crate::integer::gpu::server_key::radix::tests_unsigned::{ }; use crate::integer::gpu::CudaServerKey; use crate::integer::server_key::radix_parallel::tests_signed::test_scalar_add::{ - signed_default_scalar_add_test, signed_unchecked_scalar_add_test, + signed_default_overflowing_scalar_add_test, signed_default_scalar_add_test, + signed_unchecked_scalar_add_test, }; use crate::shortint::parameters::*; create_gpu_parametrized_test!(integer_signed_unchecked_scalar_add); create_gpu_parametrized_test!(integer_signed_scalar_add); +create_gpu_parametrized_test!(integer_signed_overflowing_scalar_add); fn integer_signed_unchecked_scalar_add
<P>
(param: P) where @@ -25,3 +27,11 @@ where let executor = GpuFunctionExecutor::new(&CudaServerKey::scalar_add); signed_default_scalar_add_test(param, executor); } + +fn integer_signed_overflowing_scalar_add
<P>
(param: P) +where + P: Into, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::signed_overflowing_scalar_add); + signed_default_overflowing_scalar_add_test(param, executor); +} diff --git a/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_scalar_sub.rs b/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_scalar_sub.rs index b6655ca1d..5020a05ce 100644 --- a/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_scalar_sub.rs +++ b/tfhe/src/integer/gpu/server_key/radix/tests_signed/test_scalar_sub.rs @@ -2,10 +2,13 @@ use crate::integer::gpu::server_key::radix::tests_unsigned::{ create_gpu_parametrized_test, GpuFunctionExecutor, }; use crate::integer::gpu::CudaServerKey; -use crate::integer::server_key::radix_parallel::tests_signed::test_scalar_sub::signed_unchecked_scalar_sub_test; +use crate::integer::server_key::radix_parallel::tests_signed::test_scalar_sub::{ + signed_default_overflowing_scalar_sub_test, signed_unchecked_scalar_sub_test, +}; use crate::shortint::parameters::*; create_gpu_parametrized_test!(integer_signed_unchecked_scalar_sub); +create_gpu_parametrized_test!(integer_signed_overflowing_scalar_sub); fn integer_signed_unchecked_scalar_sub
<P>
(param: P) where @@ -14,3 +17,11 @@ where let executor = GpuFunctionExecutor::new(&CudaServerKey::unchecked_scalar_sub); signed_unchecked_scalar_sub_test(param, executor); } + +fn integer_signed_overflowing_scalar_sub
<P>
(param: P) +where + P: Into, +{ + let executor = GpuFunctionExecutor::new(&CudaServerKey::signed_overflowing_scalar_sub); + signed_default_overflowing_scalar_sub_test(param, executor); +} diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_scalar_add.rs b/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_scalar_add.rs index 6a7cf1e15..a47e0aee0 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_scalar_add.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_scalar_add.rs @@ -195,7 +195,7 @@ where let ctxt_0 = cks.encrypt_signed(clear_0); let (ct_res, result_overflowed) = executor.execute((&ctxt_0, clear_1)); - let (tmp_ct, tmp_o) = sks.signed_overflowing_scalar_add_parallelized(&ctxt_0, clear_1); + let (tmp_ct, tmp_o) = executor.execute((&ctxt_0, clear_1)); assert!(ct_res.block_carries_are_empty()); assert_eq!(ct_res, tmp_ct, "Failed determinism check"); assert_eq!(tmp_o, result_overflowed, "Failed determinism check");