chore(gpu): refactor signed overflow sub test to use FnExecutor

This commit is contained in:
Agnes Leroy
2024-06-11 10:25:34 +02:00
committed by Agnès Leroy
parent 418409231b
commit 2185bcf80e
3 changed files with 155 additions and 144 deletions

View File

@@ -1,7 +1,7 @@
use crate::integer::ciphertext::IntegerRadixCiphertext;
use crate::integer::server_key::radix_parallel::sub::SignedOperation;
use crate::integer::server_key::CheckError;
use crate::integer::{BooleanBlock, ServerKey, SignedRadixCiphertext};
use crate::integer::server_key::radix_parallel::sub::SignedOperation;
use crate::shortint::ciphertext::{Degree, MaxDegree, NoiseLevel};
impl ServerKey {
@@ -269,5 +269,4 @@ impl ServerKey {
) -> (SignedRadixCiphertext, BooleanBlock) {
self.unchecked_signed_overflowing_add_or_sub(lhs, rhs, SignedOperation::Addition)
}
}

View File

@@ -63,13 +63,14 @@ fn integer_signed_unchecked_overflowing_add_parallelized<P>(param: P)
where
P: Into<PBSParameters>,
{
let executor = CpuFunctionExecutor::new(&ServerKey::unchecked_signed_overflowing_add_parallelized);
let executor =
CpuFunctionExecutor::new(&ServerKey::unchecked_signed_overflowing_add_parallelized);
signed_unchecked_overflowing_add_test(param, executor);
}
fn integer_signed_default_overflowing_add<P>(param: P)
where
P: Into<PBSParameters>,
where
P: Into<PBSParameters>,
{
let executor = CpuFunctionExecutor::new(&ServerKey::signed_overflowing_add_parallelized);
signed_default_overflowing_add_test(param, executor);
@@ -84,20 +85,20 @@ where
}
fn integer_signed_smart_add<P>(param: P)
where
P: Into<PBSParameters>,
where
P: Into<PBSParameters>,
{
let executor = CpuFunctionExecutor::new(&ServerKey::smart_add_parallelized);
signed_smart_add_test(param, executor);
}
pub(crate) fn signed_unchecked_overflowing_add_test<P, T>(param: P, mut executor: T)
where
P: Into<PBSParameters>,
T: for<'a> FunctionExecutor<
(&'a SignedRadixCiphertext, &'a SignedRadixCiphertext),
(SignedRadixCiphertext, BooleanBlock),
>,
where
P: Into<PBSParameters>,
T: for<'a> FunctionExecutor<
(&'a SignedRadixCiphertext, &'a SignedRadixCiphertext),
(SignedRadixCiphertext, BooleanBlock),
>,
{
let param = param.into();
let nb_tests = nb_tests_for_params(param);
@@ -122,8 +123,7 @@ pub(crate) fn signed_unchecked_overflowing_add_test<P, T>(param: P, mut executor
let ctxt_0 = cks.encrypt_signed(clear_0);
let ctxt_1 = cks.encrypt_signed(clear_1);
let (ct_res, result_overflowed) =
executor.execute((&ctxt_0, &ctxt_1));
let (ct_res, result_overflowed) = executor.execute((&ctxt_0, &ctxt_1));
let (expected_result, expected_overflowed) =
signed_overflowing_add_under_modulus(clear_0, clear_1, modulus);
@@ -279,8 +279,7 @@ where
let d1: i64 = cks.decrypt_signed(&ctxt_1);
assert_eq!(d1, clear_rhs, "Failed sanity decryption check");
let (ct_res, result_overflowed) =
executor.execute((&ctxt_0, &ctxt_1));
let (ct_res, result_overflowed) = executor.execute((&ctxt_0, &ctxt_1));
assert!(ct_res.block_carries_are_empty());
let (expected_result, expected_overflowed) =

View File

@@ -50,126 +50,12 @@ where
signed_unchecked_sub_test(param, executor);
}
fn signed_unchecked_overflowing_sub_test_case<P, F>(param: P, signed_overflowing_sub: F)
where
P: Into<PBSParameters>,
F: for<'a> Fn(
&'a ServerKey,
&'a SignedRadixCiphertext,
&'a SignedRadixCiphertext,
) -> (SignedRadixCiphertext, BooleanBlock),
{
let param = param.into();
let nb_tests = nb_tests_for_params(param);
let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let cks = RadixClientKey::from((cks, NB_CTXT));
sks.set_deterministic_pbs_execution(true);
let mut rng = rand::thread_rng();
let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64;
let hardcoded_values = [
(-modulus, 1),
(modulus - 1, -1),
(1, -modulus),
(-1, modulus - 1),
];
for (clear_0, clear_1) in hardcoded_values {
let ctxt_0 = cks.encrypt_signed(clear_0);
let (ct_res, result_overflowed) =
sks.signed_overflowing_scalar_sub_parallelized(&ctxt_0, clear_1);
let (expected_result, expected_overflowed) =
signed_overflowing_sub_under_modulus(clear_0, clear_1, modulus);
let decrypted_result: i64 = cks.decrypt_signed(&ct_res);
let decrypted_overflowed = cks.decrypt_bool(&result_overflowed);
assert_eq!(
decrypted_result, expected_result,
"Invalid result for overflowing_sub, for ({clear_0} - {clear_1}) % {modulus} \
expected {expected_result}, got {decrypted_result}"
);
assert_eq!(
decrypted_overflowed,
expected_overflowed,
"Invalid overflow flag result for overflowing_sub for ({clear_0} - {clear_1}) % {modulus} \
expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
);
assert_eq!(result_overflowed.0.degree.get(), 1);
assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL);
}
for _ in 0..nb_tests {
let clear_0 = rng.gen::<i64>() % modulus;
let clear_1 = rng.gen::<i64>() % modulus;
let ctxt_0 = cks.encrypt_signed(clear_0);
let ctxt_1 = cks.encrypt_signed(clear_1);
let (ct_res, result_overflowed) = signed_overflowing_sub(&sks, &ctxt_0, &ctxt_1);
let (tmp_ct, tmp_o) = signed_overflowing_sub(&sks, &ctxt_0, &ctxt_1);
assert!(ct_res.block_carries_are_empty());
assert_eq!(ct_res, tmp_ct, "Failed determinism check");
assert_eq!(tmp_o, result_overflowed, "Failed determinism check");
let (expected_result, expected_overflowed) =
signed_overflowing_sub_under_modulus(clear_0, clear_1, modulus);
let decrypted_result: i64 = cks.decrypt_signed(&ct_res);
let decrypted_overflowed = cks.decrypt_bool(&result_overflowed);
assert_eq!(
decrypted_result, expected_result,
"Invalid result for sub, for ({clear_0} - {clear_1}) % {modulus} \
expected {expected_result}, got {decrypted_result}"
);
assert_eq!(
decrypted_overflowed,
expected_overflowed,
"Invalid overflow flag result for overflowing_sub for ({clear_0} - {clear_1}) % {modulus} \
expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
);
}
// Test with trivial inputs, as it was bugged at some point
let values = [
(rng.gen::<i64>() % modulus, 0i64),
(rng.gen::<i64>() % modulus, rng.gen::<i64>() % modulus),
(rng.gen::<i64>() % modulus, rng.gen::<i64>() % modulus),
(rng.gen::<i64>() % modulus, rng.gen::<i64>() % modulus),
(rng.gen::<i64>() % modulus, rng.gen::<i64>() % modulus),
];
for (clear_0, clear_1) in values {
let a: SignedRadixCiphertext = sks.create_trivial_radix(clear_0, NB_CTXT);
let b: SignedRadixCiphertext = sks.create_trivial_radix(clear_1, NB_CTXT);
let (encrypted_result, encrypted_overflow) = signed_overflowing_sub(&sks, &a, &b);
let (expected_result, expected_overflowed) =
signed_overflowing_sub_under_modulus(clear_0, clear_1, modulus);
let decrypted_result: i64 = cks.decrypt_signed(&encrypted_result);
let decrypted_overflowed = cks.decrypt_bool(&encrypted_overflow);
assert_eq!(
decrypted_result, expected_result,
"Invalid result for sub, for ({clear_0} - {clear_1}) % {modulus} \
expected {expected_result}, got {decrypted_result}"
);
assert_eq!(
decrypted_overflowed,
expected_overflowed,
"Invalid overflow flag result for overflowing_sub, for ({clear_0} - {clear_1}) % {modulus} \
expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
);
}
}
fn integer_signed_unchecked_overflowing_sub<P>(param: P)
where
P: Into<PBSParameters>,
{
signed_unchecked_overflowing_sub_test_case(param, ServerKey::unchecked_signed_overflowing_sub);
let executor = CpuFunctionExecutor::new(&ServerKey::unchecked_signed_overflowing_sub);
signed_unchecked_overflowing_sub_test(param, executor);
}
fn integer_signed_unchecked_overflowing_sub_parallelized<P>(param: P)
@@ -177,12 +63,10 @@ where
P: Into<PBSParameters>,
{
// Call _impl so we are sure the parallel version is tested
//
// However this only supports param X_X where X >= 4
signed_unchecked_overflowing_sub_test_case(
param,
ServerKey::unchecked_signed_overflowing_sub_parallelized_impl,
);
let executor =
CpuFunctionExecutor::new(&ServerKey::unchecked_signed_overflowing_sub_parallelized_impl);
signed_unchecked_overflowing_sub_test(param, executor);
}
fn integer_signed_default_sub<P>(param: P)
@@ -196,6 +80,17 @@ where
fn integer_signed_default_overflowing_sub<P>(param: P)
where
P: Into<PBSParameters>,
{
let executor = CpuFunctionExecutor::new(&ServerKey::signed_overflowing_sub_parallelized);
signed_default_overflowing_sub_test(param, executor);
}
pub(crate) fn signed_default_overflowing_sub_test<P, T>(param: P, mut executor: T)
where
P: Into<PBSParameters>,
T: for<'a> FunctionExecutor<
(&'a SignedRadixCiphertext, &'a SignedRadixCiphertext),
(SignedRadixCiphertext, BooleanBlock),
>,
{
let param = param.into();
let nb_tests_smaller = nb_tests_smaller_for_params(param);
@@ -203,6 +98,9 @@ where
let cks = RadixClientKey::from((cks, NB_CTXT));
sks.set_deterministic_pbs_execution(true);
let sks = Arc::new(sks);
executor.setup(&cks, sks.clone());
let mut rng = rand::thread_rng();
@@ -216,8 +114,8 @@ where
let ctxt_0 = cks.encrypt_signed(clear_0);
let ctxt_1 = cks.encrypt_signed(clear_1);
let (ct_res, result_overflowed) = sks.signed_overflowing_sub_parallelized(&ctxt_0, &ctxt_1);
let (tmp_ct, tmp_o) = sks.signed_overflowing_sub_parallelized(&ctxt_0, &ctxt_1);
let (ct_res, result_overflowed) = executor.execute((&ctxt_0, &ctxt_1));
let (tmp_ct, tmp_o) = executor.execute((&ctxt_0, &ctxt_1));
assert!(ct_res.block_carries_are_empty());
assert_eq!(ct_res, tmp_ct, "Failed determinism check");
assert_eq!(tmp_o, result_overflowed, "Failed determinism check");
@@ -257,8 +155,7 @@ where
let d1: i64 = cks.decrypt_signed(&ctxt_1);
assert_eq!(d1, clear_rhs, "Failed sanity decryption check");
let (ct_res, result_overflowed) =
sks.signed_overflowing_sub_parallelized(&ctxt_0, &ctxt_1);
let (ct_res, result_overflowed) = executor.execute((&ctxt_0, &ctxt_1));
assert!(ct_res.block_carries_are_empty());
let (expected_result, expected_overflowed) =
@@ -292,8 +189,7 @@ where
let a: SignedRadixCiphertext = sks.create_trivial_radix(clear_0, NB_CTXT);
let b: SignedRadixCiphertext = sks.create_trivial_radix(clear_1, NB_CTXT);
let (encrypted_result, encrypted_overflow) =
sks.signed_overflowing_sub_parallelized(&a, &b);
let (encrypted_result, encrypted_overflow) = executor.execute((&a, &b));
let (expected_result, expected_overflowed) =
signed_overflowing_sub_under_modulus(clear_0, clear_1, modulus);
@@ -316,6 +212,123 @@ where
}
}
pub(crate) fn signed_unchecked_overflowing_sub_test<P, T>(param: P, mut executor: T)
where
P: Into<PBSParameters>,
T: for<'a> FunctionExecutor<
(&'a SignedRadixCiphertext, &'a SignedRadixCiphertext),
(SignedRadixCiphertext, BooleanBlock),
>,
{
let param = param.into();
let nb_tests = nb_tests_for_params(param);
let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let cks = RadixClientKey::from((cks, NB_CTXT));
sks.set_deterministic_pbs_execution(true);
let sks = Arc::new(sks);
executor.setup(&cks, sks.clone());
let mut rng = rand::thread_rng();
let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64;
let hardcoded_values = [
(-modulus, 1),
(modulus - 1, -1),
(1, -modulus),
(-1, modulus - 1),
];
for (clear_0, clear_1) in hardcoded_values {
let ctxt_0 = cks.encrypt_signed(clear_0);
let ctxt_1 = cks.encrypt_signed(clear_1);
let (ct_res, result_overflowed) = executor.execute((&ctxt_0, &ctxt_1));
let (expected_result, expected_overflowed) =
signed_overflowing_sub_under_modulus(clear_0, clear_1, modulus);
let decrypted_result: i64 = cks.decrypt_signed(&ct_res);
let decrypted_overflowed = cks.decrypt_bool(&result_overflowed);
assert_eq!(
decrypted_result, expected_result,
"Invalid result for overflowing_sub, for ({clear_0} - {clear_1}) % {modulus} \
expected {expected_result}, got {decrypted_result}"
);
assert_eq!(
decrypted_overflowed,
expected_overflowed,
"Invalid overflow flag result for overflowing_sub for ({clear_0} - {clear_1}) % {modulus} \
expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
);
assert_eq!(result_overflowed.0.degree.get(), 1);
assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL);
}
for _ in 0..nb_tests {
let clear_0 = rng.gen::<i64>() % modulus;
let clear_1 = rng.gen::<i64>() % modulus;
let ctxt_0 = cks.encrypt_signed(clear_0);
let ctxt_1 = cks.encrypt_signed(clear_1);
let (ct_res, result_overflowed) = executor.execute((&ctxt_0, &ctxt_1));
let (tmp_ct, tmp_o) = executor.execute((&ctxt_0, &ctxt_1));
assert!(ct_res.block_carries_are_empty());
assert_eq!(ct_res, tmp_ct, "Failed determinism check");
assert_eq!(tmp_o, result_overflowed, "Failed determinism check");
let (expected_result, expected_overflowed) =
signed_overflowing_sub_under_modulus(clear_0, clear_1, modulus);
let decrypted_result: i64 = cks.decrypt_signed(&ct_res);
let decrypted_overflowed = cks.decrypt_bool(&result_overflowed);
assert_eq!(
decrypted_result, expected_result,
"Invalid result for sub, for ({clear_0} - {clear_1}) % {modulus} \
expected {expected_result}, got {decrypted_result}"
);
assert_eq!(
decrypted_overflowed,
expected_overflowed,
"Invalid overflow flag result for overflowing_sub for ({clear_0} - {clear_1}) % {modulus} \
expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
);
}
// Test with trivial inputs, as it was bugged at some point
let values = [
(rng.gen::<i64>() % modulus, 0i64),
(rng.gen::<i64>() % modulus, rng.gen::<i64>() % modulus),
(rng.gen::<i64>() % modulus, rng.gen::<i64>() % modulus),
(rng.gen::<i64>() % modulus, rng.gen::<i64>() % modulus),
(rng.gen::<i64>() % modulus, rng.gen::<i64>() % modulus),
];
for (clear_0, clear_1) in values {
let a: SignedRadixCiphertext = sks.create_trivial_radix(clear_0, NB_CTXT);
let b: SignedRadixCiphertext = sks.create_trivial_radix(clear_1, NB_CTXT);
let (encrypted_result, encrypted_overflow) = executor.execute((&a, &b));
let (expected_result, expected_overflowed) =
signed_overflowing_sub_under_modulus(clear_0, clear_1, modulus);
let decrypted_result: i64 = cks.decrypt_signed(&encrypted_result);
let decrypted_overflowed = cks.decrypt_bool(&encrypted_overflow);
assert_eq!(
decrypted_result, expected_result,
"Invalid result for sub, for ({clear_0} - {clear_1}) % {modulus} \
expected {expected_result}, got {decrypted_result}"
);
assert_eq!(
decrypted_overflowed,
expected_overflowed,
"Invalid overflow flag result for overflowing_sub, for ({clear_0} - {clear_1}) % {modulus} \
expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
);
}
}
pub(crate) fn signed_unchecked_sub_test<P, T>(param: P, mut executor: T)
where
P: Into<PBSParameters>,