Mirror of https://github.com/zama-ai/tfhe-rs.git (synced 2026-01-08 06:13:58 -05:00)

chore(gpu): make deterministic long run GPU test

Committed by: Andrei Stoian
Parent: 45a849ad36
Commit: 73de886c07
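The core of the change is that every long-run GPU test now resolves one seed up front (a user-supplied seed when available, a fresh one otherwise), prints it, and drives all data generation and executor setup from that single seeder, with deterministic PBS execution enabled on the server key (see the hunks below). A minimal, self-contained sketch of that seed-resolution pattern only, in plain Rust with the standard library; the TFHE_RS_TEST_LONG_RUN_SEED variable name is an illustrative assumption, not necessarily the variable the crate actually reads:

    use std::env;
    use std::time::{SystemTime, UNIX_EPOCH};

    // Resolve the seed for a reproducible run: honor a user-provided value when
    // present, otherwise draw a fresh one, and always print it so a failing run
    // can be replayed. This mirrors the get_user_defined_seed().map_or_else(..)
    // shape in the diff below; the env var name is only for illustration.
    fn resolve_seed() -> u128 {
        env::var("TFHE_RS_TEST_LONG_RUN_SEED")
            .ok()
            .and_then(|s| s.parse::<u128>().ok())
            .unwrap_or_else(|| {
                // Any entropy source works for the fresh-seed case.
                SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_nanos()
            })
    }

    fn main() {
        let seed = resolve_seed();
        println!("random_op_sequence_test::seed = {seed}");
        // All RNG state (data generation, executor setup) is then derived from
        // this one seed so the whole operation sequence is reproducible.
    }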
@@ -287,7 +287,7 @@ mod gpu {
}

impl CudaGpuChoice {
pub(in crate::high_level_api) fn build_streams(self) -> CudaStreams {
pub(crate) fn build_streams(self) -> CudaStreams {
match self {
Self::Single(idx) => CudaStreams::new_single_gpu(idx),
Self::Multi => CudaStreams::new_multi_gpu(),
@@ -53,6 +53,8 @@ use backward_compatibility::compressed_ciphertext_list::SquashedNoiseCiphertextS
pub use config::{Config, ConfigBuilder};
#[cfg(feature = "gpu")]
pub use global_state::CudaGpuChoice;
#[cfg(feature = "gpu")]
pub use global_state::CustomMultiGpuIndexes;
pub use global_state::{set_server_key, unset_server_key, with_server_key_as_context};

pub use integers::{

(File diff suppressed because it is too large.)
@@ -1,4 +1,4 @@
use crate::integer::gpu::server_key::radix::tests_long_run::GpuMultiDeviceFunctionExecutor;
use crate::integer::gpu::server_key::radix::tests_signed::GpuMultiDeviceFunctionExecutor;
use crate::integer::gpu::server_key::radix::tests_unsigned::create_gpu_parameterized_test;
use crate::integer::gpu::CudaServerKey;
use crate::integer::server_key::radix_parallel::tests_long_run::test_erc20::{
@@ -1,17 +1,137 @@
use crate::integer::gpu::server_key::radix::tests_long_run::GpuMultiDeviceFunctionExecutor;
use crate::integer::gpu::server_key::radix::tests_long_run::OpSequenceGpuMultiDeviceFunctionExecutor;
use crate::integer::gpu::server_key::radix::tests_unsigned::create_gpu_parameterized_test;
use crate::integer::gpu::CudaServerKey;
use crate::integer::keycache::KEY_CACHE;
use crate::integer::server_key::radix_parallel::tests_long_run::test_random_op_sequence::{
random_op_sequence_test, BinaryOpExecutor, ComparisonOpExecutor, DivRemOpExecutor,
Log2OpExecutor, OverflowingOpExecutor, ScalarBinaryOpExecutor, ScalarComparisonOpExecutor,
ScalarDivRemOpExecutor, ScalarOverflowingOpExecutor, SelectOpExecutor, UnaryOpExecutor,
};
use crate::integer::server_key::radix_parallel::tests_long_run::{
get_user_defined_seed, RandomOpSequenceDataGenerator, NB_CTXT_LONG_RUN,
};
use crate::integer::{IntegerKeyKind, RadixCiphertext, RadixClientKey, ServerKey};
use crate::shortint::parameters::*;
use crate::{ClientKey, CompressedServerKey, Tag};
use std::cmp::{max, min};
use std::sync::Arc;

create_gpu_parameterized_test!(random_op_sequence {
PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128
});

#[allow(clippy::too_many_arguments)]
pub(crate) fn random_op_sequence_test_init_gpu<P>(
param: P,
binary_ops: &mut [(BinaryOpExecutor, impl Fn(u64, u64) -> u64, String)],
unary_ops: &mut [(UnaryOpExecutor, impl Fn(u64) -> u64, String)],
scalar_binary_ops: &mut [(ScalarBinaryOpExecutor, impl Fn(u64, u64) -> u64, String)],
overflowing_ops: &mut [(
OverflowingOpExecutor,
impl Fn(u64, u64) -> (u64, bool),
String,
)],
scalar_overflowing_ops: &mut [(
ScalarOverflowingOpExecutor,
impl Fn(u64, u64) -> (u64, bool),
String,
)],
comparison_ops: &mut [(ComparisonOpExecutor, impl Fn(u64, u64) -> bool, String)],
scalar_comparison_ops: &mut [(
ScalarComparisonOpExecutor,
impl Fn(u64, u64) -> bool,
String,
)],
select_op: &mut [(SelectOpExecutor, impl Fn(bool, u64, u64) -> u64, String)],
div_rem_op: &mut [(DivRemOpExecutor, impl Fn(u64, u64) -> (u64, u64), String)],
scalar_div_rem_op: &mut [(
ScalarDivRemOpExecutor,
impl Fn(u64, u64) -> (u64, u64),
String,
)],
log2_ops: &mut [(Log2OpExecutor, impl Fn(u64) -> u64, String)],
) -> (
RadixClientKey,
Arc<ServerKey>,
RandomOpSequenceDataGenerator<u64, RadixCiphertext>,
)
where
P: Into<TestParameters>,
{
let param = param.into();
let (cks0, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);

let cks = ClientKey::from_raw_parts(cks0.clone(), None, None, None, None, None, Tag::default());
let comp_sks = CompressedServerKey::new(&cks);
sks.set_deterministic_pbs_execution(true);
let sks = Arc::new(sks);

println!("Setting up operations");
let cks = RadixClientKey::from((cks0, NB_CTXT_LONG_RUN));

let total_num_ops = binary_ops.len()
+ unary_ops.len()
+ scalar_binary_ops.len()
+ overflowing_ops.len()
+ scalar_overflowing_ops.len()
+ comparison_ops.len()
+ scalar_comparison_ops.len()
+ select_op.len()
+ div_rem_op.len()
+ scalar_div_rem_op.len()
+ log2_ops.len();
println!("Total num ops {total_num_ops}");

let mut datagen = get_user_defined_seed().map_or_else(
|| RandomOpSequenceDataGenerator::<u64, RadixCiphertext>::new(total_num_ops, &cks),
|seed| {
RandomOpSequenceDataGenerator::<u64, RadixCiphertext>::new_with_seed(
total_num_ops,
seed,
&cks,
)
},
);

println!("random_op_sequence_test::seed = {}", datagen.get_seed().0);

for x in binary_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in unary_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in scalar_binary_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in overflowing_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in scalar_overflowing_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in comparison_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in scalar_comparison_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in select_op.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in div_rem_op.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in scalar_div_rem_op.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in log2_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}

(cks, sks, datagen)
}

fn random_op_sequence<P>(param: P)
where
P: Into<TestParameters> + Clone,
@@ -19,18 +139,24 @@ where
println!("Running random op sequence test");

// Binary Ops Executors
let add_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::add);
let sub_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::sub);
let bitwise_and_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::bitand);
let bitwise_or_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::bitor);
let bitwise_xor_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::bitxor);
let mul_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::mul);
let rotate_left_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::rotate_left);
let left_shift_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::left_shift);
let rotate_right_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::rotate_right);
let right_shift_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::right_shift);
let max_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::max);
let min_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::min);
let add_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::add);
let sub_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::sub);
let bitwise_and_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::bitand);
let bitwise_or_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::bitor);
let bitwise_xor_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::bitxor);
let mul_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::mul);
let rotate_left_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::rotate_left);
let left_shift_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::left_shift);
let rotate_right_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::rotate_right);
let right_shift_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::right_shift);
let max_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::max);
let min_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::min);

// Binary Ops Clear functions
let clear_add = |x, y| x + y;
@@ -93,10 +219,10 @@ where
];

// Unary Ops Executors
let neg_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::neg);
let bitnot_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::bitnot);
let neg_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::neg);
let bitnot_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::bitnot);
//let reverse_bits_executor =
// GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::reverse_bits); Unary Ops Clear
// OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::reverse_bits); Unary Ops Clear
// functions
let clear_neg = |x: u64| x.wrapping_neg();
let clear_bitnot = |x: u64| !x;
@@ -117,23 +243,26 @@ where
];

// Scalar binary Ops Executors
let scalar_add_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_add);
let scalar_sub_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_sub);
let scalar_add_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_add);
let scalar_sub_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_sub);
let scalar_bitwise_and_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_bitand);
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_bitand);
let scalar_bitwise_or_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_bitor);
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_bitor);
let scalar_bitwise_xor_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_bitxor);
let scalar_mul_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_mul);
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_bitxor);
let scalar_mul_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_mul);
let scalar_rotate_left_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_rotate_left);
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_rotate_left);
let scalar_left_shift_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_left_shift);
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_left_shift);
let scalar_rotate_right_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_rotate_right);
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_rotate_right);
let scalar_right_shift_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_right_shift);
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_right_shift);

#[allow(clippy::type_complexity)]
let mut scalar_binary_ops: Vec<(ScalarBinaryOpExecutor, &dyn Fn(u64, u64) -> u64, String)> = vec![
@@ -191,11 +320,11 @@ where

// Overflowing Ops Executors
let overflowing_add_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::unsigned_overflowing_add);
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::unsigned_overflowing_add);
let overflowing_sub_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::unsigned_overflowing_sub);
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::unsigned_overflowing_sub);
//let overflowing_mul_executor =
// GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::unsigned_overflowing_mul);
// OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::unsigned_overflowing_mul);

// Overflowing Ops Clear functions
let clear_overflowing_add = |x: u64, y: u64| -> (u64, bool) { x.overflowing_add(y) };
@@ -226,10 +355,12 @@ where
];

// Scalar Overflowing Ops Executors
let overflowing_scalar_add_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::unsigned_overflowing_scalar_add);
let overflowing_scalar_add_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(
&CudaServerKey::unsigned_overflowing_scalar_add,
);
// let overflowing_scalar_sub_executor =
// GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::unsigned_overflowing_scalar_sub);
// OpSequenceGpuMultiDeviceFunctionExecutor::new(&
// CudaServerKey::unsigned_overflowing_scalar_sub);

#[allow(clippy::type_complexity)]
let mut scalar_overflowing_ops: Vec<(
@@ -250,12 +381,12 @@ where
];

// Comparison Ops Executors
let gt_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::gt);
let ge_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::ge);
let lt_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::lt);
let le_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::le);
let eq_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::eq);
let ne_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::ne);
let gt_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::gt);
let ge_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::ge);
let lt_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::lt);
let le_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::le);
let eq_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::eq);
let ne_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::ne);

// Comparison Ops Clear functions
let clear_gt = |x: u64, y: u64| -> bool { x > y };
@@ -276,12 +407,18 @@ where
];

// Scalar Comparison Ops Executors
let scalar_gt_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_gt);
let scalar_ge_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_ge);
let scalar_lt_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_lt);
let scalar_le_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_le);
let scalar_eq_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_eq);
let scalar_ne_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_ne);
let scalar_gt_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_gt);
let scalar_ge_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_ge);
let scalar_lt_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_lt);
let scalar_le_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_le);
let scalar_eq_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_eq);
let scalar_ne_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_ne);

#[allow(clippy::type_complexity)]
let mut scalar_comparison_ops: Vec<(
@@ -322,7 +459,8 @@ where
];

// Select Executor
let select_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::if_then_else);
let select_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::if_then_else);

// Select
let clear_select = |b: bool, x: u64, y: u64| if b { x } else { y };
@@ -335,7 +473,7 @@ where
)];

// Div executor
let div_rem_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::div_rem);
let div_rem_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::div_rem);
// Div Rem Clear functions
let clear_div_rem = |x: u64, y: u64| -> (u64, u64) { (x.wrapping_div(y), x.wrapping_rem(y)) };
#[allow(clippy::type_complexity)]
@@ -347,7 +485,7 @@ where

// Scalar Div executor
let scalar_div_rem_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_div_rem);
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_div_rem);
#[allow(clippy::type_complexity)]
let mut scalar_div_rem_op: Vec<(
ScalarDivRemOpExecutor,
@@ -360,9 +498,11 @@ where
)];

// Log2/Hamming weight ops
let ilog2_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::ilog2);
//let count_zeros_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::count_zeros);
//let count_ones_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::count_ones);
let ilog2_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::ilog2);
//let count_zeros_executor =
// OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::count_zeros);
// let count_ones_executor =
// OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::count_ones);
let clear_ilog2 = |x: u64| x.ilog2() as u64;
//let clear_count_zeros = |x: u64| x.count_zeros() as u64;
//let clear_count_ones = |x: u64| x.count_ones() as u64;
@@ -382,7 +522,7 @@ where
//),
];

random_op_sequence_test(
let (cks, sks, mut datagen) = random_op_sequence_test_init_gpu(
param,
&mut binary_ops,
&mut unary_ops,
@@ -396,4 +536,21 @@ where
&mut scalar_div_rem_op,
&mut log2_ops,
);

random_op_sequence_test(
&mut datagen,
&cks,
&sks,
&mut binary_ops,
&mut unary_ops,
&mut scalar_binary_ops,
&mut overflowing_ops,
&mut scalar_overflowing_ops,
&mut comparison_ops,
&mut scalar_comparison_ops,
&mut select_op,
&mut div_rem_op,
&mut scalar_div_rem_op,
&mut log2_ops,
);
}
@@ -1,4 +1,4 @@
use crate::integer::gpu::server_key::radix::tests_long_run::GpuMultiDeviceFunctionExecutor;
use crate::integer::gpu::server_key::radix::tests_signed::GpuMultiDeviceFunctionExecutor;
use crate::integer::gpu::server_key::radix::tests_unsigned::create_gpu_parameterized_test;
use crate::integer::gpu::CudaServerKey;
use crate::integer::server_key::radix_parallel::tests_long_run::test_signed_erc20::{
@@ -1,6 +1,7 @@
use crate::integer::gpu::server_key::radix::tests_long_run::GpuMultiDeviceFunctionExecutor;
use crate::integer::gpu::server_key::radix::tests_long_run::OpSequenceGpuMultiDeviceFunctionExecutor;
use crate::integer::gpu::server_key::radix::tests_unsigned::create_gpu_parameterized_test;
use crate::integer::gpu::CudaServerKey;
use crate::integer::keycache::KEY_CACHE;
use crate::integer::server_key::radix_parallel::tests_long_run::test_signed_random_op_sequence::{
signed_random_op_sequence_test, SignedBinaryOpExecutor, SignedComparisonOpExecutor,
SignedDivRemOpExecutor, SignedLog2OpExecutor, SignedOverflowingOpExecutor,
@@ -8,25 +9,177 @@ use crate::integer::server_key::radix_parallel::tests_long_run::test_signed_rand
SignedScalarOverflowingOpExecutor, SignedScalarShiftRotateExecutor, SignedSelectOpExecutor,
SignedShiftRotateExecutor, SignedUnaryOpExecutor,
};
use crate::integer::server_key::radix_parallel::tests_long_run::{
get_user_defined_seed, RandomOpSequenceDataGenerator, NB_CTXT_LONG_RUN,
};
use crate::integer::{IntegerKeyKind, RadixClientKey, ServerKey, SignedRadixCiphertext};
use crate::shortint::parameters::*;
use crate::{ClientKey, CompressedServerKey, Tag};
use std::cmp::{max, min};
use std::sync::Arc;

create_gpu_parameterized_test!(signed_random_op_sequence {
PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128
});

#[allow(clippy::too_many_arguments)]
pub(crate) fn signed_random_op_sequence_test_init_gpu<P>(
param: P,
binary_ops: &mut [(SignedBinaryOpExecutor, impl Fn(i64, i64) -> i64, String)],
unary_ops: &mut [(SignedUnaryOpExecutor, impl Fn(i64) -> i64, String)],
scalar_binary_ops: &mut [(
SignedScalarBinaryOpExecutor,
impl Fn(i64, i64) -> i64,
String,
)],
overflowing_ops: &mut [(
SignedOverflowingOpExecutor,
impl Fn(i64, i64) -> (i64, bool),
String,
)],
scalar_overflowing_ops: &mut [(
SignedScalarOverflowingOpExecutor,
impl Fn(i64, i64) -> (i64, bool),
String,
)],
comparison_ops: &mut [(
SignedComparisonOpExecutor,
impl Fn(i64, i64) -> bool,
String,
)],
scalar_comparison_ops: &mut [(
SignedScalarComparisonOpExecutor,
impl Fn(i64, i64) -> bool,
String,
)],
select_op: &mut [(
SignedSelectOpExecutor,
impl Fn(bool, i64, i64) -> i64,
String,
)],
div_rem_op: &mut [(
SignedDivRemOpExecutor,
impl Fn(i64, i64) -> (i64, i64),
String,
)],
scalar_div_rem_op: &mut [(
SignedScalarDivRemOpExecutor,
impl Fn(i64, i64) -> (i64, i64),
String,
)],
log2_ops: &mut [(SignedLog2OpExecutor, impl Fn(i64) -> u64, String)],
rotate_shift_ops: &mut [(SignedShiftRotateExecutor, impl Fn(i64, u64) -> i64, String)],
scalar_rotate_shift_ops: &mut [(
SignedScalarShiftRotateExecutor,
impl Fn(i64, u64) -> i64,
String,
)],
) -> (
RadixClientKey,
Arc<ServerKey>,
RandomOpSequenceDataGenerator<i64, SignedRadixCiphertext>,
)
where
P: Into<TestParameters>,
{
let param = param.into();
let (cks0, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);

let total_num_ops = binary_ops.len()
+ unary_ops.len()
+ scalar_binary_ops.len()
+ overflowing_ops.len()
+ scalar_overflowing_ops.len()
+ comparison_ops.len()
+ scalar_comparison_ops.len()
+ select_op.len()
+ div_rem_op.len()
+ scalar_div_rem_op.len()
+ log2_ops.len()
+ rotate_shift_ops.len()
+ scalar_rotate_shift_ops.len();

let cks = ClientKey::from_raw_parts(cks0.clone(), None, None, None, None, None, Tag::default());
let comp_sks = CompressedServerKey::new(&cks);

sks.set_deterministic_pbs_execution(true);
let sks = Arc::new(sks);
let cks = RadixClientKey::from((cks0, NB_CTXT_LONG_RUN));

let mut datagen = get_user_defined_seed().map_or_else(
|| RandomOpSequenceDataGenerator::<i64, SignedRadixCiphertext>::new(total_num_ops, &cks),
|seed| {
RandomOpSequenceDataGenerator::<i64, SignedRadixCiphertext>::new_with_seed(
total_num_ops,
seed,
&cks,
)
},
);

println!(
"signed_random_op_sequence_test::seed = {}",
datagen.get_seed().0
);

for x in binary_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in unary_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in scalar_binary_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in overflowing_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in scalar_overflowing_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in comparison_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in scalar_comparison_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in select_op.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in div_rem_op.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in scalar_div_rem_op.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in log2_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in rotate_shift_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in scalar_rotate_shift_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}

(cks, sks, datagen)
}

fn signed_random_op_sequence<P>(param: P)
where
P: Into<TestParameters> + Clone,
{
// Binary Ops Executors
let add_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::add);
let sub_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::sub);
let bitwise_and_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::bitand);
let bitwise_or_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::bitor);
let bitwise_xor_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::bitxor);
let mul_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::mul);
let max_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::max);
let min_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::min);
let add_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::add);
let sub_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::sub);
let bitwise_and_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::bitand);
let bitwise_or_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::bitor);
let bitwise_xor_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::bitxor);
let mul_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::mul);
let max_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::max);
let min_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::min);

// Binary Ops Clear functions
let clear_add = |x, y| x + y;
@@ -62,10 +215,14 @@ where
(Box::new(min_executor), &clear_min, "min".to_string()),
];

let rotate_left_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::rotate_left);
let left_shift_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::left_shift);
let rotate_right_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::rotate_right);
let right_shift_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::right_shift);
let rotate_left_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::rotate_left);
let left_shift_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::left_shift);
let rotate_right_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::rotate_right);
let right_shift_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::right_shift);
// Warning this rotate definition only works with 64-bit ciphertexts
let clear_rotate_left = |x: i64, y: u64| x.rotate_left(y as u32);
let clear_left_shift = |x: i64, y: u64| x << y;
@@ -101,11 +258,11 @@ where
];

// Unary Ops Executors
let neg_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::neg);
let bitnot_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::bitnot);
let abs_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::abs);
let neg_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::neg);
let bitnot_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::bitnot);
let abs_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::abs);
//let reverse_bits_executor =
// GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::reverse_bits); Unary Ops Clear
// OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::reverse_bits); Unary Ops Clear
// functions
let clear_neg = |x: i64| x.wrapping_neg();
let clear_abs = |x: i64| x.wrapping_abs();
@@ -129,15 +286,18 @@ where
];

// Scalar binary Ops Executors
let scalar_add_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_add);
let scalar_sub_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_sub);
let scalar_add_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_add);
let scalar_sub_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_sub);
let scalar_bitwise_and_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_bitand);
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_bitand);
let scalar_bitwise_or_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_bitor);
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_bitor);
let scalar_bitwise_xor_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_bitxor);
let scalar_mul_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_mul);
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_bitxor);
let scalar_mul_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_mul);

#[allow(clippy::type_complexity)]
let mut scalar_binary_ops: Vec<(
@@ -178,13 +338,13 @@ where
];

let scalar_rotate_left_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_rotate_left);
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_rotate_left);
let scalar_left_shift_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_left_shift);
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_left_shift);
let scalar_rotate_right_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_rotate_right);
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_rotate_right);
let scalar_right_shift_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_right_shift);
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_right_shift);
#[allow(clippy::type_complexity)]
let mut scalar_shift_rotate_ops: Vec<(
SignedScalarShiftRotateExecutor,
@@ -215,11 +375,11 @@ where

// Overflowing Ops Executors
let overflowing_add_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::signed_overflowing_add);
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::signed_overflowing_add);
let overflowing_sub_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::signed_overflowing_sub);
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::signed_overflowing_sub);
//let overflowing_mul_executor =
// GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::signed_overflowing_mul);
// OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::signed_overflowing_mul);

// Overflowing Ops Clear functions
let clear_overflowing_add = |x: i64, y: i64| -> (i64, bool) { x.overflowing_add(y) };
@@ -250,10 +410,12 @@ where
];

// Scalar Overflowing Ops Executors
let overflowing_scalar_add_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::signed_overflowing_scalar_add);
let overflowing_scalar_add_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(
&CudaServerKey::signed_overflowing_scalar_add,
);
// let overflowing_scalar_sub_executor =
// GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::signed_overflowing_scalar_sub);
// OpSequenceGpuMultiDeviceFunctionExecutor::new(&
// CudaServerKey::signed_overflowing_scalar_sub);

#[allow(clippy::type_complexity)]
let mut scalar_overflowing_ops: Vec<(
@@ -274,12 +436,12 @@ where
];

// Comparison Ops Executors
let gt_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::gt);
let ge_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::ge);
let lt_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::lt);
let le_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::le);
let eq_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::eq);
let ne_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::ne);
let gt_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::gt);
let ge_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::ge);
let lt_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::lt);
let le_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::le);
let eq_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::eq);
let ne_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::ne);

// Comparison Ops Clear functions
let clear_gt = |x: i64, y: i64| -> bool { x > y };
@@ -304,12 +466,18 @@ where
];

// Scalar Comparison Ops Executors
let scalar_gt_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_gt);
let scalar_ge_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_ge);
let scalar_lt_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_lt);
let scalar_le_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_le);
let scalar_eq_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_eq);
let scalar_ne_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_ne);
let scalar_gt_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_gt);
let scalar_ge_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_ge);
let scalar_lt_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_lt);
let scalar_le_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_le);
let scalar_eq_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_eq);
let scalar_ne_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_ne);

#[allow(clippy::type_complexity)]
let mut scalar_comparison_ops: Vec<(
@@ -350,7 +518,8 @@ where
];

// Select Executor
let select_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::if_then_else);
let select_executor =
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::if_then_else);

// Select
let clear_select = |b: bool, x: i64, y: i64| if b { x } else { y };
@@ -367,7 +536,7 @@ where
)];

// Div executor
let div_rem_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::div_rem);
let div_rem_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::div_rem);
// Div Rem Clear functions
let clear_div_rem = |x: i64, y: i64| -> (i64, i64) { (x.wrapping_div(y), x.wrapping_rem(y)) };
#[allow(clippy::type_complexity)]
@@ -383,7 +552,7 @@ where

// Scalar Div executor
let scalar_div_rem_executor =
GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::signed_scalar_div_rem);
OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::signed_scalar_div_rem);
#[allow(clippy::type_complexity)]
let mut scalar_div_rem_op: Vec<(
SignedScalarDivRemOpExecutor,
@@ -396,9 +565,11 @@ where
)];

// Log2/Hamming weight ops
let ilog2_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::ilog2);
//let count_zeros_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::count_zeros);
//let count_ones_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::count_ones);
let ilog2_executor = OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::ilog2);
//let count_zeros_executor =
// OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::count_zeros);
// let count_ones_executor =
// OpSequenceGpuMultiDeviceFunctionExecutor::new(&CudaServerKey::count_ones);
let clear_ilog2 = |x: i64| x.ilog2() as u64;
//let clear_count_zeros = |x: i64| x.count_zeros() as i64;
//let clear_count_ones = |x: i64| x.count_ones() as i64;
@@ -418,7 +589,7 @@ where
//),
];

signed_random_op_sequence_test(
let (cks, sks, mut datagen) = signed_random_op_sequence_test_init_gpu(
param,
&mut binary_ops,
&mut unary_ops,
@@ -434,4 +605,23 @@ where
&mut shift_rotate_ops,
&mut scalar_shift_rotate_ops,
);

signed_random_op_sequence_test(
&mut datagen,
&cks,
&sks,
&mut binary_ops,
&mut unary_ops,
&mut scalar_binary_ops,
&mut overflowing_ops,
&mut scalar_overflowing_ops,
&mut comparison_ops,
&mut scalar_comparison_ops,
&mut select_op,
&mut div_rem_op,
&mut scalar_div_rem_op,
&mut log2_ops,
&mut shift_rotate_ops,
&mut scalar_shift_rotate_ops,
);
}
@@ -20,15 +20,18 @@ pub(crate) mod test_shift;
pub(crate) mod test_sub;
pub(crate) mod test_vector_comparisons;

use crate::core_crypto::gpu::CudaStreams;
use crate::core_crypto::gpu::{get_number_of_gpus, CudaStreams};
use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
use crate::integer::gpu::ciphertext::{CudaSignedRadixCiphertext, CudaUnsignedRadixCiphertext};
use crate::integer::gpu::server_key::radix::tests_unsigned::GpuFunctionExecutor;
use crate::integer::gpu::server_key::radix::tests_unsigned::{GpuContext, GpuFunctionExecutor};
use crate::integer::gpu::CudaServerKey;
use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor;
use crate::integer::{
BooleanBlock, RadixCiphertext, RadixClientKey, ServerKey, SignedRadixCiphertext,
BooleanBlock, RadixCiphertext, RadixClientKey, ServerKey, SignedRadixCiphertext, U256,
};
use crate::GpuIndex;
use rand::seq::SliceRandom;
use rand::Rng;
use std::sync::Arc;

/// For default/unchecked unary functions
@@ -643,3 +646,684 @@ where
d_block.to_boolean_block(&context.streams)
}
}

pub(crate) struct GpuMultiDeviceFunctionExecutor<F> {
pub(crate) context: Option<GpuContext>,
pub(crate) func: F,
}

impl<F> GpuMultiDeviceFunctionExecutor<F> {
pub(crate) fn new(func: F) -> Self {
Self {
context: None,
func,
}
}
}

impl<F> GpuMultiDeviceFunctionExecutor<F> {
pub(crate) fn setup_from_keys(&mut self, cks: &RadixClientKey, _sks: &Arc<ServerKey>) {
// Sample a random subset of 1-N gpus, where N is the number of available GPUs
// A GPU index should not appear twice in the subset
let num_gpus = get_number_of_gpus();
let mut rng = rand::thread_rng();
let num_gpus_to_use = rng.gen_range(1..=num_gpus as usize);
let mut all_gpu_indexes: Vec<u32> = (0..num_gpus).collect();
all_gpu_indexes.shuffle(&mut rng);
let gpu_indexes_to_use = &all_gpu_indexes[..num_gpus_to_use];
let gpu_indexes: Vec<GpuIndex> = gpu_indexes_to_use
.iter()
.map(|idx| GpuIndex::new(*idx))
.collect();
println!("Setting up server key on GPUs: [{gpu_indexes_to_use:?}]");

let streams = CudaStreams::new_multi_gpu_with_indexes(&gpu_indexes);

let sks = CudaServerKey::new(cks.as_ref(), &streams);
streams.synchronize();
let context = GpuContext { streams, sks };
self.context = Some(context);
}
}

/// For default/unchecked binary signed functions
impl<'a, F>
FunctionExecutor<(&'a SignedRadixCiphertext, &'a SignedRadixCiphertext), SignedRadixCiphertext>
for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(
&CudaServerKey,
&CudaSignedRadixCiphertext,
&CudaSignedRadixCiphertext,
&CudaStreams,
) -> CudaSignedRadixCiphertext,
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}

fn execute(
&mut self,
input: (&'a SignedRadixCiphertext, &'a SignedRadixCiphertext),
) -> SignedRadixCiphertext {
let context = self
.context
.as_ref()
.expect("setup was not properly called");

let d_ctxt_1 =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input.0, &context.streams);
let d_ctxt_2 =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input.1, &context.streams);

let gpu_result = (self.func)(&context.sks, &d_ctxt_1, &d_ctxt_2, &context.streams);

gpu_result.to_signed_radix_ciphertext(&context.streams)
}
}

impl<'a, F>
FunctionExecutor<(&'a SignedRadixCiphertext, &'a RadixCiphertext), SignedRadixCiphertext>
for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(
&CudaServerKey,
&CudaSignedRadixCiphertext,
&CudaUnsignedRadixCiphertext,
&CudaStreams,
) -> CudaSignedRadixCiphertext,
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}

fn execute(
&mut self,
input: (&'a SignedRadixCiphertext, &'a RadixCiphertext),
) -> SignedRadixCiphertext {
let context = self
.context
.as_ref()
.expect("setup was not properly called");

let d_ctxt_1 =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input.0, &context.streams);
let d_ctxt_2 =
CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.1, &context.streams);

let gpu_result = (self.func)(&context.sks, &d_ctxt_1, &d_ctxt_2, &context.streams);

gpu_result.to_signed_radix_ciphertext(&context.streams)
}
}

/// For unchecked/default assign binary functions
impl<'a, F> FunctionExecutor<(&'a mut SignedRadixCiphertext, &'a SignedRadixCiphertext), ()>
for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(&CudaServerKey, &mut CudaSignedRadixCiphertext, &CudaSignedRadixCiphertext, &CudaStreams),
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}

fn execute(&mut self, input: (&'a mut SignedRadixCiphertext, &'a SignedRadixCiphertext)) {
let context = self
.context
.as_ref()
.expect("setup was not properly called");

let mut d_ctxt_1 =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input.0, &context.streams);
let d_ctxt_2 =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input.1, &context.streams);

(self.func)(&context.sks, &mut d_ctxt_1, &d_ctxt_2, &context.streams);

*input.0 = d_ctxt_1.to_signed_radix_ciphertext(&context.streams);
}
}

/// For unchecked/default binary functions with one scalar input
impl<'a, F> FunctionExecutor<(&'a SignedRadixCiphertext, i64), SignedRadixCiphertext>
for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(
&CudaServerKey,
&CudaSignedRadixCiphertext,
i64,
&CudaStreams,
) -> CudaSignedRadixCiphertext,
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}

fn execute(&mut self, input: (&'a SignedRadixCiphertext, i64)) -> SignedRadixCiphertext {
let context = self
.context
.as_ref()
.expect("setup was not properly called");

let d_ctxt_1 =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input.0, &context.streams);

let gpu_result = (self.func)(&context.sks, &d_ctxt_1, input.1, &context.streams);

gpu_result.to_signed_radix_ciphertext(&context.streams)
}
}

/// For unchecked/default binary functions with one scalar input
impl<'a, F> FunctionExecutor<(&'a SignedRadixCiphertext, u64), SignedRadixCiphertext>
for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(
&CudaServerKey,
&CudaSignedRadixCiphertext,
u64,
&CudaStreams,
) -> CudaSignedRadixCiphertext,
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}

fn execute(&mut self, input: (&'a SignedRadixCiphertext, u64)) -> SignedRadixCiphertext {
let context = self
.context
.as_ref()
.expect("setup was not properly called");

let d_ctxt_1 =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input.0, &context.streams);

let gpu_result = (self.func)(&context.sks, &d_ctxt_1, input.1, &context.streams);

gpu_result.to_signed_radix_ciphertext(&context.streams)
}
}

/// For unchecked/default binary functions with one scalar input
impl<F> FunctionExecutor<(SignedRadixCiphertext, i64), SignedRadixCiphertext>
for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(
&CudaServerKey,
&CudaSignedRadixCiphertext,
i64,
&CudaStreams,
) -> CudaSignedRadixCiphertext,
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}

fn execute(&mut self, input: (SignedRadixCiphertext, i64)) -> SignedRadixCiphertext {
let context = self
.context
.as_ref()
.expect("setup was not properly called");

let d_ctxt_1 =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(&input.0, &context.streams);

let gpu_result = (self.func)(&context.sks, &d_ctxt_1, input.1, &context.streams);

gpu_result.to_signed_radix_ciphertext(&context.streams)
}
}

// Unary Function
impl<'a, F> FunctionExecutor<&'a SignedRadixCiphertext, SignedRadixCiphertext>
for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(&CudaServerKey, &CudaSignedRadixCiphertext, &CudaStreams) -> CudaSignedRadixCiphertext,
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}

fn execute(&mut self, input: &'a SignedRadixCiphertext) -> SignedRadixCiphertext {
let context = self
.context
.as_ref()
.expect("setup was not properly called");

let d_ctxt_1 =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input, &context.streams);

let gpu_result = (self.func)(&context.sks, &d_ctxt_1, &context.streams);

gpu_result.to_signed_radix_ciphertext(&context.streams)
}
}

impl<'a, F> FunctionExecutor<&'a SignedRadixCiphertext, RadixCiphertext>
for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(&CudaServerKey, &CudaSignedRadixCiphertext, &CudaStreams) -> CudaUnsignedRadixCiphertext,
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}

fn execute(&mut self, input: &'a SignedRadixCiphertext) -> RadixCiphertext {
let context = self
.context
.as_ref()
.expect("setup was not properly called");

let d_ctxt_1 =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input, &context.streams);

let gpu_result = (self.func)(&context.sks, &d_ctxt_1, &context.streams);

gpu_result.to_radix_ciphertext(&context.streams)
}
}

// Unary assign Function
impl<'a, F> FunctionExecutor<&'a mut SignedRadixCiphertext, ()>
for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(&CudaServerKey, &mut CudaSignedRadixCiphertext, &CudaStreams),
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}

fn execute(&mut self, input: &'a mut SignedRadixCiphertext) {
let context = self
.context
.as_ref()
.expect("setup was not properly called");

let mut d_ctxt_1 =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input, &context.streams);

(self.func)(&context.sks, &mut d_ctxt_1, &context.streams);

*input = d_ctxt_1.to_signed_radix_ciphertext(&context.streams)
}
}

impl<'a, F> FunctionExecutor<&'a Vec<SignedRadixCiphertext>, Option<SignedRadixCiphertext>>
for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(&CudaServerKey, Vec<CudaSignedRadixCiphertext>) -> Option<CudaSignedRadixCiphertext>,
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}

fn execute(&mut self, input: &'a Vec<SignedRadixCiphertext>) -> Option<SignedRadixCiphertext> {
let context = self
.context
.as_ref()
.expect("setup was not properly called");

let d_ctxt_1: Vec<CudaSignedRadixCiphertext> = input
.iter()
.map(|ct| CudaSignedRadixCiphertext::from_signed_radix_ciphertext(ct, &context.streams))
.collect();

let d_res = (self.func)(&context.sks, d_ctxt_1);

Some(d_res.unwrap().to_signed_radix_ciphertext(&context.streams))
}
}

impl<'a, F>
FunctionExecutor<
(&'a SignedRadixCiphertext, &'a SignedRadixCiphertext),
(SignedRadixCiphertext, BooleanBlock),
> for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(
&CudaServerKey,
&CudaSignedRadixCiphertext,
&CudaSignedRadixCiphertext,
&CudaStreams,
) -> (CudaSignedRadixCiphertext, CudaBooleanBlock),
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}

fn execute(
&mut self,
input: (&'a SignedRadixCiphertext, &'a SignedRadixCiphertext),
) -> (SignedRadixCiphertext, BooleanBlock) {
let context = self
.context
.as_ref()
.expect("setup was not properly called");

let d_ctxt_1: CudaSignedRadixCiphertext =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input.0, &context.streams);
let d_ctxt_2: CudaSignedRadixCiphertext =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input.1, &context.streams);

let d_res = (self.func)(&context.sks, &d_ctxt_1, &d_ctxt_2, &context.streams);

(
d_res.0.to_signed_radix_ciphertext(&context.streams),
d_res.1.to_boolean_block(&context.streams),
)
}
}

/// For unchecked/default unsigned overflowing scalar operations
impl<'a, F>
FunctionExecutor<(&'a SignedRadixCiphertext, i64), (SignedRadixCiphertext, BooleanBlock)>
for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(
&CudaServerKey,
&CudaSignedRadixCiphertext,
i64,
&CudaStreams,
) -> (CudaSignedRadixCiphertext, CudaBooleanBlock),
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}

fn execute(
&mut self,
input: (&'a SignedRadixCiphertext, i64),
) -> (SignedRadixCiphertext, BooleanBlock) {
let context = self
.context
.as_ref()
.expect("setup was not properly called");

let d_ctxt_1: CudaSignedRadixCiphertext =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input.0, &context.streams);

let d_res = (self.func)(&context.sks, &d_ctxt_1, input.1, &context.streams);

(
d_res.0.to_signed_radix_ciphertext(&context.streams),
d_res.1.to_boolean_block(&context.streams),
)
}
}

impl<'a, F> FunctionExecutor<&'a SignedRadixCiphertext, (SignedRadixCiphertext, BooleanBlock)>
for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(
&CudaServerKey,
&CudaSignedRadixCiphertext,
&CudaStreams,
) -> (CudaSignedRadixCiphertext, CudaBooleanBlock),
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}

fn execute(
&mut self,
input: &'a SignedRadixCiphertext,
) -> (SignedRadixCiphertext, BooleanBlock) {
let context = self
.context
.as_ref()
.expect("setup was not properly called");

let d_ctxt_1: CudaSignedRadixCiphertext =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input, &context.streams);

let d_res = (self.func)(&context.sks, &d_ctxt_1, &context.streams);

(
d_res.0.to_signed_radix_ciphertext(&context.streams),
d_res.1.to_boolean_block(&context.streams),
)
}
}

impl<'a, F>
FunctionExecutor<
(&'a SignedRadixCiphertext, &'a SignedRadixCiphertext),
(SignedRadixCiphertext, SignedRadixCiphertext),
> for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(
&CudaServerKey,
&CudaSignedRadixCiphertext,
&CudaSignedRadixCiphertext,
&CudaStreams,
) -> (CudaSignedRadixCiphertext, CudaSignedRadixCiphertext),
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}

fn execute(
&mut self,
input: (&'a SignedRadixCiphertext, &'a SignedRadixCiphertext),
) -> (SignedRadixCiphertext, SignedRadixCiphertext) {
let context = self
.context
.as_ref()
.expect("setup was not properly called");

let d_ctxt_1: CudaSignedRadixCiphertext =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input.0, &context.streams);
let d_ctxt_2: CudaSignedRadixCiphertext =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input.1, &context.streams);

let d_res = (self.func)(&context.sks, &d_ctxt_1, &d_ctxt_2, &context.streams);

(
d_res.0.to_signed_radix_ciphertext(&context.streams),
d_res.1.to_signed_radix_ciphertext(&context.streams),
)
}
}

impl<'a, F>
FunctionExecutor<
(&'a SignedRadixCiphertext, i64),
(SignedRadixCiphertext, SignedRadixCiphertext),
> for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(
&CudaServerKey,
&CudaSignedRadixCiphertext,
i64,
&CudaStreams,
) -> (CudaSignedRadixCiphertext, CudaSignedRadixCiphertext),
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}

fn execute(
&mut self,
input: (&'a SignedRadixCiphertext, i64),
) -> (SignedRadixCiphertext, SignedRadixCiphertext) {
let context = self
.context
.as_ref()
.expect("setup was not properly called");

let d_ctxt_1: CudaSignedRadixCiphertext =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input.0, &context.streams);

let d_res = (self.func)(&context.sks, &d_ctxt_1, input.1, &context.streams);

(
d_res.0.to_signed_radix_ciphertext(&context.streams),
d_res.1.to_signed_radix_ciphertext(&context.streams),
)
}
}

impl<'a, F> FunctionExecutor<(&'a SignedRadixCiphertext, &'a SignedRadixCiphertext), BooleanBlock>
for GpuMultiDeviceFunctionExecutor<F>
where
F: Fn(
&CudaServerKey,
&CudaSignedRadixCiphertext,
&CudaSignedRadixCiphertext,
&CudaStreams,
) -> CudaBooleanBlock,
{
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
self.setup_from_keys(cks, &sks);
}

fn execute(
&mut self,
input: (&'a SignedRadixCiphertext, &'a SignedRadixCiphertext),
) -> BooleanBlock {
let context = self
.context
.as_ref()
.expect("setup was not properly called");

let d_ctxt_1: CudaSignedRadixCiphertext =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input.0, &context.streams);
let d_ctxt_2: CudaSignedRadixCiphertext =
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input.1, &context.streams);

let d_res = (self.func)(&context.sks, &d_ctxt_1, &d_ctxt_2, &context.streams);
|
||||
|
||||
d_res.to_boolean_block(&context.streams)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, F> FunctionExecutor<(&'a SignedRadixCiphertext, i64), BooleanBlock>
|
||||
for GpuMultiDeviceFunctionExecutor<F>
|
||||
where
|
||||
F: Fn(&CudaServerKey, &CudaSignedRadixCiphertext, i64, &CudaStreams) -> CudaBooleanBlock,
|
||||
{
|
||||
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
|
||||
self.setup_from_keys(cks, &sks);
|
||||
}
|
||||
|
||||
fn execute(&mut self, input: (&'a SignedRadixCiphertext, i64)) -> BooleanBlock {
|
||||
let context = self
|
||||
.context
|
||||
.as_ref()
|
||||
.expect("setup was not properly called");
|
||||
|
||||
let d_ctxt_1: CudaSignedRadixCiphertext =
|
||||
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input.0, &context.streams);
|
||||
|
||||
let d_res = (self.func)(&context.sks, &d_ctxt_1, input.1, &context.streams);
|
||||
|
||||
d_res.to_boolean_block(&context.streams)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, F> FunctionExecutor<(&'a SignedRadixCiphertext, U256), BooleanBlock>
|
||||
for GpuMultiDeviceFunctionExecutor<F>
|
||||
where
|
||||
F: Fn(&CudaServerKey, &CudaSignedRadixCiphertext, U256, &CudaStreams) -> CudaBooleanBlock,
|
||||
{
|
||||
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
|
||||
self.setup_from_keys(cks, &sks);
|
||||
}
|
||||
|
||||
fn execute(&mut self, input: (&'a SignedRadixCiphertext, U256)) -> BooleanBlock {
|
||||
let context = self
|
||||
.context
|
||||
.as_ref()
|
||||
.expect("setup was not properly called");
|
||||
|
||||
let d_ctxt_1: CudaSignedRadixCiphertext =
|
||||
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input.0, &context.streams);
|
||||
|
||||
let d_res = (self.func)(&context.sks, &d_ctxt_1, input.1, &context.streams);
|
||||
|
||||
d_res.to_boolean_block(&context.streams)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, F> FunctionExecutor<(&'a SignedRadixCiphertext, U256), SignedRadixCiphertext>
|
||||
for GpuMultiDeviceFunctionExecutor<F>
|
||||
where
|
||||
F: Fn(
|
||||
&CudaServerKey,
|
||||
&CudaSignedRadixCiphertext,
|
||||
U256,
|
||||
&CudaStreams,
|
||||
) -> CudaSignedRadixCiphertext,
|
||||
{
|
||||
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
|
||||
self.setup_from_keys(cks, &sks);
|
||||
}
|
||||
|
||||
fn execute(&mut self, input: (&'a SignedRadixCiphertext, U256)) -> SignedRadixCiphertext {
|
||||
let context = self
|
||||
.context
|
||||
.as_ref()
|
||||
.expect("setup was not properly called");
|
||||
|
||||
let d_ctxt_1: CudaSignedRadixCiphertext =
|
||||
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input.0, &context.streams);
|
||||
|
||||
let d_res = (self.func)(&context.sks, &d_ctxt_1, input.1, &context.streams);
|
||||
|
||||
d_res.to_signed_radix_ciphertext(&context.streams)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, F>
|
||||
FunctionExecutor<
|
||||
(
|
||||
&'a BooleanBlock,
|
||||
&'a SignedRadixCiphertext,
|
||||
&'a SignedRadixCiphertext,
|
||||
),
|
||||
SignedRadixCiphertext,
|
||||
> for GpuMultiDeviceFunctionExecutor<F>
|
||||
where
|
||||
F: Fn(
|
||||
&CudaServerKey,
|
||||
&CudaBooleanBlock,
|
||||
&CudaSignedRadixCiphertext,
|
||||
&CudaSignedRadixCiphertext,
|
||||
&CudaStreams,
|
||||
) -> CudaSignedRadixCiphertext,
|
||||
{
|
||||
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
|
||||
self.setup_from_keys(cks, &sks);
|
||||
}
|
||||
|
||||
fn execute(
|
||||
&mut self,
|
||||
input: (
|
||||
&'a BooleanBlock,
|
||||
&'a SignedRadixCiphertext,
|
||||
&'a SignedRadixCiphertext,
|
||||
),
|
||||
) -> SignedRadixCiphertext {
|
||||
let context = self
|
||||
.context
|
||||
.as_ref()
|
||||
.expect("setup was not properly called");
|
||||
|
||||
let d_ctxt_1: CudaBooleanBlock =
|
||||
CudaBooleanBlock::from_boolean_block(input.0, &context.streams);
|
||||
let d_ctxt_2: CudaSignedRadixCiphertext =
|
||||
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input.1, &context.streams);
|
||||
let d_ctxt_3: CudaSignedRadixCiphertext =
|
||||
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input.2, &context.streams);
|
||||
|
||||
let d_res = (self.func)(
|
||||
&context.sks,
|
||||
&d_ctxt_1,
|
||||
&d_ctxt_2,
|
||||
&d_ctxt_3,
|
||||
&context.streams,
|
||||
);
|
||||
|
||||
d_res.to_signed_radix_ciphertext(&context.streams)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -50,6 +50,7 @@ macro_rules! create_gpu_parameterized_test{
    };
}

use crate::integer::gpu::server_key::radix::tests_signed::GpuMultiDeviceFunctionExecutor;
pub(crate) use create_gpu_parameterized_test;

pub(crate) struct GpuContext {
@@ -904,3 +905,530 @@ where
        d_block.to_boolean_block(&context.streams)
    }
}

/// For default/unchecked binary functions
|
||||
impl<'a, F> FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext), RadixCiphertext>
|
||||
for GpuMultiDeviceFunctionExecutor<F>
|
||||
where
|
||||
F: Fn(
|
||||
&CudaServerKey,
|
||||
&CudaUnsignedRadixCiphertext,
|
||||
&CudaUnsignedRadixCiphertext,
|
||||
&CudaStreams,
|
||||
) -> CudaUnsignedRadixCiphertext,
|
||||
{
|
||||
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
|
||||
self.setup_from_keys(cks, &sks);
|
||||
}
|
||||
|
||||
fn execute(&mut self, input: (&'a RadixCiphertext, &'a RadixCiphertext)) -> RadixCiphertext {
|
||||
let context = self
|
||||
.context
|
||||
.as_ref()
|
||||
.expect("setup was not properly called");
|
||||
|
||||
let d_ctxt_1 =
|
||||
CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.0, &context.streams);
|
||||
let d_ctxt_2 =
|
||||
CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.1, &context.streams);
|
||||
|
||||
let gpu_result = (self.func)(&context.sks, &d_ctxt_1, &d_ctxt_2, &context.streams);
|
||||
|
||||
gpu_result.to_radix_ciphertext(&context.streams)
|
||||
}
|
||||
}
|
||||
|
||||
/// For unchecked/default assign binary functions
|
||||
impl<'a, F> FunctionExecutor<(&'a mut RadixCiphertext, &'a RadixCiphertext), ()>
|
||||
for GpuMultiDeviceFunctionExecutor<F>
|
||||
where
|
||||
F: Fn(
|
||||
&CudaServerKey,
|
||||
&mut CudaUnsignedRadixCiphertext,
|
||||
&CudaUnsignedRadixCiphertext,
|
||||
&CudaStreams,
|
||||
),
|
||||
{
|
||||
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
|
||||
self.setup_from_keys(cks, &sks);
|
||||
}
|
||||
|
||||
fn execute(&mut self, input: (&'a mut RadixCiphertext, &'a RadixCiphertext)) {
|
||||
let context = self
|
||||
.context
|
||||
.as_ref()
|
||||
.expect("setup was not properly called");
|
||||
|
||||
let mut d_ctxt_1 =
|
||||
CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.0, &context.streams);
|
||||
let d_ctxt_2 =
|
||||
CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.1, &context.streams);
|
||||
|
||||
(self.func)(&context.sks, &mut d_ctxt_1, &d_ctxt_2, &context.streams);
|
||||
|
||||
*input.0 = d_ctxt_1.to_radix_ciphertext(&context.streams);
|
||||
}
|
||||
}
|
||||
|
||||
/// For unchecked/default binary functions with one scalar input
|
||||
impl<'a, F> FunctionExecutor<(&'a RadixCiphertext, u64), RadixCiphertext>
|
||||
for GpuMultiDeviceFunctionExecutor<F>
|
||||
where
|
||||
F: Fn(
|
||||
&CudaServerKey,
|
||||
&CudaUnsignedRadixCiphertext,
|
||||
u64,
|
||||
&CudaStreams,
|
||||
) -> CudaUnsignedRadixCiphertext,
|
||||
{
|
||||
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
|
||||
self.setup_from_keys(cks, &sks);
|
||||
}
|
||||
|
||||
fn execute(&mut self, input: (&'a RadixCiphertext, u64)) -> RadixCiphertext {
|
||||
let context = self
|
||||
.context
|
||||
.as_ref()
|
||||
.expect("setup was not properly called");
|
||||
|
||||
let d_ctxt_1 =
|
||||
CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.0, &context.streams);
|
||||
|
||||
let gpu_result = (self.func)(&context.sks, &d_ctxt_1, input.1, &context.streams);
|
||||
|
||||
gpu_result.to_radix_ciphertext(&context.streams)
|
||||
}
|
||||
}
|
||||
|
||||
/// For unchecked/default binary functions with one scalar input
|
||||
impl<F> FunctionExecutor<(RadixCiphertext, u64), RadixCiphertext>
|
||||
for GpuMultiDeviceFunctionExecutor<F>
|
||||
where
|
||||
F: Fn(
|
||||
&CudaServerKey,
|
||||
&CudaUnsignedRadixCiphertext,
|
||||
u64,
|
||||
&CudaStreams,
|
||||
) -> CudaUnsignedRadixCiphertext,
|
||||
{
|
||||
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
|
||||
self.setup_from_keys(cks, &sks);
|
||||
}
|
||||
|
||||
fn execute(&mut self, input: (RadixCiphertext, u64)) -> RadixCiphertext {
|
||||
let context = self
|
||||
.context
|
||||
.as_ref()
|
||||
.expect("setup was not properly called");
|
||||
|
||||
let d_ctxt_1 =
|
||||
CudaUnsignedRadixCiphertext::from_radix_ciphertext(&input.0, &context.streams);
|
||||
|
||||
let gpu_result = (self.func)(&context.sks, &d_ctxt_1, input.1, &context.streams);
|
||||
|
||||
gpu_result.to_radix_ciphertext(&context.streams)
|
||||
}
|
||||
}
|
||||
|
||||
// Unary Function
|
||||
impl<'a, F> FunctionExecutor<&'a RadixCiphertext, RadixCiphertext>
|
||||
for GpuMultiDeviceFunctionExecutor<F>
|
||||
where
|
||||
F: Fn(
|
||||
&CudaServerKey,
|
||||
&CudaUnsignedRadixCiphertext,
|
||||
&CudaStreams,
|
||||
) -> CudaUnsignedRadixCiphertext,
|
||||
{
|
||||
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
|
||||
self.setup_from_keys(cks, &sks);
|
||||
}
|
||||
|
||||
fn execute(&mut self, input: &'a RadixCiphertext) -> RadixCiphertext {
|
||||
let context = self
|
||||
.context
|
||||
.as_ref()
|
||||
.expect("setup was not properly called");
|
||||
|
||||
let d_ctxt_1 = CudaUnsignedRadixCiphertext::from_radix_ciphertext(input, &context.streams);
|
||||
|
||||
let gpu_result = (self.func)(&context.sks, &d_ctxt_1, &context.streams);
|
||||
|
||||
gpu_result.to_radix_ciphertext(&context.streams)
|
||||
}
|
||||
}
|
||||
|
||||
// Unary assign Function
|
||||
impl<'a, F> FunctionExecutor<&'a mut RadixCiphertext, ()> for GpuMultiDeviceFunctionExecutor<F>
|
||||
where
|
||||
F: Fn(&CudaServerKey, &mut CudaUnsignedRadixCiphertext, &CudaStreams),
|
||||
{
|
||||
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
|
||||
self.setup_from_keys(cks, &sks);
|
||||
}
|
||||
|
||||
fn execute(&mut self, input: &'a mut RadixCiphertext) {
|
||||
let context = self
|
||||
.context
|
||||
.as_ref()
|
||||
.expect("setup was not properly called");
|
||||
|
||||
let mut d_ctxt_1 =
|
||||
CudaUnsignedRadixCiphertext::from_radix_ciphertext(input, &context.streams);
|
||||
|
||||
(self.func)(&context.sks, &mut d_ctxt_1, &context.streams);
|
||||
|
||||
*input = d_ctxt_1.to_radix_ciphertext(&context.streams)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, F> FunctionExecutor<&'a Vec<RadixCiphertext>, Option<RadixCiphertext>>
|
||||
for GpuMultiDeviceFunctionExecutor<F>
|
||||
where
|
||||
F: Fn(&CudaServerKey, Vec<CudaUnsignedRadixCiphertext>) -> Option<CudaUnsignedRadixCiphertext>,
|
||||
{
|
||||
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
|
||||
self.setup_from_keys(cks, &sks);
|
||||
}
|
||||
|
||||
fn execute(&mut self, input: &'a Vec<RadixCiphertext>) -> Option<RadixCiphertext> {
|
||||
let context = self
|
||||
.context
|
||||
.as_ref()
|
||||
.expect("setup was not properly called");
|
||||
|
||||
let d_ctxt_1: Vec<CudaUnsignedRadixCiphertext> = input
|
||||
.iter()
|
||||
.map(|ct| CudaUnsignedRadixCiphertext::from_radix_ciphertext(ct, &context.streams))
|
||||
.collect();
|
||||
|
||||
let d_res = (self.func)(&context.sks, d_ctxt_1);
|
||||
|
||||
Some(d_res.unwrap().to_radix_ciphertext(&context.streams))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, F>
|
||||
FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext), (RadixCiphertext, BooleanBlock)>
|
||||
for GpuMultiDeviceFunctionExecutor<F>
|
||||
where
|
||||
F: Fn(
|
||||
&CudaServerKey,
|
||||
&CudaUnsignedRadixCiphertext,
|
||||
&CudaUnsignedRadixCiphertext,
|
||||
&CudaStreams,
|
||||
) -> (CudaUnsignedRadixCiphertext, CudaBooleanBlock),
|
||||
{
|
||||
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
|
||||
self.setup_from_keys(cks, &sks);
|
||||
}
|
||||
|
||||
fn execute(
|
||||
&mut self,
|
||||
input: (&'a RadixCiphertext, &'a RadixCiphertext),
|
||||
) -> (RadixCiphertext, BooleanBlock) {
|
||||
let context = self
|
||||
.context
|
||||
.as_ref()
|
||||
.expect("setup was not properly called");
|
||||
|
||||
let d_ctxt_1: CudaUnsignedRadixCiphertext =
|
||||
CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.0, &context.streams);
|
||||
let d_ctxt_2: CudaUnsignedRadixCiphertext =
|
||||
CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.1, &context.streams);
|
||||
|
||||
let d_res = (self.func)(&context.sks, &d_ctxt_1, &d_ctxt_2, &context.streams);
|
||||
|
||||
(
|
||||
d_res.0.to_radix_ciphertext(&context.streams),
|
||||
d_res.1.to_boolean_block(&context.streams),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// For unchecked/default unsigned overflowing scalar operations
|
||||
impl<'a, F> FunctionExecutor<(&'a RadixCiphertext, u64), (RadixCiphertext, BooleanBlock)>
|
||||
for GpuMultiDeviceFunctionExecutor<F>
|
||||
where
|
||||
F: Fn(
|
||||
&CudaServerKey,
|
||||
&CudaUnsignedRadixCiphertext,
|
||||
u64,
|
||||
&CudaStreams,
|
||||
) -> (CudaUnsignedRadixCiphertext, CudaBooleanBlock),
|
||||
{
|
||||
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
|
||||
self.setup_from_keys(cks, &sks);
|
||||
}
|
||||
|
||||
fn execute(&mut self, input: (&'a RadixCiphertext, u64)) -> (RadixCiphertext, BooleanBlock) {
|
||||
let context = self
|
||||
.context
|
||||
.as_ref()
|
||||
.expect("setup was not properly called");
|
||||
|
||||
let d_ctxt_1: CudaUnsignedRadixCiphertext =
|
||||
CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.0, &context.streams);
|
||||
|
||||
let d_res = (self.func)(&context.sks, &d_ctxt_1, input.1, &context.streams);
|
||||
|
||||
(
|
||||
d_res.0.to_radix_ciphertext(&context.streams),
|
||||
d_res.1.to_boolean_block(&context.streams),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// For ilog operation
|
||||
impl<'a, F> FunctionExecutor<&'a RadixCiphertext, (RadixCiphertext, BooleanBlock)>
|
||||
for GpuMultiDeviceFunctionExecutor<F>
|
||||
where
|
||||
F: Fn(
|
||||
&CudaServerKey,
|
||||
&CudaUnsignedRadixCiphertext,
|
||||
&CudaStreams,
|
||||
) -> (CudaUnsignedRadixCiphertext, CudaBooleanBlock),
|
||||
{
|
||||
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
|
||||
self.setup_from_keys(cks, &sks);
|
||||
}
|
||||
|
||||
fn execute(&mut self, input: &'a RadixCiphertext) -> (RadixCiphertext, BooleanBlock) {
|
||||
let context = self
|
||||
.context
|
||||
.as_ref()
|
||||
.expect("setup was not properly called");
|
||||
|
||||
let d_ctxt_1: CudaUnsignedRadixCiphertext =
|
||||
CudaUnsignedRadixCiphertext::from_radix_ciphertext(input, &context.streams);
|
||||
|
||||
let d_res = (self.func)(&context.sks, &d_ctxt_1, &context.streams);
|
||||
|
||||
(
|
||||
d_res.0.to_radix_ciphertext(&context.streams),
|
||||
d_res.1.to_boolean_block(&context.streams),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, F>
|
||||
FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext), (RadixCiphertext, RadixCiphertext)>
|
||||
for GpuMultiDeviceFunctionExecutor<F>
|
||||
where
|
||||
F: Fn(
|
||||
&CudaServerKey,
|
||||
&CudaUnsignedRadixCiphertext,
|
||||
&CudaUnsignedRadixCiphertext,
|
||||
&CudaStreams,
|
||||
) -> (CudaUnsignedRadixCiphertext, CudaUnsignedRadixCiphertext),
|
||||
{
|
||||
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
|
||||
self.setup_from_keys(cks, &sks);
|
||||
}
|
||||
|
||||
fn execute(
|
||||
&mut self,
|
||||
input: (&'a RadixCiphertext, &'a RadixCiphertext),
|
||||
) -> (RadixCiphertext, RadixCiphertext) {
|
||||
let context = self
|
||||
.context
|
||||
.as_ref()
|
||||
.expect("setup was not properly called");
|
||||
|
||||
let d_ctxt_1: CudaUnsignedRadixCiphertext =
|
||||
CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.0, &context.streams);
|
||||
let d_ctxt_2: CudaUnsignedRadixCiphertext =
|
||||
CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.1, &context.streams);
|
||||
|
||||
let d_res = (self.func)(&context.sks, &d_ctxt_1, &d_ctxt_2, &context.streams);
|
||||
|
||||
(
|
||||
d_res.0.to_radix_ciphertext(&context.streams),
|
||||
d_res.1.to_radix_ciphertext(&context.streams),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, F> FunctionExecutor<(&'a RadixCiphertext, u64), (RadixCiphertext, RadixCiphertext)>
|
||||
for GpuMultiDeviceFunctionExecutor<F>
|
||||
where
|
||||
F: Fn(
|
||||
&CudaServerKey,
|
||||
&CudaUnsignedRadixCiphertext,
|
||||
u64,
|
||||
&CudaStreams,
|
||||
) -> (CudaUnsignedRadixCiphertext, CudaUnsignedRadixCiphertext),
|
||||
{
|
||||
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
|
||||
self.setup_from_keys(cks, &sks);
|
||||
}
|
||||
|
||||
fn execute(&mut self, input: (&'a RadixCiphertext, u64)) -> (RadixCiphertext, RadixCiphertext) {
|
||||
let context = self
|
||||
.context
|
||||
.as_ref()
|
||||
.expect("setup was not properly called");
|
||||
|
||||
let d_ctxt_1: CudaUnsignedRadixCiphertext =
|
||||
CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.0, &context.streams);
|
||||
|
||||
let d_res = (self.func)(&context.sks, &d_ctxt_1, input.1, &context.streams);
|
||||
|
||||
(
|
||||
d_res.0.to_radix_ciphertext(&context.streams),
|
||||
d_res.1.to_radix_ciphertext(&context.streams),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, F> FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext), BooleanBlock>
|
||||
for GpuMultiDeviceFunctionExecutor<F>
|
||||
where
|
||||
F: Fn(
|
||||
&CudaServerKey,
|
||||
&CudaUnsignedRadixCiphertext,
|
||||
&CudaUnsignedRadixCiphertext,
|
||||
&CudaStreams,
|
||||
) -> CudaBooleanBlock,
|
||||
{
|
||||
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
|
||||
self.setup_from_keys(cks, &sks);
|
||||
}
|
||||
|
||||
fn execute(&mut self, input: (&'a RadixCiphertext, &'a RadixCiphertext)) -> BooleanBlock {
|
||||
let context = self
|
||||
.context
|
||||
.as_ref()
|
||||
.expect("setup was not properly called");
|
||||
|
||||
let d_ctxt_1: CudaUnsignedRadixCiphertext =
|
||||
CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.0, &context.streams);
|
||||
let d_ctxt_2: CudaUnsignedRadixCiphertext =
|
||||
CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.1, &context.streams);
|
||||
|
||||
let d_res = (self.func)(&context.sks, &d_ctxt_1, &d_ctxt_2, &context.streams);
|
||||
|
||||
d_res.to_boolean_block(&context.streams)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, F> FunctionExecutor<(&'a RadixCiphertext, u64), BooleanBlock>
|
||||
for GpuMultiDeviceFunctionExecutor<F>
|
||||
where
|
||||
F: Fn(&CudaServerKey, &CudaUnsignedRadixCiphertext, u64, &CudaStreams) -> CudaBooleanBlock,
|
||||
{
|
||||
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
|
||||
self.setup_from_keys(cks, &sks);
|
||||
}
|
||||
|
||||
fn execute(&mut self, input: (&'a RadixCiphertext, u64)) -> BooleanBlock {
|
||||
let context = self
|
||||
.context
|
||||
.as_ref()
|
||||
.expect("setup was not properly called");
|
||||
|
||||
let d_ctxt_1: CudaUnsignedRadixCiphertext =
|
||||
CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.0, &context.streams);
|
||||
|
||||
let d_res = (self.func)(&context.sks, &d_ctxt_1, input.1, &context.streams);
|
||||
|
||||
d_res.to_boolean_block(&context.streams)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, F> FunctionExecutor<(&'a RadixCiphertext, U256), BooleanBlock>
|
||||
for GpuMultiDeviceFunctionExecutor<F>
|
||||
where
|
||||
F: Fn(&CudaServerKey, &CudaUnsignedRadixCiphertext, U256, &CudaStreams) -> CudaBooleanBlock,
|
||||
{
|
||||
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
|
||||
self.setup_from_keys(cks, &sks);
|
||||
}
|
||||
|
||||
fn execute(&mut self, input: (&'a RadixCiphertext, U256)) -> BooleanBlock {
|
||||
let context = self
|
||||
.context
|
||||
.as_ref()
|
||||
.expect("setup was not properly called");
|
||||
|
||||
let d_ctxt_1: CudaUnsignedRadixCiphertext =
|
||||
CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.0, &context.streams);
|
||||
|
||||
let d_res = (self.func)(&context.sks, &d_ctxt_1, input.1, &context.streams);
|
||||
|
||||
d_res.to_boolean_block(&context.streams)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, F> FunctionExecutor<(&'a RadixCiphertext, U256), RadixCiphertext>
|
||||
for GpuMultiDeviceFunctionExecutor<F>
|
||||
where
|
||||
F: Fn(
|
||||
&CudaServerKey,
|
||||
&CudaUnsignedRadixCiphertext,
|
||||
U256,
|
||||
&CudaStreams,
|
||||
) -> CudaUnsignedRadixCiphertext,
|
||||
{
|
||||
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
|
||||
self.setup_from_keys(cks, &sks);
|
||||
}
|
||||
|
||||
fn execute(&mut self, input: (&'a RadixCiphertext, U256)) -> RadixCiphertext {
|
||||
let context = self
|
||||
.context
|
||||
.as_ref()
|
||||
.expect("setup was not properly called");
|
||||
|
||||
let d_ctxt_1: CudaUnsignedRadixCiphertext =
|
||||
CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.0, &context.streams);
|
||||
|
||||
let d_res = (self.func)(&context.sks, &d_ctxt_1, input.1, &context.streams);
|
||||
|
||||
d_res.to_radix_ciphertext(&context.streams)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, F>
|
||||
FunctionExecutor<(&'a BooleanBlock, &'a RadixCiphertext, &'a RadixCiphertext), RadixCiphertext>
|
||||
for GpuMultiDeviceFunctionExecutor<F>
|
||||
where
|
||||
F: Fn(
|
||||
&CudaServerKey,
|
||||
&CudaBooleanBlock,
|
||||
&CudaUnsignedRadixCiphertext,
|
||||
&CudaUnsignedRadixCiphertext,
|
||||
&CudaStreams,
|
||||
) -> CudaUnsignedRadixCiphertext,
|
||||
{
|
||||
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
|
||||
self.setup_from_keys(cks, &sks);
|
||||
}
|
||||
|
||||
fn execute(
|
||||
&mut self,
|
||||
input: (&'a BooleanBlock, &'a RadixCiphertext, &'a RadixCiphertext),
|
||||
) -> RadixCiphertext {
|
||||
let context = self
|
||||
.context
|
||||
.as_ref()
|
||||
.expect("setup was not properly called");
|
||||
|
||||
let d_ctxt_1: CudaBooleanBlock =
|
||||
CudaBooleanBlock::from_boolean_block(input.0, &context.streams);
|
||||
let d_ctxt_2: CudaUnsignedRadixCiphertext =
|
||||
CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.1, &context.streams);
|
||||
let d_ctxt_3: CudaUnsignedRadixCiphertext =
|
||||
CudaUnsignedRadixCiphertext::from_radix_ciphertext(input.2, &context.streams);
|
||||
|
||||
let d_res = (self.func)(
|
||||
&context.sks,
|
||||
&d_ctxt_1,
|
||||
&d_ctxt_2,
|
||||
&d_ctxt_3,
|
||||
&context.streams,
|
||||
);
|
||||
|
||||
d_res.to_radix_ciphertext(&context.streams)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
use crate::core_crypto::gpu::{get_number_of_gpus, CudaStreams};
use crate::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
use crate::integer::gpu::server_key::radix::tests_long_run::GpuMultiDeviceFunctionExecutor;
use crate::integer::gpu::server_key::radix::tests_signed::GpuMultiDeviceFunctionExecutor;
use crate::integer::gpu::server_key::radix::tests_unsigned::{
    create_gpu_parameterized_test, GpuFunctionExecutor,
};

@@ -1,5 +1,5 @@
use crate::core_crypto::gpu::get_number_of_gpus;
use crate::integer::gpu::server_key::radix::tests_long_run::GpuMultiDeviceFunctionExecutor;
use crate::integer::gpu::server_key::radix::tests_signed::GpuMultiDeviceFunctionExecutor;
use crate::integer::gpu::server_key::radix::tests_unsigned::{
    create_gpu_parameterized_test, GpuFunctionExecutor,
};

@@ -1,5 +1,5 @@
use crate::core_crypto::gpu::get_number_of_gpus;
use crate::integer::gpu::server_key::radix::tests_long_run::GpuMultiDeviceFunctionExecutor;
use crate::integer::gpu::server_key::radix::tests_signed::GpuMultiDeviceFunctionExecutor;
use crate::integer::gpu::server_key::radix::tests_unsigned::{
    create_gpu_parameterized_test, GpuFunctionExecutor,
};

@@ -1,5 +1,5 @@
use crate::core_crypto::gpu::get_number_of_gpus;
use crate::integer::gpu::server_key::radix::tests_long_run::GpuMultiDeviceFunctionExecutor;
use crate::integer::gpu::server_key::radix::tests_signed::GpuMultiDeviceFunctionExecutor;
use crate::integer::gpu::server_key::radix::tests_unsigned::{
    create_gpu_parameterized_test, GpuFunctionExecutor,
};

@@ -5,6 +5,7 @@ use crate::integer::{
    BooleanBlock, IntegerCiphertext, RadixCiphertext, RadixClientKey, SignedRadixCiphertext,
};
use crate::shortint::parameters::NoiseLevel;
use crate::CompressedServerKey;
use rand::Rng;
use tfhe_csprng::generators::DefaultRandomGenerator;
use tfhe_csprng::seeders::{Seed, Seeder};
@@ -17,6 +18,42 @@ pub(crate) const NB_CTXT_LONG_RUN: usize = 32;
pub(crate) const NB_TESTS_LONG_RUN: usize = 20000;
pub(crate) const NB_TESTS_LONG_RUN_MINIMAL: usize = 200;

pub(crate) fn get_user_defined_seed() -> Option<Seed> {
    match std::env::var("TFHE_RS_LONGRUN_TESTS_SEED") {
        Ok(val) => match val.parse::<u128>() {
            Ok(s) => Some(Seed(s)),
            Err(_e) => None,
        },
        Err(_e) => None,
    }
}

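A minimal sketch of the contract this helper implies, purely for illustration (the function below is hypothetical and not part of the change): the environment variable must hold a decimal u128, and anything else silently falls back to a freshly drawn random seed.

fn user_defined_seed_parsing_sketch() {
    // A decimal u128 is accepted as-is.
    std::env::set_var("TFHE_RS_LONGRUN_TESTS_SEED", "42");
    assert!(matches!(get_user_defined_seed(), Some(Seed(42))));

    // A value that does not parse as a decimal u128 (e.g. hex) yields None,
    // so the caller falls back to a random seed.
    std::env::set_var("TFHE_RS_LONGRUN_TESTS_SEED", "0xdead");
    assert!(get_user_defined_seed().is_none());
}
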
/// This trait is to be implemented by a struct that is capable
/// of executing a particular function to be tested in the
/// random op sequence tests. This executor
/// can execute ops either on the CPU or on both the CPU and the GPU.
pub(crate) trait OpSequenceFunctionExecutor<TestInput, TestOutput> {
    /// Sets up the executor
    ///
    /// Implementers are expected to be fully functional after this
    /// function has been called.
    fn setup(
        &mut self,
        cks: &RadixClientKey,
        sks: &CompressedServerKey,
        seeder: &mut DeterministicSeeder<DefaultRandomGenerator>,
    );

    /// Executes the function
    ///
    /// The function receives some inputs and returns some output.
    /// Implementers may have to do more than just call the function
    /// that is being tested (for example, inputs/outputs may need to be converted).
    ///
    /// Look at the test case function to know what the expected inputs and outputs are.
    fn execute(&mut self, input: TestInput) -> TestOutput;
}

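A minimal sketch of an implementer, for illustration only (the type below is hypothetical): it ignores the keys and the seeder and simply clones its input, which is enough to show the shape the trait expects.

struct IdentityExecutor;

impl<'a> OpSequenceFunctionExecutor<&'a RadixCiphertext, RadixCiphertext> for IdentityExecutor {
    fn setup(
        &mut self,
        _cks: &RadixClientKey,
        _sks: &CompressedServerKey,
        _seeder: &mut DeterministicSeeder<DefaultRandomGenerator>,
    ) {
        // Nothing to prepare: this executor holds no key material.
    }

    fn execute(&mut self, input: &'a RadixCiphertext) -> RadixCiphertext {
        // A real executor would run a ServerKey (or CudaServerKey) operation here,
        // converting inputs/outputs between CPU and GPU representations as needed.
        input.clone()
    }
}
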
pub(crate) fn get_long_test_iterations() -> usize {
    static ENV_KEY_LONG_TESTS: std::sync::OnceLock<bool> = std::sync::OnceLock::new();

@@ -60,11 +97,13 @@ impl RadixEncryptable for i64 {
        key.encrypt_signed(*self)
    }
}

struct RandomOpSequenceDataGenerator<P, C> {
// Generates data for the random op sequence tests. It
// can generate data deterministically using a user-specified
// seed
pub(crate) struct RandomOpSequenceDataGenerator<P, C> {
    pub(crate) lhs: Vec<TestDataSample<P, C>>,
    pub(crate) rhs: Vec<TestDataSample<P, C>>,
    deterministic_seeder: DeterministicSeeder<DefaultRandomGenerator>,
    pub(crate) deterministic_seeder: DeterministicSeeder<DefaultRandomGenerator>,
    seed: Seed,
    cks: RadixClientKey,
    op_counter: usize,
@@ -75,7 +114,7 @@ impl<
        C: IntegerCiphertext,
    > RandomOpSequenceDataGenerator<P, C>
{
    fn new(total_num_ops: usize, cks: &RadixClientKey) -> Self {
    pub(crate) fn new(total_num_ops: usize, cks: &RadixClientKey) -> Self {
        let mut rng = rand::thread_rng();

        let seed: u128 = rng.gen();
@@ -99,7 +138,7 @@ impl<
            }) // Generate random i64 values and encrypt them
            .collect()
    }
    fn new_with_seed(total_num_ops: usize, seed: Seed, cks: &RadixClientKey) -> Self {
    pub(crate) fn new_with_seed(total_num_ops: usize, seed: Seed, cks: &RadixClientKey) -> Self {
        let mut deterministic_seeder = DeterministicSeeder::<DefaultRandomGenerator>::new(seed);

        Self {
@@ -112,7 +151,7 @@ impl<
        }
    }

    fn get_seed(&self) -> Seed {
    pub(crate) fn get_seed(&self) -> Seed {
        self.seed
    }

@@ -199,6 +238,11 @@ impl<
        let val = self.deterministic_seeder.seed().0 % 2;
        (val == 1, self.cks.encrypt_bool(val == 1))
    }

    fn gen_bool_uniform(&mut self) -> bool {
        let val = self.deterministic_seeder.seed().0 % 2;
        val == 1
    }
}
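As an illustration of the determinism this generator provides (the helper below is hypothetical, not part of the change), two generators built from the same seed report the same seed and will replay the same operation sequence:

fn reproducible_generator_sketch(total_num_ops: usize, cks: &RadixClientKey) {
    let seed = get_user_defined_seed().unwrap_or(Seed(42)).0;
    let first = RandomOpSequenceDataGenerator::<u64, RadixCiphertext>::new_with_seed(
        total_num_ops,
        Seed(seed),
        cks,
    );
    let second = RandomOpSequenceDataGenerator::<u64, RadixCiphertext>::new_with_seed(
        total_num_ops,
        Seed(seed),
        cks,
    );
    assert_eq!(first.get_seed().0, second.get_seed().0);
}
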
#[allow(clippy::too_many_arguments)]
pub(crate) fn sanity_check_op_sequence_result_u64(

@@ -1,13 +1,14 @@
use crate::integer::keycache::KEY_CACHE;
use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor;
use crate::integer::server_key::radix_parallel::tests_long_run::{
    get_long_test_iterations, sanity_check_op_sequence_result_bool,
    sanity_check_op_sequence_result_u64, RandomOpSequenceDataGenerator, NB_CTXT_LONG_RUN,
    get_long_test_iterations, get_user_defined_seed, sanity_check_op_sequence_result_bool,
    sanity_check_op_sequence_result_u64, OpSequenceFunctionExecutor, RandomOpSequenceDataGenerator,
    NB_CTXT_LONG_RUN,
};
use crate::integer::server_key::radix_parallel::tests_unsigned::CpuFunctionExecutor;
use crate::integer::server_key::radix_parallel::tests_unsigned::OpSequenceCpuFunctionExecutor;
use crate::integer::tests::create_parameterized_test;
use crate::integer::{BooleanBlock, IntegerKeyKind, RadixCiphertext, RadixClientKey, ServerKey};
use crate::shortint::parameters::*;
use crate::{ClientKey, CompressedServerKey, Tag};
use std::cmp::{max, min};
use std::sync::Arc;

@@ -19,59 +20,78 @@ create_parameterized_test!(random_op_sequence_data_generator {
|
||||
PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128
|
||||
});
|
||||
|
||||
pub(crate) type BinaryOpExecutor =
|
||||
Box<dyn for<'a> FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext), RadixCiphertext>>;
|
||||
pub(crate) type BinaryOpExecutor = Box<
|
||||
dyn for<'a> OpSequenceFunctionExecutor<
|
||||
(&'a RadixCiphertext, &'a RadixCiphertext),
|
||||
RadixCiphertext,
|
||||
>,
|
||||
>;
|
||||
pub(crate) type UnaryOpExecutor =
|
||||
Box<dyn for<'a> FunctionExecutor<&'a RadixCiphertext, RadixCiphertext>>;
|
||||
Box<dyn for<'a> OpSequenceFunctionExecutor<&'a RadixCiphertext, RadixCiphertext>>;
|
||||
|
||||
pub(crate) type ScalarBinaryOpExecutor =
|
||||
Box<dyn for<'a> FunctionExecutor<(&'a RadixCiphertext, u64), RadixCiphertext>>;
|
||||
Box<dyn for<'a> OpSequenceFunctionExecutor<(&'a RadixCiphertext, u64), RadixCiphertext>>;
|
||||
pub(crate) type OverflowingOpExecutor = Box<
|
||||
dyn for<'a> FunctionExecutor<
|
||||
dyn for<'a> OpSequenceFunctionExecutor<
|
||||
(&'a RadixCiphertext, &'a RadixCiphertext),
|
||||
(RadixCiphertext, BooleanBlock),
|
||||
>,
|
||||
>;
|
||||
pub(crate) type ScalarOverflowingOpExecutor =
|
||||
Box<dyn for<'a> FunctionExecutor<(&'a RadixCiphertext, u64), (RadixCiphertext, BooleanBlock)>>;
|
||||
pub(crate) type ComparisonOpExecutor =
|
||||
Box<dyn for<'a> FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext), BooleanBlock>>;
|
||||
pub(crate) type ScalarOverflowingOpExecutor = Box<
|
||||
dyn for<'a> OpSequenceFunctionExecutor<
|
||||
(&'a RadixCiphertext, u64),
|
||||
(RadixCiphertext, BooleanBlock),
|
||||
>,
|
||||
>;
|
||||
pub(crate) type ComparisonOpExecutor = Box<
|
||||
dyn for<'a> OpSequenceFunctionExecutor<
|
||||
(&'a RadixCiphertext, &'a RadixCiphertext),
|
||||
BooleanBlock,
|
||||
>,
|
||||
>;
|
||||
pub(crate) type ScalarComparisonOpExecutor =
|
||||
Box<dyn for<'a> FunctionExecutor<(&'a RadixCiphertext, u64), BooleanBlock>>;
|
||||
Box<dyn for<'a> OpSequenceFunctionExecutor<(&'a RadixCiphertext, u64), BooleanBlock>>;
|
||||
pub(crate) type SelectOpExecutor = Box<
|
||||
dyn for<'a> FunctionExecutor<
|
||||
dyn for<'a> OpSequenceFunctionExecutor<
|
||||
(&'a BooleanBlock, &'a RadixCiphertext, &'a RadixCiphertext),
|
||||
RadixCiphertext,
|
||||
>,
|
||||
>;
|
||||
pub(crate) type DivRemOpExecutor = Box<
|
||||
dyn for<'a> FunctionExecutor<
|
||||
dyn for<'a> OpSequenceFunctionExecutor<
|
||||
(&'a RadixCiphertext, &'a RadixCiphertext),
|
||||
(RadixCiphertext, RadixCiphertext),
|
||||
>,
|
||||
>;
|
||||
pub(crate) type ScalarDivRemOpExecutor = Box<
|
||||
dyn for<'a> FunctionExecutor<(&'a RadixCiphertext, u64), (RadixCiphertext, RadixCiphertext)>,
|
||||
dyn for<'a> OpSequenceFunctionExecutor<
|
||||
(&'a RadixCiphertext, u64),
|
||||
(RadixCiphertext, RadixCiphertext),
|
||||
>,
|
||||
>;
|
||||
pub(crate) type Log2OpExecutor =
|
||||
Box<dyn for<'a> FunctionExecutor<&'a RadixCiphertext, RadixCiphertext>>;
|
||||
Box<dyn for<'a> OpSequenceFunctionExecutor<&'a RadixCiphertext, RadixCiphertext>>;
|
||||
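For orientation, a sketch of how one entry of the binary-op table is assembled with these aliases (illustrative; the bindings are made up but mirror the vec! construction used below):

let clear_add = |x: u64, y: u64| x.wrapping_add(y);
let add_entry: (BinaryOpExecutor, &dyn Fn(u64, u64) -> u64, String) = (
    Box::new(OpSequenceCpuFunctionExecutor::new(&ServerKey::add_parallelized)),
    &clear_add,
    "add".to_string(),
);
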
fn random_op_sequence<P>(param: P)
|
||||
where
|
||||
P: Into<TestParameters> + Clone,
|
||||
{
|
||||
// Binary Ops Executors
|
||||
let add_executor = CpuFunctionExecutor::new(&ServerKey::add_parallelized);
|
||||
let sub_executor = CpuFunctionExecutor::new(&ServerKey::sub_parallelized);
|
||||
let bitwise_and_executor = CpuFunctionExecutor::new(&ServerKey::bitand_parallelized);
|
||||
let bitwise_or_executor = CpuFunctionExecutor::new(&ServerKey::bitor_parallelized);
|
||||
let bitwise_xor_executor = CpuFunctionExecutor::new(&ServerKey::bitxor_parallelized);
|
||||
let mul_executor = CpuFunctionExecutor::new(&ServerKey::mul_parallelized);
|
||||
let rotate_left_executor = CpuFunctionExecutor::new(&ServerKey::rotate_left_parallelized);
|
||||
let left_shift_executor = CpuFunctionExecutor::new(&ServerKey::left_shift_parallelized);
|
||||
let rotate_right_executor = CpuFunctionExecutor::new(&ServerKey::rotate_right_parallelized);
|
||||
let right_shift_executor = CpuFunctionExecutor::new(&ServerKey::right_shift_parallelized);
|
||||
let max_executor = CpuFunctionExecutor::new(&ServerKey::max_parallelized);
|
||||
let min_executor = CpuFunctionExecutor::new(&ServerKey::min_parallelized);
|
||||
let add_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::add_parallelized);
|
||||
let sub_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::sub_parallelized);
|
||||
let bitwise_and_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::bitand_parallelized);
|
||||
let bitwise_or_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::bitor_parallelized);
|
||||
let bitwise_xor_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::bitxor_parallelized);
|
||||
let mul_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::mul_parallelized);
|
||||
let rotate_left_executor =
|
||||
OpSequenceCpuFunctionExecutor::new(&ServerKey::rotate_left_parallelized);
|
||||
let left_shift_executor =
|
||||
OpSequenceCpuFunctionExecutor::new(&ServerKey::left_shift_parallelized);
|
||||
let rotate_right_executor =
|
||||
OpSequenceCpuFunctionExecutor::new(&ServerKey::rotate_right_parallelized);
|
||||
let right_shift_executor =
|
||||
OpSequenceCpuFunctionExecutor::new(&ServerKey::right_shift_parallelized);
|
||||
let max_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::max_parallelized);
|
||||
let min_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::min_parallelized);
|
||||
|
||||
// Binary Ops Clear functions
|
||||
let clear_add = |x: u64, y: u64| x.wrapping_add(y);
|
||||
@@ -134,9 +154,10 @@ where
|
||||
];
|
||||
|
||||
// Unary Ops Executors
|
||||
let neg_executor = CpuFunctionExecutor::new(&ServerKey::neg_parallelized);
|
||||
let bitnot_executor = CpuFunctionExecutor::new(&ServerKey::bitnot);
|
||||
let reverse_bits_executor = CpuFunctionExecutor::new(&ServerKey::reverse_bits_parallelized);
|
||||
let neg_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::neg_parallelized);
|
||||
let bitnot_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::bitnot);
|
||||
let reverse_bits_executor =
|
||||
OpSequenceCpuFunctionExecutor::new(&ServerKey::reverse_bits_parallelized);
|
||||
// Unary Ops Clear functions
|
||||
let clear_neg = |x: u64| x.wrapping_neg();
|
||||
let clear_bitnot = |x: u64| !x;
|
||||
@@ -157,23 +178,26 @@ where
|
||||
];
|
||||
|
||||
// Scalar binary Ops Executors
|
||||
let scalar_add_executor = CpuFunctionExecutor::new(&ServerKey::scalar_add_parallelized);
|
||||
let scalar_sub_executor = CpuFunctionExecutor::new(&ServerKey::scalar_sub_parallelized);
|
||||
let scalar_add_executor =
|
||||
OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_add_parallelized);
|
||||
let scalar_sub_executor =
|
||||
OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_sub_parallelized);
|
||||
let scalar_bitwise_and_executor =
|
||||
CpuFunctionExecutor::new(&ServerKey::scalar_bitand_parallelized);
|
||||
OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_bitand_parallelized);
|
||||
let scalar_bitwise_or_executor =
|
||||
CpuFunctionExecutor::new(&ServerKey::scalar_bitor_parallelized);
|
||||
OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_bitor_parallelized);
|
||||
let scalar_bitwise_xor_executor =
|
||||
CpuFunctionExecutor::new(&ServerKey::scalar_bitxor_parallelized);
|
||||
let scalar_mul_executor = CpuFunctionExecutor::new(&ServerKey::scalar_mul_parallelized);
|
||||
OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_bitxor_parallelized);
|
||||
let scalar_mul_executor =
|
||||
OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_mul_parallelized);
|
||||
let scalar_rotate_left_executor =
|
||||
CpuFunctionExecutor::new(&ServerKey::scalar_rotate_left_parallelized);
|
||||
OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_rotate_left_parallelized);
|
||||
let scalar_left_shift_executor =
|
||||
CpuFunctionExecutor::new(&ServerKey::scalar_left_shift_parallelized);
|
||||
OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_left_shift_parallelized);
|
||||
let scalar_rotate_right_executor =
|
||||
CpuFunctionExecutor::new(&ServerKey::scalar_rotate_right_parallelized);
|
||||
OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_rotate_right_parallelized);
|
||||
let scalar_right_shift_executor =
|
||||
CpuFunctionExecutor::new(&ServerKey::scalar_right_shift_parallelized);
|
||||
OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_right_shift_parallelized);
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
let mut scalar_binary_ops: Vec<(ScalarBinaryOpExecutor, &dyn Fn(u64, u64) -> u64, String)> = vec![
|
||||
@@ -231,11 +255,11 @@ where
|
||||
|
||||
// Overflowing Ops Executors
|
||||
let overflowing_add_executor =
|
||||
CpuFunctionExecutor::new(&ServerKey::unsigned_overflowing_add_parallelized);
|
||||
OpSequenceCpuFunctionExecutor::new(&ServerKey::unsigned_overflowing_add_parallelized);
|
||||
let overflowing_sub_executor =
|
||||
CpuFunctionExecutor::new(&ServerKey::unsigned_overflowing_sub_parallelized);
|
||||
OpSequenceCpuFunctionExecutor::new(&ServerKey::unsigned_overflowing_sub_parallelized);
|
||||
let overflowing_mul_executor =
|
||||
CpuFunctionExecutor::new(&ServerKey::unsigned_overflowing_mul_parallelized);
|
||||
OpSequenceCpuFunctionExecutor::new(&ServerKey::unsigned_overflowing_mul_parallelized);
|
||||
|
||||
// Overflowing Ops Clear functions
|
||||
let clear_overflowing_add = |x: u64, y: u64| -> (u64, bool) { x.overflowing_add(y) };
|
||||
@@ -266,10 +290,12 @@ where
|
||||
];
|
||||
|
||||
// Scalar Overflowing Ops Executors
|
||||
let overflowing_scalar_add_executor =
|
||||
CpuFunctionExecutor::new(&ServerKey::unsigned_overflowing_scalar_add_parallelized);
|
||||
let overflowing_scalar_sub_executor =
|
||||
CpuFunctionExecutor::new(&ServerKey::unsigned_overflowing_scalar_sub_parallelized);
|
||||
let overflowing_scalar_add_executor = OpSequenceCpuFunctionExecutor::new(
|
||||
&ServerKey::unsigned_overflowing_scalar_add_parallelized,
|
||||
);
|
||||
let overflowing_scalar_sub_executor = OpSequenceCpuFunctionExecutor::new(
|
||||
&ServerKey::unsigned_overflowing_scalar_sub_parallelized,
|
||||
);
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
let mut scalar_overflowing_ops: Vec<(
|
||||
@@ -290,12 +316,12 @@ where
|
||||
];
|
||||
|
||||
// Comparison Ops Executors
|
||||
let gt_executor = CpuFunctionExecutor::new(&ServerKey::gt_parallelized);
|
||||
let ge_executor = CpuFunctionExecutor::new(&ServerKey::ge_parallelized);
|
||||
let lt_executor = CpuFunctionExecutor::new(&ServerKey::lt_parallelized);
|
||||
let le_executor = CpuFunctionExecutor::new(&ServerKey::le_parallelized);
|
||||
let eq_executor = CpuFunctionExecutor::new(&ServerKey::eq_parallelized);
|
||||
let ne_executor = CpuFunctionExecutor::new(&ServerKey::ne_parallelized);
|
||||
let gt_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::gt_parallelized);
|
||||
let ge_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::ge_parallelized);
|
||||
let lt_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::lt_parallelized);
|
||||
let le_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::le_parallelized);
|
||||
let eq_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::eq_parallelized);
|
||||
let ne_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::ne_parallelized);
|
||||
|
||||
// Comparison Ops Clear functions
|
||||
let clear_gt = |x: u64, y: u64| -> bool { x > y };
|
||||
@@ -316,12 +342,12 @@ where
|
||||
];
|
||||
|
||||
// Scalar Comparison Ops Executors
|
||||
let scalar_gt_executor = CpuFunctionExecutor::new(&ServerKey::scalar_gt_parallelized);
|
||||
let scalar_ge_executor = CpuFunctionExecutor::new(&ServerKey::scalar_ge_parallelized);
|
||||
let scalar_lt_executor = CpuFunctionExecutor::new(&ServerKey::scalar_lt_parallelized);
|
||||
let scalar_le_executor = CpuFunctionExecutor::new(&ServerKey::scalar_le_parallelized);
|
||||
let scalar_eq_executor = CpuFunctionExecutor::new(&ServerKey::scalar_eq_parallelized);
|
||||
let scalar_ne_executor = CpuFunctionExecutor::new(&ServerKey::scalar_ne_parallelized);
|
||||
let scalar_gt_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_gt_parallelized);
|
||||
let scalar_ge_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_ge_parallelized);
|
||||
let scalar_lt_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_lt_parallelized);
|
||||
let scalar_le_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_le_parallelized);
|
||||
let scalar_eq_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_eq_parallelized);
|
||||
let scalar_ne_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_ne_parallelized);
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
let mut scalar_comparison_ops: Vec<(
|
||||
@@ -362,7 +388,7 @@ where
|
||||
];
|
||||
|
||||
// Select Executor
|
||||
let select_executor = CpuFunctionExecutor::new(&ServerKey::cmux_parallelized);
|
||||
let select_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::cmux_parallelized);
|
||||
|
||||
// Select
|
||||
let clear_select = |b: bool, x: u64, y: u64| if b { x } else { y };
|
||||
@@ -375,7 +401,7 @@ where
|
||||
)];
|
||||
|
||||
// Div executor
|
||||
let div_rem_executor = CpuFunctionExecutor::new(&ServerKey::div_rem_parallelized);
|
||||
let div_rem_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::div_rem_parallelized);
|
||||
// Div Rem Clear functions
|
||||
let clear_div_rem = |x: u64, y: u64| -> (u64, u64) { (x.wrapping_div(y), x.wrapping_rem(y)) };
|
||||
#[allow(clippy::type_complexity)]
|
||||
@@ -386,7 +412,8 @@ where
|
||||
)];
|
||||
|
||||
// Scalar Div executor
|
||||
let scalar_div_rem_executor = CpuFunctionExecutor::new(&ServerKey::scalar_div_rem_parallelized);
|
||||
let scalar_div_rem_executor =
|
||||
OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_div_rem_parallelized);
|
||||
#[allow(clippy::type_complexity)]
|
||||
let mut scalar_div_rem_op: Vec<(
|
||||
ScalarDivRemOpExecutor,
|
||||
@@ -399,9 +426,11 @@ where
|
||||
)];
|
||||
|
||||
// Log2/Hamming weight ops
|
||||
let ilog2_executor = CpuFunctionExecutor::new(&ServerKey::ilog2_parallelized);
|
||||
let count_zeros_executor = CpuFunctionExecutor::new(&ServerKey::count_zeros_parallelized);
|
||||
let count_ones_executor = CpuFunctionExecutor::new(&ServerKey::count_ones_parallelized);
|
||||
let ilog2_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::ilog2_parallelized);
|
||||
let count_zeros_executor =
|
||||
OpSequenceCpuFunctionExecutor::new(&ServerKey::count_zeros_parallelized);
|
||||
let count_ones_executor =
|
||||
OpSequenceCpuFunctionExecutor::new(&ServerKey::count_ones_parallelized);
|
||||
let clear_ilog2 = |x: u64| x.ilog2() as u64;
|
||||
let clear_count_zeros = |x: u64| x.count_zeros() as u64;
|
||||
let clear_count_ones = |x: u64| x.count_ones() as u64;
|
||||
@@ -421,7 +450,7 @@ where
|
||||
),
|
||||
];
|
||||
|
||||
    random_op_sequence_test(
    let (cks, sks, mut datagen) = random_op_sequence_test_init_cpu(
        param,
        &mut binary_ops,
        &mut unary_ops,
@@ -435,10 +464,27 @@ where
        &mut scalar_div_rem_op,
        &mut log2_ops,
    );

    random_op_sequence_test(
        &mut datagen,
        &cks,
        &sks,
        &mut binary_ops,
        &mut unary_ops,
        &mut scalar_binary_ops,
        &mut overflowing_ops,
        &mut scalar_overflowing_ops,
        &mut comparison_ops,
        &mut scalar_comparison_ops,
        &mut select_op,
        &mut div_rem_op,
        &mut scalar_div_rem_op,
        &mut log2_ops,
    );
}

#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) fn random_op_sequence_test<P>(
|
||||
pub(crate) fn random_op_sequence_test_init_cpu<P>(
|
||||
param: P,
|
||||
binary_ops: &mut [(BinaryOpExecutor, impl Fn(u64, u64) -> u64, String)],
|
||||
unary_ops: &mut [(UnaryOpExecutor, impl Fn(u64) -> u64, String)],
|
||||
@@ -467,51 +513,21 @@ pub(crate) fn random_op_sequence_test<P>(
|
||||
String,
|
||||
)],
|
||||
log2_ops: &mut [(Log2OpExecutor, impl Fn(u64) -> u64, String)],
|
||||
) where
|
||||
) -> (
|
||||
RadixClientKey,
|
||||
Arc<ServerKey>,
|
||||
RandomOpSequenceDataGenerator<u64, RadixCiphertext>,
|
||||
)
|
||||
where
|
||||
P: Into<TestParameters>,
|
||||
{
|
||||
let param = param.into();
|
||||
let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
|
||||
|
||||
sks.set_deterministic_pbs_execution(true);
|
||||
let sks = Arc::new(sks);
|
||||
let cks = RadixClientKey::from((cks, NB_CTXT_LONG_RUN));
|
||||
let temp_cks =
|
||||
ClientKey::from_raw_parts(cks.clone(), None, None, None, None, None, Tag::default());
|
||||
let comp_sks = CompressedServerKey::new(&temp_cks);
|
||||
|
||||
println!("Setting up operations");
|
||||
|
||||
for x in binary_ops.iter_mut() {
|
||||
x.0.setup(&cks, sks.clone());
|
||||
}
|
||||
for x in unary_ops.iter_mut() {
|
||||
x.0.setup(&cks, sks.clone());
|
||||
}
|
||||
for x in scalar_binary_ops.iter_mut() {
|
||||
x.0.setup(&cks, sks.clone());
|
||||
}
|
||||
for x in overflowing_ops.iter_mut() {
|
||||
x.0.setup(&cks, sks.clone());
|
||||
}
|
||||
for x in scalar_overflowing_ops.iter_mut() {
|
||||
x.0.setup(&cks, sks.clone());
|
||||
}
|
||||
for x in comparison_ops.iter_mut() {
|
||||
x.0.setup(&cks, sks.clone());
|
||||
}
|
||||
for x in scalar_comparison_ops.iter_mut() {
|
||||
x.0.setup(&cks, sks.clone());
|
||||
}
|
||||
for x in select_op.iter_mut() {
|
||||
x.0.setup(&cks, sks.clone());
|
||||
}
|
||||
for x in div_rem_op.iter_mut() {
|
||||
x.0.setup(&cks, sks.clone());
|
||||
}
|
||||
for x in scalar_div_rem_op.iter_mut() {
|
||||
x.0.setup(&cks, sks.clone());
|
||||
}
|
||||
for x in log2_ops.iter_mut() {
|
||||
x.0.setup(&cks, sks.clone());
|
||||
}
|
||||
let total_num_ops = binary_ops.len()
|
||||
+ unary_ops.len()
|
||||
+ scalar_binary_ops.len()
|
||||
@@ -525,6 +541,93 @@ pub(crate) fn random_op_sequence_test<P>(
|
||||
+ log2_ops.len();
|
||||
println!("Total num ops {total_num_ops}");
|
||||
|
||||
sks.set_deterministic_pbs_execution(true);
|
||||
let sks = Arc::new(sks);
|
||||
let cks = RadixClientKey::from((cks, NB_CTXT_LONG_RUN));
|
||||
|
||||
let mut datagen = get_user_defined_seed().map_or_else(
|
||||
|| RandomOpSequenceDataGenerator::<u64, RadixCiphertext>::new(total_num_ops, &cks),
|
||||
|seed| {
|
||||
RandomOpSequenceDataGenerator::<u64, RadixCiphertext>::new_with_seed(
|
||||
total_num_ops,
|
||||
seed,
|
||||
&cks,
|
||||
)
|
||||
},
|
||||
);
|
||||
println!("random_op_sequence_test::seed = {}", datagen.get_seed().0);
|
||||
|
||||
println!("Setting up operations");
|
||||
|
||||
for x in binary_ops.iter_mut() {
|
||||
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
|
||||
}
|
||||
for x in unary_ops.iter_mut() {
|
||||
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
|
||||
}
|
||||
for x in scalar_binary_ops.iter_mut() {
|
||||
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
|
||||
}
|
||||
for x in overflowing_ops.iter_mut() {
|
||||
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
|
||||
}
|
||||
for x in scalar_overflowing_ops.iter_mut() {
|
||||
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
|
||||
}
|
||||
for x in comparison_ops.iter_mut() {
|
||||
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
|
||||
}
|
||||
for x in scalar_comparison_ops.iter_mut() {
|
||||
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
|
||||
}
|
||||
for x in select_op.iter_mut() {
|
||||
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
|
||||
}
|
||||
for x in div_rem_op.iter_mut() {
|
||||
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
|
||||
}
|
||||
for x in scalar_div_rem_op.iter_mut() {
|
||||
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
|
||||
}
|
||||
for x in log2_ops.iter_mut() {
|
||||
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
|
||||
}
|
||||
|
||||
(cks, sks, datagen)
|
||||
}
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) fn random_op_sequence_test(
|
||||
datagen: &mut RandomOpSequenceDataGenerator<u64, RadixCiphertext>,
|
||||
cks: &RadixClientKey,
|
||||
sks: &Arc<ServerKey>,
|
||||
binary_ops: &mut [(BinaryOpExecutor, impl Fn(u64, u64) -> u64, String)],
|
||||
unary_ops: &mut [(UnaryOpExecutor, impl Fn(u64) -> u64, String)],
|
||||
scalar_binary_ops: &mut [(ScalarBinaryOpExecutor, impl Fn(u64, u64) -> u64, String)],
|
||||
overflowing_ops: &mut [(
|
||||
OverflowingOpExecutor,
|
||||
impl Fn(u64, u64) -> (u64, bool),
|
||||
String,
|
||||
)],
|
||||
scalar_overflowing_ops: &mut [(
|
||||
ScalarOverflowingOpExecutor,
|
||||
impl Fn(u64, u64) -> (u64, bool),
|
||||
String,
|
||||
)],
|
||||
comparison_ops: &mut [(ComparisonOpExecutor, impl Fn(u64, u64) -> bool, String)],
|
||||
scalar_comparison_ops: &mut [(
|
||||
ScalarComparisonOpExecutor,
|
||||
impl Fn(u64, u64) -> bool,
|
||||
String,
|
||||
)],
|
||||
select_op: &mut [(SelectOpExecutor, impl Fn(bool, u64, u64) -> u64, String)],
|
||||
div_rem_op: &mut [(DivRemOpExecutor, impl Fn(u64, u64) -> (u64, u64), String)],
|
||||
scalar_div_rem_op: &mut [(
|
||||
ScalarDivRemOpExecutor,
|
||||
impl Fn(u64, u64) -> (u64, u64),
|
||||
String,
|
||||
)],
|
||||
log2_ops: &mut [(Log2OpExecutor, impl Fn(u64) -> u64, String)],
|
||||
) {
|
||||
let binary_ops_range = 0..binary_ops.len();
|
||||
let unary_ops_range = binary_ops_range.end..binary_ops_range.end + unary_ops.len();
|
||||
let scalar_binary_ops_range =
|
||||
@@ -544,10 +647,6 @@ pub(crate) fn random_op_sequence_test<P>(
|
||||
div_rem_op_range.end..div_rem_op_range.end + scalar_div_rem_op.len();
|
||||
let log2_ops_range = scalar_div_rem_op_range.end..scalar_div_rem_op_range.end + log2_ops.len();
|
||||
|
||||
let mut datagen =
|
||||
RandomOpSequenceDataGenerator::<u64, RadixCiphertext>::new(total_num_ops, &cks);
|
||||
println!("random_op_sequence_test::seed = {}", datagen.get_seed().0);
|
||||
|
||||
for fn_index in 0..get_long_test_iterations() {
|
||||
let (i, idx) = datagen.gen_op_index();
|
||||
|
||||
@@ -718,7 +817,7 @@ pub(crate) fn random_op_sequence_test<P>(
            let expected_res = clear_fn(lhs.p, rhs.p);

            let res_ct = sks.cast_to_unsigned(
                res.clone().into_radix::<RadixCiphertext>(1, &sks),
                res.clone().into_radix::<RadixCiphertext>(1, sks),
                NB_CTXT_LONG_RUN,
            );
            datagen.put_op_result_random_side(expected_res as u64, &res_ct, fn_name, idx);
@@ -749,7 +848,7 @@ pub(crate) fn random_op_sequence_test<P>(
            let expected_res = clear_fn(lhs.p, rhs.p);

            let res_ct: RadixCiphertext = sks.cast_to_unsigned(
                res.clone().into_radix::<RadixCiphertext>(1, &sks),
                res.clone().into_radix::<RadixCiphertext>(1, sks),
                NB_CTXT_LONG_RUN,
            );
            datagen.put_op_result_random_side(expected_res as u64, &res_ct, fn_name, idx);

@@ -1,16 +1,16 @@
use crate::integer::keycache::KEY_CACHE;
use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor;
use crate::integer::server_key::radix_parallel::tests_long_run::{
    get_long_test_iterations, sanity_check_op_sequence_result_bool,
    sanity_check_op_sequence_result_i64, sanity_check_op_sequence_result_u64, NB_CTXT_LONG_RUN,
    get_long_test_iterations, get_user_defined_seed, sanity_check_op_sequence_result_bool,
    sanity_check_op_sequence_result_i64, sanity_check_op_sequence_result_u64,
    OpSequenceFunctionExecutor, RandomOpSequenceDataGenerator, NB_CTXT_LONG_RUN,
};
use crate::integer::server_key::radix_parallel::tests_unsigned::CpuFunctionExecutor;
use crate::integer::server_key::radix_parallel::tests_unsigned::OpSequenceCpuFunctionExecutor;
use crate::integer::tests::create_parameterized_test;
use crate::integer::{
    BooleanBlock, IntegerKeyKind, RadixCiphertext, RadixClientKey, ServerKey, SignedRadixCiphertext,
};
use crate::shortint::parameters::*;
use rand::Rng;
use crate::{ClientKey, CompressedServerKey, Tag};
use std::cmp::{max, min};
use std::sync::Arc;

@@ -19,46 +19,48 @@ create_parameterized_test!(random_op_sequence {
|
||||
});
|
||||
|
||||
pub(crate) type SignedBinaryOpExecutor = Box<
|
||||
dyn for<'a> FunctionExecutor<
|
||||
dyn for<'a> OpSequenceFunctionExecutor<
|
||||
(&'a SignedRadixCiphertext, &'a SignedRadixCiphertext),
|
||||
SignedRadixCiphertext,
|
||||
>,
|
||||
>;
|
||||
pub(crate) type SignedShiftRotateExecutor = Box<
|
||||
dyn for<'a> FunctionExecutor<
|
||||
dyn for<'a> OpSequenceFunctionExecutor<
|
||||
(&'a SignedRadixCiphertext, &'a RadixCiphertext),
|
||||
SignedRadixCiphertext,
|
||||
>,
|
||||
>;
|
||||
pub(crate) type SignedUnaryOpExecutor =
|
||||
Box<dyn for<'a> FunctionExecutor<&'a SignedRadixCiphertext, SignedRadixCiphertext>>;
|
||||
Box<dyn for<'a> OpSequenceFunctionExecutor<&'a SignedRadixCiphertext, SignedRadixCiphertext>>;
|
||||
|
||||
pub(crate) type SignedScalarBinaryOpExecutor =
|
||||
Box<dyn for<'a> FunctionExecutor<(&'a SignedRadixCiphertext, i64), SignedRadixCiphertext>>;
|
||||
pub(crate) type SignedScalarShiftRotateExecutor =
|
||||
Box<dyn for<'a> FunctionExecutor<(&'a SignedRadixCiphertext, u64), SignedRadixCiphertext>>;
|
||||
pub(crate) type SignedScalarBinaryOpExecutor = Box<
|
||||
dyn for<'a> OpSequenceFunctionExecutor<(&'a SignedRadixCiphertext, i64), SignedRadixCiphertext>,
|
||||
>;
|
||||
pub(crate) type SignedScalarShiftRotateExecutor = Box<
|
||||
dyn for<'a> OpSequenceFunctionExecutor<(&'a SignedRadixCiphertext, u64), SignedRadixCiphertext>,
|
||||
>;
|
||||
pub(crate) type SignedOverflowingOpExecutor = Box<
|
||||
dyn for<'a> FunctionExecutor<
|
||||
dyn for<'a> OpSequenceFunctionExecutor<
|
||||
(&'a SignedRadixCiphertext, &'a SignedRadixCiphertext),
|
||||
(SignedRadixCiphertext, BooleanBlock),
|
||||
>,
|
||||
>;
|
||||
pub(crate) type SignedScalarOverflowingOpExecutor = Box<
|
||||
dyn for<'a> FunctionExecutor<
|
||||
dyn for<'a> OpSequenceFunctionExecutor<
|
||||
(&'a SignedRadixCiphertext, i64),
|
||||
(SignedRadixCiphertext, BooleanBlock),
|
||||
>,
|
||||
>;
|
||||
pub(crate) type SignedComparisonOpExecutor = Box<
|
||||
dyn for<'a> FunctionExecutor<
|
||||
dyn for<'a> OpSequenceFunctionExecutor<
|
||||
(&'a SignedRadixCiphertext, &'a SignedRadixCiphertext),
|
||||
BooleanBlock,
|
||||
>,
|
||||
>;
|
||||
pub(crate) type SignedScalarComparisonOpExecutor =
|
||||
Box<dyn for<'a> FunctionExecutor<(&'a SignedRadixCiphertext, i64), BooleanBlock>>;
|
||||
Box<dyn for<'a> OpSequenceFunctionExecutor<(&'a SignedRadixCiphertext, i64), BooleanBlock>>;
|
||||
pub(crate) type SignedSelectOpExecutor = Box<
|
||||
dyn for<'a> FunctionExecutor<
|
||||
dyn for<'a> OpSequenceFunctionExecutor<
|
||||
(
|
||||
&'a BooleanBlock,
|
||||
&'a SignedRadixCiphertext,
|
||||
@@ -68,32 +70,32 @@ pub(crate) type SignedSelectOpExecutor = Box<
|
||||
>,
|
||||
>;
|
||||
pub(crate) type SignedDivRemOpExecutor = Box<
|
||||
dyn for<'a> FunctionExecutor<
|
||||
dyn for<'a> OpSequenceFunctionExecutor<
|
||||
(&'a SignedRadixCiphertext, &'a SignedRadixCiphertext),
|
||||
(SignedRadixCiphertext, SignedRadixCiphertext),
|
||||
>,
|
||||
>;
|
||||
pub(crate) type SignedScalarDivRemOpExecutor = Box<
|
||||
dyn for<'a> FunctionExecutor<
|
||||
dyn for<'a> OpSequenceFunctionExecutor<
|
||||
(&'a SignedRadixCiphertext, i64),
|
||||
(SignedRadixCiphertext, SignedRadixCiphertext),
|
||||
>,
|
||||
>;
|
||||
pub(crate) type SignedLog2OpExecutor =
|
||||
Box<dyn for<'a> FunctionExecutor<&'a SignedRadixCiphertext, RadixCiphertext>>;
|
||||
Box<dyn for<'a> OpSequenceFunctionExecutor<&'a SignedRadixCiphertext, RadixCiphertext>>;
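Each alias above now boxes an `OpSequenceFunctionExecutor` trait object instead of a `FunctionExecutor` one; the construction pattern itself is unchanged. A minimal sketch, assuming the same executors that are built in the function below (illustrative, not part of this commit):

let neg: SignedUnaryOpExecutor =
    Box::new(OpSequenceCpuFunctionExecutor::new(&ServerKey::neg_parallelized));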
fn random_op_sequence<P>(param: P)
where
P: Into<TestParameters> + Clone,
{
// Binary Ops Executors
let add_executor = CpuFunctionExecutor::new(&ServerKey::add_parallelized);
let sub_executor = CpuFunctionExecutor::new(&ServerKey::sub_parallelized);
let bitwise_and_executor = CpuFunctionExecutor::new(&ServerKey::bitand_parallelized);
let bitwise_or_executor = CpuFunctionExecutor::new(&ServerKey::bitor_parallelized);
let bitwise_xor_executor = CpuFunctionExecutor::new(&ServerKey::bitxor_parallelized);
let mul_executor = CpuFunctionExecutor::new(&ServerKey::mul_parallelized);
let max_executor = CpuFunctionExecutor::new(&ServerKey::max_parallelized);
let min_executor = CpuFunctionExecutor::new(&ServerKey::min_parallelized);
let add_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::add_parallelized);
let sub_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::sub_parallelized);
let bitwise_and_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::bitand_parallelized);
let bitwise_or_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::bitor_parallelized);
let bitwise_xor_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::bitxor_parallelized);
let mul_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::mul_parallelized);
let max_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::max_parallelized);
let min_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::min_parallelized);

// Binary Ops Clear functions
let clear_add = |x: i64, y: i64| x.wrapping_add(y);
@@ -129,10 +131,14 @@ where
(Box::new(min_executor), &clear_min, "min".to_string()),
];

let rotate_left_executor = CpuFunctionExecutor::new(&ServerKey::rotate_left_parallelized);
let left_shift_executor = CpuFunctionExecutor::new(&ServerKey::left_shift_parallelized);
let rotate_right_executor = CpuFunctionExecutor::new(&ServerKey::rotate_right_parallelized);
let right_shift_executor = CpuFunctionExecutor::new(&ServerKey::right_shift_parallelized);
let rotate_left_executor =
OpSequenceCpuFunctionExecutor::new(&ServerKey::rotate_left_parallelized);
let left_shift_executor =
OpSequenceCpuFunctionExecutor::new(&ServerKey::left_shift_parallelized);
let rotate_right_executor =
OpSequenceCpuFunctionExecutor::new(&ServerKey::rotate_right_parallelized);
let right_shift_executor =
OpSequenceCpuFunctionExecutor::new(&ServerKey::right_shift_parallelized);
// Warning this rotate definition only works with 64-bit ciphertexts
let clear_rotate_left = |x: i64, y: u64| x.rotate_left(y as u32);
let clear_left_shift = |x: i64, y: u64| x << y;
@@ -167,9 +173,10 @@ where
),
];
// Unary Ops Executors
let neg_executor = CpuFunctionExecutor::new(&ServerKey::neg_parallelized);
let bitnot_executor = CpuFunctionExecutor::new(&ServerKey::bitnot);
let reverse_bits_executor = CpuFunctionExecutor::new(&ServerKey::reverse_bits_parallelized);
let neg_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::neg_parallelized);
let bitnot_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::bitnot);
let reverse_bits_executor =
OpSequenceCpuFunctionExecutor::new(&ServerKey::reverse_bits_parallelized);
// Unary Ops Clear functions
let clear_neg = |x: i64| x.wrapping_neg();
let clear_bitnot = |x: i64| !x;
@@ -190,15 +197,18 @@ where
];

// Scalar binary Ops Executors
let scalar_add_executor = CpuFunctionExecutor::new(&ServerKey::scalar_add_parallelized);
let scalar_sub_executor = CpuFunctionExecutor::new(&ServerKey::scalar_sub_parallelized);
let scalar_add_executor =
OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_add_parallelized);
let scalar_sub_executor =
OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_sub_parallelized);
let scalar_bitwise_and_executor =
CpuFunctionExecutor::new(&ServerKey::scalar_bitand_parallelized);
OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_bitand_parallelized);
let scalar_bitwise_or_executor =
CpuFunctionExecutor::new(&ServerKey::scalar_bitor_parallelized);
OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_bitor_parallelized);
let scalar_bitwise_xor_executor =
CpuFunctionExecutor::new(&ServerKey::scalar_bitxor_parallelized);
let scalar_mul_executor = CpuFunctionExecutor::new(&ServerKey::scalar_mul_parallelized);
OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_bitxor_parallelized);
let scalar_mul_executor =
OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_mul_parallelized);

#[allow(clippy::type_complexity)]
let mut scalar_binary_ops: Vec<(
@@ -239,13 +249,13 @@ where
];

let scalar_rotate_left_executor =
CpuFunctionExecutor::new(&ServerKey::scalar_rotate_left_parallelized);
OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_rotate_left_parallelized);
let scalar_left_shift_executor =
CpuFunctionExecutor::new(&ServerKey::scalar_left_shift_parallelized);
OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_left_shift_parallelized);
let scalar_rotate_right_executor =
CpuFunctionExecutor::new(&ServerKey::scalar_rotate_right_parallelized);
OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_rotate_right_parallelized);
let scalar_right_shift_executor =
CpuFunctionExecutor::new(&ServerKey::scalar_right_shift_parallelized);
OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_right_shift_parallelized);
#[allow(clippy::type_complexity)]
let mut scalar_shift_rotate_ops: Vec<(
SignedScalarShiftRotateExecutor,
@@ -276,11 +286,11 @@ where

// Overflowing Ops Executors
let overflowing_add_executor =
CpuFunctionExecutor::new(&ServerKey::signed_overflowing_add_parallelized);
OpSequenceCpuFunctionExecutor::new(&ServerKey::signed_overflowing_add_parallelized);
let overflowing_sub_executor =
CpuFunctionExecutor::new(&ServerKey::signed_overflowing_sub_parallelized);
OpSequenceCpuFunctionExecutor::new(&ServerKey::signed_overflowing_sub_parallelized);
let overflowing_mul_executor =
CpuFunctionExecutor::new(&ServerKey::signed_overflowing_mul_parallelized);
OpSequenceCpuFunctionExecutor::new(&ServerKey::signed_overflowing_mul_parallelized);
// Overflowing Ops Clear functions
let clear_overflowing_add = |x: i64, y: i64| -> (i64, bool) { x.overflowing_add(y) };
let clear_overflowing_sub = |x: i64, y: i64| -> (i64, bool) { x.overflowing_sub(y) };
@@ -311,9 +321,9 @@ where

// Scalar Overflowing Ops Executors
let overflowing_scalar_add_executor =
CpuFunctionExecutor::new(&ServerKey::signed_overflowing_scalar_add_parallelized);
OpSequenceCpuFunctionExecutor::new(&ServerKey::signed_overflowing_scalar_add_parallelized);
let overflowing_scalar_sub_executor =
CpuFunctionExecutor::new(&ServerKey::signed_overflowing_scalar_sub_parallelized);
OpSequenceCpuFunctionExecutor::new(&ServerKey::signed_overflowing_scalar_sub_parallelized);

#[allow(clippy::type_complexity)]
let mut scalar_overflowing_ops: Vec<(
@@ -334,12 +344,12 @@ where
];

// Comparison Ops Executors
let gt_executor = CpuFunctionExecutor::new(&ServerKey::gt_parallelized);
let ge_executor = CpuFunctionExecutor::new(&ServerKey::ge_parallelized);
let lt_executor = CpuFunctionExecutor::new(&ServerKey::lt_parallelized);
let le_executor = CpuFunctionExecutor::new(&ServerKey::le_parallelized);
let eq_executor = CpuFunctionExecutor::new(&ServerKey::eq_parallelized);
let ne_executor = CpuFunctionExecutor::new(&ServerKey::ne_parallelized);
let gt_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::gt_parallelized);
let ge_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::ge_parallelized);
let lt_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::lt_parallelized);
let le_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::le_parallelized);
let eq_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::eq_parallelized);
let ne_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::ne_parallelized);

// Comparison Ops Clear functions
let clear_gt = |x: i64, y: i64| -> bool { x > y };
@@ -364,12 +374,12 @@ where
];

// Scalar Comparison Ops Executors
let scalar_gt_executor = CpuFunctionExecutor::new(&ServerKey::scalar_gt_parallelized);
let scalar_ge_executor = CpuFunctionExecutor::new(&ServerKey::scalar_ge_parallelized);
let scalar_lt_executor = CpuFunctionExecutor::new(&ServerKey::scalar_lt_parallelized);
let scalar_le_executor = CpuFunctionExecutor::new(&ServerKey::scalar_le_parallelized);
let scalar_eq_executor = CpuFunctionExecutor::new(&ServerKey::scalar_eq_parallelized);
let scalar_ne_executor = CpuFunctionExecutor::new(&ServerKey::scalar_ne_parallelized);
let scalar_gt_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_gt_parallelized);
let scalar_ge_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_ge_parallelized);
let scalar_lt_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_lt_parallelized);
let scalar_le_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_le_parallelized);
let scalar_eq_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_eq_parallelized);
let scalar_ne_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::scalar_ne_parallelized);

#[allow(clippy::type_complexity)]
let mut scalar_comparison_ops: Vec<(
@@ -410,7 +420,7 @@ where
];

// Select Executor
let select_executor = CpuFunctionExecutor::new(&ServerKey::cmux_parallelized);
let select_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::cmux_parallelized);

// Select
let clear_select = |b: bool, x: i64, y: i64| if b { x } else { y };
@@ -427,7 +437,7 @@ where
)];

// Div executor
let div_rem_executor = CpuFunctionExecutor::new(&ServerKey::div_rem_parallelized);
let div_rem_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::div_rem_parallelized);
// Div Rem Clear functions
let clear_div_rem = |x: i64, y: i64| -> (i64, i64) { (x.wrapping_div(y), x.wrapping_rem(y)) };
#[allow(clippy::type_complexity)]
@@ -443,7 +453,7 @@ where

// Scalar Div executor
let scalar_div_rem_executor =
CpuFunctionExecutor::new(&ServerKey::signed_scalar_div_rem_parallelized);
OpSequenceCpuFunctionExecutor::new(&ServerKey::signed_scalar_div_rem_parallelized);
#[allow(clippy::type_complexity)]
let mut scalar_div_rem_op: Vec<(
SignedScalarDivRemOpExecutor,
@@ -456,9 +466,11 @@ where
)];

// Log2/Hamming weight ops
let ilog2_executor = CpuFunctionExecutor::new(&ServerKey::ilog2_parallelized);
let count_zeros_executor = CpuFunctionExecutor::new(&ServerKey::count_zeros_parallelized);
let count_ones_executor = CpuFunctionExecutor::new(&ServerKey::count_ones_parallelized);
let ilog2_executor = OpSequenceCpuFunctionExecutor::new(&ServerKey::ilog2_parallelized);
let count_zeros_executor =
OpSequenceCpuFunctionExecutor::new(&ServerKey::count_zeros_parallelized);
let count_ones_executor =
OpSequenceCpuFunctionExecutor::new(&ServerKey::count_ones_parallelized);
let clear_ilog2 = |x: i64| x.ilog2() as u64;
let clear_count_zeros = |x: i64| x.count_zeros() as u64;
let clear_count_ones = |x: i64| x.count_ones() as u64;
@@ -478,7 +490,7 @@ where
),
];

signed_random_op_sequence_test(
let (cks, sks, mut datagen) = signed_random_op_sequence_test_init_cpu(
param,
&mut binary_ops,
&mut unary_ops,
@@ -494,10 +506,29 @@ where
&mut rotate_shift_ops,
&mut scalar_shift_rotate_ops,
);

signed_random_op_sequence_test(
&mut datagen,
&cks,
&sks,
&mut binary_ops,
&mut unary_ops,
&mut scalar_binary_ops,
&mut overflowing_ops,
&mut scalar_overflowing_ops,
&mut comparison_ops,
&mut scalar_comparison_ops,
&mut select_op,
&mut div_rem_op,
&mut scalar_div_rem_op,
&mut log2_ops,
&mut rotate_shift_ops,
&mut scalar_shift_rotate_ops,
);
}

#[allow(clippy::too_many_arguments)]
pub(crate) fn signed_random_op_sequence_test<P>(
pub(crate) fn signed_random_op_sequence_test_init_cpu<P>(
param: P,
binary_ops: &mut [(SignedBinaryOpExecutor, impl Fn(i64, i64) -> i64, String)],
unary_ops: &mut [(SignedUnaryOpExecutor, impl Fn(i64) -> i64, String)],
@@ -548,57 +579,14 @@ pub(crate) fn signed_random_op_sequence_test<P>(
impl Fn(i64, u64) -> i64,
String,
)],
) where
) -> (
RadixClientKey,
Arc<ServerKey>,
RandomOpSequenceDataGenerator<i64, SignedRadixCiphertext>,
)
where
P: Into<TestParameters>,
{
let param = param.into();
let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);

sks.set_deterministic_pbs_execution(true);
let sks = Arc::new(sks);
let cks = RadixClientKey::from((cks, NB_CTXT_LONG_RUN));

let mut rng = rand::thread_rng();

for x in binary_ops.iter_mut() {
x.0.setup(&cks, sks.clone());
}
for x in unary_ops.iter_mut() {
x.0.setup(&cks, sks.clone());
}
for x in scalar_binary_ops.iter_mut() {
x.0.setup(&cks, sks.clone());
}
for x in overflowing_ops.iter_mut() {
x.0.setup(&cks, sks.clone());
}
for x in scalar_overflowing_ops.iter_mut() {
x.0.setup(&cks, sks.clone());
}
for x in comparison_ops.iter_mut() {
x.0.setup(&cks, sks.clone());
}
for x in scalar_comparison_ops.iter_mut() {
x.0.setup(&cks, sks.clone());
}
for x in select_op.iter_mut() {
x.0.setup(&cks, sks.clone());
}
for x in div_rem_op.iter_mut() {
x.0.setup(&cks, sks.clone());
}
for x in scalar_div_rem_op.iter_mut() {
x.0.setup(&cks, sks.clone());
}
for x in log2_ops.iter_mut() {
x.0.setup(&cks, sks.clone());
}
for x in rotate_shift_ops.iter_mut() {
x.0.setup(&cks, sks.clone());
}
for x in scalar_rotate_shift_ops.iter_mut() {
x.0.setup(&cks, sks.clone());
}
let total_num_ops = binary_ops.len()
+ unary_ops.len()
+ scalar_binary_ops.len()
@@ -612,6 +600,132 @@ pub(crate) fn signed_random_op_sequence_test<P>(
+ log2_ops.len()
+ rotate_shift_ops.len()
+ scalar_rotate_shift_ops.len();

let param = param.into();
let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);

let temp_cks =
ClientKey::from_raw_parts(cks.clone(), None, None, None, None, None, Tag::default());
let comp_sks = CompressedServerKey::new(&temp_cks);

sks.set_deterministic_pbs_execution(true);
let sks = Arc::new(sks);
let cks = RadixClientKey::from((cks, NB_CTXT_LONG_RUN));

let mut datagen = get_user_defined_seed().map_or_else(
|| RandomOpSequenceDataGenerator::<i64, SignedRadixCiphertext>::new(total_num_ops, &cks),
|seed| {
RandomOpSequenceDataGenerator::<i64, SignedRadixCiphertext>::new_with_seed(
total_num_ops,
seed,
&cks,
)
},
);

println!(
"signed_random_op_sequence_test::seed = {}",
datagen.get_seed().0
);

for x in binary_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in unary_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in scalar_binary_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in overflowing_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in scalar_overflowing_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in comparison_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in scalar_comparison_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in select_op.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in div_rem_op.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in scalar_div_rem_op.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in log2_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in rotate_shift_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}
for x in scalar_rotate_shift_ops.iter_mut() {
x.0.setup(&cks, &comp_sks, &mut datagen.deterministic_seeder);
}

(cks, sks, datagen)
}

#[allow(clippy::too_many_arguments)]
pub(crate) fn signed_random_op_sequence_test(
datagen: &mut RandomOpSequenceDataGenerator<i64, SignedRadixCiphertext>,
cks: &RadixClientKey,
sks: &Arc<ServerKey>,
binary_ops: &mut [(SignedBinaryOpExecutor, impl Fn(i64, i64) -> i64, String)],
unary_ops: &mut [(SignedUnaryOpExecutor, impl Fn(i64) -> i64, String)],
scalar_binary_ops: &mut [(
SignedScalarBinaryOpExecutor,
impl Fn(i64, i64) -> i64,
String,
)],
overflowing_ops: &mut [(
SignedOverflowingOpExecutor,
impl Fn(i64, i64) -> (i64, bool),
String,
)],
scalar_overflowing_ops: &mut [(
SignedScalarOverflowingOpExecutor,
impl Fn(i64, i64) -> (i64, bool),
String,
)],
comparison_ops: &mut [(
SignedComparisonOpExecutor,
impl Fn(i64, i64) -> bool,
String,
)],
scalar_comparison_ops: &mut [(
SignedScalarComparisonOpExecutor,
impl Fn(i64, i64) -> bool,
String,
)],
select_op: &mut [(
SignedSelectOpExecutor,
impl Fn(bool, i64, i64) -> i64,
String,
)],
div_rem_op: &mut [(
SignedDivRemOpExecutor,
impl Fn(i64, i64) -> (i64, i64),
String,
)],
scalar_div_rem_op: &mut [(
SignedScalarDivRemOpExecutor,
impl Fn(i64, i64) -> (i64, i64),
String,
)],
log2_ops: &mut [(SignedLog2OpExecutor, impl Fn(i64) -> u64, String)],
rotate_shift_ops: &mut [(SignedShiftRotateExecutor, impl Fn(i64, u64) -> i64, String)],
scalar_rotate_shift_ops: &mut [(
SignedScalarShiftRotateExecutor,
impl Fn(i64, u64) -> i64,
String,
)],
) {
let binary_ops_range = 0..binary_ops.len();
let unary_ops_range = binary_ops_range.end..binary_ops_range.end + unary_ops.len();
let scalar_binary_ops_range =
@@ -634,16 +748,6 @@ pub(crate) fn signed_random_op_sequence_test<P>(
let scalar_rotate_shift_ops_range =
rotate_shift_ops_range.end..rotate_shift_ops_range.end + scalar_rotate_shift_ops.len();

let mut datagen =
crate::integer::server_key::radix_parallel::tests_long_run::RandomOpSequenceDataGenerator::<
i64,
SignedRadixCiphertext,
>::new(total_num_ops, &cks);
println!(
"signed_random_op_sequence_test::seed = {}",
datagen.get_seed().0
);

for fn_index in 0..get_long_test_iterations() {
let (i, idx) = datagen.gen_op_index();

@@ -819,7 +923,7 @@ pub(crate) fn signed_random_op_sequence_test<P>(
let expected_res = clear_fn(lhs.p, rhs.p);

let res_ct: SignedRadixCiphertext = sks.cast_to_signed(
res.clone().into_radix::<SignedRadixCiphertext>(1, &sks),
res.clone().into_radix::<SignedRadixCiphertext>(1, sks),
NB_CTXT_LONG_RUN,
);

@@ -851,7 +955,7 @@ pub(crate) fn signed_random_op_sequence_test<P>(
let expected_res = clear_fn(lhs.p, rhs.p);

let res_ct: SignedRadixCiphertext = sks.cast_to_signed(
res.clone().into_radix::<SignedRadixCiphertext>(1, &sks),
res.clone().into_radix::<SignedRadixCiphertext>(1, sks),
NB_CTXT_LONG_RUN,
);

@@ -875,7 +979,7 @@ pub(crate) fn signed_random_op_sequence_test<P>(

let (lhs, rhs) = datagen.gen_op_operands(idx, fn_name);

let clear_bool: bool = rng.gen_bool(0.5);
let clear_bool: bool = datagen.gen_bool_uniform();
let bool_input = cks.encrypt_bool(clear_bool);

let res = select_op_executor.execute((&bool_input, &lhs.c, &rhs.c));

@@ -29,8 +29,10 @@ pub(crate) mod test_vector_comparisons;
pub(crate) mod test_vector_find;

use super::tests_cases_unsigned::*;
use crate::core_crypto::commons::generators::DeterministicSeeder;
use crate::core_crypto::prelude::UnsignedInteger;
use crate::integer::keycache::KEY_CACHE;
use crate::integer::server_key::radix_parallel::tests_long_run::OpSequenceFunctionExecutor;
use crate::integer::tests::create_parameterized_test;
use crate::integer::{IntegerKeyKind, RadixCiphertext, RadixClientKey, ServerKey};
use crate::shortint::ciphertext::MaxDegree;
@@ -38,9 +40,11 @@ use crate::shortint::ciphertext::MaxDegree;
use crate::shortint::parameters::coverage_parameters::*;
use crate::shortint::parameters::test_params::*;
use crate::shortint::parameters::*;
use crate::CompressedServerKey;
use rand::prelude::ThreadRng;
use rand::Rng;
use std::sync::Arc;
use tfhe_csprng::generators::DefaultRandomGenerator;

#[cfg(not(tarpaulin))]
pub(crate) const NB_CTXT: usize = 4;
@@ -619,6 +623,111 @@ where
}
}

/// The function executor for cpu server key
///
/// It will mainly simply forward calls to a server key method
pub(crate) struct OpSequenceCpuFunctionExecutor<F> {
/// The server key is set later, when the test case calls setup
pub(crate) sks: Option<Arc<ServerKey>>,
/// The server key function which will be called
pub(crate) func: F,
}

impl<F> OpSequenceCpuFunctionExecutor<F> {
pub(crate) fn new(func: F) -> Self {
Self { sks: None, func }
}
pub(crate) fn setup_from_cpu_keys(&mut self, sks: &CompressedServerKey) {
let (isks, _, _, _, _, _, _, _) = sks.decompress().into_raw_parts();
self.sks = Some(Arc::new(isks));
}
}

/// For unary operations
///
/// Note, we need the `NotTuple` constraint to avoid conflicts with binary or ternary operations
impl<F, I1, O> OpSequenceFunctionExecutor<I1, O> for OpSequenceCpuFunctionExecutor<F>
where
F: Fn(&ServerKey, I1) -> O,
I1: NotTuple,
{
fn setup(
&mut self,
_cks: &RadixClientKey,
sks: &CompressedServerKey,
_seeder: &mut DeterministicSeeder<DefaultRandomGenerator>,
) {
let (isks, _, _, _, _, _, _, _) = sks.decompress().into_raw_parts();
self.sks = Some(Arc::new(isks));
}

fn execute(&mut self, input: I1) -> O {
let sks = self.sks.as_ref().expect("setup was not properly called");
(self.func)(sks, input)
}
}

/// For binary operations
impl<F, I1, I2, O> OpSequenceFunctionExecutor<(I1, I2), O> for OpSequenceCpuFunctionExecutor<F>
where
F: Fn(&ServerKey, I1, I2) -> O,
{
fn setup(
&mut self,
_cks: &RadixClientKey,
sks: &CompressedServerKey,
_seeder: &mut DeterministicSeeder<DefaultRandomGenerator>,
) {
self.setup_from_cpu_keys(sks);
}

fn execute(&mut self, input: (I1, I2)) -> O {
let sks = self.sks.as_ref().expect("setup was not properly called");
(self.func)(sks, input.0, input.1)
}
}

/// For ternary operations
impl<F, I1, I2, I3, O> OpSequenceFunctionExecutor<(I1, I2, I3), O>
for OpSequenceCpuFunctionExecutor<F>
where
F: Fn(&ServerKey, I1, I2, I3) -> O,
{
fn setup(
&mut self,
_cks: &RadixClientKey,
sks: &CompressedServerKey,
_seeder: &mut DeterministicSeeder<DefaultRandomGenerator>,
) {
self.setup_from_cpu_keys(sks);
}

fn execute(&mut self, input: (I1, I2, I3)) -> O {
let sks = self.sks.as_ref().expect("setup was not properly called");
(self.func)(sks, input.0, input.1, input.2)
}
}

/// For 4-ary operations
impl<F, I1, I2, I3, I4, O> OpSequenceFunctionExecutor<(I1, I2, I3, I4), O>
for OpSequenceCpuFunctionExecutor<F>
where
F: Fn(&ServerKey, I1, I2, I3, I4) -> O,
{
fn setup(
&mut self,
_cks: &RadixClientKey,
sks: &CompressedServerKey,
_seeder: &mut DeterministicSeeder<DefaultRandomGenerator>,
) {
self.setup_from_cpu_keys(sks);
}
fn execute(&mut self, input: (I1, I2, I3, I4)) -> O {
let sks = self.sks.as_ref().expect("setup was not properly called");
(self.func)(sks, input.0, input.1, input.2, input.3)
}
}
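In short, the executor decompresses the CPU server key once during setup and then forwards each call to the wrapped `ServerKey` method. A minimal usage sketch; the `cks`, `comp_sks`, `seeder`, `lhs` and `rhs` bindings are assumed to come from the surrounding test harness (illustrative, not part of this commit):

let mut add: BinaryOpExecutor =
    Box::new(OpSequenceCpuFunctionExecutor::new(&ServerKey::add_parallelized));
add.setup(&cks, &comp_sks, &mut seeder);
let sum: RadixCiphertext = add.execute((&lhs, &rhs));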

//=============================================================================
// Unchecked Tests
//=============================================================================