mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-10 07:08:03 -05:00
chore(gpu): refactor signed overflow sub test to use FnExecutor
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
use crate::integer::ciphertext::IntegerRadixCiphertext;
|
||||
use crate::integer::server_key::radix_parallel::sub::SignedOperation;
|
||||
use crate::integer::server_key::CheckError;
|
||||
use crate::integer::{BooleanBlock, ServerKey, SignedRadixCiphertext};
|
||||
use crate::integer::server_key::radix_parallel::sub::SignedOperation;
|
||||
use crate::shortint::ciphertext::{Degree, MaxDegree, NoiseLevel};
|
||||
|
||||
impl ServerKey {
|
||||
@@ -269,5 +269,4 @@ impl ServerKey {
|
||||
) -> (SignedRadixCiphertext, BooleanBlock) {
|
||||
self.unchecked_signed_overflowing_add_or_sub(lhs, rhs, SignedOperation::Addition)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -63,13 +63,14 @@ fn integer_signed_unchecked_overflowing_add_parallelized<P>(param: P)
|
||||
where
|
||||
P: Into<PBSParameters>,
|
||||
{
|
||||
let executor = CpuFunctionExecutor::new(&ServerKey::unchecked_signed_overflowing_add_parallelized);
|
||||
let executor =
|
||||
CpuFunctionExecutor::new(&ServerKey::unchecked_signed_overflowing_add_parallelized);
|
||||
signed_unchecked_overflowing_add_test(param, executor);
|
||||
}
|
||||
|
||||
fn integer_signed_default_overflowing_add<P>(param: P)
|
||||
where
|
||||
P: Into<PBSParameters>,
|
||||
where
|
||||
P: Into<PBSParameters>,
|
||||
{
|
||||
let executor = CpuFunctionExecutor::new(&ServerKey::signed_overflowing_add_parallelized);
|
||||
signed_default_overflowing_add_test(param, executor);
|
||||
@@ -84,20 +85,20 @@ where
|
||||
}
|
||||
|
||||
fn integer_signed_smart_add<P>(param: P)
|
||||
where
|
||||
P: Into<PBSParameters>,
|
||||
where
|
||||
P: Into<PBSParameters>,
|
||||
{
|
||||
let executor = CpuFunctionExecutor::new(&ServerKey::smart_add_parallelized);
|
||||
signed_smart_add_test(param, executor);
|
||||
}
|
||||
|
||||
pub(crate) fn signed_unchecked_overflowing_add_test<P, T>(param: P, mut executor: T)
|
||||
where
|
||||
P: Into<PBSParameters>,
|
||||
T: for<'a> FunctionExecutor<
|
||||
(&'a SignedRadixCiphertext, &'a SignedRadixCiphertext),
|
||||
(SignedRadixCiphertext, BooleanBlock),
|
||||
>,
|
||||
where
|
||||
P: Into<PBSParameters>,
|
||||
T: for<'a> FunctionExecutor<
|
||||
(&'a SignedRadixCiphertext, &'a SignedRadixCiphertext),
|
||||
(SignedRadixCiphertext, BooleanBlock),
|
||||
>,
|
||||
{
|
||||
let param = param.into();
|
||||
let nb_tests = nb_tests_for_params(param);
|
||||
@@ -122,8 +123,7 @@ pub(crate) fn signed_unchecked_overflowing_add_test<P, T>(param: P, mut executor
|
||||
let ctxt_0 = cks.encrypt_signed(clear_0);
|
||||
let ctxt_1 = cks.encrypt_signed(clear_1);
|
||||
|
||||
let (ct_res, result_overflowed) =
|
||||
executor.execute((&ctxt_0, &ctxt_1));
|
||||
let (ct_res, result_overflowed) = executor.execute((&ctxt_0, &ctxt_1));
|
||||
let (expected_result, expected_overflowed) =
|
||||
signed_overflowing_add_under_modulus(clear_0, clear_1, modulus);
|
||||
|
||||
@@ -279,8 +279,7 @@ where
|
||||
let d1: i64 = cks.decrypt_signed(&ctxt_1);
|
||||
assert_eq!(d1, clear_rhs, "Failed sanity decryption check");
|
||||
|
||||
let (ct_res, result_overflowed) =
|
||||
executor.execute((&ctxt_0, &ctxt_1));
|
||||
let (ct_res, result_overflowed) = executor.execute((&ctxt_0, &ctxt_1));
|
||||
assert!(ct_res.block_carries_are_empty());
|
||||
|
||||
let (expected_result, expected_overflowed) =
|
||||
|
||||
@@ -50,126 +50,12 @@ where
|
||||
signed_unchecked_sub_test(param, executor);
|
||||
}
|
||||
|
||||
fn signed_unchecked_overflowing_sub_test_case<P, F>(param: P, signed_overflowing_sub: F)
|
||||
where
|
||||
P: Into<PBSParameters>,
|
||||
F: for<'a> Fn(
|
||||
&'a ServerKey,
|
||||
&'a SignedRadixCiphertext,
|
||||
&'a SignedRadixCiphertext,
|
||||
) -> (SignedRadixCiphertext, BooleanBlock),
|
||||
{
|
||||
let param = param.into();
|
||||
let nb_tests = nb_tests_for_params(param);
|
||||
let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
|
||||
let cks = RadixClientKey::from((cks, NB_CTXT));
|
||||
|
||||
sks.set_deterministic_pbs_execution(true);
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64;
|
||||
|
||||
let hardcoded_values = [
|
||||
(-modulus, 1),
|
||||
(modulus - 1, -1),
|
||||
(1, -modulus),
|
||||
(-1, modulus - 1),
|
||||
];
|
||||
for (clear_0, clear_1) in hardcoded_values {
|
||||
let ctxt_0 = cks.encrypt_signed(clear_0);
|
||||
|
||||
let (ct_res, result_overflowed) =
|
||||
sks.signed_overflowing_scalar_sub_parallelized(&ctxt_0, clear_1);
|
||||
let (expected_result, expected_overflowed) =
|
||||
signed_overflowing_sub_under_modulus(clear_0, clear_1, modulus);
|
||||
|
||||
let decrypted_result: i64 = cks.decrypt_signed(&ct_res);
|
||||
let decrypted_overflowed = cks.decrypt_bool(&result_overflowed);
|
||||
assert_eq!(
|
||||
decrypted_result, expected_result,
|
||||
"Invalid result for overflowing_sub, for ({clear_0} - {clear_1}) % {modulus} \
|
||||
expected {expected_result}, got {decrypted_result}"
|
||||
);
|
||||
assert_eq!(
|
||||
decrypted_overflowed,
|
||||
expected_overflowed,
|
||||
"Invalid overflow flag result for overflowing_sub for ({clear_0} - {clear_1}) % {modulus} \
|
||||
expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
|
||||
);
|
||||
assert_eq!(result_overflowed.0.degree.get(), 1);
|
||||
assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL);
|
||||
}
|
||||
|
||||
for _ in 0..nb_tests {
|
||||
let clear_0 = rng.gen::<i64>() % modulus;
|
||||
let clear_1 = rng.gen::<i64>() % modulus;
|
||||
|
||||
let ctxt_0 = cks.encrypt_signed(clear_0);
|
||||
let ctxt_1 = cks.encrypt_signed(clear_1);
|
||||
|
||||
let (ct_res, result_overflowed) = signed_overflowing_sub(&sks, &ctxt_0, &ctxt_1);
|
||||
let (tmp_ct, tmp_o) = signed_overflowing_sub(&sks, &ctxt_0, &ctxt_1);
|
||||
assert!(ct_res.block_carries_are_empty());
|
||||
assert_eq!(ct_res, tmp_ct, "Failed determinism check");
|
||||
assert_eq!(tmp_o, result_overflowed, "Failed determinism check");
|
||||
|
||||
let (expected_result, expected_overflowed) =
|
||||
signed_overflowing_sub_under_modulus(clear_0, clear_1, modulus);
|
||||
|
||||
let decrypted_result: i64 = cks.decrypt_signed(&ct_res);
|
||||
let decrypted_overflowed = cks.decrypt_bool(&result_overflowed);
|
||||
assert_eq!(
|
||||
decrypted_result, expected_result,
|
||||
"Invalid result for sub, for ({clear_0} - {clear_1}) % {modulus} \
|
||||
expected {expected_result}, got {decrypted_result}"
|
||||
);
|
||||
assert_eq!(
|
||||
decrypted_overflowed,
|
||||
expected_overflowed,
|
||||
"Invalid overflow flag result for overflowing_sub for ({clear_0} - {clear_1}) % {modulus} \
|
||||
expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
|
||||
);
|
||||
}
|
||||
|
||||
// Test with trivial inputs, as it was bugged at some point
|
||||
let values = [
|
||||
(rng.gen::<i64>() % modulus, 0i64),
|
||||
(rng.gen::<i64>() % modulus, rng.gen::<i64>() % modulus),
|
||||
(rng.gen::<i64>() % modulus, rng.gen::<i64>() % modulus),
|
||||
(rng.gen::<i64>() % modulus, rng.gen::<i64>() % modulus),
|
||||
(rng.gen::<i64>() % modulus, rng.gen::<i64>() % modulus),
|
||||
];
|
||||
for (clear_0, clear_1) in values {
|
||||
let a: SignedRadixCiphertext = sks.create_trivial_radix(clear_0, NB_CTXT);
|
||||
let b: SignedRadixCiphertext = sks.create_trivial_radix(clear_1, NB_CTXT);
|
||||
|
||||
let (encrypted_result, encrypted_overflow) = signed_overflowing_sub(&sks, &a, &b);
|
||||
|
||||
let (expected_result, expected_overflowed) =
|
||||
signed_overflowing_sub_under_modulus(clear_0, clear_1, modulus);
|
||||
|
||||
let decrypted_result: i64 = cks.decrypt_signed(&encrypted_result);
|
||||
let decrypted_overflowed = cks.decrypt_bool(&encrypted_overflow);
|
||||
assert_eq!(
|
||||
decrypted_result, expected_result,
|
||||
"Invalid result for sub, for ({clear_0} - {clear_1}) % {modulus} \
|
||||
expected {expected_result}, got {decrypted_result}"
|
||||
);
|
||||
assert_eq!(
|
||||
decrypted_overflowed,
|
||||
expected_overflowed,
|
||||
"Invalid overflow flag result for overflowing_sub, for ({clear_0} - {clear_1}) % {modulus} \
|
||||
expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn integer_signed_unchecked_overflowing_sub<P>(param: P)
|
||||
where
|
||||
P: Into<PBSParameters>,
|
||||
{
|
||||
signed_unchecked_overflowing_sub_test_case(param, ServerKey::unchecked_signed_overflowing_sub);
|
||||
let executor = CpuFunctionExecutor::new(&ServerKey::unchecked_signed_overflowing_sub);
|
||||
signed_unchecked_overflowing_sub_test(param, executor);
|
||||
}
|
||||
|
||||
fn integer_signed_unchecked_overflowing_sub_parallelized<P>(param: P)
|
||||
@@ -177,12 +63,10 @@ where
|
||||
P: Into<PBSParameters>,
|
||||
{
|
||||
// Call _impl so we are sure the parallel version is tested
|
||||
//
|
||||
// However this only supports param X_X where X >= 4
|
||||
signed_unchecked_overflowing_sub_test_case(
|
||||
param,
|
||||
ServerKey::unchecked_signed_overflowing_sub_parallelized_impl,
|
||||
);
|
||||
let executor =
|
||||
CpuFunctionExecutor::new(&ServerKey::unchecked_signed_overflowing_sub_parallelized_impl);
|
||||
signed_unchecked_overflowing_sub_test(param, executor);
|
||||
}
|
||||
|
||||
fn integer_signed_default_sub<P>(param: P)
|
||||
@@ -196,6 +80,17 @@ where
|
||||
fn integer_signed_default_overflowing_sub<P>(param: P)
|
||||
where
|
||||
P: Into<PBSParameters>,
|
||||
{
|
||||
let executor = CpuFunctionExecutor::new(&ServerKey::signed_overflowing_sub_parallelized);
|
||||
signed_default_overflowing_sub_test(param, executor);
|
||||
}
|
||||
pub(crate) fn signed_default_overflowing_sub_test<P, T>(param: P, mut executor: T)
|
||||
where
|
||||
P: Into<PBSParameters>,
|
||||
T: for<'a> FunctionExecutor<
|
||||
(&'a SignedRadixCiphertext, &'a SignedRadixCiphertext),
|
||||
(SignedRadixCiphertext, BooleanBlock),
|
||||
>,
|
||||
{
|
||||
let param = param.into();
|
||||
let nb_tests_smaller = nb_tests_smaller_for_params(param);
|
||||
@@ -203,6 +98,9 @@ where
|
||||
let cks = RadixClientKey::from((cks, NB_CTXT));
|
||||
|
||||
sks.set_deterministic_pbs_execution(true);
|
||||
let sks = Arc::new(sks);
|
||||
|
||||
executor.setup(&cks, sks.clone());
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
@@ -216,8 +114,8 @@ where
|
||||
let ctxt_0 = cks.encrypt_signed(clear_0);
|
||||
let ctxt_1 = cks.encrypt_signed(clear_1);
|
||||
|
||||
let (ct_res, result_overflowed) = sks.signed_overflowing_sub_parallelized(&ctxt_0, &ctxt_1);
|
||||
let (tmp_ct, tmp_o) = sks.signed_overflowing_sub_parallelized(&ctxt_0, &ctxt_1);
|
||||
let (ct_res, result_overflowed) = executor.execute((&ctxt_0, &ctxt_1));
|
||||
let (tmp_ct, tmp_o) = executor.execute((&ctxt_0, &ctxt_1));
|
||||
assert!(ct_res.block_carries_are_empty());
|
||||
assert_eq!(ct_res, tmp_ct, "Failed determinism check");
|
||||
assert_eq!(tmp_o, result_overflowed, "Failed determinism check");
|
||||
@@ -257,8 +155,7 @@ where
|
||||
let d1: i64 = cks.decrypt_signed(&ctxt_1);
|
||||
assert_eq!(d1, clear_rhs, "Failed sanity decryption check");
|
||||
|
||||
let (ct_res, result_overflowed) =
|
||||
sks.signed_overflowing_sub_parallelized(&ctxt_0, &ctxt_1);
|
||||
let (ct_res, result_overflowed) = executor.execute((&ctxt_0, &ctxt_1));
|
||||
assert!(ct_res.block_carries_are_empty());
|
||||
|
||||
let (expected_result, expected_overflowed) =
|
||||
@@ -292,8 +189,7 @@ where
|
||||
let a: SignedRadixCiphertext = sks.create_trivial_radix(clear_0, NB_CTXT);
|
||||
let b: SignedRadixCiphertext = sks.create_trivial_radix(clear_1, NB_CTXT);
|
||||
|
||||
let (encrypted_result, encrypted_overflow) =
|
||||
sks.signed_overflowing_sub_parallelized(&a, &b);
|
||||
let (encrypted_result, encrypted_overflow) = executor.execute((&a, &b));
|
||||
|
||||
let (expected_result, expected_overflowed) =
|
||||
signed_overflowing_sub_under_modulus(clear_0, clear_1, modulus);
|
||||
@@ -316,6 +212,123 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn signed_unchecked_overflowing_sub_test<P, T>(param: P, mut executor: T)
|
||||
where
|
||||
P: Into<PBSParameters>,
|
||||
T: for<'a> FunctionExecutor<
|
||||
(&'a SignedRadixCiphertext, &'a SignedRadixCiphertext),
|
||||
(SignedRadixCiphertext, BooleanBlock),
|
||||
>,
|
||||
{
|
||||
let param = param.into();
|
||||
let nb_tests = nb_tests_for_params(param);
|
||||
let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
|
||||
let cks = RadixClientKey::from((cks, NB_CTXT));
|
||||
|
||||
sks.set_deterministic_pbs_execution(true);
|
||||
let sks = Arc::new(sks);
|
||||
|
||||
executor.setup(&cks, sks.clone());
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
let modulus = (cks.parameters().message_modulus().0.pow(NB_CTXT as u32) / 2) as i64;
|
||||
|
||||
let hardcoded_values = [
|
||||
(-modulus, 1),
|
||||
(modulus - 1, -1),
|
||||
(1, -modulus),
|
||||
(-1, modulus - 1),
|
||||
];
|
||||
for (clear_0, clear_1) in hardcoded_values {
|
||||
let ctxt_0 = cks.encrypt_signed(clear_0);
|
||||
let ctxt_1 = cks.encrypt_signed(clear_1);
|
||||
|
||||
let (ct_res, result_overflowed) = executor.execute((&ctxt_0, &ctxt_1));
|
||||
let (expected_result, expected_overflowed) =
|
||||
signed_overflowing_sub_under_modulus(clear_0, clear_1, modulus);
|
||||
|
||||
let decrypted_result: i64 = cks.decrypt_signed(&ct_res);
|
||||
let decrypted_overflowed = cks.decrypt_bool(&result_overflowed);
|
||||
assert_eq!(
|
||||
decrypted_result, expected_result,
|
||||
"Invalid result for overflowing_sub, for ({clear_0} - {clear_1}) % {modulus} \
|
||||
expected {expected_result}, got {decrypted_result}"
|
||||
);
|
||||
assert_eq!(
|
||||
decrypted_overflowed,
|
||||
expected_overflowed,
|
||||
"Invalid overflow flag result for overflowing_sub for ({clear_0} - {clear_1}) % {modulus} \
|
||||
expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
|
||||
);
|
||||
assert_eq!(result_overflowed.0.degree.get(), 1);
|
||||
assert_eq!(result_overflowed.0.noise_level(), NoiseLevel::NOMINAL);
|
||||
}
|
||||
|
||||
for _ in 0..nb_tests {
|
||||
let clear_0 = rng.gen::<i64>() % modulus;
|
||||
let clear_1 = rng.gen::<i64>() % modulus;
|
||||
|
||||
let ctxt_0 = cks.encrypt_signed(clear_0);
|
||||
let ctxt_1 = cks.encrypt_signed(clear_1);
|
||||
|
||||
let (ct_res, result_overflowed) = executor.execute((&ctxt_0, &ctxt_1));
|
||||
let (tmp_ct, tmp_o) = executor.execute((&ctxt_0, &ctxt_1));
|
||||
assert!(ct_res.block_carries_are_empty());
|
||||
assert_eq!(ct_res, tmp_ct, "Failed determinism check");
|
||||
assert_eq!(tmp_o, result_overflowed, "Failed determinism check");
|
||||
|
||||
let (expected_result, expected_overflowed) =
|
||||
signed_overflowing_sub_under_modulus(clear_0, clear_1, modulus);
|
||||
|
||||
let decrypted_result: i64 = cks.decrypt_signed(&ct_res);
|
||||
let decrypted_overflowed = cks.decrypt_bool(&result_overflowed);
|
||||
assert_eq!(
|
||||
decrypted_result, expected_result,
|
||||
"Invalid result for sub, for ({clear_0} - {clear_1}) % {modulus} \
|
||||
expected {expected_result}, got {decrypted_result}"
|
||||
);
|
||||
assert_eq!(
|
||||
decrypted_overflowed,
|
||||
expected_overflowed,
|
||||
"Invalid overflow flag result for overflowing_sub for ({clear_0} - {clear_1}) % {modulus} \
|
||||
expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
|
||||
);
|
||||
}
|
||||
|
||||
// Test with trivial inputs, as it was bugged at some point
|
||||
let values = [
|
||||
(rng.gen::<i64>() % modulus, 0i64),
|
||||
(rng.gen::<i64>() % modulus, rng.gen::<i64>() % modulus),
|
||||
(rng.gen::<i64>() % modulus, rng.gen::<i64>() % modulus),
|
||||
(rng.gen::<i64>() % modulus, rng.gen::<i64>() % modulus),
|
||||
(rng.gen::<i64>() % modulus, rng.gen::<i64>() % modulus),
|
||||
];
|
||||
for (clear_0, clear_1) in values {
|
||||
let a: SignedRadixCiphertext = sks.create_trivial_radix(clear_0, NB_CTXT);
|
||||
let b: SignedRadixCiphertext = sks.create_trivial_radix(clear_1, NB_CTXT);
|
||||
|
||||
let (encrypted_result, encrypted_overflow) = executor.execute((&a, &b));
|
||||
|
||||
let (expected_result, expected_overflowed) =
|
||||
signed_overflowing_sub_under_modulus(clear_0, clear_1, modulus);
|
||||
|
||||
let decrypted_result: i64 = cks.decrypt_signed(&encrypted_result);
|
||||
let decrypted_overflowed = cks.decrypt_bool(&encrypted_overflow);
|
||||
assert_eq!(
|
||||
decrypted_result, expected_result,
|
||||
"Invalid result for sub, for ({clear_0} - {clear_1}) % {modulus} \
|
||||
expected {expected_result}, got {decrypted_result}"
|
||||
);
|
||||
assert_eq!(
|
||||
decrypted_overflowed,
|
||||
expected_overflowed,
|
||||
"Invalid overflow flag result for overflowing_sub, for ({clear_0} - {clear_1}) % {modulus} \
|
||||
expected overflow flag {expected_overflowed}, got {decrypted_overflowed}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn signed_unchecked_sub_test<P, T>(param: P, mut executor: T)
|
||||
where
|
||||
P: Into<PBSParameters>,
|
||||
|
||||
Reference in New Issue
Block a user