feat(gpu): signed scalar sub

Author: Agnes Leroy
Date: 2024-03-08 10:20:49 +01:00
Committed by: Agnès Leroy
Parent: d3801446ff
Commit: 6f954bb538
11 changed files with 357 additions and 294 deletions
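The diffs below drop the hard-coded `CudaUnsignedRadixCiphertext` parameter from the CUDA scalar-subtraction entry points and make them generic over `T: CudaIntegerRadixCiphertext` plus a separate `Scalar` bound, so the same code path can serve signed radix ciphertexts as well. A minimal standalone sketch of that pattern, using illustrative trait and type names rather than the real tfhe-rs definitions:

```rust
// Illustrative only: one method bound on a ciphertext-kind trait covers both the
// unsigned and the signed radix types, instead of a copy pinned to one of them.
trait IntegerRadixCiphertext {
    fn num_blocks(&self) -> usize;
}

struct UnsignedRadix { blocks: usize }
struct SignedRadix { blocks: usize }

impl IntegerRadixCiphertext for UnsignedRadix {
    fn num_blocks(&self) -> usize { self.blocks }
}
impl IntegerRadixCiphertext for SignedRadix {
    fn num_blocks(&self) -> usize { self.blocks }
}

// Mirrors the new `scalar_sub<Scalar, T>(&self, ct: &T, ...) -> T` shape:
// the caller's concrete ciphertext type flows through unchanged.
fn scalar_op<T: IntegerRadixCiphertext>(ct: &T) -> usize {
    ct.num_blocks()
}

fn main() {
    assert_eq!(scalar_op(&UnsignedRadix { blocks: 4 }), 4);
    assert_eq!(scalar_op(&SignedRadix { blocks: 4 }), 4);
}
```

The tests added further down exercise exactly this: the existing unsigned executors keep calling the same `CudaServerKey` methods, while new signed test cases reuse them unchanged.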

View File

@@ -503,7 +503,7 @@ generic_integer_impl_scalar_operation!(
InternalServerKey::Cuda(cuda_key) => {
let inner_result = with_thread_local_cuda_stream(|stream| {
cuda_key.key.scalar_sub(
-&lhs.ciphertext.on_gpu(), rhs, stream
+&*lhs.ciphertext.on_gpu(), rhs, stream
)
});
RadixCiphertext::Cuda(inner_result)
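The only change in this first hunk is the added `*`: the callee is now generic over `T: CudaIntegerRadixCiphertext` instead of taking `&CudaUnsignedRadixCiphertext` directly, and deref coercion does not apply when the parameter type is a generic `&T`, so the wrapper returned by `on_gpu()` (assumed here to be a `Deref` guard) has to be reborrowed explicitly. A standalone sketch of that Rust mechanic, with made-up types:

```rust
use std::ops::Deref;

struct Ciphertext(u64);

// Stand-in for whatever guard `on_gpu()` returns; only assumed to deref to Ciphertext.
struct GpuGuard(Ciphertext);

impl Deref for GpuGuard {
    type Target = Ciphertext;
    fn deref(&self) -> &Ciphertext { &self.0 }
}

// Concrete parameter: `&guard` works because deref coercion turns &GpuGuard into &Ciphertext.
fn concrete(ct: &Ciphertext) -> u64 { ct.0 }

trait HasValue { fn value(&self) -> u64; }
impl HasValue for Ciphertext { fn value(&self) -> u64 { self.0 } }

// Generic parameter: T is inferred from the argument, so no coercion happens and the
// caller must pass `&*guard` to hand over a &Ciphertext rather than a &GpuGuard.
fn generic<T: HasValue>(ct: &T) -> u64 { ct.value() }

fn main() {
    let guard = GpuGuard(Ciphertext(7));
    assert_eq!(concrete(&guard), 7); // coercion: &GpuGuard -> &Ciphertext
    assert_eq!(generic(&*guard), 7); // explicit reborrow, as in the diff above
}
```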

View File

@@ -1,7 +1,7 @@
use crate::core_crypto::gpu::CudaStream;
-use crate::core_crypto::prelude::UnsignedNumeric;
+use crate::core_crypto::prelude::Numeric;
use crate::integer::block_decomposition::DecomposableInto;
-use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext};
+use crate::integer::gpu::ciphertext::CudaIntegerRadixCiphertext;
use crate::integer::gpu::server_key::CudaServerKey;
use crate::integer::server_key::TwosComplementNegation;
@@ -43,14 +43,10 @@ impl CudaServerKey {
/// let dec: u64 = cks.decrypt(&ct_res);
/// assert_eq!(msg - scalar, dec);
/// ```
-pub fn unchecked_scalar_sub<T>(
-&self,
-ct: &CudaUnsignedRadixCiphertext,
-scalar: T,
-stream: &CudaStream,
-) -> CudaUnsignedRadixCiphertext
+pub fn unchecked_scalar_sub<Scalar, T>(&self, ct: &T, scalar: Scalar, stream: &CudaStream) -> T
where
-T: DecomposableInto<u8> + UnsignedNumeric + TwosComplementNegation,
+Scalar: DecomposableInto<u8> + Numeric + TwosComplementNegation,
+T: CudaIntegerRadixCiphertext,
{
let mut result = unsafe { ct.duplicate_async(stream) };
self.unchecked_scalar_sub_assign(&mut result, scalar, stream);
@@ -61,26 +57,28 @@ impl CudaServerKey {
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronised
-pub unsafe fn unchecked_scalar_sub_assign_async<T>(
+pub unsafe fn unchecked_scalar_sub_assign_async<Scalar, T>(
&self,
-ct: &mut CudaUnsignedRadixCiphertext,
-scalar: T,
+ct: &mut T,
+scalar: Scalar,
stream: &CudaStream,
) where
-T: DecomposableInto<u8> + UnsignedNumeric + TwosComplementNegation,
+Scalar: DecomposableInto<u8> + Numeric + TwosComplementNegation,
+T: CudaIntegerRadixCiphertext,
{
let negated_scalar = scalar.twos_complement_negation();
self.unchecked_scalar_add_assign_async(ct, negated_scalar, stream);
ct.as_mut().info = ct.as_ref().info.after_scalar_sub(scalar);
}
-pub fn unchecked_scalar_sub_assign<T>(
+pub fn unchecked_scalar_sub_assign<Scalar, T>(
&self,
-ct: &mut CudaUnsignedRadixCiphertext,
-scalar: T,
+ct: &mut T,
+scalar: Scalar,
stream: &CudaStream,
) where
-T: DecomposableInto<u8> + UnsignedNumeric + TwosComplementNegation,
+Scalar: DecomposableInto<u8> + Numeric + TwosComplementNegation,
+T: CudaIntegerRadixCiphertext,
{
unsafe {
self.unchecked_scalar_sub_assign_async(ct, scalar, stream);
@@ -125,14 +123,10 @@ impl CudaServerKey {
/// let dec: u64 = cks.decrypt(&ct_res);
/// assert_eq!(msg - scalar, dec);
/// ```
-pub fn scalar_sub<T>(
-&self,
-ct: &CudaUnsignedRadixCiphertext,
-scalar: T,
-stream: &CudaStream,
-) -> CudaUnsignedRadixCiphertext
+pub fn scalar_sub<Scalar, T>(&self, ct: &T, scalar: Scalar, stream: &CudaStream) -> T
where
-T: DecomposableInto<u8> + UnsignedNumeric + TwosComplementNegation,
+Scalar: DecomposableInto<u8> + Numeric + TwosComplementNegation,
+T: CudaIntegerRadixCiphertext,
{
let mut result = unsafe { ct.duplicate_async(stream) };
self.scalar_sub_assign(&mut result, scalar, stream);
@@ -143,13 +137,14 @@ impl CudaServerKey {
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronised
-pub unsafe fn scalar_sub_assign_async<T>(
+pub unsafe fn scalar_sub_assign_async<Scalar, T>(
&self,
-ct: &mut CudaUnsignedRadixCiphertext,
-scalar: T,
+ct: &mut T,
+scalar: Scalar,
stream: &CudaStream,
) where
-T: DecomposableInto<u8> + UnsignedNumeric + TwosComplementNegation,
+Scalar: DecomposableInto<u8> + Numeric + TwosComplementNegation,
+T: CudaIntegerRadixCiphertext,
{
if !ct.block_carries_are_empty() {
self.full_propagate_assign_async(ct, stream);
@@ -159,13 +154,10 @@ impl CudaServerKey {
self.full_propagate_assign_async(ct, stream);
}
-pub fn scalar_sub_assign<T>(
-&self,
-ct: &mut CudaUnsignedRadixCiphertext,
-scalar: T,
-stream: &CudaStream,
-) where
-T: DecomposableInto<u8> + UnsignedNumeric + TwosComplementNegation,
+pub fn scalar_sub_assign<Scalar, T>(&self, ct: &mut T, scalar: Scalar, stream: &CudaStream)
+where
+Scalar: DecomposableInto<u8> + Numeric + TwosComplementNegation,
+T: CudaIntegerRadixCiphertext,
{
unsafe {
self.scalar_sub_assign_async(ct, scalar, stream);
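The body of `unchecked_scalar_sub_assign_async` shown above is the whole mechanism: the scalar is negated with `twos_complement_negation()` and forwarded to the existing scalar-add path. That works for unsigned and signed radix ciphertexts alike because, in modular (two's complement) arithmetic, subtracting a value and adding its two's complement negation are the same operation. A standalone sketch of the identity on plain machine integers, not tfhe-rs code:

```rust
fn main() {
    // Two's complement negation: flip the bits and add one (wrapping_neg does exactly this).
    let scalar: u64 = 5678;
    assert_eq!(scalar.wrapping_neg(), (!scalar).wrapping_add(1));

    // Subtracting a scalar equals adding its two's complement negation, modulo 2^64.
    let lhs: u64 = 1234;
    assert_eq!(lhs.wrapping_sub(scalar), lhs.wrapping_add(scalar.wrapping_neg()));

    // The same identity holds for signed values in two's complement representation,
    // which is why one generic implementation also covers signed ciphertexts.
    let s_lhs: i64 = -42;
    let s_scalar: i64 = 7;
    assert_eq!(
        s_lhs.wrapping_sub(s_scalar),
        s_lhs.wrapping_add(s_scalar.wrapping_neg())
    );
}
```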

View File

@@ -2,6 +2,7 @@ pub(crate) mod test_add;
pub(crate) mod test_mul;
pub(crate) mod test_neg;
pub(crate) mod test_scalar_add;
pub(crate) mod test_scalar_sub;
pub(crate) mod test_sub;
use crate::core_crypto::gpu::CudaStream;

View File

@@ -0,0 +1,16 @@
use crate::integer::gpu::server_key::radix::tests_unsigned::{
create_gpu_parametrized_test, GpuFunctionExecutor,
};
use crate::integer::gpu::CudaServerKey;
use crate::integer::server_key::radix_parallel::tests_cases_signed::signed_unchecked_scalar_sub_test;
use crate::shortint::parameters::*;
create_gpu_parametrized_test!(integer_signed_unchecked_scalar_sub);
fn integer_signed_unchecked_scalar_sub<P>(param: P)
where
P: Into<PBSParameters>,
{
let executor = GpuFunctionExecutor::new(&CudaServerKey::unchecked_scalar_sub);
signed_unchecked_scalar_sub_test(param, executor);
}

View File

@@ -2,6 +2,7 @@ pub(crate) mod test_add;
pub(crate) mod test_mul;
pub(crate) mod test_neg;
pub(crate) mod test_scalar_add;
pub(crate) mod test_scalar_sub;
pub(crate) mod test_sub;
use crate::core_crypto::gpu::{CudaDevice, CudaStream};
@@ -75,7 +76,6 @@ impl<F> GpuFunctionExecutor<F> {
}
// Unchecked operations
create_gpu_parametrized_test!(integer_unchecked_scalar_sub);
create_gpu_parametrized_test!(integer_unchecked_small_scalar_mul);
create_gpu_parametrized_test!(integer_unchecked_bitnot);
create_gpu_parametrized_test!(integer_unchecked_bitand);
@@ -107,7 +107,6 @@ create_gpu_parametrized_test!(integer_unchecked_scalar_rotate_left);
create_gpu_parametrized_test!(integer_unchecked_scalar_rotate_right);
// Default operations
create_gpu_parametrized_test!(integer_scalar_sub);
create_gpu_parametrized_test!(integer_small_scalar_mul);
create_gpu_parametrized_test!(integer_scalar_right_shift);
create_gpu_parametrized_test!(integer_scalar_left_shift);
@@ -318,14 +317,6 @@ where
unchecked_small_scalar_mul_test(param, executor);
}
fn integer_unchecked_scalar_sub<P>(param: P)
where
P: Into<PBSParameters>,
{
let executor = GpuFunctionExecutor::new(&CudaServerKey::unchecked_scalar_sub);
unchecked_scalar_sub_test(param, executor);
}
fn integer_unchecked_bitnot<P>(param: P)
where
P: Into<PBSParameters>,
@@ -1496,14 +1487,6 @@ where
default_small_scalar_mul_test(param, executor);
}
fn integer_scalar_sub<P>(param: P)
where
P: Into<PBSParameters>,
{
let executor = GpuFunctionExecutor::new(&CudaServerKey::scalar_sub);
default_scalar_sub_test(param, executor);
}
fn integer_bitnot<P>(param: P)
where
P: Into<PBSParameters> + Copy,

View File

@@ -0,0 +1,27 @@
use crate::integer::gpu::server_key::radix::tests_unsigned::{
create_gpu_parametrized_test, GpuFunctionExecutor,
};
use crate::integer::gpu::CudaServerKey;
use crate::integer::server_key::radix_parallel::tests_cases_unsigned::{
default_scalar_sub_test, unchecked_scalar_sub_test,
};
use crate::shortint::parameters::*;
create_gpu_parametrized_test!(integer_unchecked_scalar_sub);
create_gpu_parametrized_test!(integer_scalar_sub);
fn integer_unchecked_scalar_sub<P>(param: P)
where
P: Into<PBSParameters> + Copy,
{
let executor = GpuFunctionExecutor::new(&CudaServerKey::unchecked_scalar_sub);
unchecked_scalar_sub_test(param, executor);
}
fn integer_scalar_sub<P>(param: P)
where
P: Into<PBSParameters> + Copy,
{
let executor = GpuFunctionExecutor::new(&CudaServerKey::scalar_sub);
default_scalar_sub_test(param, executor);
}

View File

@@ -1057,3 +1057,221 @@ where
assert_eq!(encrypted_overflow.0.noise_level(), NoiseLevel::ZERO);
}
}
pub(crate) fn signed_unchecked_scalar_sub_test<P, T>(param: P, mut executor: T)
where
P: Into<PBSParameters>,
T: for<'a> FunctionExecutor<(&'a SignedRadixCiphertext, i64), SignedRadixCiphertext>,
{
let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let cks = RadixClientKey::from((cks, NB_CTXT));
let sks = Arc::new(sks);
let mut rng = rand::thread_rng();
let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64;
executor.setup(&cks, sks);
// check some overflow behaviour
let overflowing_values = [
(-modulus, 1, modulus - 1),
(modulus - 1, -1, -modulus),
(-modulus, 2, modulus - 2),
(modulus - 2, -2, -modulus),
];
for (clear_0, clear_1, expected_clear) in overflowing_values {
let ctxt_0 = cks.encrypt_signed(clear_0);
let ct_res = executor.execute((&ctxt_0, clear_1));
let dec_res: i64 = cks.decrypt_signed(&ct_res);
let clear_res = signed_sub_under_modulus(clear_0, clear_1, modulus);
assert_eq!(clear_res, dec_res);
assert_eq!(clear_res, expected_clear);
}
for _ in 0..NB_TESTS {
let clear_0 = rng.gen::<i64>() % modulus;
let clear_1 = rng.gen::<i64>() % modulus;
let ctxt_0 = cks.encrypt_signed(clear_0);
let ct_res = executor.execute((&ctxt_0, clear_1));
let dec_res: i64 = cks.decrypt_signed(&ct_res);
let clear_res = signed_sub_under_modulus(clear_0, clear_1, modulus);
assert_eq!(clear_res, dec_res);
}
}
pub(crate) fn signed_default_overflowing_scalar_sub_test<P, T>(param: P, mut executor: T)
where
P: Into<PBSParameters>,
T: for<'a> FunctionExecutor<
(&'a SignedRadixCiphertext, i64),
(SignedRadixCiphertext, BooleanBlock),
>,
{
let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let cks = RadixClientKey::from((cks, NB_CTXT));
sks.set_deterministic_pbs_execution(true);
let sks = Arc::new(sks);
let mut rng = rand::thread_rng();
// message_modulus^vec_length
let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64;
executor.setup(&cks, sks.clone());
let hardcoded_values = [
(-modulus, 1),
(modulus - 1, -1),
(1, -modulus),
(-1, modulus - 1),
];
for (clear_0, clear_1) in hardcoded_values {
let ctxt_0 = cks.encrypt_signed(clear_0);
let (ct_res, result_overflowed) = executor.execute((&ctxt_0, clear_1));
let (expected_result, expected_overflowed) =
signed_overflowing_sub_under_modulus(clear_0, clear_1, modulus);
let decrypted_result: i64 = cks.decrypt_signed(&ct_res);
let decrypted_overflowed = cks.decrypt_bool(&result_overflowed);
assert_eq!(
decrypted_result, expected_result,
"Invalid result for overflowing_sub, for ({clear_0} - {clear_1}) % {modulus} \
expected {expected_result}, got {decrypted_result}"
);
assert_eq!(
decrypted_overflowed,
expected_overflowed,
"Invalid overflow flag result for overflowing_sub for ({clear_0} - {clear_1}) % {modulus} \
expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
);
assert_eq!(result_overflowed.0.degree.get(), 1);
assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL);
}
for _ in 0..NB_TESTS_SMALLER {
let clear_0 = rng.gen::<i64>() % modulus;
let clear_1 = rng.gen::<i64>() % modulus;
let ctxt_0 = cks.encrypt_signed(clear_0);
let (ct_res, result_overflowed) = executor.execute((&ctxt_0, clear_1));
let (tmp_ct, tmp_o) = executor.execute((&ctxt_0, clear_1));
assert!(ct_res.block_carries_are_empty());
assert_eq!(ct_res, tmp_ct, "Failed determinism check");
assert_eq!(tmp_o, result_overflowed, "Failed determinism check");
let (expected_result, expected_overflowed) =
signed_overflowing_sub_under_modulus(clear_0, clear_1, modulus);
let decrypted_result: i64 = cks.decrypt_signed(&ct_res);
let decrypted_overflowed = cks.decrypt_bool(&result_overflowed);
assert_eq!(
decrypted_result, expected_result,
"Invalid result for sub, for ({clear_0} - {clear_1}) % {modulus} \
expected {expected_result}, got {decrypted_result}"
);
assert_eq!(
decrypted_overflowed,
expected_overflowed,
"Invalid overflow flag result for overflowing_sub for ({clear_0} - {clear_1}) % {modulus} \
expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
);
assert_eq!(result_overflowed.0.degree.get(), 1);
assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL);
for _ in 0..NB_TESTS_SMALLER {
// Add non zero scalar to have non clean ciphertexts
let clear_2 = random_non_zero_value(&mut rng, modulus);
let clear_rhs = random_non_zero_value(&mut rng, modulus);
let ctxt_0 = sks.unchecked_scalar_add(&ctxt_0, clear_2);
let (clear_lhs, _) = signed_overflowing_add_under_modulus(clear_0, clear_2, modulus);
let d0: i64 = cks.decrypt_signed(&ctxt_0);
assert_eq!(d0, clear_lhs, "Failed sanity decryption check");
let (ct_res, result_overflowed) = executor.execute((&ctxt_0, clear_rhs));
assert!(ct_res.block_carries_are_empty());
let (expected_result, expected_overflowed) =
signed_overflowing_sub_under_modulus(clear_lhs, clear_rhs, modulus);
let decrypted_result: i64 = cks.decrypt_signed(&ct_res);
let decrypted_overflowed = cks.decrypt_bool(&result_overflowed);
assert_eq!(
decrypted_result, expected_result,
"Invalid result for sub, for ({clear_lhs} + {clear_rhs}) % {modulus} \
expected {expected_result}, got {decrypted_result}"
);
assert_eq!(
decrypted_overflowed,
expected_overflowed,
"Invalid overflow flag result for overflowing_sub, for ({clear_lhs} - {clear_rhs}) % {modulus} \
expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
);
assert_eq!(result_overflowed.0.degree.get(), 1);
assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL);
}
}
// Test with trivial inputs
for _ in 0..4 {
let clear_0 = rng.gen::<i64>() % modulus;
let clear_1 = rng.gen::<i64>() % modulus;
let a: SignedRadixCiphertext = sks.create_trivial_radix(clear_0, NB_CTXT);
let (encrypted_result, encrypted_overflow) = executor.execute((&a, clear_1));
let (expected_result, expected_overflowed) =
signed_overflowing_sub_under_modulus(clear_0, clear_1, modulus);
let decrypted_result: i64 = cks.decrypt_signed(&encrypted_result);
let decrypted_overflowed = cks.decrypt_bool(&encrypted_overflow);
assert_eq!(
decrypted_result, expected_result,
"Invalid result for add, for ({clear_0} - {clear_1}) % {modulus} \
expected {expected_result}, got {decrypted_result}"
);
assert_eq!(
decrypted_overflowed,
expected_overflowed,
"Invalid overflow flag result for overflowing_sub, for ({clear_0} - {clear_1}) % {modulus} \
expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
);
assert_eq!(encrypted_overflow.0.degree.get(), 1);
assert_eq!(encrypted_overflow.0.noise_level(), NoiseLevel::ZERO);
}
// Test with scalar that is bigger than ciphertext modulus
for _ in 0..2 {
let clear_0 = rng.gen::<i64>() % modulus;
let clear_1 = rng.gen_range(modulus..=i64::MAX);
let a = cks.encrypt_signed(clear_0);
let (encrypted_result, encrypted_overflow) = executor.execute((&a, clear_1));
let (expected_result, expected_overflowed) =
signed_overflowing_sub_under_modulus(clear_0, clear_1, modulus);
let decrypted_result: i64 = cks.decrypt_signed(&encrypted_result);
let decrypted_overflowed = cks.decrypt_bool(&encrypted_overflow);
assert_eq!(
decrypted_result, expected_result,
"Invalid result for sub, for ({clear_0} - {clear_1}) % {modulus} \
expected {expected_result}, got {decrypted_result}"
);
assert_eq!(
decrypted_overflowed,
expected_overflowed,
"Invalid overflow flag result for overflowing_sub, for ({clear_0} - {clear_1}) % {modulus} \
expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
);
assert!(decrypted_overflowed); // Actually we know its an overflow case
assert_eq!(encrypted_overflow.0.degree.get(), 1);
assert_eq!(encrypted_overflow.0.noise_level(), NoiseLevel::ZERO);
}
}
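The assertions in these new test cases delegate the clear-value arithmetic to helpers such as `signed_sub_under_modulus` and `signed_overflowing_sub_under_modulus`, which are not part of this diff. As a rough guide to what they are expected to return, here is a hypothetical reference implementation of the overflowing variant, inferred only from how the tests above use it:

```rust
// Hypothetical reference helper, not the actual tfhe-rs test function: wraps the exact
// difference into [-modulus, modulus) and flags overflow, with the extra assumption that
// a scalar outside the representable range always counts as overflow (as the
// "scalar that is bigger than ciphertext modulus" loop above asserts).
fn signed_overflowing_sub_under_modulus(lhs: i64, rhs: i64, modulus: i64) -> (i64, bool) {
    let exact = lhs as i128 - rhs as i128;
    let full = 2 * modulus as i128; // [-modulus, modulus) holds 2 * modulus values
    let mut wrapped = exact.rem_euclid(full);
    if wrapped >= modulus as i128 {
        wrapped -= full;
    }
    let out_of_range = |v: i128| v < -(modulus as i128) || v >= modulus as i128;
    let overflowed = out_of_range(rhs as i128) || out_of_range(exact);
    (wrapped as i64, overflowed)
}

fn main() {
    let modulus = 128;
    // First hardcoded case above: (-modulus) - 1 wraps to modulus - 1 and overflows.
    assert_eq!(
        signed_overflowing_sub_under_modulus(-modulus, 1, modulus),
        (modulus - 1, true)
    );
    // An in-range subtraction neither wraps nor overflows.
    assert_eq!(signed_overflowing_sub_under_modulus(5, 3, modulus), (2, false));
}
```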

View File

@@ -2,6 +2,7 @@ pub(crate) mod test_add;
pub(crate) mod test_mul;
pub(crate) mod test_neg;
pub(crate) mod test_scalar_add;
pub(crate) mod test_scalar_sub;
pub(crate) mod test_sub;
use crate::integer::keycache::KEY_CACHE;
@@ -11,7 +12,6 @@ use crate::integer::server_key::radix_parallel::tests_unsigned::CpuFunctionExecu
use crate::integer::{
BooleanBlock, IntegerKeyKind, RadixClientKey, ServerKey, SignedRadixCiphertext,
};
use crate::shortint::ciphertext::NoiseLevel;
#[cfg(tarpaulin)]
use crate::shortint::parameters::coverage_parameters::*;
use crate::shortint::parameters::*;
@@ -1512,7 +1512,6 @@ where
// Unchecked Scalar Tests
//================================================================================
create_parametrized_test!(integer_signed_unchecked_scalar_sub);
create_parametrized_test!(integer_signed_unchecked_scalar_mul);
create_parametrized_test!(integer_signed_unchecked_scalar_rotate_left);
create_parametrized_test!(integer_signed_unchecked_scalar_rotate_right);
@@ -1524,42 +1523,6 @@ create_parametrized_test!(integer_signed_unchecked_scalar_bitxor);
create_parametrized_test!(integer_signed_unchecked_scalar_div_rem);
create_parametrized_test!(integer_signed_unchecked_scalar_div_rem_floor);
fn integer_signed_unchecked_scalar_sub(param: impl Into<PBSParameters>) {
let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let mut rng = rand::thread_rng();
let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64;
// check some overflow behaviour
let overflowing_values = [
(-modulus, 1, modulus - 1),
(modulus - 1, -1, -modulus),
(-modulus, 2, modulus - 2),
(modulus - 2, -2, -modulus),
];
for (clear_0, clear_1, expected_clear) in overflowing_values {
let ctxt_0 = cks.encrypt_signed_radix(clear_0, NB_CTXT);
let ct_res = sks.unchecked_scalar_sub(&ctxt_0, clear_1);
let dec_res: i64 = cks.decrypt_signed_radix(&ct_res);
let clear_res = signed_sub_under_modulus(clear_0, clear_1, modulus);
assert_eq!(clear_res, dec_res);
assert_eq!(clear_res, expected_clear);
}
for _ in 0..NB_TESTS {
let clear_0 = rng.gen::<i64>() % modulus;
let clear_1 = rng.gen::<i64>() % modulus;
let ctxt_0 = cks.encrypt_signed_radix(clear_0, NB_CTXT);
let ct_res = sks.unchecked_scalar_sub(&ctxt_0, clear_1);
let dec_res: i64 = cks.decrypt_signed_radix(&ct_res);
let clear_res = signed_sub_under_modulus(clear_0, clear_1, modulus);
assert_eq!(clear_res, dec_res);
}
}
fn integer_signed_unchecked_scalar_mul(param: impl Into<PBSParameters>) {
let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
@@ -2040,7 +2003,6 @@ fn integer_signed_unchecked_scalar_div_rem_floor(param: impl Into<PBSParameters>
// Default Scalar Tests
//================================================================================
create_parametrized_test!(integer_signed_default_overflowing_scalar_sub);
create_parametrized_test!(integer_signed_default_scalar_bitand);
create_parametrized_test!(integer_signed_default_scalar_bitor);
create_parametrized_test!(integer_signed_default_scalar_bitxor);
@@ -2050,179 +2012,6 @@ create_parametrized_test!(integer_signed_default_scalar_right_shift);
create_parametrized_test!(integer_signed_default_scalar_rotate_right);
create_parametrized_test!(integer_signed_default_scalar_rotate_left);
pub(crate) fn integer_signed_default_overflowing_scalar_sub<P>(param: P)
where
P: Into<PBSParameters>,
{
let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let cks = RadixClientKey::from((cks, NB_CTXT));
sks.set_deterministic_pbs_execution(true);
let mut rng = rand::thread_rng();
// message_modulus^vec_length
let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64;
let hardcoded_values = [
(-modulus, 1),
(modulus - 1, -1),
(1, -modulus),
(-1, modulus - 1),
];
for (clear_0, clear_1) in hardcoded_values {
let ctxt_0 = cks.encrypt_signed(clear_0);
let (ct_res, result_overflowed) =
sks.signed_overflowing_scalar_sub_parallelized(&ctxt_0, clear_1);
let (expected_result, expected_overflowed) =
signed_overflowing_sub_under_modulus(clear_0, clear_1, modulus);
let decrypted_result: i64 = cks.decrypt_signed(&ct_res);
let decrypted_overflowed = cks.decrypt_bool(&result_overflowed);
assert_eq!(
decrypted_result, expected_result,
"Invalid result for overflowing_sub, for ({clear_0} - {clear_1}) % {modulus} \
expected {expected_result}, got {decrypted_result}"
);
assert_eq!(
decrypted_overflowed,
expected_overflowed,
"Invalid overflow flag result for overflowing_sub for ({clear_0} - {clear_1}) % {modulus} \
expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
);
assert_eq!(result_overflowed.0.degree.get(), 1);
assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL);
}
for _ in 0..NB_TESTS_SMALLER {
let clear_0 = rng.gen::<i64>() % modulus;
let clear_1 = rng.gen::<i64>() % modulus;
let ctxt_0 = cks.encrypt_signed(clear_0);
let (ct_res, result_overflowed) =
sks.signed_overflowing_scalar_sub_parallelized(&ctxt_0, clear_1);
let (tmp_ct, tmp_o) = sks.signed_overflowing_scalar_sub_parallelized(&ctxt_0, clear_1);
assert!(ct_res.block_carries_are_empty());
assert_eq!(ct_res, tmp_ct, "Failed determinism check");
assert_eq!(tmp_o, result_overflowed, "Failed determinism check");
let (expected_result, expected_overflowed) =
signed_overflowing_sub_under_modulus(clear_0, clear_1, modulus);
let decrypted_result: i64 = cks.decrypt_signed(&ct_res);
let decrypted_overflowed = cks.decrypt_bool(&result_overflowed);
assert_eq!(
decrypted_result, expected_result,
"Invalid result for sub, for ({clear_0} - {clear_1}) % {modulus} \
expected {expected_result}, got {decrypted_result}"
);
assert_eq!(
decrypted_overflowed,
expected_overflowed,
"Invalid overflow flag result for overflowing_sub for ({clear_0} - {clear_1}) % {modulus} \
expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
);
assert_eq!(result_overflowed.0.degree.get(), 1);
assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL);
for _ in 0..NB_TESTS_SMALLER {
// Add non zero scalar to have non clean ciphertexts
let clear_2 = random_non_zero_value(&mut rng, modulus);
let clear_rhs = random_non_zero_value(&mut rng, modulus);
let ctxt_0 = sks.unchecked_scalar_add(&ctxt_0, clear_2);
let (clear_lhs, _) = signed_overflowing_add_under_modulus(clear_0, clear_2, modulus);
let d0: i64 = cks.decrypt_signed(&ctxt_0);
assert_eq!(d0, clear_lhs, "Failed sanity decryption check");
let (ct_res, result_overflowed) =
sks.signed_overflowing_scalar_sub_parallelized(&ctxt_0, clear_rhs);
assert!(ct_res.block_carries_are_empty());
let (expected_result, expected_overflowed) =
signed_overflowing_sub_under_modulus(clear_lhs, clear_rhs, modulus);
let decrypted_result: i64 = cks.decrypt_signed(&ct_res);
let decrypted_overflowed = cks.decrypt_bool(&result_overflowed);
assert_eq!(
decrypted_result, expected_result,
"Invalid result for sub, for ({clear_lhs} + {clear_rhs}) % {modulus} \
expected {expected_result}, got {decrypted_result}"
);
assert_eq!(
decrypted_overflowed,
expected_overflowed,
"Invalid overflow flag result for overflowing_sub, for ({clear_lhs} - {clear_rhs}) % {modulus} \
expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
);
assert_eq!(result_overflowed.0.degree.get(), 1);
assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL);
}
}
// Test with trivial inputs
for _ in 0..4 {
let clear_0 = rng.gen::<i64>() % modulus;
let clear_1 = rng.gen::<i64>() % modulus;
let a: SignedRadixCiphertext = sks.create_trivial_radix(clear_0, NB_CTXT);
let (encrypted_result, encrypted_overflow) =
sks.signed_overflowing_scalar_sub_parallelized(&a, clear_1);
let (expected_result, expected_overflowed) =
signed_overflowing_sub_under_modulus(clear_0, clear_1, modulus);
let decrypted_result: i64 = cks.decrypt_signed(&encrypted_result);
let decrypted_overflowed = cks.decrypt_bool(&encrypted_overflow);
assert_eq!(
decrypted_result, expected_result,
"Invalid result for add, for ({clear_0} - {clear_1}) % {modulus} \
expected {expected_result}, got {decrypted_result}"
);
assert_eq!(
decrypted_overflowed,
expected_overflowed,
"Invalid overflow flag result for overflowing_sub, for ({clear_0} - {clear_1}) % {modulus} \
expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
);
assert_eq!(encrypted_overflow.0.degree.get(), 1);
assert_eq!(encrypted_overflow.0.noise_level(), NoiseLevel::ZERO);
}
// Test with scalar that is bigger than ciphertext modulus
for _ in 0..2 {
let clear_0 = rng.gen::<i64>() % modulus;
let clear_1 = rng.gen_range(modulus..=i64::MAX);
let a = cks.encrypt_signed(clear_0);
let (encrypted_result, encrypted_overflow) =
sks.signed_overflowing_scalar_sub_parallelized(&a, clear_1);
let (expected_result, expected_overflowed) =
signed_overflowing_sub_under_modulus(clear_0, clear_1, modulus);
let decrypted_result: i64 = cks.decrypt_signed(&encrypted_result);
let decrypted_overflowed = cks.decrypt_bool(&encrypted_overflow);
assert_eq!(
decrypted_result, expected_result,
"Invalid result for sub, for ({clear_0} - {clear_1}) % {modulus} \
expected {expected_result}, got {decrypted_result}"
);
assert_eq!(
decrypted_overflowed,
expected_overflowed,
"Invalid overflow flag result for overflowing_sub, for ({clear_0} - {clear_1}) % {modulus} \
expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
);
assert!(decrypted_overflowed); // Actually we know its an overflow case
assert_eq!(encrypted_overflow.0.degree.get(), 1);
assert_eq!(encrypted_overflow.0.noise_level(), NoiseLevel::ZERO);
}
}
fn integer_signed_default_scalar_bitand(param: impl Into<PBSParameters>) {
let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
sks.set_deterministic_pbs_execution(true);

View File

@@ -0,0 +1,27 @@
use crate::integer::server_key::radix_parallel::tests_cases_signed::{
signed_default_overflowing_scalar_sub_test, signed_unchecked_scalar_sub_test,
};
use crate::integer::server_key::radix_parallel::tests_unsigned::CpuFunctionExecutor;
use crate::integer::ServerKey;
#[cfg(tarpaulin)]
use crate::shortint::parameters::coverage_parameters::*;
use crate::shortint::parameters::*;
create_parametrized_test!(integer_signed_unchecked_scalar_sub);
create_parametrized_test!(integer_signed_default_overflowing_scalar_sub);
fn integer_signed_unchecked_scalar_sub<P>(param: P)
where
P: Into<PBSParameters>,
{
let executor = CpuFunctionExecutor::new(&ServerKey::unchecked_scalar_sub);
signed_unchecked_scalar_sub_test(param, executor);
}
fn integer_signed_default_overflowing_scalar_sub<P>(param: P)
where
P: Into<PBSParameters>,
{
let executor = CpuFunctionExecutor::new(&ServerKey::signed_overflowing_scalar_sub_parallelized);
signed_default_overflowing_scalar_sub_test(param, executor);
}

View File

@@ -2,6 +2,7 @@ pub(crate) mod test_add;
pub(crate) mod test_mul;
pub(crate) mod test_neg;
pub(crate) mod test_scalar_add;
pub(crate) mod test_scalar_sub;
pub(crate) mod test_sub;
use super::tests_cases_unsigned::*;
@@ -597,9 +598,6 @@ create_parametrized_test!(integer_unchecked_scalar_rotate_left);
create_parametrized_test!(integer_default_scalar_rotate_right);
create_parametrized_test!(integer_default_scalar_rotate_left);
create_parametrized_test!(integer_default_scalar_div_rem);
create_parametrized_test!(integer_smart_scalar_sub);
create_parametrized_test!(integer_default_scalar_sub);
create_parametrized_test!(integer_default_overflowing_scalar_sub);
create_parametrized_test!(integer_smart_if_then_else);
create_parametrized_test!(integer_default_if_then_else);
create_parametrized_test!(integer_trim_radix_msb_blocks_handles_dirty_inputs);
@@ -982,14 +980,6 @@ where
// Smart Scalar Tests
//=============================================================================
fn integer_smart_scalar_sub<P>(param: P)
where
P: Into<PBSParameters>,
{
let executor = CpuFunctionExecutor::new(&ServerKey::smart_scalar_sub_parallelized);
smart_scalar_sub_test(param, executor);
}
fn integer_smart_small_scalar_mul<P>(param: P)
where
P: Into<PBSParameters>,
@@ -1134,23 +1124,6 @@ where
// Default Scalar Tests
//=============================================================================
fn integer_default_scalar_sub<P>(param: P)
where
P: Into<PBSParameters>,
{
let executor = CpuFunctionExecutor::new(&ServerKey::scalar_sub_parallelized);
default_scalar_sub_test(param, executor);
}
fn integer_default_overflowing_scalar_sub<P>(param: P)
where
P: Into<PBSParameters>,
{
let executor =
CpuFunctionExecutor::new(&ServerKey::unsigned_overflowing_scalar_sub_parallelized);
default_overflowing_scalar_sub_test(param, executor);
}
fn integer_default_scalar_bitand<P>(param: P)
where
P: Into<PBSParameters>,

View File

@@ -0,0 +1,37 @@
use crate::integer::server_key::radix_parallel::tests_cases_unsigned::{
default_overflowing_scalar_sub_test, default_scalar_sub_test, smart_scalar_sub_test,
};
use crate::integer::server_key::radix_parallel::tests_unsigned::CpuFunctionExecutor;
use crate::integer::ServerKey;
#[cfg(tarpaulin)]
use crate::shortint::parameters::coverage_parameters::*;
use crate::shortint::parameters::*;
create_parametrized_test!(integer_smart_scalar_sub);
create_parametrized_test!(integer_default_scalar_sub);
create_parametrized_test!(integer_default_overflowing_scalar_sub);
fn integer_smart_scalar_sub<P>(param: P)
where
P: Into<PBSParameters>,
{
let executor = CpuFunctionExecutor::new(&ServerKey::smart_scalar_sub_parallelized);
smart_scalar_sub_test(param, executor);
}
fn integer_default_scalar_sub<P>(param: P)
where
P: Into<PBSParameters>,
{
let executor = CpuFunctionExecutor::new(&ServerKey::scalar_sub_parallelized);
default_scalar_sub_test(param, executor);
}
fn integer_default_overflowing_scalar_sub<P>(param: P)
where
P: Into<PBSParameters>,
{
let executor =
CpuFunctionExecutor::new(&ServerKey::unsigned_overflowing_scalar_sub_parallelized);
default_overflowing_scalar_sub_test(param, executor);
}