chore(gpu): make deterministic long run GPU test

This commit is contained in:
Andrei Stoian
2025-09-22 22:57:44 +02:00
committed by Andrei Stoian
parent 45a849ad36
commit 73de886c07
17 changed files with 2684 additions and 542 deletions

View File

@@ -287,7 +287,7 @@ mod gpu {
}
impl CudaGpuChoice {
pub(in crate::high_level_api) fn build_streams(self) -> CudaStreams {
pub(crate) fn build_streams(self) -> CudaStreams {
match self {
Self::Single(idx) => CudaStreams::new_single_gpu(idx),
Self::Multi => CudaStreams::new_multi_gpu(),

View File

@@ -53,6 +53,8 @@ use backward_compatibility::compressed_ciphertext_list::SquashedNoiseCiphertextS
pub use config::{Config, ConfigBuilder};
#[cfg(feature = "gpu")]
pub use global_state::CudaGpuChoice;
#[cfg(feature = "gpu")]
pub use global_state::CustomMultiGpuIndexes;
pub use global_state::{set_server_key, unset_server_key, with_server_key_as_context};
pub use integers::{

File diff suppressed because it is too large Load Diff

View File

@@ -1,4 +1,4 @@
use crate::integer::gpu::server_key::radix::tests_long_run::GpuMultiDeviceFunctionExecutor;
use crate::integer::gpu::server_key::radix::tests_signed::GpuMultiDeviceFunctionExecutor;
use crate::integer::gpu::server_key::radix::tests_unsigned::create_gpu_parameterized_test;
use crate::integer::gpu::CudaServerKey;
use crate::integer::server_key::radix_parallel::tests_long_run::test_erc20::{

View File

@@ -1,17 +1,137 @@
use crate::integer::gpu::server_key::radix::tests_long_run::GpuMultiDeviceFunctionExecutor;
use crate::integer::gpu::server_key::radix::tests_long_run::OpSequenceGpuMultiDeviceFunctionExecutor;
use crate::integer::gpu::server_key::radix::tests_unsigned::create_gpu_parameterized_test;
use crate::integer::gpu::CudaServerKey;
use crate::integer::keycache::KEY_CACHE;
use crate::integer::server_key::radix_parallel::tests_long_run::test_random_op_sequence::{
random_op_sequence_test, BinaryOpExecutor, ComparisonOpExecutor, DivRemOpExecutor,
Log2OpExecutor, OverflowingOpExecutor, ScalarBinaryOpExecutor, ScalarComparisonOpExecutor,
ScalarDivRemOpExecutor, ScalarOverflowingOpExecutor, SelectOpExecutor, UnaryOpExecutor,
};
use crate::integer::server_key::radix_parallel::tests_long_run::{
get_user_defined_seed, RandomOpSequenceDataGenerator, NB_CTXT_LONG_RUN,
};
use crate::integer::{IntegerKeyKind, RadixCiphertext, RadixClientKey, ServerKey};
use crate::shortint::parameters::*;
use crate::{ClientKey, CompressedServerKey, Tag};
use std::cmp::{max, min};
use std::sync::Arc;
create_gpu_parameterized_test!(random_op_sequence {
PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128
});
#[allow(clippy::too_many_arguments)]
pub(crate) fn random_op_sequence_test_init_gpu<P>(
param: P,
binary_ops: &mut [(BinaryOpExecutor, impl Fn(u64, u64) -> u64, String)],
unary_ops: &mut [(UnaryOpExecutor, impl Fn(u64) -> u64, String)],
scalar_binary_ops: &mut [(ScalarBinaryOpExecutor, impl Fn(u64, u64) -> u64, String)],
overflowing_ops: &mut [(
OverflowingOpExecutor,
impl Fn(u64, u64) -> (u64, bool),
String,
)],
scalar_overflowing_ops: &mut [(
ScalarOverflowingOpExecutor,
impl Fn(u64, u64) -> (u64, bool),
String,
)],
comparison_ops: &mut [(ComparisonOpExecutor, impl Fn(u64, u64) -> bool, String)],
scalar_comparison_ops: &mut [(
ScalarComparisonOpExecutor,
impl Fn(u64, u64) -> bool,
String,
)],
select_op: &mut [(SelectOpExecutor, impl Fn(bool, u64, u64) -> u64, String)],
div_rem_op: &mut [(DivRemOpExecutor, impl Fn(u64, u64) -> (u64, u64), String)],
scalar_div_rem_op: &mut [(
ScalarDivRemOpExecutor,
impl Fn(u64, u64) -> (u64, u64),
String,
)],
log2_ops: &mut [(Log2OpExecutor, impl Fn(u64) -> u64, String)],
) -> (
RadixClientKey,
Arc<ServerKey>,
RandomOpSequenceDataGenerator<u64, RadixCiphertext>,
)
where
P: Into<TestParameters>,
{
let param = param.into();
let (cks0, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let cks = ClientKey::from_raw_parts(cks0.clone(), None, None, None, None, None, Tag::default());
let comp_sks = CompressedServerKey::new(&cks);
sks.set_deterministic_pbs_execution(true);
let sks = Arc::new(sks);
println!("Setting up operations");
let cks = RadixClientKey::from((cks0, NB_CTXT_LONG_RUN));
let total_num_ops = binary_ops.len()
+ unary_ops.len()
+ scalar_binary_ops.len()
+ overflowing_ops.len()
+ scalar_overflowing_ops.len()
+ comparison_ops.len()
+ scalar_comparison_ops.len()
+ select_op.len()
+ div_rem_op.len()
+ scalar_div_rem_op.len()
+ log2_ops.len();
println!("Total num ops {total_num_ops}");
let mut datagen = get_user_defined_seed().map_or_else(
|| RandomOpSequenceDataGenerator::<u64, RadixCiphertext>::new(total_num_ops, &cks),
|seed| {
RandomOpSequenceDataGenerator::<u64, RadixCiphertext>::new_with_seed(
total_num_ops,
seed,
&cks,
)
},
);
println!("random_op_sequence_test::seed = {}", datagen.get_seed().0);
for x in binary_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in unary_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in scalar_binary_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in overflowing_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in scalar_overflowing_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in comparison_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in scalar_comparison_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in select_op.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in div_rem_op.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in scalar_div_rem_op.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in log2_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
(cks, sks, datagen)
}
fn random_op_sequence<P>(param: P)
where
P: Into<TestParameters> + Clone,
@@ -19,18 +139,24 @@ where
println!("Running random op sequence test");
// Binary Ops Executors
let add_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::add);
let sub_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::sub);
let bitwise_and_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::bitand);
let bitwise_or_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::bitor);
let bitwise_xor_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::bitxor);
let mul_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::mul);
let rotate_left_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::rotate_left);
let left_shift_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::left_shift);
let rotate_right_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::rotate_right);
let right_shift_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::right_shift);
let max_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::max);
let min_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::min);
let add_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::add);
let sub_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::sub);
let bitwise_and_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::bitand);
let bitwise_or_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::bitor);
let bitwise_xor_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::bitxor);
let mul_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::mul);
let rotate_left_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::rotate_left);
let left_shift_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::left_shift);
let rotate_right_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::rotate_right);
let right_shift_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::right_shift);
let max_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::max);
let min_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::min);
// Binary Ops Clear functions
let clear_add = |x, y| x + y;
@@ -93,10 +219,10 @@ where
];
// Unary Ops Executors
let neg_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::neg);
let bitnot_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::bitnot);
let neg_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::neg);
let bitnot_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::bitnot);
//let reverse_bits_executor =
// GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::reverse_bits); Unary Ops Clear
// OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::reverse_bits); Unary Ops Clear
// functions
let clear_neg = |x: u64| x.wrapping_neg();
let clear_bitnot = |x: u64| !x;
@@ -117,23 +243,26 @@ where
];
// Scalar binary Ops Executors
let scalar_add_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_add);
let scalar_sub_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_sub);
let scalar_add_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_add);
let scalar_sub_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_sub);
let scalar_bitwise_and_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_bitand);
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_bitand);
let scalar_bitwise_or_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_bitor);
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_bitor);
let scalar_bitwise_xor_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_bitxor);
let scalar_mul_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_mul);
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_bitxor);
let scalar_mul_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_mul);
let scalar_rotate_left_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_rotate_left);
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_rotate_left);
let scalar_left_shift_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_left_shift);
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_left_shift);
let scalar_rotate_right_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_rotate_right);
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_rotate_right);
let scalar_right_shift_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_right_shift);
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_right_shift);
#[allow(clippy::type_complexity)]
let mut scalar_binary_ops: Vec<(ScalarBinaryOpExecutor, &dyn Fn(u64, u64) -> u64, String)> = vec![
@@ -191,11 +320,11 @@ where
// Overflowing Ops Executors
let overflowing_add_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::unsigned_overflowing_add);
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::unsigned_overflowing_add);
let overflowing_sub_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::unsigned_overflowing_sub);
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::unsigned_overflowing_sub);
//let overflowing_mul_executor =
// GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::unsigned_overflowing_mul);
// OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::unsigned_overflowing_mul);
// Overflowing Ops Clear functions
let clear_overflowing_add = |x: u64, y: u64| -> (u64, bool) { x.overflowing_add(y) };
@@ -226,10 +355,12 @@ where
];
// Scalar Overflowing Ops Executors
let overflowing_scalar_add_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::unsigned_overflowing_scalar_add);
let overflowing_scalar_add_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(
&CudaServerKey::unsigned_overflowing_scalar_add,
);
// let overflowing_scalar_sub_executor =
// GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::unsigned_overflowing_scalar_sub);
// OpSequenceGpuMultiDeviceFunctionExecutor::new(&
// CudaServerKey::unsigned_overflowing_scalar_sub);
#[allow(clippy::type_complexity)]
let mut scalar_overflowing_ops: Vec<(
@@ -250,12 +381,12 @@ where
];
// Comparison Ops Executors
let gt_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::gt);
let ge_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::ge);
let lt_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::lt);
let le_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::le);
let eq_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::eq);
let ne_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::ne);
let gt_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::gt);
let ge_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::ge);
let lt_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::lt);
let le_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::le);
let eq_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::eq);
let ne_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::ne);
// Comparison Ops Clear functions
let clear_gt = |x: u64, y: u64| -> bool { x > y };
@@ -276,12 +407,18 @@ where
];
// Scalar Comparison Ops Executors
let scalar_gt_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_gt);
let scalar_ge_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_ge);
let scalar_lt_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_lt);
let scalar_le_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_le);
let scalar_eq_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_eq);
let scalar_ne_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_ne);
let scalar_gt_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_gt);
let scalar_ge_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_ge);
let scalar_lt_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_lt);
let scalar_le_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_le);
let scalar_eq_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_eq);
let scalar_ne_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_ne);
#[allow(clippy::type_complexity)]
let mut scalar_comparison_ops: Vec<(
@@ -322,7 +459,8 @@ where
];
// Select Executor
let select_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::if_then_else);
let select_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::if_then_else);
// Select
let clear_select = |b: bool, x: u64, y: u64| if b { x } else { y };
@@ -335,7 +473,7 @@ where
)];
// Div executor
let div_rem_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::div_rem);
let div_rem_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::div_rem);
// Div Rem Clear functions
let clear_div_rem = |x: u64, y: u64| -> (u64, u64) { (x.wrapping_div(y), x.wrapping_rem(y)) };
#[allow(clippy::type_complexity)]
@@ -347,7 +485,7 @@ where
// Scalar Div executor
let scalar_div_rem_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_div_rem);
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_div_rem);
#[allow(clippy::type_complexity)]
let mut scalar_div_rem_op: Vec<(
ScalarDivRemOpExecutor,
@@ -360,9 +498,11 @@ where
)];
// Log2/Hamming weight ops
let ilog2_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::ilog2);
//let count_zeros_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::count_zeros);
//let count_ones_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::count_ones);
let ilog2_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::ilog2);
//let count_zeros_executor =
// OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::count_zeros);
// let count_ones_executor =
// OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::count_ones);
let clear_ilog2 = |x: u64| x.ilog2() as u64;
//let clear_count_zeros = |x: u64| x.count_zeros() as u64;
//let clear_count_ones = |x: u64| x.count_ones() as u64;
@@ -382,7 +522,7 @@ where
//),
];
random_op_sequence_test(
let (cks, sks, mut datagen) = random_op_sequence_test_init_gpu(
param,
&mut binary_ops,
&mut unary_ops,
@@ -396,4 +536,21 @@ where
&mut scalar_div_rem_op,
&mut log2_ops,
);
random_op_sequence_test(
&mut datagen,
&cks,
&sks,
&mut binary_ops,
&mut unary_ops,
&mut scalar_binary_ops,
&mut overflowing_ops,
&mut scalar_overflowing_ops,
&mut comparison_ops,
&mut scalar_comparison_ops,
&mut select_op,
&mut div_rem_op,
&mut scalar_div_rem_op,
&mut log2_ops,
);
}

View File

@@ -1,4 +1,4 @@
use crate::integer::gpu::server_key::radix::tests_long_run::GpuMultiDeviceFunctionExecutor;
use crate::integer::gpu::server_key::radix::tests_signed::GpuMultiDeviceFunctionExecutor;
use crate::integer::gpu::server_key::radix::tests_unsigned::create_gpu_parameterized_test;
use crate::integer::gpu::CudaServerKey;
use crate::integer::server_key::radix_parallel::tests_long_run::test_signed_erc20::{

View File

@@ -1,6 +1,7 @@
use crate::integer::gpu::server_key::radix::tests_long_run::GpuMultiDeviceFunctionExecutor;
use crate::integer::gpu::server_key::radix::tests_long_run::OpSequenceGpuMultiDeviceFunctionExecutor;
use crate::integer::gpu::server_key::radix::tests_unsigned::create_gpu_parameterized_test;
use crate::integer::gpu::CudaServerKey;
use crate::integer::keycache::KEY_CACHE;
use crate::integer::server_key::radix_parallel::tests_long_run::test_signed_random_op_sequence::{
signed_random_op_sequence_test, SignedBinaryOpExecutor, SignedComparisonOpExecutor,
SignedDivRemOpExecutor, SignedLog2OpExecutor, SignedOverflowingOpExecutor,
@@ -8,25 +9,177 @@ use crate::integer::server_key::radix_parallel::tests_long_run::test_signed_rand
SignedScalarOverflowingOpExecutor, SignedScalarShiftRotateExecutor, SignedSelectOpExecutor,
SignedShiftRotateExecutor, SignedUnaryOpExecutor,
};
use crate::integer::server_key::radix_parallel::tests_long_run::{
get_user_defined_seed, RandomOpSequenceDataGenerator, NB_CTXT_LONG_RUN,
};
use crate::integer::{IntegerKeyKind, RadixClientKey, ServerKey, SignedRadixCiphertext};
use crate::shortint::parameters::*;
use crate::{ClientKey, CompressedServerKey, Tag};
use std::cmp::{max, min};
use std::sync::Arc;
create_gpu_parameterized_test!(signed_random_op_sequence {
PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128
});
#[allow(clippy::too_many_arguments)]
pub(crate) fn signed_random_op_sequence_test_init_gpu<P>(
param: P,
binary_ops: &mut [(SignedBinaryOpExecutor, impl Fn(i64, i64) -> i64, String)],
unary_ops: &mut [(SignedUnaryOpExecutor, impl Fn(i64) -> i64, String)],
scalar_binary_ops: &mut [(
SignedScalarBinaryOpExecutor,
impl Fn(i64, i64) -> i64,
String,
)],
overflowing_ops: &mut [(
SignedOverflowingOpExecutor,
impl Fn(i64, i64) -> (i64, bool),
String,
)],
scalar_overflowing_ops: &mut [(
SignedScalarOverflowingOpExecutor,
impl Fn(i64, i64) -> (i64, bool),
String,
)],
comparison_ops: &mut [(
SignedComparisonOpExecutor,
impl Fn(i64, i64) -> bool,
String,
)],
scalar_comparison_ops: &mut [(
SignedScalarComparisonOpExecutor,
impl Fn(i64, i64) -> bool,
String,
)],
select_op: &mut [(
SignedSelectOpExecutor,
impl Fn(bool, i64, i64) -> i64,
String,
)],
div_rem_op: &mut [(
SignedDivRemOpExecutor,
impl Fn(i64, i64) -> (i64, i64),
String,
)],
scalar_div_rem_op: &mut [(
SignedScalarDivRemOpExecutor,
impl Fn(i64, i64) -> (i64, i64),
String,
)],
log2_ops: &mut [(SignedLog2OpExecutor, impl Fn(i64) -> u64, String)],
rotate_shift_ops: &mut [(SignedShiftRotateExecutor, impl Fn(i64, u64) -> i64, String)],
scalar_rotate_shift_ops: &mut [(
SignedScalarShiftRotateExecutor,
impl Fn(i64, u64) -> i64,
String,
)],
) -> (
RadixClientKey,
Arc<ServerKey>,
RandomOpSequenceDataGenerator<i64, SignedRadixCiphertext>,
)
where
P: Into<TestParameters>,
{
let param = param.into();
let (cks0, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let total_num_ops = binary_ops.len()
+ unary_ops.len()
+ scalar_binary_ops.len()
+ overflowing_ops.len()
+ scalar_overflowing_ops.len()
+ comparison_ops.len()
+ scalar_comparison_ops.len()
+ select_op.len()
+ div_rem_op.len()
+ scalar_div_rem_op.len()
+ log2_ops.len()
+ rotate_shift_ops.len()
+ scalar_rotate_shift_ops.len();
let cks = ClientKey::from_raw_parts(cks0.clone(), None, None, None, None, None, Tag::default());
let comp_sks = CompressedServerKey::new(&cks);
sks.set_deterministic_pbs_execution(true);
let sks = Arc::new(sks);
let cks = RadixClientKey::from((cks0, NB_CTXT_LONG_RUN));
let mut datagen = get_user_defined_seed().map_or_else(
|| RandomOpSequenceDataGenerator::<i64, SignedRadixCiphertext>::new(total_num_ops, &cks),
|seed| {
RandomOpSequenceDataGenerator::<i64, SignedRadixCiphertext>::new_with_seed(
total_num_ops,
seed,
&cks,
)
},
);
println!(
"signed_random_op_sequence_test::seed = {}",
datagen.get_seed().0
);
for x in binary_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in unary_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in scalar_binary_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in overflowing_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in scalar_overflowing_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in comparison_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in scalar_comparison_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in select_op.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in div_rem_op.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in scalar_div_rem_op.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in log2_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in rotate_shift_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in scalar_rotate_shift_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
(cks, sks, datagen)
}
fn signed_random_op_sequence<P>(param: P)
where
P: Into<TestParameters> + Clone,
{
// Binary Ops Executors
let add_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::add);
let sub_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::sub);
let bitwise_and_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::bitand);
let bitwise_or_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::bitor);
let bitwise_xor_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::bitxor);
let mul_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::mul);
let max_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::max);
let min_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::min);
let add_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::add);
let sub_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::sub);
let bitwise_and_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::bitand);
let bitwise_or_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::bitor);
let bitwise_xor_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::bitxor);
let mul_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::mul);
let max_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::max);
let min_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::min);
// Binary Ops Clear functions
let clear_add = |x, y| x + y;
@@ -62,10 +215,14 @@ where
(Box::new(min_executor), &clear_min, "min".to_string()),
];
let rotate_left_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::rotate_left);
let left_shift_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::left_shift);
let rotate_right_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::rotate_right);
let right_shift_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::right_shift);
let rotate_left_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::rotate_left);
let left_shift_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::left_shift);
let rotate_right_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::rotate_right);
let right_shift_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::right_shift);
// Warning this rotate definition only works with 64-bit ciphertexts
let clear_rotate_left = |x: i64, y: u64| x.rotate_left(y as u32);
let clear_left_shift = |x: i64, y: u64| x << y;
@@ -101,11 +258,11 @@ where
];
// Unary Ops Executors
let neg_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::neg);
let bitnot_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::bitnot);
let abs_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::abs);
let neg_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::neg);
let bitnot_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::bitnot);
let abs_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::abs);
//let reverse_bits_executor =
// GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::reverse_bits); Unary Ops Clear
// OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::reverse_bits); Unary Ops Clear
// functions
let clear_neg = |x: i64| x.wrapping_neg();
let clear_abs = |x: i64| x.wrapping_abs();
@@ -129,15 +286,18 @@ where
];
// Scalar binary Ops Executors
let scalar_add_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_add);
let scalar_sub_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_sub);
let scalar_add_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_add);
let scalar_sub_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_sub);
let scalar_bitwise_and_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_bitand);
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_bitand);
let scalar_bitwise_or_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_bitor);
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_bitor);
let scalar_bitwise_xor_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_bitxor);
let scalar_mul_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_mul);
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_bitxor);
let scalar_mul_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_mul);
#[allow(clippy::type_complexity)]
let mut scalar_binary_ops: Vec<(
@@ -178,13 +338,13 @@ where
];
let scalar_rotate_left_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_rotate_left);
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_rotate_left);
let scalar_left_shift_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_left_shift);
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_left_shift);
let scalar_rotate_right_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_rotate_right);
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_rotate_right);
let scalar_right_shift_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_right_shift);
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_right_shift);
#[allow(clippy::type_complexity)]
let mut scalar_shift_rotate_ops: Vec<(
SignedScalarShiftRotateExecutor,
@@ -215,11 +375,11 @@ where
// Overflowing Ops Executors
let overflowing_add_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::signed_overflowing_add);
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::signed_overflowing_add);
let overflowing_sub_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::signed_overflowing_sub);
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::signed_overflowing_sub);
//let overflowing_mul_executor =
// GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::signed_overflowing_mul);
// OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::signed_overflowing_mul);
// Overflowing Ops Clear functions
let clear_overflowing_add = |x: i64, y: i64| -> (i64, bool) { x.overflowing_add(y) };
@@ -250,10 +410,12 @@ where
];
// Scalar Overflowing Ops Executors
let overflowing_scalar_add_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::signed_overflowing_scalar_add);
let overflowing_scalar_add_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(
&CudaServerKey::signed_overflowing_scalar_add,
);
// let overflowing_scalar_sub_executor =
// GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::signed_overflowing_scalar_sub);
// OpSequenceGpuMultiDeviceFunctionExecutor::new(&
// CudaServerKey::signed_overflowing_scalar_sub);
#[allow(clippy::type_complexity)]
let mut scalar_overflowing_ops: Vec<(
@@ -274,12 +436,12 @@ where
];
// Comparison Ops Executors
let gt_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::gt);
let ge_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::ge);
let lt_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::lt);
let le_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::le);
let eq_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::eq);
let ne_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::ne);
let gt_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::gt);
let ge_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::ge);
let lt_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::lt);
let le_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::le);
let eq_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::eq);
let ne_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::ne);
// Comparison Ops Clear functions
let clear_gt = |x: i64, y: i64| -> bool { x > y };
@@ -304,12 +466,18 @@ where
];
// Scalar Comparison Ops Executors
let scalar_gt_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_gt);
let scalar_ge_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_ge);
let scalar_lt_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_lt);
let scalar_le_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_le);
let scalar_eq_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_eq);
let scalar_ne_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_ne);
let scalar_gt_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_gt);
let scalar_ge_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_ge);
let scalar_lt_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_lt);
let scalar_le_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_le);
let scalar_eq_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_eq);
let scalar_ne_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_ne);
#[allow(clippy::type_complexity)]
let mut scalar_comparison_ops: Vec<(
@@ -350,7 +518,8 @@ where
];
// Select Executor
let select_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::if_then_else);
let select_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::if_then_else);
// Select
let clear_select = |b: bool, x: i64, y: i64| if b { x } else { y };
@@ -367,7 +536,7 @@ where
)];
// Div executor
let div_rem_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::div_rem);
let div_rem_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::div_rem);
// Div Rem Clear functions
let clear_div_rem = |x: i64, y: i64| -> (i64, i64) { (x.wrapping_div(y), x.wrapping_rem(y)) };
#[allow(clippy::type_complexity)]
@@ -383,7 +552,7 @@ where
// Scalar Div executor
let scalar_div_rem_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::signed_scalar_div_rem);
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::signed_scalar_div_rem);
#[allow(clippy::type_complexity)]
let mut scalar_div_rem_op: Vec<(
SignedScalarDivRemOpExecutor,
@@ -396,9 +565,11 @@ where
)];
// Log2/Hamming weight ops
let ilog2_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::ilog2);
//let count_zeros_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::count_zeros);
//let count_ones_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::count_ones);
let ilog2_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::ilog2);
//let count_zeros_executor =
// OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::count_zeros);
// let count_ones_executor =
// OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::count_ones);
let clear_ilog2 = |x: i64| x.ilog2() as u64;
//let clear_count_zeros = |x: i64| x.count_zeros() as i64;
//let clear_count_ones = |x: i64| x.count_ones() as i64;
@@ -418,7 +589,7 @@ where
//),
];
signed_random_op_sequence_test(
let (cks, sks, mut datagen) = signed_random_op_sequence_test_init_gpu(
param,
&mut binary_ops,
&mut unary_ops,
@@ -434,4 +605,23 @@ where
&mut shift_rotate_ops,
&mut scalar_shift_rotate_ops,
);
signed_random_op_sequence_test(
&mut datagen,
&cks,
&sks,
&mut binary_ops,
&mut unary_ops,
&mut scalar_binary_ops,
&mut overflowing_ops,
&mut scalar_overflowing_ops,
&mut comparison_ops,
&mut scalar_comparison_ops,
&mut select_op,
&mut div_rem_op,
&mut scalar_div_rem_op,
&mut log2_ops,
&mut shift_rotate_ops,
&mut scalar_shift_rotate_ops,
);
}

View File

@@ -20,15 +20,18 @@ pub(crate) mod test_shift;
pub(crate) mod test_sub;
pub(crate) mod test_vector_comparisons;
use crate::core_crypto::gpu::CudaStreams;
use crate::core_crypto::gpu::{get_number_of_gpus, CudaStreams};
use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
use crate::integer::gpu::ciphertext::{CudaSignedRadixCiphertext, CudaUnsignedRadixCiphertext};
use crate::integer::gpu::server_key::radix::tests_unsigned::GpuFunctionExecutor;
use crate::integer::gpu::server_key::radix::tests_unsigned::{GpuContext, GpuFunctionExecutor};
use crate::integer::gpu::CudaServerKey;
use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor;
use crate::integer::{
BooleanBlock, RadixCiphertext, RadixClientKey, ServerKey, SignedRadixCiphertext,
BooleanBlock, RadixCiphertext, RadixClientKey, ServerKey, SignedRadixCiphertext, U256,
};
use crate::GpuIndex;
use rand::seq::SliceRandom;
use rand::Rng;
use std::sync::Arc;
/// For default/unchecked unary functions
@@ -643,3 +646,684 @@ where
d_block.to_boolean_block(&context.streams)
}
}
pub(crate) struct GpuMultiDeviceFunctionExecutor<F> {
pub(crate) context: Option<GpuContext>,
pub(crate) func: F,
}
impl<F> GpuMultiDeviceFunctionExecutor<F> {
pub(crate) fn new(func: F) -> Self {
Self {
context: None,
func,
}
}
}
impl<F> GpuMultiDeviceFunctionExecutor<F> {
pub(crate) fn setup_from_keys(&mut self, cks: &RadixClientKey, _sks: &Arc<ServerKey>) {
// Sample a random subset of 1-N gpus, where N is the number of available GPUs
// A GPU index should not appear twice in the subset
let num_gpus = get_number_of_gpus();
let mut rng = rand::thread_rng();
let num_gpus_to_use = rng.gen_range(1..=num_gpus as usize);
let mut all_gpu_indexes: Vec<u32> = (0..num_gpus).collect();
all_gpu_indexes.shuffle(&mut rng);
let gpu_indexes_to_use = &all_gpu_indexes[..num_gpus_to_use];
let gpu_indexes: Vec<GpuIndex> = gpu_indexes_to_use
.iter()
.map(|idx| GpuIndex::new(*idx))
.collect();
println!("Setting up server key on GPUs: [{gpu_indexes_to_use:?}]");
let streams = CudaStreams::new_multi_gpu_with_indexes(&gpu_indexes);
let sks = CudaServerKey::new(cks.as_ref(), &streams);
streams.synchronize();
let context = GpuContext { streams, sks };
self.context = Some(context);
}
}
/// For default/unchecked binary signed functions
impl<'a, F>
FunctionExecutor<(&'a SignedRadixCiphertext, &'a SignedRadixCiphertext), SignedRadixCiphertext>
for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(
&CudaServerKey,
&CudaSignedRadixCiphertext,
&CudaSignedRadixCiphertext,
&CudaStreams,
) -> CudaSignedRadixCiphertext,
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}
fn execute(
&mut self,
input: (&'a SignedRadixCiphertext, &'a SignedRadixCiphertext),
) -> SignedRadixCiphertext {
let context = self
.context
.as_ref()
.expect("setup was not properly called");
let d_ctxt_1 =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input.0, &context.streams);
let d_ctxt_2 =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input.1, &context.streams);
let gpu_result = (self.func)(&context.sks, &d_ctxt_1, &d_ctxt_2, &context.streams);
gpu_result.to_signed_radix_ciphertext(&context.streams)
}
}
impl<'a, F>
FunctionExecutor<(&'a SignedRadixCiphertext, &'a RadixCiphertext), SignedRadixCiphertext>
for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(
&CudaServerKey,
&CudaSignedRadixCiphertext,
&CudaUnsignedRadixCiphertext,
&CudaStreams,
) -> CudaSignedRadixCiphertext,
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}
fn execute(
&mut self,
input: (&'a SignedRadixCiphertext, &'a RadixCiphertext),
) -> SignedRadixCiphertext {
let context = self
.context
.as_ref()
.expect("setup was not properly called");
let d_ctxt_1 =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input.0, &context.streams);
let d_ctxt_2 =
CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.1, &context.streams);
let gpu_result = (self.func)(&context.sks, &d_ctxt_1, &d_ctxt_2, &context.streams);
gpu_result.to_signed_radix_ciphertext(&context.streams)
}
}
/// For unchecked/default assign binary functions
impl<'a, F> FunctionExecutor<(&'a mut SignedRadixCiphertext, &'a SignedRadixCiphertext), ()>
for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(&CudaServerKey, &mut CudaSignedRadixCiphertext, &CudaSignedRadixCiphertext, &CudaStreams),
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}
fn execute(&mut self, input: (&'a mut SignedRadixCiphertext, &'a SignedRadixCiphertext)) {
let context = self
.context
.as_ref()
.expect("setup was not properly called");
let mut d_ctxt_1 =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input.0, &context.streams);
let d_ctxt_2 =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input.1, &context.streams);
(self.func)(&context.sks, &mut d_ctxt_1, &d_ctxt_2, &context.streams);
*input.0 = d_ctxt_1.to_signed_radix_ciphertext(&context.streams);
}
}
/// For unchecked/default binary functions with one scalar input
impl<'a, F> FunctionExecutor<(&'a SignedRadixCiphertext, i64), SignedRadixCiphertext>
for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(
&CudaServerKey,
&CudaSignedRadixCiphertext,
i64,
&CudaStreams,
) -> CudaSignedRadixCiphertext,
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}
fn execute(&mut self, input: (&'a SignedRadixCiphertext, i64)) -> SignedRadixCiphertext {
let context = self
.context
.as_ref()
.expect("setup was not properly called");
let d_ctxt_1 =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input.0, &context.streams);
let gpu_result = (self.func)(&context.sks, &d_ctxt_1, input.1, &context.streams);
gpu_result.to_signed_radix_ciphertext(&context.streams)
}
}
/// For unchecked/default binary functions with one scalar input
impl<'a, F> FunctionExecutor<(&'a SignedRadixCiphertext, u64), SignedRadixCiphertext>
for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(
&CudaServerKey,
&CudaSignedRadixCiphertext,
u64,
&CudaStreams,
) -> CudaSignedRadixCiphertext,
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}
fn execute(&mut self, input: (&'a SignedRadixCiphertext, u64)) -> SignedRadixCiphertext {
let context = self
.context
.as_ref()
.expect("setup was not properly called");
let d_ctxt_1 =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input.0, &context.streams);
let gpu_result = (self.func)(&context.sks, &d_ctxt_1, input.1, &context.streams);
gpu_result.to_signed_radix_ciphertext(&context.streams)
}
}
/// For unchecked/default binary functions with one scalar input
impl<F> FunctionExecutor<(SignedRadixCiphertext, i64), SignedRadixCiphertext>
for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(
&CudaServerKey,
&CudaSignedRadixCiphertext,
i64,
&CudaStreams,
) -> CudaSignedRadixCiphertext,
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}
fn execute(&mut self, input: (SignedRadixCiphertext, i64)) -> SignedRadixCiphertext {
let context = self
.context
.as_ref()
.expect("setup was not properly called");
let d_ctxt_1 =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(&input.0, &context.streams);
let gpu_result = (self.func)(&context.sks, &d_ctxt_1, input.1, &context.streams);
gpu_result.to_signed_radix_ciphertext(&context.streams)
}
}
// Unary Function
impl<'a, F> FunctionExecutor<&'a SignedRadixCiphertext, SignedRadixCiphertext>
for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(&CudaServerKey, &CudaSignedRadixCiphertext, &CudaStreams) -> CudaSignedRadixCiphertext,
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}
fn execute(&mut self, input: &'a SignedRadixCiphertext) -> SignedRadixCiphertext {
let context = self
.context
.as_ref()
.expect("setup was not properly called");
let d_ctxt_1 =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input, &context.streams);
let gpu_result = (self.func)(&context.sks, &d_ctxt_1, &context.streams);
gpu_result.to_signed_radix_ciphertext(&context.streams)
}
}
impl<'a, F> FunctionExecutor<&'a SignedRadixCiphertext, RadixCiphertext>
for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(&CudaServerKey, &CudaSignedRadixCiphertext, &CudaStreams) -> CudaUnsignedRadixCiphertext,
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}
fn execute(&mut self, input: &'a SignedRadixCiphertext) -> RadixCiphertext {
let context = self
.context
.as_ref()
.expect("setup was not properly called");
let d_ctxt_1 =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input, &context.streams);
let gpu_result = (self.func)(&context.sks, &d_ctxt_1, &context.streams);
gpu_result.to_radix_ciphertext(&context.streams)
}
}
// Unary assign Function
impl<'a, F> FunctionExecutor<&'a mut SignedRadixCiphertext, ()>
for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(&CudaServerKey, &mut CudaSignedRadixCiphertext, &CudaStreams),
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}
fn execute(&mut self, input: &'a mut SignedRadixCiphertext) {
let context = self
.context
.as_ref()
.expect("setup was not properly called");
let mut d_ctxt_1 =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input, &context.streams);
(self.func)(&context.sks, &mut d_ctxt_1, &context.streams);
*input = d_ctxt_1.to_signed_radix_ciphertext(&context.streams)
}
}
impl<'a, F> FunctionExecutor<&'a Vec<SignedRadixCiphertext>, Option<SignedRadixCiphertext>>
for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(&CudaServerKey, Vec<CudaSignedRadixCiphertext>) -> Option<CudaSignedRadixCiphertext>,
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}
fn execute(&mut self, input: &'a Vec<SignedRadixCiphertext>) -> Option<SignedRadixCiphertext> {
let context = self
.context
.as_ref()
.expect("setup was not properly called");
let d_ctxt_1: Vec<CudaSignedRadixCiphertext> = input
.iter()
.map(|ct| CudaSignedRadixCiphertext::from_signed_radix_ciphertext(ct, &context.streams))
.collect();
let d_res = (self.func)(&context.sks, d_ctxt_1);
Some(d_res.unwrap().to_signed_radix_ciphertext(&context.streams))
}
}
impl<'a, F>
FunctionExecutor<
(&'a SignedRadixCiphertext, &'a SignedRadixCiphertext),
(SignedRadixCiphertext, BooleanBlock),
> for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(
&CudaServerKey,
&CudaSignedRadixCiphertext,
&CudaSignedRadixCiphertext,
&CudaStreams,
) -> (CudaSignedRadixCiphertext, CudaBooleanBlock),
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}
fn execute(
&mut self,
input: (&'a SignedRadixCiphertext, &'a SignedRadixCiphertext),
) -> (SignedRadixCiphertext, BooleanBlock) {
let context = self
.context
.as_ref()
.expect("setup was not properly called");
let d_ctxt_1: CudaSignedRadixCiphertext =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input.0, &context.streams);
let d_ctxt_2: CudaSignedRadixCiphertext =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input.1, &context.streams);
let d_res = (self.func)(&context.sks, &d_ctxt_1, &d_ctxt_2, &context.streams);
(
d_res.0.to_signed_radix_ciphertext(&context.streams),
d_res.1.to_boolean_block(&context.streams),
)
}
}
/// For unchecked/default unsigned overflowing scalar operations
impl<'a, F>
FunctionExecutor<(&'a SignedRadixCiphertext, i64), (SignedRadixCiphertext, BooleanBlock)>
for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(
&CudaServerKey,
&CudaSignedRadixCiphertext,
i64,
&CudaStreams,
) -> (CudaSignedRadixCiphertext, CudaBooleanBlock),
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}
fn execute(
&mut self,
input: (&'a SignedRadixCiphertext, i64),
) -> (SignedRadixCiphertext, BooleanBlock) {
let context = self
.context
.as_ref()
.expect("setup was not properly called");
let d_ctxt_1: CudaSignedRadixCiphertext =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input.0, &context.streams);
let d_res = (self.func)(&context.sks, &d_ctxt_1, input.1, &context.streams);
(
d_res.0.to_signed_radix_ciphertext(&context.streams),
d_res.1.to_boolean_block(&context.streams),
)
}
}
impl<'a, F> FunctionExecutor<&'a SignedRadixCiphertext, (SignedRadixCiphertext, BooleanBlock)>
for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(
&CudaServerKey,
&CudaSignedRadixCiphertext,
&CudaStreams,
) -> (CudaSignedRadixCiphertext, CudaBooleanBlock),
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}
fn execute(
&mut self,
input: &'a SignedRadixCiphertext,
) -> (SignedRadixCiphertext, BooleanBlock) {
let context = self
.context
.as_ref()
.expect("setup was not properly called");
let d_ctxt_1: CudaSignedRadixCiphertext =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input, &context.streams);
let d_res = (self.func)(&context.sks, &d_ctxt_1, &context.streams);
(
d_res.0.to_signed_radix_ciphertext(&context.streams),
d_res.1.to_boolean_block(&context.streams),
)
}
}
impl<'a, F>
FunctionExecutor<
(&'a SignedRadixCiphertext, &'a SignedRadixCiphertext),
(SignedRadixCiphertext, SignedRadixCiphertext),
> for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(
&CudaServerKey,
&CudaSignedRadixCiphertext,
&CudaSignedRadixCiphertext,
&CudaStreams,
) -> (CudaSignedRadixCiphertext, CudaSignedRadixCiphertext),
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}
fn execute(
&mut self,
input: (&'a SignedRadixCiphertext, &'a SignedRadixCiphertext),
) -> (SignedRadixCiphertext, SignedRadixCiphertext) {
let context = self
.context
.as_ref()
.expect("setup was not properly called");
let d_ctxt_1: CudaSignedRadixCiphertext =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input.0, &context.streams);
let d_ctxt_2: CudaSignedRadixCiphertext =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input.1, &context.streams);
let d_res = (self.func)(&context.sks, &d_ctxt_1, &d_ctxt_2, &context.streams);
(
d_res.0.to_signed_radix_ciphertext(&context.streams),
d_res.1.to_signed_radix_ciphertext(&context.streams),
)
}
}
impl<'a, F>
FunctionExecutor<
(&'a SignedRadixCiphertext, i64),
(SignedRadixCiphertext, SignedRadixCiphertext),
> for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(
&CudaServerKey,
&CudaSignedRadixCiphertext,
i64,
&CudaStreams,
) -> (CudaSignedRadixCiphertext, CudaSignedRadixCiphertext),
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}
fn execute(
&mut self,
input: (&'a SignedRadixCiphertext, i64),
) -> (SignedRadixCiphertext, SignedRadixCiphertext) {
let context = self
.context
.as_ref()
.expect("setup was not properly called");
let d_ctxt_1: CudaSignedRadixCiphertext =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input.0, &context.streams);
let d_res = (self.func)(&context.sks, &d_ctxt_1, input.1, &context.streams);
(
d_res.0.to_signed_radix_ciphertext(&context.streams),
d_res.1.to_signed_radix_ciphertext(&context.streams),
)
}
}
impl<'a, F> FunctionExecutor<(&'a SignedRadixCiphertext, &'a SignedRadixCiphertext), BooleanBlock>
for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(
&CudaServerKey,
&CudaSignedRadixCiphertext,
&CudaSignedRadixCiphertext,
&CudaStreams,
) -> CudaBooleanBlock,
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}
fn execute(
&mut self,
input: (&'a SignedRadixCiphertext, &'a SignedRadixCiphertext),
) -> BooleanBlock {
let context = self
.context
.as_ref()
.expect("setup was not properly called");
let d_ctxt_1: CudaSignedRadixCiphertext =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input.0, &context.streams);
let d_ctxt_2: CudaSignedRadixCiphertext =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input.1, &context.streams);
let d_res = (self.func)(&context.sks, &d_ctxt_1, &d_ctxt_2, &context.streams);
d_res.to_boolean_block(&context.streams)
}
}
impl<'a, F> FunctionExecutor<(&'a SignedRadixCiphertext, i64), BooleanBlock>
for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(&CudaServerKey, &CudaSignedRadixCiphertext, i64, &CudaStreams) -> CudaBooleanBlock,
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}
fn execute(&mut self, input: (&'a SignedRadixCiphertext, i64)) -> BooleanBlock {
let context = self
.context
.as_ref()
.expect("setup was not properly called");
let d_ctxt_1: CudaSignedRadixCiphertext =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input.0, &context.streams);
let d_res = (self.func)(&context.sks, &d_ctxt_1, input.1, &context.streams);
d_res.to_boolean_block(&context.streams)
}
}
impl<'a, F> FunctionExecutor<(&'a SignedRadixCiphertext, U256), BooleanBlock>
for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(&CudaServerKey, &CudaSignedRadixCiphertext, U256, &CudaStreams) -> CudaBooleanBlock,
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}
fn execute(&mut self, input: (&'a SignedRadixCiphertext, U256)) -> BooleanBlock {
let context = self
.context
.as_ref()
.expect("setup was not properly called");
let d_ctxt_1: CudaSignedRadixCiphertext =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input.0, &context.streams);
let d_res = (self.func)(&context.sks, &d_ctxt_1, input.1, &context.streams);
d_res.to_boolean_block(&context.streams)
}
}
impl<'a, F> FunctionExecutor<(&'a SignedRadixCiphertext, U256), SignedRadixCiphertext>
for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(
&CudaServerKey,
&CudaSignedRadixCiphertext,
U256,
&CudaStreams,
) -> CudaSignedRadixCiphertext,
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}
fn execute(&mut self, input: (&'a SignedRadixCiphertext, U256)) -> SignedRadixCiphertext {
let context = self
.context
.as_ref()
.expect("setup was not properly called");
let d_ctxt_1: CudaSignedRadixCiphertext =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input.0, &context.streams);
let d_res = (self.func)(&context.sks, &d_ctxt_1, input.1, &context.streams);
d_res.to_signed_radix_ciphertext(&context.streams)
}
}
impl<'a, F>
FunctionExecutor<
(
&'a BooleanBlock,
&'a SignedRadixCiphertext,
&'a SignedRadixCiphertext,
),
SignedRadixCiphertext,
> for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(
&CudaServerKey,
&CudaBooleanBlock,
&CudaSignedRadixCiphertext,
&CudaSignedRadixCiphertext,
&CudaStreams,
) -> CudaSignedRadixCiphertext,
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}
fn execute(
&mut self,
input: (
&'a BooleanBlock,
&'a SignedRadixCiphertext,
&'a SignedRadixCiphertext,
),
) -> SignedRadixCiphertext {
let context = self
.context
.as_ref()
.expect("setup was not properly called");
let d_ctxt_1: CudaBooleanBlock =
CudaBooleanBlock::from_boolean_block(input.0, &context.streams);
let d_ctxt_2: CudaSignedRadixCiphertext =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input.1, &context.streams);
let d_ctxt_3: CudaSignedRadixCiphertext =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input.2, &context.streams);
let d_res = (self.func)(
&context.sks,
&d_ctxt_1,
&d_ctxt_2,
&d_ctxt_3,
&context.streams,
);
d_res.to_signed_radix_ciphertext(&context.streams)
}
}

View File

@@ -50,6 +50,7 @@ macro_rules! create_gpu_parameterized_test{
};
}
use crate::integer::gpu::server_key::radix::tests_signed::GpuMultiDeviceFunctionExecutor;
pub(crate) use create_gpu_parameterized_test;
pub(crate) struct GpuContext {
@@ -904,3 +905,530 @@ where
d_block.to_boolean_block(&context.streams)
}
}
/// For default/unchecked binary functions
impl<'a, F> FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext), RadixCiphertext>
for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(
&CudaServerKey,
&CudaUnsignedRadixCiphertext,
&CudaUnsignedRadixCiphertext,
&CudaStreams,
) -> CudaUnsignedRadixCiphertext,
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}
fn execute(&mut self, input: (&'a RadixCiphertext, &'a RadixCiphertext)) -> RadixCiphertext {
let context = self
.context
.as_ref()
.expect("setup was not properly called");
let d_ctxt_1 =
CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.0, &context.streams);
let d_ctxt_2 =
CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.1, &context.streams);
let gpu_result = (self.func)(&context.sks, &d_ctxt_1, &d_ctxt_2, &context.streams);
gpu_result.to_radix_ciphertext(&context.streams)
}
}
/// For unchecked/default assign binary functions
impl<'a, F> FunctionExecutor<(&'a mut RadixCiphertext, &'a RadixCiphertext), ()>
for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(
&CudaServerKey,
&mut CudaUnsignedRadixCiphertext,
&CudaUnsignedRadixCiphertext,
&CudaStreams,
),
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}
fn execute(&mut self, input: (&'a mut RadixCiphertext, &'a RadixCiphertext)) {
let context = self
.context
.as_ref()
.expect("setup was not properly called");
let mut d_ctxt_1 =
CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.0, &context.streams);
let d_ctxt_2 =
CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.1, &context.streams);
(self.func)(&context.sks, &mut d_ctxt_1, &d_ctxt_2, &context.streams);
*input.0 = d_ctxt_1.to_radix_ciphertext(&context.streams);
}
}
/// For unchecked/default binary functions with one scalar input
impl<'a, F> FunctionExecutor<(&'a RadixCiphertext, u64), RadixCiphertext>
for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(
&CudaServerKey,
&CudaUnsignedRadixCiphertext,
u64,
&CudaStreams,
) -> CudaUnsignedRadixCiphertext,
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}
fn execute(&mut self, input: (&'a RadixCiphertext, u64)) -> RadixCiphertext {
let context = self
.context
.as_ref()
.expect("setup was not properly called");
let d_ctxt_1 =
CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.0, &context.streams);
let gpu_result = (self.func)(&context.sks, &d_ctxt_1, input.1, &context.streams);
gpu_result.to_radix_ciphertext(&context.streams)
}
}
/// For unchecked/default binary functions with one scalar input
impl<F> FunctionExecutor<(RadixCiphertext, u64), RadixCiphertext>
for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(
&CudaServerKey,
&CudaUnsignedRadixCiphertext,
u64,
&CudaStreams,
) -> CudaUnsignedRadixCiphertext,
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}
fn execute(&mut self, input: (RadixCiphertext, u64)) -> RadixCiphertext {
let context = self
.context
.as_ref()
.expect("setup was not properly called");
let d_ctxt_1 =
CudaUnsignedRadixCiphertext::from_radix_ciphertext(&input.0, &context.streams);
let gpu_result = (self.func)(&context.sks, &d_ctxt_1, input.1, &context.streams);
gpu_result.to_radix_ciphertext(&context.streams)
}
}
// Unary Function
impl<'a, F> FunctionExecutor<&'a RadixCiphertext, RadixCiphertext>
for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(
&CudaServerKey,
&CudaUnsignedRadixCiphertext,
&CudaStreams,
) -> CudaUnsignedRadixCiphertext,
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}
fn execute(&mut self, input: &'a RadixCiphertext) -> RadixCiphertext {
let context = self
.context
.as_ref()
.expect("setup was not properly called");
let d_ctxt_1 = CudaUnsignedRadixCiphertext::from_radix_ciphertext(input, &context.streams);
let gpu_result = (self.func)(&context.sks, &d_ctxt_1, &context.streams);
gpu_result.to_radix_ciphertext(&context.streams)
}
}
// Unary assign Function
impl<'a, F> FunctionExecutor<&'a mut RadixCiphertext, ()> for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(&CudaServerKey, &mut CudaUnsignedRadixCiphertext, &CudaStreams),
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}
fn execute(&mut self, input: &'a mut RadixCiphertext) {
let context = self
.context
.as_ref()
.expect("setup was not properly called");
let mut d_ctxt_1 =
CudaUnsignedRadixCiphertext::from_radix_ciphertext(input, &context.streams);
(self.func)(&context.sks, &mut d_ctxt_1, &context.streams);
*input = d_ctxt_1.to_radix_ciphertext(&context.streams)
}
}
impl<'a, F> FunctionExecutor<&'a Vec<RadixCiphertext>, Option<RadixCiphertext>>
for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(&CudaServerKey, Vec<CudaUnsignedRadixCiphertext>) -> Option<CudaUnsignedRadixCiphertext>,
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}
fn execute(&mut self, input: &'a Vec<RadixCiphertext>) -> Option<RadixCiphertext> {
let context = self
.context
.as_ref()
.expect("setup was not properly called");
let d_ctxt_1: Vec<CudaUnsignedRadixCiphertext> = input
.iter()
.map(|ct| CudaUnsignedRadixCiphertext::from_radix_ciphertext(ct, &context.streams))
.collect();
let d_res = (self.func)(&context.sks, d_ctxt_1);
Some(d_res.unwrap().to_radix_ciphertext(&context.streams))
}
}
impl<'a, F>
FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext), (RadixCiphertext, BooleanBlock)>
for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(
&CudaServerKey,
&CudaUnsignedRadixCiphertext,
&CudaUnsignedRadixCiphertext,
&CudaStreams,
) -> (CudaUnsignedRadixCiphertext, CudaBooleanBlock),
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}
fn execute(
&mut self,
input: (&'a RadixCiphertext, &'a RadixCiphertext),
) -> (RadixCiphertext, BooleanBlock) {
let context = self
.context
.as_ref()
.expect("setup was not properly called");
let d_ctxt_1: CudaUnsignedRadixCiphertext =
CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.0, &context.streams);
let d_ctxt_2: CudaUnsignedRadixCiphertext =
CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.1, &context.streams);
let d_res = (self.func)(&context.sks, &d_ctxt_1, &d_ctxt_2, &context.streams);
(
d_res.0.to_radix_ciphertext(&context.streams),
d_res.1.to_boolean_block(&context.streams),
)
}
}
/// For unchecked/default unsigned overflowing scalar operations
impl<'a, F> FunctionExecutor<(&'a RadixCiphertext, u64), (RadixCiphertext, BooleanBlock)>
for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(
&CudaServerKey,
&CudaUnsignedRadixCiphertext,
u64,
&CudaStreams,
) -> (CudaUnsignedRadixCiphertext, CudaBooleanBlock),
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}
fn execute(&mut self, input: (&'a RadixCiphertext, u64)) -> (RadixCiphertext, BooleanBlock) {
let context = self
.context
.as_ref()
.expect("setup was not properly called");
let d_ctxt_1: CudaUnsignedRadixCiphertext =
CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.0, &context.streams);
let d_res = (self.func)(&context.sks, &d_ctxt_1, input.1, &context.streams);
(
d_res.0.to_radix_ciphertext(&context.streams),
d_res.1.to_boolean_block(&context.streams),
)
}
}
/// For ilog operation
impl<'a, F> FunctionExecutor<&'a RadixCiphertext, (RadixCiphertext, BooleanBlock)>
for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(
&CudaServerKey,
&CudaUnsignedRadixCiphertext,
&CudaStreams,
) -> (CudaUnsignedRadixCiphertext, CudaBooleanBlock),
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}
fn execute(&mut self, input: &'a RadixCiphertext) -> (RadixCiphertext, BooleanBlock) {
let context = self
.context
.as_ref()
.expect("setup was not properly called");
let d_ctxt_1: CudaUnsignedRadixCiphertext =
CudaUnsignedRadixCiphertext::from_radix_ciphertext(input, &context.streams);
let d_res = (self.func)(&context.sks, &d_ctxt_1, &context.streams);
(
d_res.0.to_radix_ciphertext(&context.streams),
d_res.1.to_boolean_block(&context.streams),
)
}
}
impl<'a, F>
FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext), (RadixCiphertext, RadixCiphertext)>
for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(
&CudaServerKey,
&CudaUnsignedRadixCiphertext,
&CudaUnsignedRadixCiphertext,
&CudaStreams,
) -> (CudaUnsignedRadixCiphertext, CudaUnsignedRadixCiphertext),
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}
fn execute(
&mut self,
input: (&'a RadixCiphertext, &'a RadixCiphertext),
) -> (RadixCiphertext, RadixCiphertext) {
let context = self
.context
.as_ref()
.expect("setup was not properly called");
let d_ctxt_1: CudaUnsignedRadixCiphertext =
CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.0, &context.streams);
let d_ctxt_2: CudaUnsignedRadixCiphertext =
CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.1, &context.streams);
let d_res = (self.func)(&context.sks, &d_ctxt_1, &d_ctxt_2, &context.streams);
(
d_res.0.to_radix_ciphertext(&context.streams),
d_res.1.to_radix_ciphertext(&context.streams),
)
}
}
impl<'a, F> FunctionExecutor<(&'a RadixCiphertext, u64), (RadixCiphertext, RadixCiphertext)>
for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(
&CudaServerKey,
&CudaUnsignedRadixCiphertext,
u64,
&CudaStreams,
) -> (CudaUnsignedRadixCiphertext, CudaUnsignedRadixCiphertext),
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}
fn execute(&mut self, input: (&'a RadixCiphertext, u64)) -> (RadixCiphertext, RadixCiphertext) {
let context = self
.context
.as_ref()
.expect("setup was not properly called");
let d_ctxt_1: CudaUnsignedRadixCiphertext =
CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.0, &context.streams);
let d_res = (self.func)(&context.sks, &d_ctxt_1, input.1, &context.streams);
(
d_res.0.to_radix_ciphertext(&context.streams),
d_res.1.to_radix_ciphertext(&context.streams),
)
}
}
impl<'a, F> FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext), BooleanBlock>
for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(
&CudaServerKey,
&CudaUnsignedRadixCiphertext,
&CudaUnsignedRadixCiphertext,
&CudaStreams,
) -> CudaBooleanBlock,
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}
fn execute(&mut self, input: (&'a RadixCiphertext, &'a RadixCiphertext)) -> BooleanBlock {
let context = self
.context
.as_ref()
.expect("setup was not properly called");
let d_ctxt_1: CudaUnsignedRadixCiphertext =
CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.0, &context.streams);
let d_ctxt_2: CudaUnsignedRadixCiphertext =
CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.1, &context.streams);
let d_res = (self.func)(&context.sks, &d_ctxt_1, &d_ctxt_2, &context.streams);
d_res.to_boolean_block(&context.streams)
}
}
impl<'a, F> FunctionExecutor<(&'a RadixCiphertext, u64), BooleanBlock>
for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(&CudaServerKey, &CudaUnsignedRadixCiphertext, u64, &CudaStreams) -> CudaBooleanBlock,
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}
fn execute(&mut self, input: (&'a RadixCiphertext, u64)) -> BooleanBlock {
let context = self
.context
.as_ref()
.expect("setup was not properly called");
let d_ctxt_1: CudaUnsignedRadixCiphertext =
CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.0, &context.streams);
let d_res = (self.func)(&context.sks, &d_ctxt_1, input.1, &context.streams);
d_res.to_boolean_block(&context.streams)
}
}
impl<'a, F> FunctionExecutor<(&'a RadixCiphertext, U256), BooleanBlock>
for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(&CudaServerKey, &CudaUnsignedRadixCiphertext, U256, &CudaStreams) -> CudaBooleanBlock,
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}
fn execute(&mut self, input: (&'a RadixCiphertext, U256)) -> BooleanBlock {
let context = self
.context
.as_ref()
.expect("setup was not properly called");
let d_ctxt_1: CudaUnsignedRadixCiphertext =
CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.0, &context.streams);
let d_res = (self.func)(&context.sks, &d_ctxt_1, input.1, &context.streams);
d_res.to_boolean_block(&context.streams)
}
}
impl<'a, F> FunctionExecutor<(&'a RadixCiphertext, U256), RadixCiphertext>
for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(
&CudaServerKey,
&CudaUnsignedRadixCiphertext,
U256,
&CudaStreams,
) -> CudaUnsignedRadixCiphertext,
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}
fn execute(&mut self, input: (&'a RadixCiphertext, U256)) -> RadixCiphertext {
let context = self
.context
.as_ref()
.expect("setup was not properly called");
let d_ctxt_1: CudaUnsignedRadixCiphertext =
CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.0, &context.streams);
let d_res = (self.func)(&context.sks, &d_ctxt_1, input.1, &context.streams);
d_res.to_radix_ciphertext(&context.streams)
}
}
impl<'a, F>
FunctionExecutor<(&'a BooleanBlock, &'a RadixCiphertext, &'a RadixCiphertext), RadixCiphertext>
for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(
&CudaServerKey,
&CudaBooleanBlock,
&CudaUnsignedRadixCiphertext,
&CudaUnsignedRadixCiphertext,
&CudaStreams,
) -> CudaUnsignedRadixCiphertext,
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}
fn execute(
&mut self,
input: (&'a BooleanBlock, &'a RadixCiphertext, &'a RadixCiphertext),
) -> RadixCiphertext {
let context = self
.context
.as_ref()
.expect("setup was not properly called");
let d_ctxt_1: CudaBooleanBlock =
CudaBooleanBlock::from_boolean_block(input.0, &context.streams);
let d_ctxt_2: CudaUnsignedRadixCiphertext =
CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.1, &context.streams);
let d_ctxt_3: CudaUnsignedRadixCiphertext =
CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.2, &context.streams);
let d_res = (self.func)(
&context.sks,
&d_ctxt_1,
&d_ctxt_2,
&d_ctxt_3,
&context.streams,
);
d_res.to_radix_ciphertext(&context.streams)
}
}

View File

@@ -1,6 +1,6 @@
use crate::core_crypto::gpu::{get_number_of_gpus, CudaStreams};
use crate::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
use crate::integer::gpu::server_key::radix::tests_long_run::GpuMultiDeviceFunctionExecutor;
use crate::integer::gpu::server_key::radix::tests_signed::GpuMultiDeviceFunctionExecutor;
use crate::integer::gpu::server_key::radix::tests_unsigned::{
create_gpu_parameterized_test, GpuFunctionExecutor,
};

View File

@@ -1,5 +1,5 @@
use crate::core_crypto::gpu::get_number_of_gpus;
use crate::integer::gpu::server_key::radix::tests_long_run::GpuMultiDeviceFunctionExecutor;
use crate::integer::gpu::server_key::radix::tests_signed::GpuMultiDeviceFunctionExecutor;
use crate::integer::gpu::server_key::radix::tests_unsigned::{
create_gpu_parameterized_test, GpuFunctionExecutor,
};

View File

@@ -1,5 +1,5 @@
use crate::core_crypto::gpu::get_number_of_gpus;
use crate::integer::gpu::server_key::radix::tests_long_run::GpuMultiDeviceFunctionExecutor;
use crate::integer::gpu::server_key::radix::tests_signed::GpuMultiDeviceFunctionExecutor;
use crate::integer::gpu::server_key::radix::tests_unsigned::{
create_gpu_parameterized_test, GpuFunctionExecutor,
};

View File

@@ -1,5 +1,5 @@
use crate::core_crypto::gpu::get_number_of_gpus;
use crate::integer::gpu::server_key::radix::tests_long_run::GpuMultiDeviceFunctionExecutor;
use crate::integer::gpu::server_key::radix::tests_signed::GpuMultiDeviceFunctionExecutor;
use crate::integer::gpu::server_key::radix::tests_unsigned::{
create_gpu_parameterized_test, GpuFunctionExecutor,
};

View File

@@ -5,6 +5,7 @@ use crate::integer::{
BooleanBlock, IntegerCiphertext, RadixCiphertext, RadixClientKey, SignedRadixCiphertext,
};
use crate::shortint::parameters::NoiseLevel;
use crate::CompressedServerKey;
use rand::Rng;
use tfhe_csprng::generators::DefaultRandomGenerator;
use tfhe_csprng::seeders::{Seed, Seeder};
@@ -17,6 +18,42 @@ pub(crate) const NB_CTXT_LONG_RUN: usize = 32;
pub(crate) const NB_TESTS_LONG_RUN: usize = 20000;
pub(crate) const NB_TESTS_LONG_RUN_MINIMAL: usize = 200;
pub(crate) fn get_user_defined_seed() -> Option<Seed> {
match std::env::var("TFHE_RS_LONGRUN_TESTS_SEED") {
Ok(val) => match val.parse::<u128>() {
Ok(s) => Some(Seed(s)),
Err(_e) => None,
},
Err(_e) => None,
}
}
/// This trait is to be implemented by a struct that is capable
/// of executing a particular function to be tested in the
/// random op sequence tests. This executor
/// can execute ops either on CPU or on both CPU and GPU
pub(crate) trait OpSequenceFunctionExecutor<TestInput, TestOutput> {
/// Setups the executor
///
/// Implementers are expected to be fully functional after this
/// function has been called.
fn setup(
&mut self,
cks: &RadixClientKey,
sks: &CompressedServerKey,
seeder: &mut DeterministicSeeder<DefaultRandomGenerator>,
);
/// Executes the function
///
/// The function receives some inputs and return some output.
/// Implementers may have to do more than just calling the function
/// that is being tested (for example input/output may need to be converted)
///
/// Look at the test case function to know what are the expected inputs and outputs.
fn execute(&mut self, input: TestInput) -> TestOutput;
}
pub(crate) fn get_long_test_iterations() -> usize {
static ENV_KEY_LONG_TESTS: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
@@ -60,11 +97,13 @@ impl RadixEncryptable for i64 {
key.encrypt_signed(*self)
}
}
struct RandomOpSequenceDataGenerator<P, C> {
// Generates data for the random op sequence tests. It
// can generate data deterministically using a user-specified
// seed
pub(crate) struct RandomOpSequenceDataGenerator<P, C> {
pub(crate) lhs: Vec<TestDataSample<P, C>>,
pub(crate) rhs: Vec<TestDataSample<P, C>>,
deterministic_seeder: DeterministicSeeder<DefaultRandomGenerator>,
pub(crate) deterministic_seeder: DeterministicSeeder<DefaultRandomGenerator>,
seed: Seed,
cks: RadixClientKey,
op_counter: usize,
@@ -75,7 +114,7 @@ impl<
C: IntegerCiphertext,
> RandomOpSequenceDataGenerator<P, C>
{
fn new(total_num_ops: usize, cks: &RadixClientKey) -> Self {
pub(crate) fn new(total_num_ops: usize, cks: &RadixClientKey) -> Self {
let mut rng = rand::thread_rng();
let seed: u128 = rng.gen();
@@ -99,7 +138,7 @@ impl<
}) // Generate random i64 values and encrypt them
.collect()
}
fn new_with_seed(total_num_ops: usize, seed: Seed, cks: &RadixClientKey) -> Self {
pub(crate) fn new_with_seed(total_num_ops: usize, seed: Seed, cks: &RadixClientKey) -> Self {
let mut deterministic_seeder = DeterministicSeeder::<DefaultRandomGenerator>::new(seed);
Self {
@@ -112,7 +151,7 @@ impl<
}
}
fn get_seed(&self) -> Seed {
pub(crate) fn get_seed(&self) -> Seed {
self.seed
}
@@ -199,6 +238,11 @@ impl<
let val = self.deterministic_seeder.seed().0 % 2;
(val == 1, self.cks.encrypt_bool(val == 1))
}
fn gen_bool_uniform(&mut self) -> bool {
let val = self.deterministic_seeder.seed().0 % 2;
val == 1
}
}
#[allow(clippy::too_many_arguments)]
pub(crate) fn sanity_check_op_sequence_result_u64(

View File

@@ -1,13 +1,14 @@
use crate::integer::keycache::KEY_CACHE;
use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor;
use crate::integer::server_key::radix_parallel::tests_long_run::{
get_long_test_iterations, sanity_check_op_sequence_result_bool,
sanity_check_op_sequence_result_u64, RandomOpSequenceDataGenerator, NB_CTXT_LONG_RUN,
get_long_test_iterations, get_user_defined_seed, sanity_check_op_sequence_result_bool,
sanity_check_op_sequence_result_u64, OpSequenceFunctionExecutor, RandomOpSequenceDataGenerator,
NB_CTXT_LONG_RUN,
};
use crate::integer::server_key::radix_parallel::tests_unsigned::CpuFunctionExecutor;
use crate::integer::server_key::radix_parallel::tests_unsigned::OpSequenceCpuFunctionExecutor;
use crate::integer::tests::create_parameterized_test;
use crate::integer::{BooleanBlock, IntegerKeyKind, RadixCiphertext, RadixClientKey, ServerKey};
use crate::shortint::parameters::*;
use crate::{ClientKey, CompressedServerKey, Tag};
use std::cmp::{max, min};
use std::sync::Arc;
@@ -19,59 +20,78 @@ create_parameterized_test!(random_op_sequence_data_generator {
PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128
});
pub(crate) type BinaryOpExecutor =
Box<dyn for<'a> FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext), RadixCiphertext>>;
pub(crate) type BinaryOpExecutor = Box<
dyn for<'a> OpSequenceFunctionExecutor<
(&'a RadixCiphertext, &'a RadixCiphertext),
RadixCiphertext,
>,
>;
pub(crate) type UnaryOpExecutor =
Box<dyn for<'a> FunctionExecutor<&'a RadixCiphertext, RadixCiphertext>>;
Box<dyn for<'a> OpSequenceFunctionExecutor<&'a RadixCiphertext, RadixCiphertext>>;
pub(crate) type ScalarBinaryOpExecutor =
Box<dyn for<'a> FunctionExecutor<(&'a RadixCiphertext, u64), RadixCiphertext>>;
Box<dyn for<'a> OpSequenceFunctionExecutor<(&'a RadixCiphertext, u64), RadixCiphertext>>;
pub(crate) type OverflowingOpExecutor = Box<
dyn for<'a> FunctionExecutor<
dyn for<'a> OpSequenceFunctionExecutor<
(&'a RadixCiphertext, &'a RadixCiphertext),
(RadixCiphertext, BooleanBlock),
>,
>;
pub(crate) type ScalarOverflowingOpExecutor =
Box<dyn for<'a> FunctionExecutor<(&'a RadixCiphertext, u64), (RadixCiphertext, BooleanBlock)>>;
pub(crate) type ComparisonOpExecutor =
Box<dyn for<'a> FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext), BooleanBlock>>;
pub(crate) type ScalarOverflowingOpExecutor = Box<
dyn for<'a> OpSequenceFunctionExecutor<
(&'a RadixCiphertext, u64),
(RadixCiphertext, BooleanBlock),
>,
>;
pub(crate) type ComparisonOpExecutor = Box<
dyn for<'a> OpSequenceFunctionExecutor<
(&'a RadixCiphertext, &'a RadixCiphertext),
BooleanBlock,
>,
>;
pub(crate) type ScalarComparisonOpExecutor =
Box<dyn for<'a> FunctionExecutor<(&'a RadixCiphertext, u64), BooleanBlock>>;
Box<dyn for<'a> OpSequenceFunctionExecutor<(&'a RadixCiphertext, u64), BooleanBlock>>;
pub(crate) type SelectOpExecutor = Box<
dyn for<'a> FunctionExecutor<
dyn for<'a> OpSequenceFunctionExecutor<
(&'a BooleanBlock, &'a RadixCiphertext, &'a RadixCiphertext),
RadixCiphertext,
>,
>;
pub(crate) type DivRemOpExecutor = Box<
dyn for<'a> FunctionExecutor<
dyn for<'a> OpSequenceFunctionExecutor<
(&'a RadixCiphertext, &'a RadixCiphertext),
(RadixCiphertext, RadixCiphertext),
>,
>;
pub(crate) type ScalarDivRemOpExecutor = Box<
dyn for<'a> FunctionExecutor<(&'a RadixCiphertext, u64), (RadixCiphertext, RadixCiphertext)>,
dyn for<'a> OpSequenceFunctionExecutor<
(&'a RadixCiphertext, u64),
(RadixCiphertext, RadixCiphertext),
>,
>;
pub(crate) type Log2OpExecutor =
Box<dyn for<'a> FunctionExecutor<&'a RadixCiphertext, RadixCiphertext>>;
Box<dyn for<'a> OpSequenceFunctionExecutor<&'a RadixCiphertext, RadixCiphertext>>;
fn random_op_sequence<P>(param: P)
where
P: Into<TestParameters> + Clone,
{
// Binary Ops Executors
let add_executor = CpuFunctionExecutor::new(&ServerKey::add_parallelized);
let sub_executor = CpuFunctionExecutor::new(&ServerKey::sub_parallelized);
let bitwise_and_executor = CpuFunctionExecutor::new(&ServerKey::bitand_parallelized);
let bitwise_or_executor = CpuFunctionExecutor::new(&ServerKey::bitor_parallelized);
let bitwise_xor_executor = CpuFunctionExecutor::new(&ServerKey::bitxor_parallelized);
let mul_executor = CpuFunctionExecutor::new(&ServerKey::mul_parallelized);
let rotate_left_executor = CpuFunctionExecutor::new(&ServerKey::rotate_left_parallelized);
let left_shift_executor = CpuFunctionExecutor::new(&ServerKey::left_shift_parallelized);
let rotate_right_executor = CpuFunctionExecutor::new(&ServerKey::rotate_right_parallelized);
let right_shift_executor = CpuFunctionExecutor::new(&ServerKey::right_shift_parallelized);
let max_executor = CpuFunctionExecutor::new(&ServerKey::max_parallelized);
let min_executor = CpuFunctionExecutor::new(&ServerKey::min_parallelized);
let add_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::add_parallelized);
let sub_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::sub_parallelized);
let bitwise_and_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::bitand_parallelized);
let bitwise_or_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::bitor_parallelized);
let bitwise_xor_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::bitxor_parallelized);
let mul_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::mul_parallelized);
let rotate_left_executor =
OpSequenceCpuFunctionExecutor::new(&ServerKey::rotate_left_parallelized);
let left_shift_executor =
OpSequenceCpuFunctionExecutor::new(&ServerKey::left_shift_parallelized);
let rotate_right_executor =
OpSequenceCpuFunctionExecutor::new(&ServerKey::rotate_right_parallelized);
let right_shift_executor =
OpSequenceCpuFunctionExecutor::new(&ServerKey::right_shift_parallelized);
let max_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::max_parallelized);
let min_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::min_parallelized);
// Binary Ops Clear functions
let clear_add = |x: u64, y: u64| x.wrapping_add(y);
@@ -134,9 +154,10 @@ where
];
// Unary Ops Executors
let neg_executor = CpuFunctionExecutor::new(&ServerKey::neg_parallelized);
let bitnot_executor = CpuFunctionExecutor::new(&ServerKey::bitnot);
let reverse_bits_executor = CpuFunctionExecutor::new(&ServerKey::reverse_bits_parallelized);
let neg_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::neg_parallelized);
let bitnot_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::bitnot);
let reverse_bits_executor =
OpSequenceCpuFunctionExecutor::new(&ServerKey::reverse_bits_parallelized);
// Unary Ops Clear functions
let clear_neg = |x: u64| x.wrapping_neg();
let clear_bitnot = |x: u64| !x;
@@ -157,23 +178,26 @@ where
];
// Scalar binary Ops Executors
let scalar_add_executor = CpuFunctionExecutor::new(&ServerKey::scalar_add_parallelized);
let scalar_sub_executor = CpuFunctionExecutor::new(&ServerKey::scalar_sub_parallelized);
let scalar_add_executor =
OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_add_parallelized);
let scalar_sub_executor =
OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_sub_parallelized);
let scalar_bitwise_and_executor =
CpuFunctionExecutor::new(&ServerKey::scalar_bitand_parallelized);
OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_bitand_parallelized);
let scalar_bitwise_or_executor =
CpuFunctionExecutor::new(&ServerKey::scalar_bitor_parallelized);
OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_bitor_parallelized);
let scalar_bitwise_xor_executor =
CpuFunctionExecutor::new(&ServerKey::scalar_bitxor_parallelized);
let scalar_mul_executor = CpuFunctionExecutor::new(&ServerKey::scalar_mul_parallelized);
OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_bitxor_parallelized);
let scalar_mul_executor =
OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_mul_parallelized);
let scalar_rotate_left_executor =
CpuFunctionExecutor::new(&ServerKey::scalar_rotate_left_parallelized);
OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_rotate_left_parallelized);
let scalar_left_shift_executor =
CpuFunctionExecutor::new(&ServerKey::scalar_left_shift_parallelized);
OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_left_shift_parallelized);
let scalar_rotate_right_executor =
CpuFunctionExecutor::new(&ServerKey::scalar_rotate_right_parallelized);
OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_rotate_right_parallelized);
let scalar_right_shift_executor =
CpuFunctionExecutor::new(&ServerKey::scalar_right_shift_parallelized);
OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_right_shift_parallelized);
#[allow(clippy::type_complexity)]
let mut scalar_binary_ops: Vec<(ScalarBinaryOpExecutor, &dyn Fn(u64, u64) -> u64, String)> = vec![
@@ -231,11 +255,11 @@ where
// Overflowing Ops Executors
let overflowing_add_executor =
CpuFunctionExecutor::new(&ServerKey::unsigned_overflowing_add_parallelized);
OpSequenceCpuFunctionExecutor::new(&ServerKey::unsigned_overflowing_add_parallelized);
let overflowing_sub_executor =
CpuFunctionExecutor::new(&ServerKey::unsigned_overflowing_sub_parallelized);
OpSequenceCpuFunctionExecutor::new(&ServerKey::unsigned_overflowing_sub_parallelized);
let overflowing_mul_executor =
CpuFunctionExecutor::new(&ServerKey::unsigned_overflowing_mul_parallelized);
OpSequenceCpuFunctionExecutor::new(&ServerKey::unsigned_overflowing_mul_parallelized);
// Overflowing Ops Clear functions
let clear_overflowing_add = |x: u64, y: u64| -> (u64, bool) { x.overflowing_add(y) };
@@ -266,10 +290,12 @@ where
];
// Scalar Overflowing Ops Executors
let overflowing_scalar_add_executor =
CpuFunctionExecutor::new(&ServerKey::unsigned_overflowing_scalar_add_parallelized);
let overflowing_scalar_sub_executor =
CpuFunctionExecutor::new(&ServerKey::unsigned_overflowing_scalar_sub_parallelized);
let overflowing_scalar_add_executor = OpSequenceCpuFunctionExecutor::new(
&ServerKey::unsigned_overflowing_scalar_add_parallelized,
);
let overflowing_scalar_sub_executor = OpSequenceCpuFunctionExecutor::new(
&ServerKey::unsigned_overflowing_scalar_sub_parallelized,
);
#[allow(clippy::type_complexity)]
let mut scalar_overflowing_ops: Vec<(
@@ -290,12 +316,12 @@ where
];
// Comparison Ops Executors
let gt_executor = CpuFunctionExecutor::new(&ServerKey::gt_parallelized);
let ge_executor = CpuFunctionExecutor::new(&ServerKey::ge_parallelized);
let lt_executor = CpuFunctionExecutor::new(&ServerKey::lt_parallelized);
let le_executor = CpuFunctionExecutor::new(&ServerKey::le_parallelized);
let eq_executor = CpuFunctionExecutor::new(&ServerKey::eq_parallelized);
let ne_executor = CpuFunctionExecutor::new(&ServerKey::ne_parallelized);
let gt_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::gt_parallelized);
let ge_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::ge_parallelized);
let lt_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::lt_parallelized);
let le_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::le_parallelized);
let eq_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::eq_parallelized);
let ne_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::ne_parallelized);
// Comparison Ops Clear functions
let clear_gt = |x: u64, y: u64| -> bool { x > y };
@@ -316,12 +342,12 @@ where
];
// Scalar Comparison Ops Executors
let scalar_gt_executor = CpuFunctionExecutor::new(&ServerKey::scalar_gt_parallelized);
let scalar_ge_executor = CpuFunctionExecutor::new(&ServerKey::scalar_ge_parallelized);
let scalar_lt_executor = CpuFunctionExecutor::new(&ServerKey::scalar_lt_parallelized);
let scalar_le_executor = CpuFunctionExecutor::new(&ServerKey::scalar_le_parallelized);
let scalar_eq_executor = CpuFunctionExecutor::new(&ServerKey::scalar_eq_parallelized);
let scalar_ne_executor = CpuFunctionExecutor::new(&ServerKey::scalar_ne_parallelized);
let scalar_gt_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_gt_parallelized);
let scalar_ge_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_ge_parallelized);
let scalar_lt_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_lt_parallelized);
let scalar_le_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_le_parallelized);
let scalar_eq_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_eq_parallelized);
let scalar_ne_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_ne_parallelized);
#[allow(clippy::type_complexity)]
let mut scalar_comparison_ops: Vec<(
@@ -362,7 +388,7 @@ where
];
// Select Executor
let select_executor = CpuFunctionExecutor::new(&ServerKey::cmux_parallelized);
let select_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::cmux_parallelized);
// Select
let clear_select = |b: bool, x: u64, y: u64| if b { x } else { y };
@@ -375,7 +401,7 @@ where
)];
// Div executor
let div_rem_executor = CpuFunctionExecutor::new(&ServerKey::div_rem_parallelized);
let div_rem_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::div_rem_parallelized);
// Div Rem Clear functions
let clear_div_rem = |x: u64, y: u64| -> (u64, u64) { (x.wrapping_div(y), x.wrapping_rem(y)) };
#[allow(clippy::type_complexity)]
@@ -386,7 +412,8 @@ where
)];
// Scalar Div executor
let scalar_div_rem_executor = CpuFunctionExecutor::new(&ServerKey::scalar_div_rem_parallelized);
let scalar_div_rem_executor =
OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_div_rem_parallelized);
#[allow(clippy::type_complexity)]
let mut scalar_div_rem_op: Vec<(
ScalarDivRemOpExecutor,
@@ -399,9 +426,11 @@ where
)];
// Log2/Hamming weight ops
let ilog2_executor = CpuFunctionExecutor::new(&ServerKey::ilog2_parallelized);
let count_zeros_executor = CpuFunctionExecutor::new(&ServerKey::count_zeros_parallelized);
let count_ones_executor = CpuFunctionExecutor::new(&ServerKey::count_ones_parallelized);
let ilog2_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::ilog2_parallelized);
let count_zeros_executor =
OpSequenceCpuFunctionExecutor::new(&ServerKey::count_zeros_parallelized);
let count_ones_executor =
OpSequenceCpuFunctionExecutor::new(&ServerKey::count_ones_parallelized);
let clear_ilog2 = |x: u64| x.ilog2() as u64;
let clear_count_zeros = |x: u64| x.count_zeros() as u64;
let clear_count_ones = |x: u64| x.count_ones() as u64;
@@ -421,7 +450,7 @@ where
),
];
random_op_sequence_test(
let (cks, sks, mut datagen) = random_op_sequence_test_init_cpu(
param,
&mut binary_ops,
&mut unary_ops,
@@ -435,10 +464,27 @@ where
&mut scalar_div_rem_op,
&mut log2_ops,
);
random_op_sequence_test(
&mut datagen,
&cks,
&sks,
&mut binary_ops,
&mut unary_ops,
&mut scalar_binary_ops,
&mut overflowing_ops,
&mut scalar_overflowing_ops,
&mut comparison_ops,
&mut scalar_comparison_ops,
&mut select_op,
&mut div_rem_op,
&mut scalar_div_rem_op,
&mut log2_ops,
);
}
#[allow(clippy::too_many_arguments)]
pub(crate) fn random_op_sequence_test<P>(
pub(crate) fn random_op_sequence_test_init_cpu<P>(
param: P,
binary_ops: &mut [(BinaryOpExecutor, impl Fn(u64, u64) -> u64, String)],
unary_ops: &mut [(UnaryOpExecutor, impl Fn(u64) -> u64, String)],
@@ -467,51 +513,21 @@ pub(crate) fn random_op_sequence_test<P>(
String,
)],
log2_ops: &mut [(Log2OpExecutor, impl Fn(u64) -> u64, String)],
) where
) -> (
RadixClientKey,
Arc<ServerKey>,
RandomOpSequenceDataGenerator<u64, RadixCiphertext>,
)
where
P: Into<TestParameters>,
{
let param = param.into();
let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
sks.set_deterministic_pbs_execution(true);
let sks = Arc::new(sks);
let cks = RadixClientKey::from((cks, NB_CTXT_LONG_RUN));
let temp_cks =
ClientKey::from_raw_parts(cks.clone(), None, None, None, None, None, Tag::default());
let comp_sks = CompressedServerKey::new(&temp_cks);
println!("Setting up operations");
for x in binary_ops.iter_mut() {
x.0.setup(&cks, sks.clone());
}
for x in unary_ops.iter_mut() {
x.0.setup(&cks, sks.clone());
}
for x in scalar_binary_ops.iter_mut() {
x.0.setup(&cks, sks.clone());
}
for x in overflowing_ops.iter_mut() {
x.0.setup(&cks, sks.clone());
}
for x in scalar_overflowing_ops.iter_mut() {
x.0.setup(&cks, sks.clone());
}
for x in comparison_ops.iter_mut() {
x.0.setup(&cks, sks.clone());
}
for x in scalar_comparison_ops.iter_mut() {
x.0.setup(&cks, sks.clone());
}
for x in select_op.iter_mut() {
x.0.setup(&cks, sks.clone());
}
for x in div_rem_op.iter_mut() {
x.0.setup(&cks, sks.clone());
}
for x in scalar_div_rem_op.iter_mut() {
x.0.setup(&cks, sks.clone());
}
for x in log2_ops.iter_mut() {
x.0.setup(&cks, sks.clone());
}
let total_num_ops = binary_ops.len()
+ unary_ops.len()
+ scalar_binary_ops.len()
@@ -525,6 +541,93 @@ pub(crate) fn random_op_sequence_test<P>(
+ log2_ops.len();
println!("Total num ops {total_num_ops}");
sks.set_deterministic_pbs_execution(true);
let sks = Arc::new(sks);
let cks = RadixClientKey::from((cks, NB_CTXT_LONG_RUN));
let mut datagen = get_user_defined_seed().map_or_else(
|| RandomOpSequenceDataGenerator::<u64, RadixCiphertext>::new(total_num_ops, &cks),
|seed| {
RandomOpSequenceDataGenerator::<u64, RadixCiphertext>::new_with_seed(
total_num_ops,
seed,
&cks,
)
},
);
println!("random_op_sequence_test::seed = {}", datagen.get_seed().0);
println!("Setting up operations");
for x in binary_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in unary_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in scalar_binary_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in overflowing_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in scalar_overflowing_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in comparison_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in scalar_comparison_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in select_op.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in div_rem_op.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in scalar_div_rem_op.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in log2_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
(cks, sks, datagen)
}
#[allow(clippy::too_many_arguments)]
pub(crate) fn random_op_sequence_test(
datagen: &mut RandomOpSequenceDataGenerator<u64, RadixCiphertext>,
cks: &RadixClientKey,
sks: &Arc<ServerKey>,
binary_ops: &mut [(BinaryOpExecutor, impl Fn(u64, u64) -> u64, String)],
unary_ops: &mut [(UnaryOpExecutor, impl Fn(u64) -> u64, String)],
scalar_binary_ops: &mut [(ScalarBinaryOpExecutor, impl Fn(u64, u64) -> u64, String)],
overflowing_ops: &mut [(
OverflowingOpExecutor,
impl Fn(u64, u64) -> (u64, bool),
String,
)],
scalar_overflowing_ops: &mut [(
ScalarOverflowingOpExecutor,
impl Fn(u64, u64) -> (u64, bool),
String,
)],
comparison_ops: &mut [(ComparisonOpExecutor, impl Fn(u64, u64) -> bool, String)],
scalar_comparison_ops: &mut [(
ScalarComparisonOpExecutor,
impl Fn(u64, u64) -> bool,
String,
)],
select_op: &mut [(SelectOpExecutor, impl Fn(bool, u64, u64) -> u64, String)],
div_rem_op: &mut [(DivRemOpExecutor, impl Fn(u64, u64) -> (u64, u64), String)],
scalar_div_rem_op: &mut [(
ScalarDivRemOpExecutor,
impl Fn(u64, u64) -> (u64, u64),
String,
)],
log2_ops: &mut [(Log2OpExecutor, impl Fn(u64) -> u64, String)],
) {
let binary_ops_range = 0..binary_ops.len();
let unary_ops_range = binary_ops_range.end..binary_ops_range.end + unary_ops.len();
let scalar_binary_ops_range =
@@ -544,10 +647,6 @@ pub(crate) fn random_op_sequence_test<P>(
div_rem_op_range.end..div_rem_op_range.end + scalar_div_rem_op.len();
let log2_ops_range = scalar_div_rem_op_range.end..scalar_div_rem_op_range.end + log2_ops.len();
let mut datagen =
RandomOpSequenceDataGenerator::<u64, RadixCiphertext>::new(total_num_ops, &cks);
println!("random_op_sequence_test::seed = {}", datagen.get_seed().0);
for fn_index in 0..get_long_test_iterations() {
let (i, idx) = datagen.gen_op_index();
@@ -718,7 +817,7 @@ pub(crate) fn random_op_sequence_test<P>(
let expected_res = clear_fn(lhs.p, rhs.p);
let res_ct = sks.cast_to_unsigned(
res.clone().into_radix::<RadixCiphertext>(1, &sks),
res.clone().into_radix::<RadixCiphertext>(1, sks),
NB_CTXT_LONG_RUN,
);
datagen.put_op_result_random_side(expected_res as u64, &res_ct, fn_name, idx);
@@ -749,7 +848,7 @@ pub(crate) fn random_op_sequence_test<P>(
let expected_res = clear_fn(lhs.p, rhs.p);
let res_ct: RadixCiphertext = sks.cast_to_unsigned(
res.clone().into_radix::<RadixCiphertext>(1, &sks),
res.clone().into_radix::<RadixCiphertext>(1, sks),
NB_CTXT_LONG_RUN,
);
datagen.put_op_result_random_side(expected_res as u64, &res_ct, fn_name, idx);

View File

@@ -1,16 +1,16 @@
use crate::integer::keycache::KEY_CACHE;
use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor;
use crate::integer::server_key::radix_parallel::tests_long_run::{
get_long_test_iterations, sanity_check_op_sequence_result_bool,
sanity_check_op_sequence_result_i64, sanity_check_op_sequence_result_u64, NB_CTXT_LONG_RUN,
get_long_test_iterations, get_user_defined_seed, sanity_check_op_sequence_result_bool,
sanity_check_op_sequence_result_i64, sanity_check_op_sequence_result_u64,
OpSequenceFunctionExecutor, RandomOpSequenceDataGenerator, NB_CTXT_LONG_RUN,
};
use crate::integer::server_key::radix_parallel::tests_unsigned::CpuFunctionExecutor;
use crate::integer::server_key::radix_parallel::tests_unsigned::OpSequenceCpuFunctionExecutor;
use crate::integer::tests::create_parameterized_test;
use crate::integer::{
BooleanBlock, IntegerKeyKind, RadixCiphertext, RadixClientKey, ServerKey, SignedRadixCiphertext,
};
use crate::shortint::parameters::*;
use rand::Rng;
use crate::{ClientKey, CompressedServerKey, Tag};
use std::cmp::{max, min};
use std::sync::Arc;
@@ -19,46 +19,48 @@ create_parameterized_test!(random_op_sequence {
});
pub(crate) type SignedBinaryOpExecutor = Box<
dyn for<'a> FunctionExecutor<
dyn for<'a> OpSequenceFunctionExecutor<
(&'a SignedRadixCiphertext, &'a SignedRadixCiphertext),
SignedRadixCiphertext,
>,
>;
pub(crate) type SignedShiftRotateExecutor = Box<
dyn for<'a> FunctionExecutor<
dyn for<'a> OpSequenceFunctionExecutor<
(&'a SignedRadixCiphertext, &'a RadixCiphertext),
SignedRadixCiphertext,
>,
>;
pub(crate) type SignedUnaryOpExecutor =
Box<dyn for<'a> FunctionExecutor<&'a SignedRadixCiphertext, SignedRadixCiphertext>>;
Box<dyn for<'a> OpSequenceFunctionExecutor<&'a SignedRadixCiphertext, SignedRadixCiphertext>>;
pub(crate) type SignedScalarBinaryOpExecutor =
Box<dyn for<'a> FunctionExecutor<(&'a SignedRadixCiphertext, i64), SignedRadixCiphertext>>;
pub(crate) type SignedScalarShiftRotateExecutor =
Box<dyn for<'a> FunctionExecutor<(&'a SignedRadixCiphertext, u64), SignedRadixCiphertext>>;
pub(crate) type SignedScalarBinaryOpExecutor = Box<
dyn for<'a> OpSequenceFunctionExecutor<(&'a SignedRadixCiphertext, i64), SignedRadixCiphertext>,
>;
pub(crate) type SignedScalarShiftRotateExecutor = Box<
dyn for<'a> OpSequenceFunctionExecutor<(&'a SignedRadixCiphertext, u64), SignedRadixCiphertext>,
>;
pub(crate) type SignedOverflowingOpExecutor = Box<
dyn for<'a> FunctionExecutor<
dyn for<'a> OpSequenceFunctionExecutor<
(&'a SignedRadixCiphertext, &'a SignedRadixCiphertext),
(SignedRadixCiphertext, BooleanBlock),
>,
>;
pub(crate) type SignedScalarOverflowingOpExecutor = Box<
dyn for<'a> FunctionExecutor<
dyn for<'a> OpSequenceFunctionExecutor<
(&'a SignedRadixCiphertext, i64),
(SignedRadixCiphertext, BooleanBlock),
>,
>;
pub(crate) type SignedComparisonOpExecutor = Box<
dyn for<'a> FunctionExecutor<
dyn for<'a> OpSequenceFunctionExecutor<
(&'a SignedRadixCiphertext, &'a SignedRadixCiphertext),
BooleanBlock,
>,
>;
pub(crate) type SignedScalarComparisonOpExecutor =
Box<dyn for<'a> FunctionExecutor<(&'a SignedRadixCiphertext, i64), BooleanBlock>>;
Box<dyn for<'a> OpSequenceFunctionExecutor<(&'a SignedRadixCiphertext, i64), BooleanBlock>>;
pub(crate) type SignedSelectOpExecutor = Box<
dyn for<'a> FunctionExecutor<
dyn for<'a> OpSequenceFunctionExecutor<
(
&'a BooleanBlock,
&'a SignedRadixCiphertext,
@@ -68,32 +70,32 @@ pub(crate) type SignedSelectOpExecutor = Box<
>,
>;
pub(crate) type SignedDivRemOpExecutor = Box<
dyn for<'a> FunctionExecutor<
dyn for<'a> OpSequenceFunctionExecutor<
(&'a SignedRadixCiphertext, &'a SignedRadixCiphertext),
(SignedRadixCiphertext, SignedRadixCiphertext),
>,
>;
pub(crate) type SignedScalarDivRemOpExecutor = Box<
dyn for<'a> FunctionExecutor<
dyn for<'a> OpSequenceFunctionExecutor<
(&'a SignedRadixCiphertext, i64),
(SignedRadixCiphertext, SignedRadixCiphertext),
>,
>;
pub(crate) type SignedLog2OpExecutor =
Box<dyn for<'a> FunctionExecutor<&'a SignedRadixCiphertext, RadixCiphertext>>;
Box<dyn for<'a> OpSequenceFunctionExecutor<&'a SignedRadixCiphertext, RadixCiphertext>>;
fn random_op_sequence<P>(param: P)
where
P: Into<TestParameters> + Clone,
{
// Binary Ops Executors
let add_executor = CpuFunctionExecutor::new(&ServerKey::add_parallelized);
let sub_executor = CpuFunctionExecutor::new(&ServerKey::sub_parallelized);
let bitwise_and_executor = CpuFunctionExecutor::new(&ServerKey::bitand_parallelized);
let bitwise_or_executor = CpuFunctionExecutor::new(&ServerKey::bitor_parallelized);
let bitwise_xor_executor = CpuFunctionExecutor::new(&ServerKey::bitxor_parallelized);
let mul_executor = CpuFunctionExecutor::new(&ServerKey::mul_parallelized);
let max_executor = CpuFunctionExecutor::new(&ServerKey::max_parallelized);
let min_executor = CpuFunctionExecutor::new(&ServerKey::min_parallelized);
let add_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::add_parallelized);
let sub_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::sub_parallelized);
let bitwise_and_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::bitand_parallelized);
let bitwise_or_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::bitor_parallelized);
let bitwise_xor_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::bitxor_parallelized);
let mul_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::mul_parallelized);
let max_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::max_parallelized);
let min_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::min_parallelized);
// Binary Ops Clear functions
let clear_add = |x: i64, y: i64| x.wrapping_add(y);
@@ -129,10 +131,14 @@ where
(Box::new(min_executor), &clear_min, "min".to_string()),
];
let rotate_left_executor = CpuFunctionExecutor::new(&ServerKey::rotate_left_parallelized);
let left_shift_executor = CpuFunctionExecutor::new(&ServerKey::left_shift_parallelized);
let rotate_right_executor = CpuFunctionExecutor::new(&ServerKey::rotate_right_parallelized);
let right_shift_executor = CpuFunctionExecutor::new(&ServerKey::right_shift_parallelized);
let rotate_left_executor =
OpSequenceCpuFunctionExecutor::new(&ServerKey::rotate_left_parallelized);
let left_shift_executor =
OpSequenceCpuFunctionExecutor::new(&ServerKey::left_shift_parallelized);
let rotate_right_executor =
OpSequenceCpuFunctionExecutor::new(&ServerKey::rotate_right_parallelized);
let right_shift_executor =
OpSequenceCpuFunctionExecutor::new(&ServerKey::right_shift_parallelized);
// Warning this rotate definition only works with 64-bit ciphertexts
let clear_rotate_left = |x: i64, y: u64| x.rotate_left(y as u32);
let clear_left_shift = |x: i64, y: u64| x << y;
@@ -167,9 +173,10 @@ where
),
];
// Unary Ops Executors
let neg_executor = CpuFunctionExecutor::new(&ServerKey::neg_parallelized);
let bitnot_executor = CpuFunctionExecutor::new(&ServerKey::bitnot);
let reverse_bits_executor = CpuFunctionExecutor::new(&ServerKey::reverse_bits_parallelized);
let neg_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::neg_parallelized);
let bitnot_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::bitnot);
let reverse_bits_executor =
OpSequenceCpuFunctionExecutor::new(&ServerKey::reverse_bits_parallelized);
// Unary Ops Clear functions
let clear_neg = |x: i64| x.wrapping_neg();
let clear_bitnot = |x: i64| !x;
@@ -190,15 +197,18 @@ where
];
// Scalar binary Ops Executors
let scalar_add_executor = CpuFunctionExecutor::new(&ServerKey::scalar_add_parallelized);
let scalar_sub_executor = CpuFunctionExecutor::new(&ServerKey::scalar_sub_parallelized);
let scalar_add_executor =
OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_add_parallelized);
let scalar_sub_executor =
OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_sub_parallelized);
let scalar_bitwise_and_executor =
CpuFunctionExecutor::new(&ServerKey::scalar_bitand_parallelized);
OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_bitand_parallelized);
let scalar_bitwise_or_executor =
CpuFunctionExecutor::new(&ServerKey::scalar_bitor_parallelized);
OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_bitor_parallelized);
let scalar_bitwise_xor_executor =
CpuFunctionExecutor::new(&ServerKey::scalar_bitxor_parallelized);
let scalar_mul_executor = CpuFunctionExecutor::new(&ServerKey::scalar_mul_parallelized);
OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_bitxor_parallelized);
let scalar_mul_executor =
OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_mul_parallelized);
#[allow(clippy::type_complexity)]
let mut scalar_binary_ops: Vec<(
@@ -239,13 +249,13 @@ where
];
let scalar_rotate_left_executor =
CpuFunctionExecutor::new(&ServerKey::scalar_rotate_left_parallelized);
OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_rotate_left_parallelized);
let scalar_left_shift_executor =
CpuFunctionExecutor::new(&ServerKey::scalar_left_shift_parallelized);
OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_left_shift_parallelized);
let scalar_rotate_right_executor =
CpuFunctionExecutor::new(&ServerKey::scalar_rotate_right_parallelized);
OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_rotate_right_parallelized);
let scalar_right_shift_executor =
CpuFunctionExecutor::new(&ServerKey::scalar_right_shift_parallelized);
OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_right_shift_parallelized);
#[allow(clippy::type_complexity)]
let mut scalar_shift_rotate_ops: Vec<(
SignedScalarShiftRotateExecutor,
@@ -276,11 +286,11 @@ where
// Overflowing Ops Executors
let overflowing_add_executor =
CpuFunctionExecutor::new(&ServerKey::signed_overflowing_add_parallelized);
OpSequenceCpuFunctionExecutor::new(&ServerKey::signed_overflowing_add_parallelized);
let overflowing_sub_executor =
CpuFunctionExecutor::new(&ServerKey::signed_overflowing_sub_parallelized);
OpSequenceCpuFunctionExecutor::new(&ServerKey::signed_overflowing_sub_parallelized);
let overflowing_mul_executor =
CpuFunctionExecutor::new(&ServerKey::signed_overflowing_mul_parallelized);
OpSequenceCpuFunctionExecutor::new(&ServerKey::signed_overflowing_mul_parallelized);
// Overflowing Ops Clear functions
let clear_overflowing_add = |x: i64, y: i64| -> (i64, bool) { x.overflowing_add(y) };
let clear_overflowing_sub = |x: i64, y: i64| -> (i64, bool) { x.overflowing_sub(y) };
@@ -311,9 +321,9 @@ where
// Scalar Overflowing Ops Executors
let overflowing_scalar_add_executor =
CpuFunctionExecutor::new(&ServerKey::signed_overflowing_scalar_add_parallelized);
OpSequenceCpuFunctionExecutor::new(&ServerKey::signed_overflowing_scalar_add_parallelized);
let overflowing_scalar_sub_executor =
CpuFunctionExecutor::new(&ServerKey::signed_overflowing_scalar_sub_parallelized);
OpSequenceCpuFunctionExecutor::new(&ServerKey::signed_overflowing_scalar_sub_parallelized);
#[allow(clippy::type_complexity)]
let mut scalar_overflowing_ops: Vec<(
@@ -334,12 +344,12 @@ where
];
// Comparison Ops Executors
let gt_executor = CpuFunctionExecutor::new(&ServerKey::gt_parallelized);
let ge_executor = CpuFunctionExecutor::new(&ServerKey::ge_parallelized);
let lt_executor = CpuFunctionExecutor::new(&ServerKey::lt_parallelized);
let le_executor = CpuFunctionExecutor::new(&ServerKey::le_parallelized);
let eq_executor = CpuFunctionExecutor::new(&ServerKey::eq_parallelized);
let ne_executor = CpuFunctionExecutor::new(&ServerKey::ne_parallelized);
let gt_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::gt_parallelized);
let ge_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::ge_parallelized);
let lt_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::lt_parallelized);
let le_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::le_parallelized);
let eq_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::eq_parallelized);
let ne_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::ne_parallelized);
// Comparison Ops Clear functions
let clear_gt = |x: i64, y: i64| -> bool { x > y };
@@ -364,12 +374,12 @@ where
];
// Scalar Comparison Ops Executors
let scalar_gt_executor = CpuFunctionExecutor::new(&ServerKey::scalar_gt_parallelized);
let scalar_ge_executor = CpuFunctionExecutor::new(&ServerKey::scalar_ge_parallelized);
let scalar_lt_executor = CpuFunctionExecutor::new(&ServerKey::scalar_lt_parallelized);
let scalar_le_executor = CpuFunctionExecutor::new(&ServerKey::scalar_le_parallelized);
let scalar_eq_executor = CpuFunctionExecutor::new(&ServerKey::scalar_eq_parallelized);
let scalar_ne_executor = CpuFunctionExecutor::new(&ServerKey::scalar_ne_parallelized);
let scalar_gt_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_gt_parallelized);
let scalar_ge_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_ge_parallelized);
let scalar_lt_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_lt_parallelized);
let scalar_le_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_le_parallelized);
let scalar_eq_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_eq_parallelized);
let scalar_ne_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_ne_parallelized);
#[allow(clippy::type_complexity)]
let mut scalar_comparison_ops: Vec<(
@@ -410,7 +420,7 @@ where
];
// Select Executor
let select_executor = CpuFunctionExecutor::new(&ServerKey::cmux_parallelized);
let select_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::cmux_parallelized);
// Select
let clear_select = |b: bool, x: i64, y: i64| if b { x } else { y };
@@ -427,7 +437,7 @@ where
)];
// Div executor
let div_rem_executor = CpuFunctionExecutor::new(&ServerKey::div_rem_parallelized);
let div_rem_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::div_rem_parallelized);
// Div Rem Clear functions
let clear_div_rem = |x: i64, y: i64| -> (i64, i64) { (x.wrapping_div(y), x.wrapping_rem(y)) };
#[allow(clippy::type_complexity)]
@@ -443,7 +453,7 @@ where
// Scalar Div executor
let scalar_div_rem_executor =
CpuFunctionExecutor::new(&ServerKey::signed_scalar_div_rem_parallelized);
OpSequenceCpuFunctionExecutor::new(&ServerKey::signed_scalar_div_rem_parallelized);
#[allow(clippy::type_complexity)]
let mut scalar_div_rem_op: Vec<(
SignedScalarDivRemOpExecutor,
@@ -456,9 +466,11 @@ where
)];
// Log2/Hamming weight ops
let ilog2_executor = CpuFunctionExecutor::new(&ServerKey::ilog2_parallelized);
let count_zeros_executor = CpuFunctionExecutor::new(&ServerKey::count_zeros_parallelized);
let count_ones_executor = CpuFunctionExecutor::new(&ServerKey::count_ones_parallelized);
let ilog2_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::ilog2_parallelized);
let count_zeros_executor =
OpSequenceCpuFunctionExecutor::new(&ServerKey::count_zeros_parallelized);
let count_ones_executor =
OpSequenceCpuFunctionExecutor::new(&ServerKey::count_ones_parallelized);
let clear_ilog2 = |x: i64| x.ilog2() as u64;
let clear_count_zeros = |x: i64| x.count_zeros() as u64;
let clear_count_ones = |x: i64| x.count_ones() as u64;
@@ -478,7 +490,7 @@ where
),
];
signed_random_op_sequence_test(
let (cks, sks, mut datagen) = signed_random_op_sequence_test_init_cpu(
param,
&mut binary_ops,
&mut unary_ops,
@@ -494,10 +506,29 @@ where
&mut rotate_shift_ops,
&mut scalar_shift_rotate_ops,
);
signed_random_op_sequence_test(
&mut datagen,
&cks,
&sks,
&mut binary_ops,
&mut unary_ops,
&mut scalar_binary_ops,
&mut overflowing_ops,
&mut scalar_overflowing_ops,
&mut comparison_ops,
&mut scalar_comparison_ops,
&mut select_op,
&mut div_rem_op,
&mut scalar_div_rem_op,
&mut log2_ops,
&mut rotate_shift_ops,
&mut scalar_shift_rotate_ops,
);
}
#[allow(clippy::too_many_arguments)]
pub(crate) fn signed_random_op_sequence_test<P>(
pub(crate) fn signed_random_op_sequence_test_init_cpu<P>(
param: P,
binary_ops: &mut [(SignedBinaryOpExecutor, impl Fn(i64, i64) -> i64, String)],
unary_ops: &mut [(SignedUnaryOpExecutor, impl Fn(i64) -> i64, String)],
@@ -548,57 +579,14 @@ pub(crate) fn signed_random_op_sequence_test<P>(
impl Fn(i64, u64) -> i64,
String,
)],
) where
) -> (
RadixClientKey,
Arc<ServerKey>,
RandomOpSequenceDataGenerator<i64, SignedRadixCiphertext>,
)
where
P: Into<TestParameters>,
{
let param = param.into();
let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
sks.set_deterministic_pbs_execution(true);
let sks = Arc::new(sks);
let cks = RadixClientKey::from((cks, NB_CTXT_LONG_RUN));
let mut rng = rand::thread_rng();
for x in binary_ops.iter_mut() {
x.0.setup(&cks, sks.clone());
}
for x in unary_ops.iter_mut() {
x.0.setup(&cks, sks.clone());
}
for x in scalar_binary_ops.iter_mut() {
x.0.setup(&cks, sks.clone());
}
for x in overflowing_ops.iter_mut() {
x.0.setup(&cks, sks.clone());
}
for x in scalar_overflowing_ops.iter_mut() {
x.0.setup(&cks, sks.clone());
}
for x in comparison_ops.iter_mut() {
x.0.setup(&cks, sks.clone());
}
for x in scalar_comparison_ops.iter_mut() {
x.0.setup(&cks, sks.clone());
}
for x in select_op.iter_mut() {
x.0.setup(&cks, sks.clone());
}
for x in div_rem_op.iter_mut() {
x.0.setup(&cks, sks.clone());
}
for x in scalar_div_rem_op.iter_mut() {
x.0.setup(&cks, sks.clone());
}
for x in log2_ops.iter_mut() {
x.0.setup(&cks, sks.clone());
}
for x in rotate_shift_ops.iter_mut() {
x.0.setup(&cks, sks.clone());
}
for x in scalar_rotate_shift_ops.iter_mut() {
x.0.setup(&cks, sks.clone());
}
let total_num_ops = binary_ops.len()
+ unary_ops.len()
+ scalar_binary_ops.len()
@@ -612,6 +600,132 @@ pub(crate) fn signed_random_op_sequence_test<P>(
+ log2_ops.len()
+ rotate_shift_ops.len()
+ scalar_rotate_shift_ops.len();
let param = param.into();
let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
let temp_cks =
ClientKey::from_raw_parts(cks.clone(), None, None, None, None, None, Tag::default());
let comp_sks = CompressedServerKey::new(&temp_cks);
sks.set_deterministic_pbs_execution(true);
let sks = Arc::new(sks);
let cks = RadixClientKey::from((cks, NB_CTXT_LONG_RUN));
let mut datagen = get_user_defined_seed().map_or_else(
|| RandomOpSequenceDataGenerator::<i64, SignedRadixCiphertext>::new(total_num_ops, &cks),
|seed| {
RandomOpSequenceDataGenerator::<i64, SignedRadixCiphertext>::new_with_seed(
total_num_ops,
seed,
&cks,
)
},
);
println!(
"signed_random_op_sequence_test::seed = {}",
datagen.get_seed().0
);
for x in binary_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in unary_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in scalar_binary_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in overflowing_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in scalar_overflowing_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in comparison_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in scalar_comparison_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in select_op.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in div_rem_op.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in scalar_div_rem_op.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in log2_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in rotate_shift_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in scalar_rotate_shift_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
(cks, sks, datagen)
}
#[allow(clippy::too_many_arguments)]
pub(crate) fn signed_random_op_sequence_test(
datagen: &mut RandomOpSequenceDataGenerator<i64, SignedRadixCiphertext>,
cks: &RadixClientKey,
sks: &Arc<ServerKey>,
binary_ops: &mut [(SignedBinaryOpExecutor, impl Fn(i64, i64) -> i64, String)],
unary_ops: &mut [(SignedUnaryOpExecutor, impl Fn(i64) -> i64, String)],
scalar_binary_ops: &mut [(
SignedScalarBinaryOpExecutor,
impl Fn(i64, i64) -> i64,
String,
)],
overflowing_ops: &mut [(
SignedOverflowingOpExecutor,
impl Fn(i64, i64) -> (i64, bool),
String,
)],
scalar_overflowing_ops: &mut [(
SignedScalarOverflowingOpExecutor,
impl Fn(i64, i64) -> (i64, bool),
String,
)],
comparison_ops: &mut [(
SignedComparisonOpExecutor,
impl Fn(i64, i64) -> bool,
String,
)],
scalar_comparison_ops: &mut [(
SignedScalarComparisonOpExecutor,
impl Fn(i64, i64) -> bool,
String,
)],
select_op: &mut [(
SignedSelectOpExecutor,
impl Fn(bool, i64, i64) -> i64,
String,
)],
div_rem_op: &mut [(
SignedDivRemOpExecutor,
impl Fn(i64, i64) -> (i64, i64),
String,
)],
scalar_div_rem_op: &mut [(
SignedScalarDivRemOpExecutor,
impl Fn(i64, i64) -> (i64, i64),
String,
)],
log2_ops: &mut [(SignedLog2OpExecutor, impl Fn(i64) -> u64, String)],
rotate_shift_ops: &mut [(SignedShiftRotateExecutor, impl Fn(i64, u64) -> i64, String)],
scalar_rotate_shift_ops: &mut [(
SignedScalarShiftRotateExecutor,
impl Fn(i64, u64) -> i64,
String,
)],
) {
let binary_ops_range = 0..binary_ops.len();
let unary_ops_range = binary_ops_range.end..binary_ops_range.end + unary_ops.len();
let scalar_binary_ops_range =
@@ -634,16 +748,6 @@ pub(crate) fn signed_random_op_sequence_test<P>(
let scalar_rotate_shift_ops_range =
rotate_shift_ops_range.end..rotate_shift_ops_range.end + scalar_rotate_shift_ops.len();
let mut datagen =
crate::integer::server_key::radix_parallel::tests_long_run::RandomOpSequenceDataGenerator::<
i64,
SignedRadixCiphertext,
>::new(total_num_ops, &cks);
println!(
"signed_random_op_sequence_test::seed = {}",
datagen.get_seed().0
);
for fn_index in 0..get_long_test_iterations() {
let (i, idx) = datagen.gen_op_index();
@@ -819,7 +923,7 @@ pub(crate) fn signed_random_op_sequence_test<P>(
let expected_res = clear_fn(lhs.p, rhs.p);
let res_ct: SignedRadixCiphertext = sks.cast_to_signed(
res.clone().into_radix::<SignedRadixCiphertext>(1, &sks),
res.clone().into_radix::<SignedRadixCiphertext>(1, sks),
NB_CTXT_LONG_RUN,
);
@@ -851,7 +955,7 @@ pub(crate) fn signed_random_op_sequence_test<P>(
let expected_res = clear_fn(lhs.p, rhs.p);
let res_ct: SignedRadixCiphertext = sks.cast_to_signed(
res.clone().into_radix::<SignedRadixCiphertext>(1, &sks),
res.clone().into_radix::<SignedRadixCiphertext>(1, sks),
NB_CTXT_LONG_RUN,
);
@@ -875,7 +979,7 @@ pub(crate) fn signed_random_op_sequence_test<P>(
let (lhs, rhs) = datagen.gen_op_operands(idx, fn_name);
let clear_bool: bool = rng.gen_bool(0.5);
let clear_bool: bool = datagen.gen_bool_uniform();
let bool_input = cks.encrypt_bool(clear_bool);
let res = select_op_executor.execute((&bool_input, &lhs.c, &rhs.c));

View File

@@ -29,8 +29,10 @@ pub(crate) mod test_vector_comparisons;
pub(crate) mod test_vector_find;
use super::tests_cases_unsigned::*;
use crate::core_crypto::commons::generators::DeterministicSeeder;
use crate::core_crypto::prelude::UnsignedInteger;
use crate::integer::keycache::KEY_CACHE;
use crate::integer::server_key::radix_parallel::tests_long_run::OpSequenceFunctionExecutor;
use crate::integer::tests::create_parameterized_test;
use crate::integer::{IntegerKeyKind, RadixCiphertext, RadixClientKey, ServerKey};
use crate::shortint::ciphertext::MaxDegree;
@@ -38,9 +40,11 @@ use crate::shortint::ciphertext::MaxDegree;
use crate::shortint::parameters::coverage_parameters::*;
use crate::shortint::parameters::test_params::*;
use crate::shortint::parameters::*;
use crate::CompressedServerKey;
use rand::prelude::ThreadRng;
use rand::Rng;
use std::sync::Arc;
use tfhe_csprng::generators::DefaultRandomGenerator;
#[cfg(not(tarpaulin))]
pub(crate) const NB_CTXT: usize = 4;
@@ -619,6 +623,111 @@ where
}
}
/// The function executor for cpu server key
///
/// It will mainly simply forward call to a server key method
pub(crate) struct OpSequenceCpuFunctionExecutor<F> {
/// The server key is set later, when the test cast calls setup
pub(crate) sks: Option<Arc<ServerKey>>,
/// The server key function which will be called
pub(crate) func: F,
}
impl<F> OpSequenceCpuFunctionExecutor<F> {
pub(crate) fn new(func: F) -> Self {
Self { sks: None, func }
}
pub(crate) fn setup_from_cpu_keys(&mut self, sks: &CompressedServerKey) {
let (isks, _, _, _, _, _, _, _) = sks.decompress().into_raw_parts();
self.sks = Some(Arc::new(isks));
}
}
/// For unary operations
///
/// Note, we need to `NotTuple` constraint to avoid conflicts with binary or ternary operations
impl<F, I1, O> OpSequenceFunctionExecutor<I1, O> for OpSequenceCpuFunctionExecutor<F>
where
F: Fn(&ServerKey, I1) -> O,
I1: NotTuple,
{
fn setup(
&mut self,
_cks: &RadixClientKey,
sks: &CompressedServerKey,
_seeder: &mut DeterministicSeeder<DefaultRandomGenerator>,
) {
let (isks, _, _, _, _, _, _, _) = sks.decompress().into_raw_parts();
self.sks = Some(Arc::new(isks));
}
fn execute(&mut self, input: I1) -> O {
let sks = self.sks.as_ref().expect("setup was not properly called");
(self.func)(sks, input)
}
}
/// For binary operations
impl<F, I1, I2, O> OpSequenceFunctionExecutor<(I1, I2), O> for OpSequenceCpuFunctionExecutor<F>
where
F: Fn(&ServerKey, I1, I2) -> O,
{
fn setup(
&mut self,
_cks: &RadixClientKey,
sks: &CompressedServerKey,
_seeder: &mut DeterministicSeeder<DefaultRandomGenerator>,
) {
self.setup_from_cpu_keys(sks);
}
fn execute(&mut self, input: (I1, I2)) -> O {
let sks = self.sks.as_ref().expect("setup was not properly called");
(self.func)(sks, input.0, input.1)
}
}
/// For ternary operations
impl<F, I1, I2, I3, O> OpSequenceFunctionExecutor<(I1, I2, I3), O>
for OpSequenceCpuFunctionExecutor<F>
where
F: Fn(&ServerKey, I1, I2, I3) -> O,
{
fn setup(
&mut self,
_cks: &RadixClientKey,
sks: &CompressedServerKey,
_seeder: &mut DeterministicSeeder<DefaultRandomGenerator>,
) {
self.setup_from_cpu_keys(sks);
}
fn execute(&mut self, input: (I1, I2, I3)) -> O {
let sks = self.sks.as_ref().expect("setup was not properly called");
(self.func)(sks, input.0, input.1, input.2)
}
}
/// For 4-ary operations
impl<F, I1, I2, I3, I4, O> OpSequenceFunctionExecutor<(I1, I2, I3, I4), O>
for OpSequenceCpuFunctionExecutor<F>
where
F: Fn(&ServerKey, I1, I2, I3, I4) -> O,
{
fn setup(
&mut self,
_cks: &RadixClientKey,
sks: &CompressedServerKey,
_seeder: &mut DeterministicSeeder<DefaultRandomGenerator>,
) {
self.setup_from_cpu_keys(sks);
}
fn execute(&mut self, input: (I1, I2, I3, I4)) -> O {
let sks = self.sks.as_ref().expect("setup was not properly called");
(self.func)(sks, input.0, input.1, input.2, input.3)
}
}
//=============================================================================
// Unchecked Tests
//=============================================================================