mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-10 07:08:03 -05:00
feat(gpu): signed scalar sub
This commit is contained in:
@@ -503,7 +503,7 @@ generic_integer_impl_scalar_operation!(
|
||||
InternalServerKey::Cuda(cuda_key) => {
|
||||
let inner_result = with_thread_local_cuda_stream(|stream| {
|
||||
cuda_key.key.scalar_sub(
|
||||
&lhs.ciphertext.on_gpu(), rhs, stream
|
||||
&*lhs.ciphertext.on_gpu(), rhs, stream
|
||||
)
|
||||
});
|
||||
RadixCiphertext::Cuda(inner_result)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use crate::core_crypto::gpu::CudaStream;
|
||||
use crate::core_crypto::prelude::UnsignedNumeric;
|
||||
use crate::core_crypto::prelude::Numeric;
|
||||
use crate::integer::block_decomposition::DecomposableInto;
|
||||
use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext};
|
||||
use crate::integer::gpu::ciphertext::CudaIntegerRadixCiphertext;
|
||||
use crate::integer::gpu::server_key::CudaServerKey;
|
||||
use crate::integer::server_key::TwosComplementNegation;
|
||||
|
||||
@@ -43,14 +43,10 @@ impl CudaServerKey {
|
||||
/// let dec: u64 = cks.decrypt(&ct_res);
|
||||
/// assert_eq!(msg - scalar, dec);
|
||||
/// ```
|
||||
pub fn unchecked_scalar_sub<T>(
|
||||
&self,
|
||||
ct: &CudaUnsignedRadixCiphertext,
|
||||
scalar: T,
|
||||
stream: &CudaStream,
|
||||
) -> CudaUnsignedRadixCiphertext
|
||||
pub fn unchecked_scalar_sub<Scalar, T>(&self, ct: &T, scalar: Scalar, stream: &CudaStream) -> T
|
||||
where
|
||||
T: DecomposableInto<u8> + UnsignedNumeric + TwosComplementNegation,
|
||||
Scalar: DecomposableInto<u8> + Numeric + TwosComplementNegation,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let mut result = unsafe { ct.duplicate_async(stream) };
|
||||
self.unchecked_scalar_sub_assign(&mut result, scalar, stream);
|
||||
@@ -61,26 +57,28 @@ impl CudaServerKey {
|
||||
///
|
||||
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until stream is synchronised
|
||||
pub unsafe fn unchecked_scalar_sub_assign_async<T>(
|
||||
pub unsafe fn unchecked_scalar_sub_assign_async<Scalar, T>(
|
||||
&self,
|
||||
ct: &mut CudaUnsignedRadixCiphertext,
|
||||
scalar: T,
|
||||
ct: &mut T,
|
||||
scalar: Scalar,
|
||||
stream: &CudaStream,
|
||||
) where
|
||||
T: DecomposableInto<u8> + UnsignedNumeric + TwosComplementNegation,
|
||||
Scalar: DecomposableInto<u8> + Numeric + TwosComplementNegation,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let negated_scalar = scalar.twos_complement_negation();
|
||||
self.unchecked_scalar_add_assign_async(ct, negated_scalar, stream);
|
||||
ct.as_mut().info = ct.as_ref().info.after_scalar_sub(scalar);
|
||||
}
|
||||
|
||||
pub fn unchecked_scalar_sub_assign<T>(
|
||||
pub fn unchecked_scalar_sub_assign<Scalar, T>(
|
||||
&self,
|
||||
ct: &mut CudaUnsignedRadixCiphertext,
|
||||
scalar: T,
|
||||
ct: &mut T,
|
||||
scalar: Scalar,
|
||||
stream: &CudaStream,
|
||||
) where
|
||||
T: DecomposableInto<u8> + UnsignedNumeric + TwosComplementNegation,
|
||||
Scalar: DecomposableInto<u8> + Numeric + TwosComplementNegation,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
unsafe {
|
||||
self.unchecked_scalar_sub_assign_async(ct, scalar, stream);
|
||||
@@ -125,14 +123,10 @@ impl CudaServerKey {
|
||||
/// let dec: u64 = cks.decrypt(&ct_res);
|
||||
/// assert_eq!(msg - scalar, dec);
|
||||
/// ```
|
||||
pub fn scalar_sub<T>(
|
||||
&self,
|
||||
ct: &CudaUnsignedRadixCiphertext,
|
||||
scalar: T,
|
||||
stream: &CudaStream,
|
||||
) -> CudaUnsignedRadixCiphertext
|
||||
pub fn scalar_sub<Scalar, T>(&self, ct: &T, scalar: Scalar, stream: &CudaStream) -> T
|
||||
where
|
||||
T: DecomposableInto<u8> + UnsignedNumeric + TwosComplementNegation,
|
||||
Scalar: DecomposableInto<u8> + Numeric + TwosComplementNegation,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let mut result = unsafe { ct.duplicate_async(stream) };
|
||||
self.scalar_sub_assign(&mut result, scalar, stream);
|
||||
@@ -143,13 +137,14 @@ impl CudaServerKey {
|
||||
///
|
||||
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until stream is synchronised
|
||||
pub unsafe fn scalar_sub_assign_async<T>(
|
||||
pub unsafe fn scalar_sub_assign_async<Scalar, T>(
|
||||
&self,
|
||||
ct: &mut CudaUnsignedRadixCiphertext,
|
||||
scalar: T,
|
||||
ct: &mut T,
|
||||
scalar: Scalar,
|
||||
stream: &CudaStream,
|
||||
) where
|
||||
T: DecomposableInto<u8> + UnsignedNumeric + TwosComplementNegation,
|
||||
Scalar: DecomposableInto<u8> + Numeric + TwosComplementNegation,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
if !ct.block_carries_are_empty() {
|
||||
self.full_propagate_assign_async(ct, stream);
|
||||
@@ -159,13 +154,10 @@ impl CudaServerKey {
|
||||
self.full_propagate_assign_async(ct, stream);
|
||||
}
|
||||
|
||||
pub fn scalar_sub_assign<T>(
|
||||
&self,
|
||||
ct: &mut CudaUnsignedRadixCiphertext,
|
||||
scalar: T,
|
||||
stream: &CudaStream,
|
||||
) where
|
||||
T: DecomposableInto<u8> + UnsignedNumeric + TwosComplementNegation,
|
||||
pub fn scalar_sub_assign<Scalar, T>(&self, ct: &mut T, scalar: Scalar, stream: &CudaStream)
|
||||
where
|
||||
Scalar: DecomposableInto<u8> + Numeric + TwosComplementNegation,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
unsafe {
|
||||
self.scalar_sub_assign_async(ct, scalar, stream);
|
||||
|
||||
@@ -2,6 +2,7 @@ pub(crate) mod test_add;
|
||||
pub(crate) mod test_mul;
|
||||
pub(crate) mod test_neg;
|
||||
pub(crate) mod test_scalar_add;
|
||||
pub(crate) mod test_scalar_sub;
|
||||
pub(crate) mod test_sub;
|
||||
|
||||
use crate::core_crypto::gpu::CudaStream;
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
use crate::integer::gpu::server_key::radix::tests_unsigned::{
|
||||
create_gpu_parametrized_test, GpuFunctionExecutor,
|
||||
};
|
||||
use crate::integer::gpu::CudaServerKey;
|
||||
use crate::integer::server_key::radix_parallel::tests_cases_signed::signed_unchecked_scalar_sub_test;
|
||||
use crate::shortint::parameters::*;
|
||||
|
||||
create_gpu_parametrized_test!(integer_signed_unchecked_scalar_sub);
|
||||
|
||||
fn integer_signed_unchecked_scalar_sub<P>(param: P)
|
||||
where
|
||||
P: Into<PBSParameters>,
|
||||
{
|
||||
let executor = GpuFunctionExecutor::new(&CudaServerKey::unchecked_scalar_sub);
|
||||
signed_unchecked_scalar_sub_test(param, executor);
|
||||
}
|
||||
@@ -2,6 +2,7 @@ pub(crate) mod test_add;
|
||||
pub(crate) mod test_mul;
|
||||
pub(crate) mod test_neg;
|
||||
pub(crate) mod test_scalar_add;
|
||||
pub(crate) mod test_scalar_sub;
|
||||
pub(crate) mod test_sub;
|
||||
|
||||
use crate::core_crypto::gpu::{CudaDevice, CudaStream};
|
||||
@@ -75,7 +76,6 @@ impl<F> GpuFunctionExecutor<F> {
|
||||
}
|
||||
|
||||
// Unchecked operations
|
||||
create_gpu_parametrized_test!(integer_unchecked_scalar_sub);
|
||||
create_gpu_parametrized_test!(integer_unchecked_small_scalar_mul);
|
||||
create_gpu_parametrized_test!(integer_unchecked_bitnot);
|
||||
create_gpu_parametrized_test!(integer_unchecked_bitand);
|
||||
@@ -107,7 +107,6 @@ create_gpu_parametrized_test!(integer_unchecked_scalar_rotate_left);
|
||||
create_gpu_parametrized_test!(integer_unchecked_scalar_rotate_right);
|
||||
|
||||
// Default operations
|
||||
create_gpu_parametrized_test!(integer_scalar_sub);
|
||||
create_gpu_parametrized_test!(integer_small_scalar_mul);
|
||||
create_gpu_parametrized_test!(integer_scalar_right_shift);
|
||||
create_gpu_parametrized_test!(integer_scalar_left_shift);
|
||||
@@ -318,14 +317,6 @@ where
|
||||
unchecked_small_scalar_mul_test(param, executor);
|
||||
}
|
||||
|
||||
fn integer_unchecked_scalar_sub<P>(param: P)
|
||||
where
|
||||
P: Into<PBSParameters>,
|
||||
{
|
||||
let executor = GpuFunctionExecutor::new(&CudaServerKey::unchecked_scalar_sub);
|
||||
unchecked_scalar_sub_test(param, executor);
|
||||
}
|
||||
|
||||
fn integer_unchecked_bitnot<P>(param: P)
|
||||
where
|
||||
P: Into<PBSParameters>,
|
||||
@@ -1496,14 +1487,6 @@ where
|
||||
default_small_scalar_mul_test(param, executor);
|
||||
}
|
||||
|
||||
fn integer_scalar_sub<P>(param: P)
|
||||
where
|
||||
P: Into<PBSParameters>,
|
||||
{
|
||||
let executor = GpuFunctionExecutor::new(&CudaServerKey::scalar_sub);
|
||||
default_scalar_sub_test(param, executor);
|
||||
}
|
||||
|
||||
fn integer_bitnot<P>(param: P)
|
||||
where
|
||||
P: Into<PBSParameters> + Copy,
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
use crate::integer::gpu::server_key::radix::tests_unsigned::{
|
||||
create_gpu_parametrized_test, GpuFunctionExecutor,
|
||||
};
|
||||
use crate::integer::gpu::CudaServerKey;
|
||||
use crate::integer::server_key::radix_parallel::tests_cases_unsigned::{
|
||||
default_scalar_sub_test, unchecked_scalar_sub_test,
|
||||
};
|
||||
use crate::shortint::parameters::*;
|
||||
|
||||
create_gpu_parametrized_test!(integer_unchecked_scalar_sub);
|
||||
create_gpu_parametrized_test!(integer_scalar_sub);
|
||||
|
||||
fn integer_unchecked_scalar_sub<P>(param: P)
|
||||
where
|
||||
P: Into<PBSParameters> + Copy,
|
||||
{
|
||||
let executor = GpuFunctionExecutor::new(&CudaServerKey::unchecked_scalar_sub);
|
||||
unchecked_scalar_sub_test(param, executor);
|
||||
}
|
||||
|
||||
fn integer_scalar_sub<P>(param: P)
|
||||
where
|
||||
P: Into<PBSParameters> + Copy,
|
||||
{
|
||||
let executor = GpuFunctionExecutor::new(&CudaServerKey::scalar_sub);
|
||||
default_scalar_sub_test(param, executor);
|
||||
}
|
||||
@@ -1057,3 +1057,221 @@ where
|
||||
assert_eq!(encrypted_overflow.0.noise_level(), NoiseLevel::ZERO);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn signed_unchecked_scalar_sub_test<P, T>(param: P, mut executor: T)
|
||||
where
|
||||
P: Into<PBSParameters>,
|
||||
T: for<'a> FunctionExecutor<(&'a SignedRadixCiphertext, i64), SignedRadixCiphertext>,
|
||||
{
|
||||
let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
|
||||
let cks = RadixClientKey::from((cks, NB_CTXT));
|
||||
let sks = Arc::new(sks);
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64;
|
||||
|
||||
executor.setup(&cks, sks);
|
||||
|
||||
// check some overflow behaviour
|
||||
let overflowing_values = [
|
||||
(-modulus, 1, modulus - 1),
|
||||
(modulus - 1, -1, -modulus),
|
||||
(-modulus, 2, modulus - 2),
|
||||
(modulus - 2, -2, -modulus),
|
||||
];
|
||||
for (clear_0, clear_1, expected_clear) in overflowing_values {
|
||||
let ctxt_0 = cks.encrypt_signed(clear_0);
|
||||
let ct_res = executor.execute((&ctxt_0, clear_1));
|
||||
let dec_res: i64 = cks.decrypt_signed(&ct_res);
|
||||
let clear_res = signed_sub_under_modulus(clear_0, clear_1, modulus);
|
||||
assert_eq!(clear_res, dec_res);
|
||||
assert_eq!(clear_res, expected_clear);
|
||||
}
|
||||
|
||||
for _ in 0..NB_TESTS {
|
||||
let clear_0 = rng.gen::<i64>() % modulus;
|
||||
let clear_1 = rng.gen::<i64>() % modulus;
|
||||
|
||||
let ctxt_0 = cks.encrypt_signed(clear_0);
|
||||
|
||||
let ct_res = executor.execute((&ctxt_0, clear_1));
|
||||
let dec_res: i64 = cks.decrypt_signed(&ct_res);
|
||||
let clear_res = signed_sub_under_modulus(clear_0, clear_1, modulus);
|
||||
assert_eq!(clear_res, dec_res);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn signed_default_overflowing_scalar_sub_test<P, T>(param: P, mut executor: T)
|
||||
where
|
||||
P: Into<PBSParameters>,
|
||||
T: for<'a> FunctionExecutor<
|
||||
(&'a SignedRadixCiphertext, i64),
|
||||
(SignedRadixCiphertext, BooleanBlock),
|
||||
>,
|
||||
{
|
||||
let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
|
||||
let cks = RadixClientKey::from((cks, NB_CTXT));
|
||||
sks.set_deterministic_pbs_execution(true);
|
||||
let sks = Arc::new(sks);
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
// message_modulus^vec_length
|
||||
let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64;
|
||||
|
||||
executor.setup(&cks, sks.clone());
|
||||
|
||||
let hardcoded_values = [
|
||||
(-modulus, 1),
|
||||
(modulus - 1, -1),
|
||||
(1, -modulus),
|
||||
(-1, modulus - 1),
|
||||
];
|
||||
for (clear_0, clear_1) in hardcoded_values {
|
||||
let ctxt_0 = cks.encrypt_signed(clear_0);
|
||||
|
||||
let (ct_res, result_overflowed) = executor.execute((&ctxt_0, clear_1));
|
||||
let (expected_result, expected_overflowed) =
|
||||
signed_overflowing_sub_under_modulus(clear_0, clear_1, modulus);
|
||||
|
||||
let decrypted_result: i64 = cks.decrypt_signed(&ct_res);
|
||||
let decrypted_overflowed = cks.decrypt_bool(&result_overflowed);
|
||||
assert_eq!(
|
||||
decrypted_result, expected_result,
|
||||
"Invalid result for overflowing_sub, for ({clear_0} - {clear_1}) % {modulus} \
|
||||
expected {expected_result}, got {decrypted_result}"
|
||||
);
|
||||
assert_eq!(
|
||||
decrypted_overflowed,
|
||||
expected_overflowed,
|
||||
"Invalid overflow flag result for overflowing_sub for ({clear_0} - {clear_1}) % {modulus} \
|
||||
expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
|
||||
);
|
||||
assert_eq!(result_overflowed.0.degree.get(), 1);
|
||||
assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL);
|
||||
}
|
||||
|
||||
for _ in 0..NB_TESTS_SMALLER {
|
||||
let clear_0 = rng.gen::<i64>() % modulus;
|
||||
let clear_1 = rng.gen::<i64>() % modulus;
|
||||
|
||||
let ctxt_0 = cks.encrypt_signed(clear_0);
|
||||
|
||||
let (ct_res, result_overflowed) = executor.execute((&ctxt_0, clear_1));
|
||||
let (tmp_ct, tmp_o) = executor.execute((&ctxt_0, clear_1));
|
||||
assert!(ct_res.block_carries_are_empty());
|
||||
assert_eq!(ct_res, tmp_ct, "Failed determinism check");
|
||||
assert_eq!(tmp_o, result_overflowed, "Failed determinism check");
|
||||
|
||||
let (expected_result, expected_overflowed) =
|
||||
signed_overflowing_sub_under_modulus(clear_0, clear_1, modulus);
|
||||
|
||||
let decrypted_result: i64 = cks.decrypt_signed(&ct_res);
|
||||
let decrypted_overflowed = cks.decrypt_bool(&result_overflowed);
|
||||
assert_eq!(
|
||||
decrypted_result, expected_result,
|
||||
"Invalid result for sub, for ({clear_0} - {clear_1}) % {modulus} \
|
||||
expected {expected_result}, got {decrypted_result}"
|
||||
);
|
||||
assert_eq!(
|
||||
decrypted_overflowed,
|
||||
expected_overflowed,
|
||||
"Invalid overflow flag result for overflowing_sub for ({clear_0} - {clear_1}) % {modulus} \
|
||||
expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
|
||||
);
|
||||
assert_eq!(result_overflowed.0.degree.get(), 1);
|
||||
assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL);
|
||||
|
||||
for _ in 0..NB_TESTS_SMALLER {
|
||||
// Add non zero scalar to have non clean ciphertexts
|
||||
let clear_2 = random_non_zero_value(&mut rng, modulus);
|
||||
let clear_rhs = random_non_zero_value(&mut rng, modulus);
|
||||
|
||||
let ctxt_0 = sks.unchecked_scalar_add(&ctxt_0, clear_2);
|
||||
let (clear_lhs, _) = signed_overflowing_add_under_modulus(clear_0, clear_2, modulus);
|
||||
let d0: i64 = cks.decrypt_signed(&ctxt_0);
|
||||
assert_eq!(d0, clear_lhs, "Failed sanity decryption check");
|
||||
|
||||
let (ct_res, result_overflowed) = executor.execute((&ctxt_0, clear_rhs));
|
||||
assert!(ct_res.block_carries_are_empty());
|
||||
let (expected_result, expected_overflowed) =
|
||||
signed_overflowing_sub_under_modulus(clear_lhs, clear_rhs, modulus);
|
||||
|
||||
let decrypted_result: i64 = cks.decrypt_signed(&ct_res);
|
||||
let decrypted_overflowed = cks.decrypt_bool(&result_overflowed);
|
||||
assert_eq!(
|
||||
decrypted_result, expected_result,
|
||||
"Invalid result for sub, for ({clear_lhs} + {clear_rhs}) % {modulus} \
|
||||
expected {expected_result}, got {decrypted_result}"
|
||||
);
|
||||
assert_eq!(
|
||||
decrypted_overflowed,
|
||||
expected_overflowed,
|
||||
"Invalid overflow flag result for overflowing_sub, for ({clear_lhs} - {clear_rhs}) % {modulus} \
|
||||
expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
|
||||
);
|
||||
assert_eq!(result_overflowed.0.degree.get(), 1);
|
||||
assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL);
|
||||
}
|
||||
}
|
||||
|
||||
// Test with trivial inputs
|
||||
for _ in 0..4 {
|
||||
let clear_0 = rng.gen::<i64>() % modulus;
|
||||
let clear_1 = rng.gen::<i64>() % modulus;
|
||||
|
||||
let a: SignedRadixCiphertext = sks.create_trivial_radix(clear_0, NB_CTXT);
|
||||
|
||||
let (encrypted_result, encrypted_overflow) = executor.execute((&a, clear_1));
|
||||
|
||||
let (expected_result, expected_overflowed) =
|
||||
signed_overflowing_sub_under_modulus(clear_0, clear_1, modulus);
|
||||
|
||||
let decrypted_result: i64 = cks.decrypt_signed(&encrypted_result);
|
||||
let decrypted_overflowed = cks.decrypt_bool(&encrypted_overflow);
|
||||
assert_eq!(
|
||||
decrypted_result, expected_result,
|
||||
"Invalid result for add, for ({clear_0} - {clear_1}) % {modulus} \
|
||||
expected {expected_result}, got {decrypted_result}"
|
||||
);
|
||||
assert_eq!(
|
||||
decrypted_overflowed,
|
||||
expected_overflowed,
|
||||
"Invalid overflow flag result for overflowing_sub, for ({clear_0} - {clear_1}) % {modulus} \
|
||||
expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
|
||||
);
|
||||
assert_eq!(encrypted_overflow.0.degree.get(), 1);
|
||||
assert_eq!(encrypted_overflow.0.noise_level(), NoiseLevel::ZERO);
|
||||
}
|
||||
|
||||
// Test with scalar that is bigger than ciphertext modulus
|
||||
for _ in 0..2 {
|
||||
let clear_0 = rng.gen::<i64>() % modulus;
|
||||
let clear_1 = rng.gen_range(modulus..=i64::MAX);
|
||||
|
||||
let a = cks.encrypt_signed(clear_0);
|
||||
|
||||
let (encrypted_result, encrypted_overflow) = executor.execute((&a, clear_1));
|
||||
|
||||
let (expected_result, expected_overflowed) =
|
||||
signed_overflowing_sub_under_modulus(clear_0, clear_1, modulus);
|
||||
|
||||
let decrypted_result: i64 = cks.decrypt_signed(&encrypted_result);
|
||||
let decrypted_overflowed = cks.decrypt_bool(&encrypted_overflow);
|
||||
assert_eq!(
|
||||
decrypted_result, expected_result,
|
||||
"Invalid result for sub, for ({clear_0} - {clear_1}) % {modulus} \
|
||||
expected {expected_result}, got {decrypted_result}"
|
||||
);
|
||||
assert_eq!(
|
||||
decrypted_overflowed,
|
||||
expected_overflowed,
|
||||
"Invalid overflow flag result for overflowing_sub, for ({clear_0} - {clear_1}) % {modulus} \
|
||||
expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
|
||||
);
|
||||
assert!(decrypted_overflowed); // Actually we know its an overflow case
|
||||
assert_eq!(encrypted_overflow.0.degree.get(), 1);
|
||||
assert_eq!(encrypted_overflow.0.noise_level(), NoiseLevel::ZERO);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ pub(crate) mod test_add;
|
||||
pub(crate) mod test_mul;
|
||||
pub(crate) mod test_neg;
|
||||
pub(crate) mod test_scalar_add;
|
||||
pub(crate) mod test_scalar_sub;
|
||||
pub(crate) mod test_sub;
|
||||
|
||||
use crate::integer::keycache::KEY_CACHE;
|
||||
@@ -11,7 +12,6 @@ use crate::integer::server_key::radix_parallel::tests_unsigned::CpuFunctionExecu
|
||||
use crate::integer::{
|
||||
BooleanBlock, IntegerKeyKind, RadixClientKey, ServerKey, SignedRadixCiphertext,
|
||||
};
|
||||
use crate::shortint::ciphertext::NoiseLevel;
|
||||
#[cfg(tarpaulin)]
|
||||
use crate::shortint::parameters::coverage_parameters::*;
|
||||
use crate::shortint::parameters::*;
|
||||
@@ -1512,7 +1512,6 @@ where
|
||||
// Unchecked Scalar Tests
|
||||
//================================================================================
|
||||
|
||||
create_parametrized_test!(integer_signed_unchecked_scalar_sub);
|
||||
create_parametrized_test!(integer_signed_unchecked_scalar_mul);
|
||||
create_parametrized_test!(integer_signed_unchecked_scalar_rotate_left);
|
||||
create_parametrized_test!(integer_signed_unchecked_scalar_rotate_right);
|
||||
@@ -1524,42 +1523,6 @@ create_parametrized_test!(integer_signed_unchecked_scalar_bitxor);
|
||||
create_parametrized_test!(integer_signed_unchecked_scalar_div_rem);
|
||||
create_parametrized_test!(integer_signed_unchecked_scalar_div_rem_floor);
|
||||
|
||||
fn integer_signed_unchecked_scalar_sub(param: impl Into<PBSParameters>) {
|
||||
let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64;
|
||||
|
||||
// check some overflow behaviour
|
||||
let overflowing_values = [
|
||||
(-modulus, 1, modulus - 1),
|
||||
(modulus - 1, -1, -modulus),
|
||||
(-modulus, 2, modulus - 2),
|
||||
(modulus - 2, -2, -modulus),
|
||||
];
|
||||
for (clear_0, clear_1, expected_clear) in overflowing_values {
|
||||
let ctxt_0 = cks.encrypt_signed_radix(clear_0, NB_CTXT);
|
||||
let ct_res = sks.unchecked_scalar_sub(&ctxt_0, clear_1);
|
||||
let dec_res: i64 = cks.decrypt_signed_radix(&ct_res);
|
||||
let clear_res = signed_sub_under_modulus(clear_0, clear_1, modulus);
|
||||
assert_eq!(clear_res, dec_res);
|
||||
assert_eq!(clear_res, expected_clear);
|
||||
}
|
||||
|
||||
for _ in 0..NB_TESTS {
|
||||
let clear_0 = rng.gen::<i64>() % modulus;
|
||||
let clear_1 = rng.gen::<i64>() % modulus;
|
||||
|
||||
let ctxt_0 = cks.encrypt_signed_radix(clear_0, NB_CTXT);
|
||||
|
||||
let ct_res = sks.unchecked_scalar_sub(&ctxt_0, clear_1);
|
||||
let dec_res: i64 = cks.decrypt_signed_radix(&ct_res);
|
||||
let clear_res = signed_sub_under_modulus(clear_0, clear_1, modulus);
|
||||
assert_eq!(clear_res, dec_res);
|
||||
}
|
||||
}
|
||||
|
||||
fn integer_signed_unchecked_scalar_mul(param: impl Into<PBSParameters>) {
|
||||
let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
|
||||
|
||||
@@ -2040,7 +2003,6 @@ fn integer_signed_unchecked_scalar_div_rem_floor(param: impl Into<PBSParameters>
|
||||
// Default Scalar Tests
|
||||
//================================================================================
|
||||
|
||||
create_parametrized_test!(integer_signed_default_overflowing_scalar_sub);
|
||||
create_parametrized_test!(integer_signed_default_scalar_bitand);
|
||||
create_parametrized_test!(integer_signed_default_scalar_bitor);
|
||||
create_parametrized_test!(integer_signed_default_scalar_bitxor);
|
||||
@@ -2050,179 +2012,6 @@ create_parametrized_test!(integer_signed_default_scalar_right_shift);
|
||||
create_parametrized_test!(integer_signed_default_scalar_rotate_right);
|
||||
create_parametrized_test!(integer_signed_default_scalar_rotate_left);
|
||||
|
||||
pub(crate) fn integer_signed_default_overflowing_scalar_sub<P>(param: P)
|
||||
where
|
||||
P: Into<PBSParameters>,
|
||||
{
|
||||
let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
|
||||
let cks = RadixClientKey::from((cks, NB_CTXT));
|
||||
|
||||
sks.set_deterministic_pbs_execution(true);
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
// message_modulus^vec_length
|
||||
let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64;
|
||||
|
||||
let hardcoded_values = [
|
||||
(-modulus, 1),
|
||||
(modulus - 1, -1),
|
||||
(1, -modulus),
|
||||
(-1, modulus - 1),
|
||||
];
|
||||
for (clear_0, clear_1) in hardcoded_values {
|
||||
let ctxt_0 = cks.encrypt_signed(clear_0);
|
||||
|
||||
let (ct_res, result_overflowed) =
|
||||
sks.signed_overflowing_scalar_sub_parallelized(&ctxt_0, clear_1);
|
||||
let (expected_result, expected_overflowed) =
|
||||
signed_overflowing_sub_under_modulus(clear_0, clear_1, modulus);
|
||||
|
||||
let decrypted_result: i64 = cks.decrypt_signed(&ct_res);
|
||||
let decrypted_overflowed = cks.decrypt_bool(&result_overflowed);
|
||||
assert_eq!(
|
||||
decrypted_result, expected_result,
|
||||
"Invalid result for overflowing_sub, for ({clear_0} - {clear_1}) % {modulus} \
|
||||
expected {expected_result}, got {decrypted_result}"
|
||||
);
|
||||
assert_eq!(
|
||||
decrypted_overflowed,
|
||||
expected_overflowed,
|
||||
"Invalid overflow flag result for overflowing_sub for ({clear_0} - {clear_1}) % {modulus} \
|
||||
expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
|
||||
);
|
||||
assert_eq!(result_overflowed.0.degree.get(), 1);
|
||||
assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL);
|
||||
}
|
||||
|
||||
for _ in 0..NB_TESTS_SMALLER {
|
||||
let clear_0 = rng.gen::<i64>() % modulus;
|
||||
let clear_1 = rng.gen::<i64>() % modulus;
|
||||
|
||||
let ctxt_0 = cks.encrypt_signed(clear_0);
|
||||
|
||||
let (ct_res, result_overflowed) =
|
||||
sks.signed_overflowing_scalar_sub_parallelized(&ctxt_0, clear_1);
|
||||
let (tmp_ct, tmp_o) = sks.signed_overflowing_scalar_sub_parallelized(&ctxt_0, clear_1);
|
||||
assert!(ct_res.block_carries_are_empty());
|
||||
assert_eq!(ct_res, tmp_ct, "Failed determinism check");
|
||||
assert_eq!(tmp_o, result_overflowed, "Failed determinism check");
|
||||
|
||||
let (expected_result, expected_overflowed) =
|
||||
signed_overflowing_sub_under_modulus(clear_0, clear_1, modulus);
|
||||
|
||||
let decrypted_result: i64 = cks.decrypt_signed(&ct_res);
|
||||
let decrypted_overflowed = cks.decrypt_bool(&result_overflowed);
|
||||
assert_eq!(
|
||||
decrypted_result, expected_result,
|
||||
"Invalid result for sub, for ({clear_0} - {clear_1}) % {modulus} \
|
||||
expected {expected_result}, got {decrypted_result}"
|
||||
);
|
||||
assert_eq!(
|
||||
decrypted_overflowed,
|
||||
expected_overflowed,
|
||||
"Invalid overflow flag result for overflowing_sub for ({clear_0} - {clear_1}) % {modulus} \
|
||||
expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
|
||||
);
|
||||
assert_eq!(result_overflowed.0.degree.get(), 1);
|
||||
assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL);
|
||||
|
||||
for _ in 0..NB_TESTS_SMALLER {
|
||||
// Add non zero scalar to have non clean ciphertexts
|
||||
let clear_2 = random_non_zero_value(&mut rng, modulus);
|
||||
let clear_rhs = random_non_zero_value(&mut rng, modulus);
|
||||
|
||||
let ctxt_0 = sks.unchecked_scalar_add(&ctxt_0, clear_2);
|
||||
let (clear_lhs, _) = signed_overflowing_add_under_modulus(clear_0, clear_2, modulus);
|
||||
let d0: i64 = cks.decrypt_signed(&ctxt_0);
|
||||
assert_eq!(d0, clear_lhs, "Failed sanity decryption check");
|
||||
|
||||
let (ct_res, result_overflowed) =
|
||||
sks.signed_overflowing_scalar_sub_parallelized(&ctxt_0, clear_rhs);
|
||||
assert!(ct_res.block_carries_are_empty());
|
||||
let (expected_result, expected_overflowed) =
|
||||
signed_overflowing_sub_under_modulus(clear_lhs, clear_rhs, modulus);
|
||||
|
||||
let decrypted_result: i64 = cks.decrypt_signed(&ct_res);
|
||||
let decrypted_overflowed = cks.decrypt_bool(&result_overflowed);
|
||||
assert_eq!(
|
||||
decrypted_result, expected_result,
|
||||
"Invalid result for sub, for ({clear_lhs} + {clear_rhs}) % {modulus} \
|
||||
expected {expected_result}, got {decrypted_result}"
|
||||
);
|
||||
assert_eq!(
|
||||
decrypted_overflowed,
|
||||
expected_overflowed,
|
||||
"Invalid overflow flag result for overflowing_sub, for ({clear_lhs} - {clear_rhs}) % {modulus} \
|
||||
expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
|
||||
);
|
||||
assert_eq!(result_overflowed.0.degree.get(), 1);
|
||||
assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL);
|
||||
}
|
||||
}
|
||||
|
||||
// Test with trivial inputs
|
||||
for _ in 0..4 {
|
||||
let clear_0 = rng.gen::<i64>() % modulus;
|
||||
let clear_1 = rng.gen::<i64>() % modulus;
|
||||
|
||||
let a: SignedRadixCiphertext = sks.create_trivial_radix(clear_0, NB_CTXT);
|
||||
|
||||
let (encrypted_result, encrypted_overflow) =
|
||||
sks.signed_overflowing_scalar_sub_parallelized(&a, clear_1);
|
||||
|
||||
let (expected_result, expected_overflowed) =
|
||||
signed_overflowing_sub_under_modulus(clear_0, clear_1, modulus);
|
||||
|
||||
let decrypted_result: i64 = cks.decrypt_signed(&encrypted_result);
|
||||
let decrypted_overflowed = cks.decrypt_bool(&encrypted_overflow);
|
||||
assert_eq!(
|
||||
decrypted_result, expected_result,
|
||||
"Invalid result for add, for ({clear_0} - {clear_1}) % {modulus} \
|
||||
expected {expected_result}, got {decrypted_result}"
|
||||
);
|
||||
assert_eq!(
|
||||
decrypted_overflowed,
|
||||
expected_overflowed,
|
||||
"Invalid overflow flag result for overflowing_sub, for ({clear_0} - {clear_1}) % {modulus} \
|
||||
expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
|
||||
);
|
||||
assert_eq!(encrypted_overflow.0.degree.get(), 1);
|
||||
assert_eq!(encrypted_overflow.0.noise_level(), NoiseLevel::ZERO);
|
||||
}
|
||||
|
||||
// Test with scalar that is bigger than ciphertext modulus
|
||||
for _ in 0..2 {
|
||||
let clear_0 = rng.gen::<i64>() % modulus;
|
||||
let clear_1 = rng.gen_range(modulus..=i64::MAX);
|
||||
|
||||
let a = cks.encrypt_signed(clear_0);
|
||||
|
||||
let (encrypted_result, encrypted_overflow) =
|
||||
sks.signed_overflowing_scalar_sub_parallelized(&a, clear_1);
|
||||
|
||||
let (expected_result, expected_overflowed) =
|
||||
signed_overflowing_sub_under_modulus(clear_0, clear_1, modulus);
|
||||
|
||||
let decrypted_result: i64 = cks.decrypt_signed(&encrypted_result);
|
||||
let decrypted_overflowed = cks.decrypt_bool(&encrypted_overflow);
|
||||
assert_eq!(
|
||||
decrypted_result, expected_result,
|
||||
"Invalid result for sub, for ({clear_0} - {clear_1}) % {modulus} \
|
||||
expected {expected_result}, got {decrypted_result}"
|
||||
);
|
||||
assert_eq!(
|
||||
decrypted_overflowed,
|
||||
expected_overflowed,
|
||||
"Invalid overflow flag result for overflowing_sub, for ({clear_0} - {clear_1}) % {modulus} \
|
||||
expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
|
||||
);
|
||||
assert!(decrypted_overflowed); // Actually we know its an overflow case
|
||||
assert_eq!(encrypted_overflow.0.degree.get(), 1);
|
||||
assert_eq!(encrypted_overflow.0.noise_level(), NoiseLevel::ZERO);
|
||||
}
|
||||
}
|
||||
|
||||
fn integer_signed_default_scalar_bitand(param: impl Into<PBSParameters>) {
|
||||
let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
|
||||
sks.set_deterministic_pbs_execution(true);
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
use crate::integer::server_key::radix_parallel::tests_cases_signed::{
|
||||
signed_default_overflowing_scalar_sub_test, signed_unchecked_scalar_sub_test,
|
||||
};
|
||||
use crate::integer::server_key::radix_parallel::tests_unsigned::CpuFunctionExecutor;
|
||||
use crate::integer::ServerKey;
|
||||
#[cfg(tarpaulin)]
|
||||
use crate::shortint::parameters::coverage_parameters::*;
|
||||
use crate::shortint::parameters::*;
|
||||
|
||||
create_parametrized_test!(integer_signed_unchecked_scalar_sub);
|
||||
create_parametrized_test!(integer_signed_default_overflowing_scalar_sub);
|
||||
|
||||
fn integer_signed_unchecked_scalar_sub<P>(param: P)
|
||||
where
|
||||
P: Into<PBSParameters>,
|
||||
{
|
||||
let executor = CpuFunctionExecutor::new(&ServerKey::unchecked_scalar_sub);
|
||||
signed_unchecked_scalar_sub_test(param, executor);
|
||||
}
|
||||
|
||||
fn integer_signed_default_overflowing_scalar_sub<P>(param: P)
|
||||
where
|
||||
P: Into<PBSParameters>,
|
||||
{
|
||||
let executor = CpuFunctionExecutor::new(&ServerKey::signed_overflowing_scalar_sub_parallelized);
|
||||
signed_default_overflowing_scalar_sub_test(param, executor);
|
||||
}
|
||||
@@ -2,6 +2,7 @@ pub(crate) mod test_add;
|
||||
pub(crate) mod test_mul;
|
||||
pub(crate) mod test_neg;
|
||||
pub(crate) mod test_scalar_add;
|
||||
pub(crate) mod test_scalar_sub;
|
||||
pub(crate) mod test_sub;
|
||||
|
||||
use super::tests_cases_unsigned::*;
|
||||
@@ -597,9 +598,6 @@ create_parametrized_test!(integer_unchecked_scalar_rotate_left);
|
||||
create_parametrized_test!(integer_default_scalar_rotate_right);
|
||||
create_parametrized_test!(integer_default_scalar_rotate_left);
|
||||
create_parametrized_test!(integer_default_scalar_div_rem);
|
||||
create_parametrized_test!(integer_smart_scalar_sub);
|
||||
create_parametrized_test!(integer_default_scalar_sub);
|
||||
create_parametrized_test!(integer_default_overflowing_scalar_sub);
|
||||
create_parametrized_test!(integer_smart_if_then_else);
|
||||
create_parametrized_test!(integer_default_if_then_else);
|
||||
create_parametrized_test!(integer_trim_radix_msb_blocks_handles_dirty_inputs);
|
||||
@@ -982,14 +980,6 @@ where
|
||||
// Smart Scalar Tests
|
||||
//=============================================================================
|
||||
|
||||
fn integer_smart_scalar_sub<P>(param: P)
|
||||
where
|
||||
P: Into<PBSParameters>,
|
||||
{
|
||||
let executor = CpuFunctionExecutor::new(&ServerKey::smart_scalar_sub_parallelized);
|
||||
smart_scalar_sub_test(param, executor);
|
||||
}
|
||||
|
||||
fn integer_smart_small_scalar_mul<P>(param: P)
|
||||
where
|
||||
P: Into<PBSParameters>,
|
||||
@@ -1134,23 +1124,6 @@ where
|
||||
// Default Scalar Tests
|
||||
//=============================================================================
|
||||
|
||||
fn integer_default_scalar_sub<P>(param: P)
|
||||
where
|
||||
P: Into<PBSParameters>,
|
||||
{
|
||||
let executor = CpuFunctionExecutor::new(&ServerKey::scalar_sub_parallelized);
|
||||
default_scalar_sub_test(param, executor);
|
||||
}
|
||||
|
||||
fn integer_default_overflowing_scalar_sub<P>(param: P)
|
||||
where
|
||||
P: Into<PBSParameters>,
|
||||
{
|
||||
let executor =
|
||||
CpuFunctionExecutor::new(&ServerKey::unsigned_overflowing_scalar_sub_parallelized);
|
||||
default_overflowing_scalar_sub_test(param, executor);
|
||||
}
|
||||
|
||||
fn integer_default_scalar_bitand<P>(param: P)
|
||||
where
|
||||
P: Into<PBSParameters>,
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
use crate::integer::server_key::radix_parallel::tests_cases_unsigned::{
|
||||
default_overflowing_scalar_sub_test, default_scalar_sub_test, smart_scalar_sub_test,
|
||||
};
|
||||
use crate::integer::server_key::radix_parallel::tests_unsigned::CpuFunctionExecutor;
|
||||
use crate::integer::ServerKey;
|
||||
#[cfg(tarpaulin)]
|
||||
use crate::shortint::parameters::coverage_parameters::*;
|
||||
use crate::shortint::parameters::*;
|
||||
|
||||
create_parametrized_test!(integer_smart_scalar_sub);
|
||||
create_parametrized_test!(integer_default_scalar_sub);
|
||||
create_parametrized_test!(integer_default_overflowing_scalar_sub);
|
||||
|
||||
fn integer_smart_scalar_sub<P>(param: P)
|
||||
where
|
||||
P: Into<PBSParameters>,
|
||||
{
|
||||
let executor = CpuFunctionExecutor::new(&ServerKey::smart_scalar_sub_parallelized);
|
||||
smart_scalar_sub_test(param, executor);
|
||||
}
|
||||
|
||||
fn integer_default_scalar_sub<P>(param: P)
|
||||
where
|
||||
P: Into<PBSParameters>,
|
||||
{
|
||||
let executor = CpuFunctionExecutor::new(&ServerKey::scalar_sub_parallelized);
|
||||
default_scalar_sub_test(param, executor);
|
||||
}
|
||||
|
||||
fn integer_default_overflowing_scalar_sub<P>(param: P)
|
||||
where
|
||||
P: Into<PBSParameters>,
|
||||
{
|
||||
let executor =
|
||||
CpuFunctionExecutor::new(&ServerKey::unsigned_overflowing_scalar_sub_parallelized);
|
||||
default_overflowing_scalar_sub_test(param, executor);
|
||||
}
|
||||
Reference in New Issue
Block a user