diff --git a/Makefile b/Makefile index e688404d7..f00eebb84 100644 --- a/Makefile +++ b/Makefile @@ -284,7 +284,7 @@ check_typos: install_typos_checker .PHONY: clippy_gpu # Run clippy lints on tfhe with "gpu" enabled clippy_gpu: install_rs_check_toolchain RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy \ - --features=boolean,shortint,integer,internal-keycache,gpu,pbs-stats \ + --features=boolean,shortint,integer,internal-keycache,gpu,pbs-stats,extended-types \ --all-targets \ -p $(TFHE_SPEC) -- --no-deps -D warnings diff --git a/tfhe/src/high_level_api/integers/signed/scalar_ops.rs b/tfhe/src/high_level_api/integers/signed/scalar_ops.rs index c0a2da167..1e8560a5e 100644 --- a/tfhe/src/high_level_api/integers/signed/scalar_ops.rs +++ b/tfhe/src/high_level_api/integers/signed/scalar_ops.rs @@ -13,7 +13,6 @@ use crate::high_level_api::traits::{ }; use crate::integer::bigint::{I1024, I2048, U1024, U2048}; use crate::integer::block_decomposition::DecomposableInto; -use crate::integer::ciphertext::IntegerCiphertext; use crate::integer::{I256, I512, U256, U512}; use crate::{FheBool, FheInt}; use std::ops::{ @@ -1694,25 +1693,21 @@ generic_integer_impl_scalar_left_operation!( rust_trait: Sub(sub), implem: { |lhs, rhs: &FheInt<_>| { - // `-` is not commutative, so we resort to converting to trivial - // which should give same perf - global_state::with_internal_keys(|key| match key { + global_state::with_internal_keys(|key| match key { InternalServerKey::Cpu(cpu_key) => { - let mut result = cpu_key - .pbs_key() - .create_trivial_radix(lhs, rhs.ciphertext.on_cpu().blocks().len()); - cpu_key - .pbs_key() - .sub_assign_parallelized(&mut result, &*rhs.ciphertext.on_cpu()); + let result = cpu_key.pbs_key().left_scalar_sub_parallelized(lhs, &*rhs.ciphertext.on_cpu()); SignedRadixCiphertext::Cpu(result) }, #[cfg(feature = "gpu")] - InternalServerKey::Cuda(_cuda_key) => { - with_thread_local_cuda_streams(|_stream| { - panic!("Cuda devices do not support subtracting a chiphertext to a clear") -// let mut result = cuda_key.key.key.create_signed_trivial_radix(lhs, rhs.ciphertext.on_gpu(streams).ciphertext.info.blocks.len(), streams); -// cuda_key.key.key.sub_assign(&mut result, &rhs.ciphertext.on_gpu(streams), streams); -// RadixCiphertext::Cuda(result) + InternalServerKey::Cuda(cuda_key) => { + with_thread_local_cuda_streams(|streams| { + let mut result = cuda_key.pbs_key().create_trivial_radix( + lhs, + rhs.ciphertext.on_gpu(streams).ciphertext.info.blocks.len(), + streams + ); + cuda_key.pbs_key().sub_assign(&mut result, &*rhs.ciphertext.on_gpu(streams), streams); + SignedRadixCiphertext::Cuda(result) }) } }) @@ -1749,25 +1744,21 @@ generic_integer_impl_scalar_left_operation!( rust_trait: Sub(sub), implem: { |lhs, rhs: &FheInt<_>| { - // `-` is not commutative, so we resort to converting to trivial - // which should give same perf global_state::with_internal_keys(|key| match key { InternalServerKey::Cpu(cpu_key) => { - let mut result = cpu_key - .pbs_key() - .create_trivial_radix(lhs, rhs.ciphertext.on_cpu().blocks().len()); - cpu_key - .pbs_key() - .sub_assign_parallelized(&mut result, &*rhs.ciphertext.on_cpu()); + let result = cpu_key.pbs_key().left_scalar_sub_parallelized(lhs, &*rhs.ciphertext.on_cpu()); SignedRadixCiphertext::Cpu(result) }, #[cfg(feature = "gpu")] - InternalServerKey::Cuda(_cuda_key) => { - with_thread_local_cuda_streams(|_stream| { - panic!("Cuda devices do not support subtracting a chiphertext to a clear") -// let mut result = 
cuda_key.key.key.create_signed_trivial_radix(lhs, rhs.ciphertext.on_gpu(streams).ciphertext.info.blocks.len(), streams); -// cuda_key.key.key.sub_assign(&mut result, &rhs.ciphertext.on_gpu(streams), streams); -// RadixCiphertext::Cuda(result) + InternalServerKey::Cuda(cuda_key) => { + with_thread_local_cuda_streams(|streams| { + let mut result = cuda_key.pbs_key().create_trivial_radix( + lhs, + rhs.ciphertext.on_gpu(streams).ciphertext.info.blocks.len(), + streams + ); + cuda_key.pbs_key().sub_assign(&mut result, &*rhs.ciphertext.on_gpu(streams), streams); + SignedRadixCiphertext::Cuda(result) }) } }) diff --git a/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs b/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs index 8e82ece70..2f1804120 100644 --- a/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs +++ b/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs @@ -17,7 +17,6 @@ use crate::high_level_api::traits::{ }; use crate::integer::bigint::{U1024, U2048, U512}; use crate::integer::block_decomposition::DecomposableInto; -use crate::integer::ciphertext::IntegerCiphertext; #[cfg(feature = "gpu")] use crate::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext; use crate::integer::U256; @@ -1883,24 +1882,17 @@ generic_integer_impl_scalar_left_operation!( rust_trait: Sub(sub), implem: { |lhs, rhs: &FheUint<_>| { - // `-` is not commutative, so we resort to converting to trivial - // which should give same perf - global_state::with_internal_keys(|key| match key { + global_state::with_internal_keys(|key| match key { InternalServerKey::Cpu(cpu_key) => { - let mut result = cpu_key - .pbs_key() - .create_trivial_radix(lhs, rhs.ciphertext.on_cpu().blocks().len()); - cpu_key - .pbs_key() - .sub_assign_parallelized(&mut result, &*rhs.ciphertext.on_cpu()); + let result = cpu_key.pbs_key().left_scalar_sub_parallelized(lhs, &*rhs.ciphertext.on_cpu()); RadixCiphertext::Cpu(result) }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { - let mut result: CudaUnsignedRadixCiphertext = cuda_key.key.key.create_trivial_radix( + let mut result: CudaUnsignedRadixCiphertext = cuda_key.pbs_key().create_trivial_radix( lhs, rhs.ciphertext.on_gpu(streams).ciphertext.info.blocks.len(), streams); - cuda_key.key.key.sub_assign(&mut result, &rhs.ciphertext.on_gpu(streams), streams); + cuda_key.pbs_key().sub_assign(&mut result, &rhs.ciphertext.on_gpu(streams), streams); RadixCiphertext::Cuda(result) }) } @@ -1938,24 +1930,17 @@ generic_integer_impl_scalar_left_operation!( rust_trait: Sub(sub), implem: { |lhs, rhs: &FheUint<_>| { - // `-` is not commutative, so we resort to converting to trivial - // which should give same perf global_state::with_internal_keys(|key| match key { InternalServerKey::Cpu(cpu_key) => { - let mut result = cpu_key - .pbs_key() - .create_trivial_radix(lhs, rhs.ciphertext.on_cpu().blocks().len()); - cpu_key - .pbs_key() - .sub_assign_parallelized(&mut result, &*rhs.ciphertext.on_cpu()); + let result = cpu_key.pbs_key().left_scalar_sub_parallelized(lhs, &*rhs.ciphertext.on_cpu()); RadixCiphertext::Cpu(result) }, #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { - let mut result: CudaUnsignedRadixCiphertext = cuda_key.key.key.create_trivial_radix( + let mut result: CudaUnsignedRadixCiphertext = cuda_key.pbs_key().create_trivial_radix( lhs, rhs.ciphertext.on_gpu(streams).ciphertext.info.blocks.len(), streams); - cuda_key.key.key.sub_assign(&mut result, 
&rhs.ciphertext.on_gpu(streams), streams);
+                        cuda_key.pbs_key().sub_assign(&mut result, &rhs.ciphertext.on_gpu(streams), streams);
                         RadixCiphertext::Cuda(result)
                     })
                 }
diff --git a/tfhe/src/integer/server_key/radix/neg.rs b/tfhe/src/integer/server_key/radix/neg.rs
index d76bef924..d5ff88af8 100644
--- a/tfhe/src/integer/server_key/radix/neg.rs
+++ b/tfhe/src/integer/server_key/radix/neg.rs
@@ -2,7 +2,6 @@ use crate::integer::ciphertext::IntegerRadixCiphertext;
 use crate::integer::server_key::CheckError;
 use crate::integer::ServerKey;
 use crate::shortint::ciphertext::{Degree, MaxDegree};
-#[cfg(test)]
 use crate::shortint::MessageModulus;
 
 /// Iterator that returns the new degree of blocks
@@ -10,13 +9,11 @@ use crate::shortint::MessageModulus;
 ///
 /// It takes as input an iterator that returns the degree of the blocks
 /// before negation as well as their message modulus.
-#[cfg(test)]
 pub(crate) struct NegatedDegreeIter<I> {
     iter: I,
     z_b: u64,
 }
 
-#[cfg(test)]
 impl<I> NegatedDegreeIter<I>
 where
     I: Iterator<Item = (Degree, MessageModulus)>,
@@ -26,7 +23,6 @@ where
     }
 }
 
-#[cfg(test)]
 impl<I> Iterator for NegatedDegreeIter<I>
 where
     I: Iterator<Item = (Degree, MessageModulus)>,
diff --git a/tfhe/src/integer/server_key/radix/scalar_add.rs b/tfhe/src/integer/server_key/radix/scalar_add.rs
index 3f413ae89..7e8a07389 100644
--- a/tfhe/src/integer/server_key/radix/scalar_add.rs
+++ b/tfhe/src/integer/server_key/radix/scalar_add.rs
@@ -3,6 +3,7 @@ use crate::integer::ciphertext::IntegerRadixCiphertext;
 use crate::integer::server_key::CheckError;
 use crate::integer::ServerKey;
 use crate::shortint::ciphertext::{Degree, MaxDegree};
+use crate::shortint::{CarryModulus, MessageModulus};
 
 impl ServerKey {
     /// Computes homomorphically an addition between a scalar and a ciphertext.
@@ -89,6 +90,22 @@ impl ServerKey {
     where
         T: DecomposableInto<u64>,
         C: IntegerRadixCiphertext,
+    {
+        let block_metadata_iter = ct
+            .blocks()
+            .iter()
+            .map(|b| (b.degree, b.message_modulus, b.carry_modulus));
+        self.is_scalar_add_possible_impl(block_metadata_iter, scalar)
+    }
+
+    pub(crate) fn is_scalar_add_possible_impl<T, Iter>(
+        &self,
+        block_metadata: Iter,
+        scalar: T,
+    ) -> Result<(), CheckError>
+    where
+        T: DecomposableInto<u64>,
+        Iter: Iterator<Item = (Degree, MessageModulus, CarryModulus)>,
     {
         let bits_in_message = self.key.message_modulus.0.ilog2();
         let decomposer =
@@ -96,19 +113,16 @@
 
         // Assumes message_modulus and carry_modulus matches between pairs of block
         let mut preceding_block_carry = Degree::new(0);
-        for (left_block, scalar_block_value) in ct.blocks().iter().zip(decomposer) {
-            let degree_after_add = left_block.degree + Degree::new(u64::from(scalar_block_value));
+        for (block_metadata, scalar_block_value) in block_metadata.zip(decomposer) {
+            let (block_degree, block_msg_mod, block_carry_mod) = block_metadata;
+            let degree_after_add = block_degree + Degree::new(u64::from(scalar_block_value));
 
             // Also need to take into account preceding_carry
-            let max_degree = MaxDegree::from_msg_carry_modulus(
-                left_block.message_modulus,
-                left_block.carry_modulus,
-            );
+            let max_degree = MaxDegree::from_msg_carry_modulus(block_msg_mod, block_carry_mod);
 
             max_degree.validate(degree_after_add + preceding_block_carry)?;
 
-            preceding_block_carry =
-                Degree::new(degree_after_add.get() / left_block.message_modulus.0);
+            preceding_block_carry = Degree::new(degree_after_add.get() / block_msg_mod.0);
         }
         Ok(())
     }
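Note on the refactor above: `is_scalar_add_possible` now delegates to `is_scalar_add_possible_impl`, which consumes `(degree, message_modulus, carry_modulus)` tuples instead of borrowing blocks, so the same check can be fed either real block metadata or the predicted post-negation degrees produced by `NegatedDegreeIter`. A minimal sketch of the loop's logic, with plain `u64` values standing in for the `Degree`/`MessageModulus`/`CarryModulus` wrappers (illustrative only, not tfhe API):

    /// Returns true when adding the decomposed scalar blocks to blocks with
    /// the given metadata cannot overflow any block's carry space.
    fn scalar_add_fits(
        block_metadata: impl Iterator<Item = (u64, u64, u64)>, // (degree, msg_mod, carry_mod)
        scalar_blocks: impl Iterator<Item = u64>,
    ) -> bool {
        let mut preceding_block_carry = 0u64;
        for ((degree, msg_mod, carry_mod), scalar_block) in block_metadata.zip(scalar_blocks) {
            let degree_after_add = degree + scalar_block;
            // Mirrors MaxDegree::from_msg_carry_modulus: a block can hold
            // values in [0, msg_mod * carry_mod - 1].
            let max_degree = msg_mod * carry_mod - 1;
            if degree_after_add + preceding_block_carry > max_degree {
                return false;
            }
            // Carry this block may propagate into the next one.
            preceding_block_carry = degree_after_add / msg_mod;
        }
        true
    }

For 2-bit messages (msg_mod = 4, carry_mod = 4, max degree 15), two blocks of degree 3 absorb the scalar blocks [3, 3]: the first block reaches degree 6 and propagates a carry of 1, the second reaches 6 + 1 = 7, so the check passes.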
diff --git a/tfhe/src/integer/server_key/radix_parallel/scalar_sub.rs b/tfhe/src/integer/server_key/radix_parallel/scalar_sub.rs
index c783985e6..755cd5430 100644
--- a/tfhe/src/integer/server_key/radix_parallel/scalar_sub.rs
+++ b/tfhe/src/integer/server_key/radix_parallel/scalar_sub.rs
@@ -1,8 +1,9 @@
 use crate::core_crypto::prelude::{Cleartext, SignedNumeric, UnsignedNumeric};
 use crate::integer::block_decomposition::{BlockDecomposer, DecomposableInto};
 use crate::integer::ciphertext::IntegerRadixCiphertext;
+use crate::integer::server_key::radix::neg::NegatedDegreeIter;
 use crate::integer::server_key::radix::scalar_sub::TwosComplementNegation;
-use crate::integer::{BooleanBlock, RadixCiphertext, ServerKey, SignedRadixCiphertext};
+use crate::integer::{BooleanBlock, CheckError, RadixCiphertext, ServerKey, SignedRadixCiphertext};
 use crate::shortint::{Ciphertext, PaddingBit};
 use rayon::prelude::*;
 
@@ -114,6 +115,99 @@ impl ServerKey {
         self.scalar_add_assign_parallelized(ct, scalar.twos_complement_negation());
     }
 
+    pub fn unchecked_left_scalar_sub<Scalar, T>(&self, scalar: Scalar, rhs: &T) -> T
+    where
+        Scalar: DecomposableInto<u8>,
+        T: IntegerRadixCiphertext,
+    {
+        // a - b <=> a + (-b)
+        let mut neg_rhs = self.unchecked_neg(rhs);
+        self.unchecked_scalar_add_assign(&mut neg_rhs, scalar);
+        neg_rhs
+    }
+
+    pub fn is_left_scalar_sub_possible<Scalar, T>(
+        &self,
+        scalar: Scalar,
+        rhs: &T,
+    ) -> Result<(), CheckError>
+    where
+        Scalar: DecomposableInto<u64>,
+        T: IntegerRadixCiphertext,
+    {
+        // We do scalar - ct by doing scalar + (-ct)
+        // So we first have to check `-ct` is possible,
+        // then that adding the scalar to it is possible
+        self.is_neg_possible(rhs)?;
+        let neg_degree_iter =
+            NegatedDegreeIter::new(rhs.blocks().iter().map(|b| (b.degree, b.message_modulus)));
+        let block_metadata_iter = rhs
+            .blocks()
+            .iter()
+            .zip(neg_degree_iter)
+            .map(|(block, neg_degree)| (neg_degree, block.message_modulus, block.carry_modulus));
+
+        self.is_scalar_add_possible_impl(block_metadata_iter, scalar)
+    }
+
+    pub fn smart_left_scalar_sub_parallelized<Scalar, T>(&self, scalar: Scalar, rhs: &mut T) -> T
+    where
+        Scalar: DecomposableInto<u8> + DecomposableInto<u64>,
+        T: IntegerRadixCiphertext,
+    {
+        if self.is_neg_possible(rhs).is_err() {
+            self.full_propagate_parallelized(rhs);
+            self.unchecked_left_scalar_sub(scalar, rhs)
+        } else {
+            // a - b <=> a + (-b)
+            let mut neg_rhs = self.unchecked_neg(rhs);
+            if self.is_scalar_add_possible(&neg_rhs, scalar).is_err() {
+                // since adding a scalar does not increase the noise, only the
+                // degree can be problematic
+                self.full_propagate_parallelized(&mut neg_rhs);
+            }
+            self.unchecked_scalar_add_assign(&mut neg_rhs, scalar);
+            neg_rhs
+        }
+    }
+
+    pub fn left_scalar_sub_parallelized<Scalar, T>(&self, scalar: Scalar, rhs: &T) -> T
+    where
+        Scalar: DecomposableInto<u8> + DecomposableInto<u64>,
+        T: IntegerRadixCiphertext,
+    {
+        if rhs.block_carries_are_empty() {
+            // a - b <=> a + (-b) <=> a + (!b + 1) <=> !b + a + 1
+            let mut flipped_ct = self.bitnot(rhs);
+            let scalar_blocks = BlockDecomposer::with_block_count(
+                scalar,
+                self.message_modulus().0.ilog2(),
+                rhs.blocks().len(),
+            )
+            .iter_as::<u8>()
+            .collect();
+            let (input_carry, compute_overflow) = (true, false);
+            self.add_assign_scalar_blocks_parallelized(
+                &mut flipped_ct,
+                scalar_blocks,
+                input_carry,
+                compute_overflow,
+            );
+            flipped_ct
+        } else {
+            // We could clone rhs and full_propagate, then do the same thing as when the
+            // rhs's carries are clean. This would cost two full_propagate calls, but the
+            // second one would be cheaper because it happens in a scalar_add.
+            //
+            // However, we chose to call the smart version on the cloned rhs, as the carries
+            // may not be so bad, in which case the smart version avoids the first full_prop
+            let mut tmp_rhs = rhs.clone();
+            let mut res = self.smart_left_scalar_sub_parallelized(scalar, &mut tmp_rhs);
+            self.full_propagate_parallelized(&mut res);
+            res
+        }
+    }
+
     pub fn unsigned_overflowing_scalar_sub_assign_parallelized(
         &self,
         lhs: &mut RadixCiphertext,
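The fast path of `left_scalar_sub_parallelized` rests on the two's-complement identity `a - b = a + (-b) = a + (!b + 1) = !b + a + 1`: one cheap `bitnot` on the ciphertext, then a single scalar addition with an input carry of 1. The identity is easy to verify on plain integers, and the new entry point can be exercised as in the sketch below (assuming radix key generation as in the tests; the parameter constant name varies between tfhe-rs versions):

    use tfhe::integer::gen_keys_radix;
    use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;

    fn main() {
        // a - b == !b + a + 1 in wrapping arithmetic, checked on plain u8:
        let (a, b) = (100u8, 7u8);
        assert_eq!((!b).wrapping_add(a).wrapping_add(1), a.wrapping_sub(b));

        // Same subtraction with a clear left-hand side and an encrypted
        // right-hand side: 4 blocks of 2-bit messages form an 8-bit integer.
        let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, 4);
        let ct = cks.encrypt(7u64);
        let res = sks.left_scalar_sub_parallelized(100u64, &ct);
        let dec: u64 = cks.decrypt(&res);
        assert_eq!(dec, 93); // 100 - 7
    }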
diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_scalar_sub.rs b/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_scalar_sub.rs
index 99603fdae..d9695f323 100644
--- a/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_scalar_sub.rs
+++ b/tfhe/src/integer/server_key/radix_parallel/tests_signed/test_scalar_sub.rs
@@ -1,11 +1,11 @@
 use crate::integer::keycache::KEY_CACHE;
 use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor;
 use crate::integer::server_key::radix_parallel::tests_signed::{
-    random_non_zero_value, signed_overflowing_add_under_modulus,
+    random_non_zero_value, signed_add_under_modulus, signed_overflowing_add_under_modulus,
     signed_overflowing_sub_under_modulus, signed_sub_under_modulus, NB_CTXT,
 };
 use crate::integer::server_key::radix_parallel::tests_unsigned::{
-    nb_tests_for_params, nb_tests_smaller_for_params, CpuFunctionExecutor,
+    nb_tests_for_params, nb_tests_smaller_for_params, CpuFunctionExecutor, MAX_NB_CTXT,
 };
 use crate::integer::tests::create_parameterized_test;
 use crate::integer::{
@@ -20,6 +20,9 @@ use std::sync::Arc;
 
 create_parameterized_test!(integer_signed_unchecked_scalar_sub);
 create_parameterized_test!(integer_signed_default_overflowing_scalar_sub);
+create_parameterized_test!(integer_signed_unchecked_left_scalar_sub);
+create_parameterized_test!(integer_signed_smart_left_scalar_sub);
+create_parameterized_test!(integer_signed_default_left_scalar_sub);
 
 fn integer_signed_unchecked_scalar_sub<P>
(param: P) where @@ -36,6 +39,31 @@ where let executor = CpuFunctionExecutor::new(&ServerKey::signed_overflowing_scalar_sub_parallelized); signed_default_overflowing_scalar_sub_test(param, executor); } + +fn integer_signed_unchecked_left_scalar_sub
<P>(param: P)
+where
+    P: Into<TestParameters>,
+{
+    let executor = CpuFunctionExecutor::new(&ServerKey::unchecked_left_scalar_sub);
+    signed_unchecked_left_scalar_sub_test(param, executor);
+}
+
+fn integer_signed_smart_left_scalar_sub
<P>(param: P)
+where
+    P: Into<TestParameters>,
+{
+    let executor = CpuFunctionExecutor::new(&ServerKey::smart_left_scalar_sub_parallelized);
+    signed_smart_left_scalar_sub_test(param, executor);
+}
+
+fn integer_signed_default_left_scalar_sub
<P>(param: P)
+where
+    P: Into<TestParameters>,
+{
+    let executor = CpuFunctionExecutor::new(&ServerKey::left_scalar_sub_parallelized);
+    signed_default_left_scalar_sub_test(param, executor);
+}
+
 pub(crate) fn signed_unchecked_scalar_sub_test<P, T>(param: P, mut executor: T)
 where
@@ -261,3 +289,159 @@ where
         assert_eq!(encrypted_overflow.0.noise_level(), NoiseLevel::ZERO);
     }
 }
+
+pub(crate) fn signed_unchecked_left_scalar_sub_test<P, T>(param: P, mut executor: T)
+where
+    P: Into<TestParameters>,
+    T: for<'a> FunctionExecutor<(i64, &'a SignedRadixCiphertext), SignedRadixCiphertext>,
+{
+    let param = param.into();
+    let nb_tests = nb_tests_for_params(param);
+    let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
+    let cks = RadixClientKey::from((cks, NB_CTXT));
+
+    sks.set_deterministic_pbs_execution(true);
+    let sks = Arc::new(sks);
+
+    let mut rng = rand::thread_rng();
+
+    executor.setup(&cks, sks.clone());
+
+    let cks: crate::integer::ClientKey = cks.into();
+
+    for num_blocks in 1..MAX_NB_CTXT {
+        // message_modulus^vec_length
+        let modulus = (cks.parameters().message_modulus().0.pow(num_blocks as u32) / 2) as i64;
+        if modulus <= 1 {
+            continue;
+        }
+
+        for _ in 0..nb_tests {
+            let clear_lhs = rng.gen::<i64>() % modulus;
+            let mut clear_rhs = rng.gen::<i64>() % modulus;
+
+            let mut ct_rhs = cks.encrypt_signed_radix(clear_rhs, num_blocks);
+
+            ct_rhs = executor.execute((clear_lhs, &ct_rhs));
+            clear_rhs = signed_sub_under_modulus(clear_lhs, clear_rhs, modulus);
+
+            let dec_res: i64 = cks.decrypt_signed_radix(&ct_rhs);
+            assert_eq!(dec_res, clear_rhs);
+
+            let mut clear_lhs = rng.gen::<i64>() % modulus;
+            while sks.is_left_scalar_sub_possible(clear_lhs, &ct_rhs).is_ok() {
+                ct_rhs = executor.execute((clear_lhs, &ct_rhs));
+                clear_rhs = signed_sub_under_modulus(clear_lhs, clear_rhs, modulus);
+                let dec_res: i64 = cks.decrypt_signed_radix(&ct_rhs);
+                assert_eq!(dec_res, clear_rhs);
+                clear_lhs = rng.gen::<i64>() % modulus;
+            }
+        }
+    }
+}
+
+pub(crate) fn signed_smart_left_scalar_sub_test<P, T>(param: P, mut executor: T)
+where
+    P: Into<TestParameters>,
+    T: for<'a> FunctionExecutor<(i64, &'a mut SignedRadixCiphertext), SignedRadixCiphertext>,
+{
+    let param = param.into();
+    let nb_tests = nb_tests_for_params(param);
+    let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
+    let cks = RadixClientKey::from((cks, NB_CTXT));
+
+    sks.set_deterministic_pbs_execution(true);
+    let sks = Arc::new(sks);
+
+    let mut rng = rand::thread_rng();
+
+    executor.setup(&cks, sks);
+
+    let cks: crate::integer::ClientKey = cks.into();
+
+    for num_blocks in 1..MAX_NB_CTXT {
+        // message_modulus^vec_length
+        let modulus = (cks.parameters().message_modulus().0.pow(num_blocks as u32) / 2) as i64;
+        if modulus <= 1 {
+            continue;
+        }
+
+        let clear_lhs = rng.gen::<i64>() % modulus;
+        let mut clear_rhs = rng.gen::<i64>() % modulus;
+
+        let mut ct_rhs = cks.encrypt_signed_radix(clear_rhs, num_blocks);
+
+        ct_rhs = executor.execute((clear_lhs, &mut ct_rhs));
+        clear_rhs = signed_sub_under_modulus(clear_lhs, clear_rhs, modulus);
+
+        let dec_res: i64 = cks.decrypt_signed_radix(&ct_rhs);
+        assert_eq!(dec_res, clear_rhs);
+        for _ in 0..nb_tests {
+            let clear_lhs = rng.gen::<i64>() % modulus;
+
+            ct_rhs = executor.execute((clear_lhs, &mut ct_rhs));
+            clear_rhs = signed_sub_under_modulus(clear_lhs, clear_rhs, modulus);
+            let dec_res: i64 = cks.decrypt_signed_radix(&ct_rhs);
+            assert_eq!(dec_res, clear_rhs);
+        }
+    }
+}
+
+pub(crate) fn signed_default_left_scalar_sub_test<P, T>(param: P, mut executor: T)
+where
+    P: Into<TestParameters>,
+    T: for<'a> FunctionExecutor<(i64, &'a
SignedRadixCiphertext), SignedRadixCiphertext>,
+{
+    let param = param.into();
+    let nb_tests = nb_tests_for_params(param);
+    let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
+    let cks = RadixClientKey::from((cks, NB_CTXT));
+
+    sks.set_deterministic_pbs_execution(true);
+    let sks = Arc::new(sks);
+
+    let mut rng = rand::thread_rng();
+
+    executor.setup(&cks, sks.clone());
+
+    let cks: crate::integer::ClientKey = cks.into();
+
+    for num_blocks in 1..MAX_NB_CTXT {
+        // message_modulus^vec_length
+        let modulus = (cks.parameters().message_modulus().0.pow(num_blocks as u32) / 2) as i64;
+        if modulus <= 1 {
+            continue;
+        }
+
+        for _ in 0..nb_tests {
+            let clear_0 = rng.gen::<i64>() % modulus;
+            let clear_1 = rng.gen::<i64>() % modulus;
+
+            let ctxt_1 = cks.encrypt_signed_radix(clear_1, num_blocks);
+
+            let ct_res = executor.execute((clear_0, &ctxt_1));
+            assert!(ct_res.block_carries_are_empty());
+
+            let tmp = executor.execute((clear_0, &ctxt_1));
+            assert_eq!(ct_res, tmp, "Operation is not deterministic");
+
+            let dec_res: i64 = cks.decrypt_signed_radix(&ct_res);
+            assert_eq!(dec_res, signed_sub_under_modulus(clear_0, clear_1, modulus));
+
+            let non_zero = random_non_zero_value(&mut rng, modulus);
+            let non_clean = sks.unchecked_scalar_add(&ctxt_1, non_zero);
+            let ct_res = executor.execute((clear_0, &non_clean));
+            assert!(ct_res.block_carries_are_empty());
+            let dec_res: i64 = cks.decrypt_signed_radix(&ct_res);
+            let expected = signed_sub_under_modulus(
+                clear_0,
+                signed_add_under_modulus(clear_1, non_zero, modulus),
+                modulus,
+            );
+            assert_eq!(dec_res, expected);
+
+            let ct_res2 = executor.execute((clear_0, &non_clean));
+            assert_eq!(ct_res, ct_res2, "Failed determinism check");
+        }
+    }
+}
diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_scalar_sub.rs b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_scalar_sub.rs
index 6b2e2b002..9c0e12e9c 100644
--- a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_scalar_sub.rs
+++ b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_scalar_sub.rs
@@ -1,16 +1,28 @@
+use std::sync::Arc;
+
+use crate::integer::keycache::KEY_CACHE;
 use crate::integer::server_key::radix_parallel::tests_cases_unsigned::{
     default_overflowing_scalar_sub_test, default_scalar_sub_test, smart_scalar_sub_test,
+    FunctionExecutor,
+};
+use crate::integer::server_key::radix_parallel::tests_unsigned::{
+    nb_tests_for_params, random_non_zero_value, CpuFunctionExecutor,
 };
-use crate::integer::server_key::radix_parallel::tests_unsigned::CpuFunctionExecutor;
 use crate::integer::tests::create_parameterized_test;
-use crate::integer::ServerKey;
+use crate::integer::{IntegerKeyKind, RadixCiphertext, RadixClientKey, ServerKey};
 #[cfg(tarpaulin)]
 use crate::shortint::parameters::coverage_parameters::*;
 use crate::shortint::parameters::test_params::*;
 use crate::shortint::parameters::*;
+use rand::prelude::*;
+
+use super::{MAX_NB_CTXT, NB_CTXT};
 
 create_parameterized_test!(integer_smart_scalar_sub);
 create_parameterized_test!(integer_default_scalar_sub);
+create_parameterized_test!(integer_unchecked_left_scalar_sub);
+create_parameterized_test!(integer_smart_left_scalar_sub);
+create_parameterized_test!(integer_default_left_scalar_sub);
 create_parameterized_test!(integer_default_overflowing_scalar_sub);
 
 fn integer_smart_scalar_sub
<P>
(param: P) @@ -29,6 +41,30 @@ where default_scalar_sub_test(param, executor); } +fn integer_unchecked_left_scalar_sub
<P>(param: P)
+where
+    P: Into<TestParameters>,
+{
+    let executor = CpuFunctionExecutor::new(&ServerKey::unchecked_left_scalar_sub);
+    unchecked_left_scalar_sub_test(param, executor);
+}
+
+fn integer_smart_left_scalar_sub
<P>(param: P)
+where
+    P: Into<TestParameters>,
+{
+    let executor = CpuFunctionExecutor::new(&ServerKey::smart_left_scalar_sub_parallelized);
+    smart_left_scalar_sub_test(param, executor);
+}
+
+fn integer_default_left_scalar_sub
<P>(param: P)
+where
+    P: Into<TestParameters>,
+{
+    let executor = CpuFunctionExecutor::new(&ServerKey::left_scalar_sub_parallelized);
+    default_left_scalar_sub_test(param, executor);
+}
+
 fn integer_default_overflowing_scalar_sub
<P>(param: P)
 where
     P: Into<TestParameters>,
 {
     let executor =
@@ -37,3 +73,148 @@ where
         CpuFunctionExecutor::new(&ServerKey::unsigned_overflowing_scalar_sub_parallelized);
     default_overflowing_scalar_sub_test(param, executor);
 }
+
+pub(crate) fn unchecked_left_scalar_sub_test<P, T>(param: P, mut executor: T)
+where
+    P: Into<TestParameters>,
+    T: for<'a> FunctionExecutor<(u64, &'a RadixCiphertext), RadixCiphertext>,
+{
+    let param = param.into();
+    let nb_tests = nb_tests_for_params(param);
+    let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
+    let cks = RadixClientKey::from((cks, NB_CTXT));
+
+    sks.set_deterministic_pbs_execution(true);
+    let sks = Arc::new(sks);
+
+    let mut rng = thread_rng();
+
+    executor.setup(&cks, sks.clone());
+
+    let cks: crate::integer::ClientKey = cks.into();
+
+    for num_blocks in 1..MAX_NB_CTXT {
+        // message_modulus^vec_length
+        let modulus = cks.parameters().message_modulus().0.pow(num_blocks as u32);
+
+        for _ in 0..nb_tests {
+            let clear_lhs = rng.gen::<u64>() % modulus;
+            let mut clear_rhs = rng.gen::<u64>() % modulus;
+
+            let mut ct_rhs = cks.encrypt_radix(clear_rhs, num_blocks);
+
+            ct_rhs = executor.execute((clear_lhs, &ct_rhs));
+            clear_rhs = clear_lhs.wrapping_sub(clear_rhs) % modulus;
+
+            let dec_res: u64 = cks.decrypt_radix(&ct_rhs);
+            assert_eq!(dec_res, clear_rhs);
+
+            let mut clear_lhs = rng.gen::<u64>() % modulus;
+            while sks.is_left_scalar_sub_possible(clear_lhs, &ct_rhs).is_ok() {
+                ct_rhs = executor.execute((clear_lhs, &ct_rhs));
+                clear_rhs = clear_lhs.wrapping_sub(clear_rhs) % modulus;
+                let dec_res: u64 = cks.decrypt_radix(&ct_rhs);
+                assert_eq!(dec_res, clear_rhs);
+                clear_lhs = rng.gen::<u64>() % modulus;
+            }
+        }
+    }
+}
+
+pub(crate) fn smart_left_scalar_sub_test<P, T>(param: P, mut executor: T)
+where
+    P: Into<TestParameters>,
+    T: for<'a> FunctionExecutor<(u64, &'a mut RadixCiphertext), RadixCiphertext>,
+{
+    let param = param.into();
+    let nb_tests = nb_tests_for_params(param);
+    let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
+    let cks = RadixClientKey::from((cks, NB_CTXT));
+
+    sks.set_deterministic_pbs_execution(true);
+    let sks = Arc::new(sks);
+
+    let mut rng = thread_rng();
+
+    executor.setup(&cks, sks);
+
+    let cks: crate::integer::ClientKey = cks.into();
+
+    for num_blocks in 1..MAX_NB_CTXT {
+        // message_modulus^vec_length
+        let modulus = cks.parameters().message_modulus().0.pow(num_blocks as u32);
+
+        let clear_lhs = rng.gen::<u64>() % modulus;
+        let mut clear_rhs = rng.gen::<u64>() % modulus;
+
+        let mut ct_rhs = cks.encrypt_radix(clear_rhs, num_blocks);
+
+        ct_rhs = executor.execute((clear_lhs, &mut ct_rhs));
+        clear_rhs = clear_lhs.wrapping_sub(clear_rhs) % modulus;
+
+        let dec_res: u64 = cks.decrypt_radix(&ct_rhs);
+        assert_eq!(dec_res, clear_rhs);
+        for _ in 0..nb_tests {
+            let clear_lhs = rng.gen::<u64>() % modulus;
+
+            ct_rhs = executor.execute((clear_lhs, &mut ct_rhs));
+            clear_rhs = clear_lhs.wrapping_sub(clear_rhs) % modulus;
+            let dec_res: u64 = cks.decrypt_radix(&ct_rhs);
+            assert_eq!(dec_res, clear_rhs);
+        }
+    }
+}
+
+pub(crate) fn default_left_scalar_sub_test<P, T>(param: P, mut executor: T)
+where
+    P: Into<TestParameters>,
+    T: for<'a> FunctionExecutor<(u64, &'a RadixCiphertext), RadixCiphertext>,
+{
+    let param = param.into();
+    let nb_tests = nb_tests_for_params(param);
+    let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
+    let cks = RadixClientKey::from((cks, NB_CTXT));
+
+    sks.set_deterministic_pbs_execution(true);
+    let sks = Arc::new(sks);
+
+    let mut rng = thread_rng();
+
+    executor.setup(&cks, sks.clone());
+
+    let cks: crate::integer::ClientKey =
cks.into();
+
+    for num_blocks in 1..MAX_NB_CTXT {
+        // message_modulus^vec_length
+        let modulus = cks.parameters().message_modulus().0.pow(num_blocks as u32);
+
+        for _ in 0..nb_tests {
+            let clear_0 = rng.gen::<u64>() % modulus;
+            let clear_1 = rng.gen::<u64>() % modulus;
+
+            let ctxt_1 = cks.encrypt_radix(clear_1, num_blocks);
+
+            let ct_res = executor.execute((clear_0, &ctxt_1));
+            assert!(ct_res.block_carries_are_empty());
+
+            let tmp = executor.execute((clear_0, &ctxt_1));
+            assert_eq!(ct_res, tmp, "Operation is not deterministic");
+
+            let dec_res: u64 = cks.decrypt_radix(&ct_res);
+            assert_eq!(dec_res, (clear_0.wrapping_sub(clear_1)) % modulus);
+
+            let non_zero = random_non_zero_value(&mut rng, modulus);
+            let non_clean = sks.unchecked_scalar_add(&ctxt_1, non_zero);
+            let ct_res = executor.execute((clear_0, &non_clean));
+            assert!(ct_res.block_carries_are_empty());
+            let dec_res: u64 = cks.decrypt_radix(&ct_res);
+            assert_eq!(
+                dec_res,
+                (clear_0.wrapping_sub(clear_1.wrapping_add(non_zero))) % modulus
+            );
+
+            let ct_res2 = executor.execute((clear_0, &non_clean));
+            assert_eq!(ct_res, ct_res2, "Failed determinism check");
+        }
+    }
+}
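At the high-level API these pieces back the clear-minus-encrypted `Sub` impls rewritten above: the CPU path now calls `left_scalar_sub_parallelized` directly, and the signed GPU path, which previously panicked, builds a trivial radix and subtracts in place. A usage sketch (assuming a scalar-left `Sub` impl is generated for this scalar/integer pairing by `generic_integer_impl_scalar_left_operation!`):

    use tfhe::prelude::*;
    use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheUint32};

    fn main() {
        let config = ConfigBuilder::default().build();
        let (client_key, server_key) = generate_keys(config);
        set_server_key(server_key);

        let encrypted = FheUint32::encrypt(7u32, &client_key);
        // Clear left-hand side minus encrypted right-hand side.
        let result = 100u32 - &encrypted;
        let decrypted: u32 = result.decrypt(&client_key);
        assert_eq!(decrypted, 93);
    }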