mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-08 22:28:01 -05:00
feat: add clear - ciphertex
Add integer and hlapi function to perform `clear - ciphertext` As subtraction is not commutative having a specialized version is better. As can be seen from the code, the real benefit is for the default version where the cost of `clear - ciphertext` is the same as `clear + ciphertex` which is better that transforming the clear into a trivial ciphertext to perform the subtract algorithm
This commit is contained in:
2
Makefile
2
Makefile
@@ -284,7 +284,7 @@ check_typos: install_typos_checker
|
||||
.PHONY: clippy_gpu # Run clippy lints on tfhe with "gpu" enabled
|
||||
clippy_gpu: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy \
|
||||
--features=boolean,shortint,integer,internal-keycache,gpu,pbs-stats \
|
||||
--features=boolean,shortint,integer,internal-keycache,gpu,pbs-stats,extended-types \
|
||||
--all-targets \
|
||||
-p $(TFHE_SPEC) -- --no-deps -D warnings
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@ use crate::high_level_api::traits::{
|
||||
};
|
||||
use crate::integer::bigint::{I1024, I2048, U1024, U2048};
|
||||
use crate::integer::block_decomposition::DecomposableInto;
|
||||
use crate::integer::ciphertext::IntegerCiphertext;
|
||||
use crate::integer::{I256, I512, U256, U512};
|
||||
use crate::{FheBool, FheInt};
|
||||
use std::ops::{
|
||||
@@ -1694,25 +1693,21 @@ generic_integer_impl_scalar_left_operation!(
|
||||
rust_trait: Sub(sub),
|
||||
implem: {
|
||||
|lhs, rhs: &FheInt<_>| {
|
||||
// `-` is not commutative, so we resort to converting to trivial
|
||||
// which should give same perf
|
||||
global_state::with_internal_keys(|key| match key {
|
||||
global_state::with_internal_keys(|key| match key {
|
||||
InternalServerKey::Cpu(cpu_key) => {
|
||||
let mut result = cpu_key
|
||||
.pbs_key()
|
||||
.create_trivial_radix(lhs, rhs.ciphertext.on_cpu().blocks().len());
|
||||
cpu_key
|
||||
.pbs_key()
|
||||
.sub_assign_parallelized(&mut result, &*rhs.ciphertext.on_cpu());
|
||||
let result = cpu_key.pbs_key().left_scalar_sub_parallelized(lhs, &*rhs.ciphertext.on_cpu());
|
||||
SignedRadixCiphertext::Cpu(result)
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(_cuda_key) => {
|
||||
with_thread_local_cuda_streams(|_stream| {
|
||||
panic!("Cuda devices do not support subtracting a chiphertext to a clear")
|
||||
// let mut result = cuda_key.key.key.create_signed_trivial_radix(lhs, rhs.ciphertext.on_gpu(streams).ciphertext.info.blocks.len(), streams);
|
||||
// cuda_key.key.key.sub_assign(&mut result, &rhs.ciphertext.on_gpu(streams), streams);
|
||||
// RadixCiphertext::Cuda(result)
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let mut result = cuda_key.pbs_key().create_trivial_radix(
|
||||
lhs,
|
||||
rhs.ciphertext.on_gpu(streams).ciphertext.info.blocks.len(),
|
||||
streams
|
||||
);
|
||||
cuda_key.pbs_key().sub_assign(&mut result, &*rhs.ciphertext.on_gpu(streams), streams);
|
||||
SignedRadixCiphertext::Cuda(result)
|
||||
})
|
||||
}
|
||||
})
|
||||
@@ -1749,25 +1744,21 @@ generic_integer_impl_scalar_left_operation!(
|
||||
rust_trait: Sub(sub),
|
||||
implem: {
|
||||
|lhs, rhs: &FheInt<_>| {
|
||||
// `-` is not commutative, so we resort to converting to trivial
|
||||
// which should give same perf
|
||||
global_state::with_internal_keys(|key| match key {
|
||||
InternalServerKey::Cpu(cpu_key) => {
|
||||
let mut result = cpu_key
|
||||
.pbs_key()
|
||||
.create_trivial_radix(lhs, rhs.ciphertext.on_cpu().blocks().len());
|
||||
cpu_key
|
||||
.pbs_key()
|
||||
.sub_assign_parallelized(&mut result, &*rhs.ciphertext.on_cpu());
|
||||
let result = cpu_key.pbs_key().left_scalar_sub_parallelized(lhs, &*rhs.ciphertext.on_cpu());
|
||||
SignedRadixCiphertext::Cpu(result)
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(_cuda_key) => {
|
||||
with_thread_local_cuda_streams(|_stream| {
|
||||
panic!("Cuda devices do not support subtracting a chiphertext to a clear")
|
||||
// let mut result = cuda_key.key.key.create_signed_trivial_radix(lhs, rhs.ciphertext.on_gpu(streams).ciphertext.info.blocks.len(), streams);
|
||||
// cuda_key.key.key.sub_assign(&mut result, &rhs.ciphertext.on_gpu(streams), streams);
|
||||
// RadixCiphertext::Cuda(result)
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let mut result = cuda_key.pbs_key().create_trivial_radix(
|
||||
lhs,
|
||||
rhs.ciphertext.on_gpu(streams).ciphertext.info.blocks.len(),
|
||||
streams
|
||||
);
|
||||
cuda_key.pbs_key().sub_assign(&mut result, &*rhs.ciphertext.on_gpu(streams), streams);
|
||||
SignedRadixCiphertext::Cuda(result)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
@@ -17,7 +17,6 @@ use crate::high_level_api::traits::{
|
||||
};
|
||||
use crate::integer::bigint::{U1024, U2048, U512};
|
||||
use crate::integer::block_decomposition::DecomposableInto;
|
||||
use crate::integer::ciphertext::IntegerCiphertext;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
|
||||
use crate::integer::U256;
|
||||
@@ -1883,24 +1882,17 @@ generic_integer_impl_scalar_left_operation!(
|
||||
rust_trait: Sub(sub),
|
||||
implem: {
|
||||
|lhs, rhs: &FheUint<_>| {
|
||||
// `-` is not commutative, so we resort to converting to trivial
|
||||
// which should give same perf
|
||||
global_state::with_internal_keys(|key| match key {
|
||||
global_state::with_internal_keys(|key| match key {
|
||||
InternalServerKey::Cpu(cpu_key) => {
|
||||
let mut result = cpu_key
|
||||
.pbs_key()
|
||||
.create_trivial_radix(lhs, rhs.ciphertext.on_cpu().blocks().len());
|
||||
cpu_key
|
||||
.pbs_key()
|
||||
.sub_assign_parallelized(&mut result, &*rhs.ciphertext.on_cpu());
|
||||
let result = cpu_key.pbs_key().left_scalar_sub_parallelized(lhs, &*rhs.ciphertext.on_cpu());
|
||||
RadixCiphertext::Cpu(result)
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let mut result: CudaUnsignedRadixCiphertext = cuda_key.key.key.create_trivial_radix(
|
||||
let mut result: CudaUnsignedRadixCiphertext = cuda_key.pbs_key().create_trivial_radix(
|
||||
lhs, rhs.ciphertext.on_gpu(streams).ciphertext.info.blocks.len(), streams);
|
||||
cuda_key.key.key.sub_assign(&mut result, &rhs.ciphertext.on_gpu(streams), streams);
|
||||
cuda_key.pbs_key().sub_assign(&mut result, &rhs.ciphertext.on_gpu(streams), streams);
|
||||
RadixCiphertext::Cuda(result)
|
||||
})
|
||||
}
|
||||
@@ -1938,24 +1930,17 @@ generic_integer_impl_scalar_left_operation!(
|
||||
rust_trait: Sub(sub),
|
||||
implem: {
|
||||
|lhs, rhs: &FheUint<_>| {
|
||||
// `-` is not commutative, so we resort to converting to trivial
|
||||
// which should give same perf
|
||||
global_state::with_internal_keys(|key| match key {
|
||||
InternalServerKey::Cpu(cpu_key) => {
|
||||
let mut result = cpu_key
|
||||
.pbs_key()
|
||||
.create_trivial_radix(lhs, rhs.ciphertext.on_cpu().blocks().len());
|
||||
cpu_key
|
||||
.pbs_key()
|
||||
.sub_assign_parallelized(&mut result, &*rhs.ciphertext.on_cpu());
|
||||
let result = cpu_key.pbs_key().left_scalar_sub_parallelized(lhs, &*rhs.ciphertext.on_cpu());
|
||||
RadixCiphertext::Cpu(result)
|
||||
},
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
with_thread_local_cuda_streams(|streams| {
|
||||
let mut result: CudaUnsignedRadixCiphertext = cuda_key.key.key.create_trivial_radix(
|
||||
let mut result: CudaUnsignedRadixCiphertext = cuda_key.pbs_key().create_trivial_radix(
|
||||
lhs, rhs.ciphertext.on_gpu(streams).ciphertext.info.blocks.len(), streams);
|
||||
cuda_key.key.key.sub_assign(&mut result, &rhs.ciphertext.on_gpu(streams), streams);
|
||||
cuda_key.pbs_key().sub_assign(&mut result, &rhs.ciphertext.on_gpu(streams), streams);
|
||||
RadixCiphertext::Cuda(result)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ use crate::integer::ciphertext::IntegerRadixCiphertext;
|
||||
use crate::integer::server_key::CheckError;
|
||||
use crate::integer::ServerKey;
|
||||
use crate::shortint::ciphertext::{Degree, MaxDegree};
|
||||
#[cfg(test)]
|
||||
use crate::shortint::MessageModulus;
|
||||
|
||||
/// Iterator that returns the new degree of blocks
|
||||
@@ -10,13 +9,11 @@ use crate::shortint::MessageModulus;
|
||||
///
|
||||
/// It takes as input an iterator that returns the degree of the blocks
|
||||
/// before negation as well as their message modulus.
|
||||
#[cfg(test)]
|
||||
pub(crate) struct NegatedDegreeIter<I> {
|
||||
iter: I,
|
||||
z_b: u64,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl<I> NegatedDegreeIter<I>
|
||||
where
|
||||
I: Iterator<Item = (Degree, MessageModulus)>,
|
||||
@@ -26,7 +23,6 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl<I> Iterator for NegatedDegreeIter<I>
|
||||
where
|
||||
I: Iterator<Item = (Degree, MessageModulus)>,
|
||||
|
||||
@@ -3,6 +3,7 @@ use crate::integer::ciphertext::IntegerRadixCiphertext;
|
||||
use crate::integer::server_key::CheckError;
|
||||
use crate::integer::ServerKey;
|
||||
use crate::shortint::ciphertext::{Degree, MaxDegree};
|
||||
use crate::shortint::{CarryModulus, MessageModulus};
|
||||
|
||||
impl ServerKey {
|
||||
/// Computes homomorphically an addition between a scalar and a ciphertext.
|
||||
@@ -89,6 +90,22 @@ impl ServerKey {
|
||||
where
|
||||
T: DecomposableInto<u8>,
|
||||
C: IntegerRadixCiphertext,
|
||||
{
|
||||
let block_metadata_iter = ct
|
||||
.blocks()
|
||||
.iter()
|
||||
.map(|b| (b.degree, b.message_modulus, b.carry_modulus));
|
||||
self.is_scalar_add_possible_impl(block_metadata_iter, scalar)
|
||||
}
|
||||
|
||||
pub(crate) fn is_scalar_add_possible_impl<T, Iter>(
|
||||
&self,
|
||||
block_metadata: Iter,
|
||||
scalar: T,
|
||||
) -> Result<(), CheckError>
|
||||
where
|
||||
T: DecomposableInto<u8>,
|
||||
Iter: Iterator<Item = (Degree, MessageModulus, CarryModulus)>,
|
||||
{
|
||||
let bits_in_message = self.key.message_modulus.0.ilog2();
|
||||
let decomposer =
|
||||
@@ -96,19 +113,16 @@ impl ServerKey {
|
||||
|
||||
// Assumes message_modulus and carry_modulus matches between pairs of block
|
||||
let mut preceding_block_carry = Degree::new(0);
|
||||
for (left_block, scalar_block_value) in ct.blocks().iter().zip(decomposer) {
|
||||
let degree_after_add = left_block.degree + Degree::new(u64::from(scalar_block_value));
|
||||
for (block_metadata, scalar_block_value) in block_metadata.zip(decomposer) {
|
||||
let (block_degree, block_msg_mod, block_carry_mod) = block_metadata;
|
||||
let degree_after_add = block_degree + Degree::new(u64::from(scalar_block_value));
|
||||
|
||||
// Also need to take into account preceding_carry
|
||||
let max_degree = MaxDegree::from_msg_carry_modulus(
|
||||
left_block.message_modulus,
|
||||
left_block.carry_modulus,
|
||||
);
|
||||
let max_degree = MaxDegree::from_msg_carry_modulus(block_msg_mod, block_carry_mod);
|
||||
|
||||
max_degree.validate(degree_after_add + preceding_block_carry)?;
|
||||
|
||||
preceding_block_carry =
|
||||
Degree::new(degree_after_add.get() / left_block.message_modulus.0);
|
||||
preceding_block_carry = Degree::new(degree_after_add.get() / block_msg_mod.0);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
use crate::core_crypto::prelude::{Cleartext, SignedNumeric, UnsignedNumeric};
|
||||
use crate::integer::block_decomposition::{BlockDecomposer, DecomposableInto};
|
||||
use crate::integer::ciphertext::IntegerRadixCiphertext;
|
||||
use crate::integer::server_key::radix::neg::NegatedDegreeIter;
|
||||
use crate::integer::server_key::radix::scalar_sub::TwosComplementNegation;
|
||||
use crate::integer::{BooleanBlock, RadixCiphertext, ServerKey, SignedRadixCiphertext};
|
||||
use crate::integer::{BooleanBlock, CheckError, RadixCiphertext, ServerKey, SignedRadixCiphertext};
|
||||
use crate::shortint::{Ciphertext, PaddingBit};
|
||||
use rayon::prelude::*;
|
||||
|
||||
@@ -114,6 +115,99 @@ impl ServerKey {
|
||||
self.scalar_add_assign_parallelized(ct, scalar.twos_complement_negation());
|
||||
}
|
||||
|
||||
pub fn unchecked_left_scalar_sub<Scalar, T>(&self, scalar: Scalar, rhs: &T) -> T
|
||||
where
|
||||
Scalar: DecomposableInto<u8>,
|
||||
T: IntegerRadixCiphertext,
|
||||
{
|
||||
// a - b <=> a + (-b)
|
||||
let mut neg_rhs = self.unchecked_neg(rhs);
|
||||
self.unchecked_scalar_add_assign(&mut neg_rhs, scalar);
|
||||
neg_rhs
|
||||
}
|
||||
|
||||
pub fn is_left_scalar_sub_possible<Scalar, T>(
|
||||
&self,
|
||||
scalar: Scalar,
|
||||
rhs: &T,
|
||||
) -> Result<(), CheckError>
|
||||
where
|
||||
Scalar: DecomposableInto<u8>,
|
||||
T: IntegerRadixCiphertext,
|
||||
{
|
||||
// We do scalar - ct by doing scalar + (-ct)
|
||||
// So we first have to check `-ct` is possible
|
||||
// then that adding scalar to it is possible
|
||||
self.is_neg_possible(rhs)?;
|
||||
let neg_degree_iter =
|
||||
NegatedDegreeIter::new(rhs.blocks().iter().map(|b| (b.degree, b.message_modulus)));
|
||||
let block_metadata_iter = rhs
|
||||
.blocks()
|
||||
.iter()
|
||||
.zip(neg_degree_iter)
|
||||
.map(|(block, neg_degree)| (neg_degree, block.message_modulus, block.carry_modulus));
|
||||
|
||||
self.is_scalar_add_possible_impl(block_metadata_iter, scalar)
|
||||
}
|
||||
|
||||
pub fn smart_left_scalar_sub_parallelized<Scalar, T>(&self, scalar: Scalar, rhs: &mut T) -> T
|
||||
where
|
||||
Scalar: DecomposableInto<u8>,
|
||||
T: IntegerRadixCiphertext,
|
||||
{
|
||||
if self.is_neg_possible(rhs).is_err() {
|
||||
self.full_propagate_parallelized(rhs);
|
||||
self.unchecked_left_scalar_sub(scalar, rhs)
|
||||
} else {
|
||||
// a - b <=> a + (-b)
|
||||
let mut neg_rhs = self.unchecked_neg(rhs);
|
||||
if self.is_scalar_add_possible(&neg_rhs, scalar).is_err() {
|
||||
// since adding scalar does not increase the nose, only the
|
||||
// degree can be problematic
|
||||
self.full_propagate_parallelized(&mut neg_rhs);
|
||||
}
|
||||
self.unchecked_scalar_add_assign(&mut neg_rhs, scalar);
|
||||
neg_rhs
|
||||
}
|
||||
}
|
||||
|
||||
pub fn left_scalar_sub_parallelized<Scalar, T>(&self, scalar: Scalar, rhs: &T) -> T
|
||||
where
|
||||
Scalar: DecomposableInto<u8>,
|
||||
T: IntegerRadixCiphertext,
|
||||
{
|
||||
if rhs.block_carries_are_empty() {
|
||||
// a - b <=> a + (-b) <=> a + (!b + 1) <=> !b + a + 1
|
||||
let mut flipped_ct = self.bitnot(rhs);
|
||||
let scalar_blocks = BlockDecomposer::with_block_count(
|
||||
scalar,
|
||||
self.message_modulus().0.ilog2(),
|
||||
rhs.blocks().len(),
|
||||
)
|
||||
.iter_as::<u8>()
|
||||
.collect();
|
||||
let (input_carry, compute_overflow) = (true, false);
|
||||
self.add_assign_scalar_blocks_parallelized(
|
||||
&mut flipped_ct,
|
||||
scalar_blocks,
|
||||
input_carry,
|
||||
compute_overflow,
|
||||
);
|
||||
flipped_ct
|
||||
} else {
|
||||
// We could clone rhs and full_propagate, then do the same thing as when the
|
||||
// rhs's carries are clean. This would cost 2 full_propagate, but the second one
|
||||
// would be less expensive because it happens in a scalar_add.
|
||||
//
|
||||
// However, we chose to all the smart version on the cloned_rhs, as maybe the carries
|
||||
// are not so bad that the smart version will be able to avoid the first full_prop
|
||||
let mut tmp_rhs = rhs.clone();
|
||||
let mut res = self.smart_left_scalar_sub_parallelized(scalar, &mut tmp_rhs);
|
||||
self.full_propagate_parallelized(&mut res);
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
pub fn unsigned_overflowing_scalar_sub_assign_parallelized<T>(
|
||||
&self,
|
||||
lhs: &mut RadixCiphertext,
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
use crate::integer::keycache::KEY_CACHE;
|
||||
use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor;
|
||||
use crate::integer::server_key::radix_parallel::tests_signed::{
|
||||
random_non_zero_value, signed_overflowing_add_under_modulus,
|
||||
random_non_zero_value, signed_add_under_modulus, signed_overflowing_add_under_modulus,
|
||||
signed_overflowing_sub_under_modulus, signed_sub_under_modulus, NB_CTXT,
|
||||
};
|
||||
use crate::integer::server_key::radix_parallel::tests_unsigned::{
|
||||
nb_tests_for_params, nb_tests_smaller_for_params, CpuFunctionExecutor,
|
||||
nb_tests_for_params, nb_tests_smaller_for_params, CpuFunctionExecutor, MAX_NB_CTXT,
|
||||
};
|
||||
use crate::integer::tests::create_parameterized_test;
|
||||
use crate::integer::{
|
||||
@@ -20,6 +20,9 @@ use std::sync::Arc;
|
||||
|
||||
create_parameterized_test!(integer_signed_unchecked_scalar_sub);
|
||||
create_parameterized_test!(integer_signed_default_overflowing_scalar_sub);
|
||||
create_parameterized_test!(integer_signed_unchecked_left_scalar_sub);
|
||||
create_parameterized_test!(integer_signed_smart_left_scalar_sub);
|
||||
create_parameterized_test!(integer_signed_default_left_scalar_sub);
|
||||
|
||||
fn integer_signed_unchecked_scalar_sub<P>(param: P)
|
||||
where
|
||||
@@ -36,6 +39,31 @@ where
|
||||
let executor = CpuFunctionExecutor::new(&ServerKey::signed_overflowing_scalar_sub_parallelized);
|
||||
signed_default_overflowing_scalar_sub_test(param, executor);
|
||||
}
|
||||
|
||||
fn integer_signed_unchecked_left_scalar_sub<P>(param: P)
|
||||
where
|
||||
P: Into<PBSParameters>,
|
||||
{
|
||||
let executor = CpuFunctionExecutor::new(&ServerKey::unchecked_left_scalar_sub);
|
||||
signed_unchecked_left_scalar_sub_test(param, executor);
|
||||
}
|
||||
|
||||
fn integer_signed_smart_left_scalar_sub<P>(param: P)
|
||||
where
|
||||
P: Into<PBSParameters>,
|
||||
{
|
||||
let executor = CpuFunctionExecutor::new(&ServerKey::smart_left_scalar_sub_parallelized);
|
||||
signed_smart_left_scalar_sub_test(param, executor);
|
||||
}
|
||||
|
||||
fn integer_signed_default_left_scalar_sub<P>(param: P)
|
||||
where
|
||||
P: Into<PBSParameters>,
|
||||
{
|
||||
let executor = CpuFunctionExecutor::new(&ServerKey::left_scalar_sub_parallelized);
|
||||
signed_default_left_scalar_sub_test(param, executor);
|
||||
}
|
||||
|
||||
pub(crate) fn signed_unchecked_scalar_sub_test<P, T>(param: P, mut executor: T)
|
||||
where
|
||||
P: Into<PBSParameters>,
|
||||
@@ -261,3 +289,159 @@ where
|
||||
assert_eq!(encrypted_overflow.0.noise_level(), NoiseLevel::ZERO);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn signed_unchecked_left_scalar_sub_test<P, T>(param: P, mut executor: T)
|
||||
where
|
||||
P: Into<PBSParameters>,
|
||||
T: for<'a> FunctionExecutor<(i64, &'a SignedRadixCiphertext), SignedRadixCiphertext>,
|
||||
{
|
||||
let param = param.into();
|
||||
let nb_tests = nb_tests_for_params(param);
|
||||
let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
|
||||
let cks = RadixClientKey::from((cks, NB_CTXT));
|
||||
|
||||
sks.set_deterministic_pbs_execution(true);
|
||||
let sks = Arc::new(sks);
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
executor.setup(&cks, sks.clone());
|
||||
|
||||
let cks: crate::integer::ClientKey = cks.into();
|
||||
|
||||
for num_blocks in 1..MAX_NB_CTXT {
|
||||
// message_modulus^vec_length
|
||||
let modulus = (cks.parameters().message_modulus().0.pow(num_blocks as u32) / 2) as i64;
|
||||
if modulus <= 1 {
|
||||
continue;
|
||||
}
|
||||
|
||||
for _ in 0..nb_tests {
|
||||
let clear_lhs = rng.gen::<i64>() % modulus;
|
||||
let mut clear_rhs = rng.gen::<i64>() % modulus;
|
||||
|
||||
let mut ct_rhs = cks.encrypt_signed_radix(clear_rhs, num_blocks);
|
||||
|
||||
ct_rhs = executor.execute((clear_lhs, &ct_rhs));
|
||||
clear_rhs = signed_sub_under_modulus(clear_lhs, clear_rhs, modulus);
|
||||
|
||||
let dec_res: i64 = cks.decrypt_signed_radix(&ct_rhs);
|
||||
assert_eq!(dec_res, clear_rhs);
|
||||
|
||||
let mut clear_lhs = rng.gen::<i64>() % modulus;
|
||||
while sks.is_left_scalar_sub_possible(clear_lhs, &ct_rhs).is_ok() {
|
||||
ct_rhs = executor.execute((clear_lhs, &ct_rhs));
|
||||
clear_rhs = signed_sub_under_modulus(clear_lhs, clear_rhs, modulus);
|
||||
let dec_res: i64 = cks.decrypt_signed_radix(&ct_rhs);
|
||||
assert_eq!(dec_res, clear_rhs);
|
||||
clear_lhs = rng.gen::<i64>() % modulus;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn signed_smart_left_scalar_sub_test<P, T>(param: P, mut executor: T)
|
||||
where
|
||||
P: Into<PBSParameters>,
|
||||
T: for<'a> FunctionExecutor<(i64, &'a mut SignedRadixCiphertext), SignedRadixCiphertext>,
|
||||
{
|
||||
let param = param.into();
|
||||
let nb_tests = nb_tests_for_params(param);
|
||||
let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
|
||||
let cks = RadixClientKey::from((cks, NB_CTXT));
|
||||
|
||||
sks.set_deterministic_pbs_execution(true);
|
||||
let sks = Arc::new(sks);
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
executor.setup(&cks, sks);
|
||||
|
||||
let cks: crate::integer::ClientKey = cks.into();
|
||||
|
||||
for num_blocks in 1..MAX_NB_CTXT {
|
||||
// message_modulus^vec_length
|
||||
let modulus = (cks.parameters().message_modulus().0.pow(num_blocks as u32) / 2) as i64;
|
||||
if modulus <= 1 {
|
||||
continue;
|
||||
}
|
||||
|
||||
let clear_lhs = rng.gen::<i64>() % modulus;
|
||||
let mut clear_rhs = rng.gen::<i64>() % modulus;
|
||||
|
||||
let mut ct_rhs = cks.encrypt_signed_radix(clear_rhs, num_blocks);
|
||||
|
||||
ct_rhs = executor.execute((clear_lhs, &mut ct_rhs));
|
||||
clear_rhs = signed_sub_under_modulus(clear_lhs, clear_rhs, modulus);
|
||||
|
||||
let dec_res: i64 = cks.decrypt_signed_radix(&ct_rhs);
|
||||
assert_eq!(dec_res, clear_rhs);
|
||||
for _ in 0..nb_tests {
|
||||
let clear_lhs = rng.gen::<i64>() % modulus;
|
||||
|
||||
ct_rhs = executor.execute((clear_lhs, &mut ct_rhs));
|
||||
clear_rhs = signed_sub_under_modulus(clear_lhs, clear_rhs, modulus);
|
||||
let dec_res: i64 = cks.decrypt_signed_radix(&ct_rhs);
|
||||
assert_eq!(dec_res, clear_rhs);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn signed_default_left_scalar_sub_test<P, T>(param: P, mut executor: T)
|
||||
where
|
||||
P: Into<PBSParameters>,
|
||||
T: for<'a> FunctionExecutor<(i64, &'a SignedRadixCiphertext), SignedRadixCiphertext>,
|
||||
{
|
||||
let param = param.into();
|
||||
let nb_tests = nb_tests_for_params(param);
|
||||
let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
|
||||
let cks = RadixClientKey::from((cks, NB_CTXT));
|
||||
|
||||
sks.set_deterministic_pbs_execution(true);
|
||||
let sks = Arc::new(sks);
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
executor.setup(&cks, sks.clone());
|
||||
|
||||
let cks: crate::integer::ClientKey = cks.into();
|
||||
|
||||
for num_blocks in 1..MAX_NB_CTXT {
|
||||
// message_modulus^vec_length
|
||||
let modulus = (cks.parameters().message_modulus().0.pow(num_blocks as u32) / 2) as i64;
|
||||
if modulus <= 1 {
|
||||
continue;
|
||||
}
|
||||
|
||||
for _ in 0..nb_tests {
|
||||
let clear_0 = rng.gen::<i64>() % modulus;
|
||||
let clear_1 = rng.gen::<i64>() % modulus;
|
||||
|
||||
let ctxt_1 = cks.encrypt_signed_radix(clear_1, num_blocks);
|
||||
|
||||
let ct_res = executor.execute((clear_0, &ctxt_1));
|
||||
assert!(ct_res.block_carries_are_empty());
|
||||
|
||||
let tmp = executor.execute((clear_0, &ctxt_1));
|
||||
assert_eq!(ct_res, tmp, "Operation is not deterministic");
|
||||
|
||||
let dec_res: i64 = cks.decrypt_signed_radix(&ct_res);
|
||||
assert_eq!(dec_res, signed_sub_under_modulus(clear_0, clear_1, modulus));
|
||||
|
||||
let non_zero = random_non_zero_value(&mut rng, modulus);
|
||||
let non_clean = sks.unchecked_scalar_add(&ctxt_1, non_zero);
|
||||
let ct_res = executor.execute((clear_0, &non_clean));
|
||||
assert!(ct_res.block_carries_are_empty());
|
||||
let dec_res: i64 = cks.decrypt_signed_radix(&ct_res);
|
||||
let expected = signed_sub_under_modulus(
|
||||
clear_0,
|
||||
signed_add_under_modulus(clear_1, non_zero, modulus),
|
||||
modulus,
|
||||
);
|
||||
assert_eq!(dec_res, expected);
|
||||
|
||||
let ct_res2 = executor.execute((clear_0, &non_clean));
|
||||
assert_eq!(ct_res, ct_res2, "Failed determinism check");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,16 +1,28 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::integer::keycache::KEY_CACHE;
|
||||
use crate::integer::server_key::radix_parallel::tests_cases_unsigned::{
|
||||
default_overflowing_scalar_sub_test, default_scalar_sub_test, smart_scalar_sub_test,
|
||||
FunctionExecutor,
|
||||
};
|
||||
use crate::integer::server_key::radix_parallel::tests_unsigned::{
|
||||
nb_tests_for_params, random_non_zero_value, CpuFunctionExecutor,
|
||||
};
|
||||
use crate::integer::server_key::radix_parallel::tests_unsigned::CpuFunctionExecutor;
|
||||
use crate::integer::tests::create_parameterized_test;
|
||||
use crate::integer::ServerKey;
|
||||
use crate::integer::{IntegerKeyKind, RadixCiphertext, RadixClientKey, ServerKey};
|
||||
#[cfg(tarpaulin)]
|
||||
use crate::shortint::parameters::coverage_parameters::*;
|
||||
use crate::shortint::parameters::test_params::*;
|
||||
use crate::shortint::parameters::*;
|
||||
use rand::prelude::*;
|
||||
|
||||
use super::{MAX_NB_CTXT, NB_CTXT};
|
||||
|
||||
create_parameterized_test!(integer_smart_scalar_sub);
|
||||
create_parameterized_test!(integer_default_scalar_sub);
|
||||
create_parameterized_test!(integer_unchecked_left_scalar_sub);
|
||||
create_parameterized_test!(integer_smart_left_scalar_sub);
|
||||
create_parameterized_test!(integer_default_left_scalar_sub);
|
||||
create_parameterized_test!(integer_default_overflowing_scalar_sub);
|
||||
|
||||
fn integer_smart_scalar_sub<P>(param: P)
|
||||
@@ -29,6 +41,30 @@ where
|
||||
default_scalar_sub_test(param, executor);
|
||||
}
|
||||
|
||||
fn integer_unchecked_left_scalar_sub<P>(param: P)
|
||||
where
|
||||
P: Into<PBSParameters>,
|
||||
{
|
||||
let executor = CpuFunctionExecutor::new(&ServerKey::unchecked_left_scalar_sub);
|
||||
unchecked_left_scalar_sub_test(param, executor);
|
||||
}
|
||||
|
||||
fn integer_smart_left_scalar_sub<P>(param: P)
|
||||
where
|
||||
P: Into<PBSParameters>,
|
||||
{
|
||||
let executor = CpuFunctionExecutor::new(&ServerKey::smart_left_scalar_sub_parallelized);
|
||||
smart_left_scalar_sub_test(param, executor);
|
||||
}
|
||||
|
||||
fn integer_default_left_scalar_sub<P>(param: P)
|
||||
where
|
||||
P: Into<PBSParameters>,
|
||||
{
|
||||
let executor = CpuFunctionExecutor::new(&ServerKey::left_scalar_sub_parallelized);
|
||||
default_left_scalar_sub_test(param, executor);
|
||||
}
|
||||
|
||||
fn integer_default_overflowing_scalar_sub<P>(param: P)
|
||||
where
|
||||
P: Into<PBSParameters>,
|
||||
@@ -37,3 +73,148 @@ where
|
||||
CpuFunctionExecutor::new(&ServerKey::unsigned_overflowing_scalar_sub_parallelized);
|
||||
default_overflowing_scalar_sub_test(param, executor);
|
||||
}
|
||||
|
||||
pub(crate) fn unchecked_left_scalar_sub_test<P, T>(param: P, mut executor: T)
|
||||
where
|
||||
P: Into<PBSParameters>,
|
||||
T: for<'a> FunctionExecutor<(u64, &'a RadixCiphertext), RadixCiphertext>,
|
||||
{
|
||||
let param = param.into();
|
||||
let nb_tests = nb_tests_for_params(param);
|
||||
let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
|
||||
let cks = RadixClientKey::from((cks, NB_CTXT));
|
||||
|
||||
sks.set_deterministic_pbs_execution(true);
|
||||
let sks = Arc::new(sks);
|
||||
|
||||
let mut rng = thread_rng();
|
||||
|
||||
executor.setup(&cks, sks.clone());
|
||||
|
||||
let cks: crate::integer::ClientKey = cks.into();
|
||||
|
||||
for num_blocks in 1..MAX_NB_CTXT {
|
||||
// message_modulus^vec_length
|
||||
let modulus = cks.parameters().message_modulus().0.pow(num_blocks as u32);
|
||||
|
||||
for _ in 0..nb_tests {
|
||||
let clear_lhs = rng.gen::<u64>() % modulus;
|
||||
let mut clear_rhs = rng.gen::<u64>() % modulus;
|
||||
|
||||
let mut ct_rhs = cks.encrypt_radix(clear_rhs, num_blocks);
|
||||
|
||||
ct_rhs = executor.execute((clear_lhs, &ct_rhs));
|
||||
clear_rhs = clear_lhs.wrapping_sub(clear_rhs) % modulus;
|
||||
|
||||
let dec_res: u64 = cks.decrypt_radix(&ct_rhs);
|
||||
assert_eq!(dec_res, clear_rhs);
|
||||
|
||||
let mut clear_lhs = rng.gen::<u64>() % modulus;
|
||||
while sks.is_left_scalar_sub_possible(clear_lhs, &ct_rhs).is_ok() {
|
||||
ct_rhs = executor.execute((clear_lhs, &ct_rhs));
|
||||
clear_rhs = clear_lhs.wrapping_sub(clear_rhs) % modulus;
|
||||
let dec_res: u64 = cks.decrypt_radix(&ct_rhs);
|
||||
assert_eq!(dec_res, clear_rhs);
|
||||
clear_lhs = rng.gen::<u64>() % modulus;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn smart_left_scalar_sub_test<P, T>(param: P, mut executor: T)
|
||||
where
|
||||
P: Into<PBSParameters>,
|
||||
T: for<'a> FunctionExecutor<(u64, &'a mut RadixCiphertext), RadixCiphertext>,
|
||||
{
|
||||
let param = param.into();
|
||||
let nb_tests = nb_tests_for_params(param);
|
||||
let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
|
||||
let cks = RadixClientKey::from((cks, NB_CTXT));
|
||||
|
||||
sks.set_deterministic_pbs_execution(true);
|
||||
let sks = Arc::new(sks);
|
||||
|
||||
let mut rng = thread_rng();
|
||||
|
||||
executor.setup(&cks, sks);
|
||||
|
||||
let cks: crate::integer::ClientKey = cks.into();
|
||||
|
||||
for num_blocks in 1..MAX_NB_CTXT {
|
||||
// message_modulus^vec_length
|
||||
let modulus = cks.parameters().message_modulus().0.pow(num_blocks as u32);
|
||||
|
||||
let clear_lhs = rng.gen::<u64>() % modulus;
|
||||
let mut clear_rhs = rng.gen::<u64>() % modulus;
|
||||
|
||||
let mut ct_rhs = cks.encrypt_radix(clear_rhs, num_blocks);
|
||||
|
||||
ct_rhs = executor.execute((clear_lhs, &mut ct_rhs));
|
||||
clear_rhs = clear_lhs.wrapping_sub(clear_rhs) % modulus;
|
||||
|
||||
let dec_res: u64 = cks.decrypt_radix(&ct_rhs);
|
||||
assert_eq!(dec_res, clear_rhs);
|
||||
for _ in 0..nb_tests {
|
||||
let clear_lhs = rng.gen::<u64>() % modulus;
|
||||
|
||||
ct_rhs = executor.execute((clear_lhs, &mut ct_rhs));
|
||||
clear_rhs = clear_lhs.wrapping_sub(clear_rhs) % modulus;
|
||||
let dec_res: u64 = cks.decrypt_radix(&ct_rhs);
|
||||
assert_eq!(dec_res, clear_rhs);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn default_left_scalar_sub_test<P, T>(param: P, mut executor: T)
|
||||
where
|
||||
P: Into<PBSParameters>,
|
||||
T: for<'a> FunctionExecutor<(u64, &'a RadixCiphertext), RadixCiphertext>,
|
||||
{
|
||||
let param = param.into();
|
||||
let nb_tests = nb_tests_for_params(param);
|
||||
let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
|
||||
let cks = RadixClientKey::from((cks, NB_CTXT));
|
||||
|
||||
sks.set_deterministic_pbs_execution(true);
|
||||
let sks = Arc::new(sks);
|
||||
|
||||
let mut rng = thread_rng();
|
||||
|
||||
executor.setup(&cks, sks.clone());
|
||||
|
||||
let cks: crate::integer::ClientKey = cks.into();
|
||||
|
||||
for num_blocks in 1..MAX_NB_CTXT {
|
||||
// message_modulus^vec_length
|
||||
let modulus = cks.parameters().message_modulus().0.pow(num_blocks as u32);
|
||||
|
||||
for _ in 0..nb_tests {
|
||||
let clear_0 = rng.gen::<u64>() % modulus;
|
||||
let clear_1 = rng.gen::<u64>() % modulus;
|
||||
|
||||
let ctxt_1 = cks.encrypt_radix(clear_1, num_blocks);
|
||||
|
||||
let ct_res = executor.execute((clear_0, &ctxt_1));
|
||||
assert!(ct_res.block_carries_are_empty());
|
||||
|
||||
let tmp = executor.execute((clear_0, &ctxt_1));
|
||||
assert_eq!(ct_res, tmp, "Operation is not deterministic");
|
||||
|
||||
let dec_res: u64 = cks.decrypt_radix(&ct_res);
|
||||
assert_eq!(dec_res, (clear_0.wrapping_sub(clear_1)) % modulus);
|
||||
|
||||
let non_zero = random_non_zero_value(&mut rng, modulus);
|
||||
let non_clean = sks.unchecked_scalar_add(&ctxt_1, non_zero);
|
||||
let ct_res = executor.execute((clear_0, &non_clean));
|
||||
assert!(ct_res.block_carries_are_empty());
|
||||
let dec_res: u64 = cks.decrypt_radix(&ct_res);
|
||||
assert_eq!(
|
||||
dec_res,
|
||||
(clear_0.wrapping_sub(clear_1.wrapping_add(non_zero))) % modulus
|
||||
);
|
||||
|
||||
let ct_res2 = executor.execute((clear_0, &non_clean));
|
||||
assert_eq!(ct_res, ct_res2, "Failed determinism check");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user