minimal compiling and working bench it seems

This commit is contained in:
Arthur Meyre
2025-10-10 16:26:45 +00:00
parent 5e875462df
commit c4798dd2b3
2 changed files with 243 additions and 90 deletions

View File

@@ -13,24 +13,30 @@ use rayon::prelude::*;
use std::cell::LazyCell;
use std::cmp::max;
use std::env;
use tfhe::core_crypto::algorithms::modulus_switch::ModulusSwitchedLweCiphertext;
use tfhe::core_crypto::algorithms::{
allocate_and_generate_new_binary_glwe_secret_key,
allocate_and_generate_new_binary_lwe_secret_key, allocate_and_generate_new_lwe_keyswitch_key,
extract_lwe_sample_from_glwe_ciphertext, keyswitch_lwe_ciphertext_with_scalar_change,
lwe_ciphertext_centered_binary_modulus_switch,
};
use tfhe::core_crypto::entities::modulus_switched_lwe_ciphertext::LazyStandardModulusSwitchedLweCiphertext;
use tfhe::core_crypto::entities::LweCiphertext;
use tfhe::integer::keycache::KEY_CACHE;
use tfhe::integer::prelude::*;
use tfhe::integer::{IntegerKeyKind, RadixCiphertext, RadixClientKey, ServerKey, U256};
use tfhe::keycache::NamedParam;
use tfhe::shortint::atomic_pattern::{AtomicPattern, AtomicPatternMut};
use tfhe::shortint::atomic_pattern::{AtomicPattern, AtomicPatternMut, AtomicPatternServerKey};
use tfhe::shortint::ciphertext::{Degree, NoiseLevel};
use tfhe::shortint::engine::ShortintEngine;
use tfhe::{get_pbs_count, reset_pbs_count};
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct ParamKS32MB {
pub lwe_dimension: LweDimension,
pub glwe_dimension: GlweDimension,
pub polynomial_size: PolynomialSize,
pub lwe_noise_distribution: DynamicDistribution<u64>,
pub lwe_noise_distribution: DynamicDistribution<u32>,
pub glwe_noise_distribution: DynamicDistribution<u64>,
pub pbs_base_log: DecompositionBaseLog,
pub pbs_level: DecompositionLevelCount,
@@ -69,68 +75,119 @@ impl ClientKeyKS32MB {
Self {
lwe_secret_key,
glwe_secret_key,
parameters: *params.to_owned(),
parameters: *params,
}
}
// warning fake encryption for benchmark
pub fn encrypt_shortint(&self, msg: u64) -> Ciphertext {
todo!()
let cont = (0..self
.parameters
.glwe_dimension
.to_equivalent_lwe_dimension(self.parameters.polynomial_size)
.to_lwe_size()
.0)
.map(|_| rand::random())
.collect::<Vec<_>>();
Ciphertext::new(
LweCiphertext::from_container(cont, self.parameters.ciphertext_modulus),
Degree::new(self.parameters.message_modulus.0 - 1),
NoiseLevel::NOMINAL,
self.parameters.message_modulus,
self.parameters.carry_modulus,
tfhe::shortint::AtomicPatternKind::KeySwitch32,
)
}
// tfhe/src/integer/ciphertext/base.rs:29
pub fn encrypt_radix(&self, msg: u64, block_count: usize) -> RadixCiphertext {
let blocks: Vec<Ciphertext> = todo!();
let blocks: Vec<Ciphertext> = (0..block_count)
.map(|_| self.encrypt_shortint(msg))
.collect();
RadixCiphertext::from(blocks)
}
}
#[derive(Clone, Debug, PartialEq)]
pub struct ServerKeyKS32MB {
lwe_ksk: LweKeyswitchKeyOwned<u32>,
lwe_bsk: ShortintBootstrappingKey<u32>,
is_2m128: bool,
parameters: ParamKS32MB,
}
impl ServerKeyKS32MB {
pub fn new(cks: &ClientKeyKS32MB, is_2m128: bool) -> Self {
let params = &cks.parameters;
let in_key = cks.lwe_secret_key.as_view();
let in_key_u64 = {
// Convert to u64 as a workaround
let cont = cks
.lwe_secret_key
.as_ref()
.iter()
.map(|x| *x as u64)
.collect::<Vec<_>>();
LweSecretKey::from_container(cont)
};
let out_key = &cks.glwe_secret_key;
let (key_switching_key, bootstrapping_key) = ShortintEngine::with_thread_local_mut(|engine| {
let (key_switching_key, bootstrapping_key) =
ShortintEngine::with_thread_local_mut(|engine| {
let bootstrapping_key_base = engine.new_multibit_bootstrapping_key(
&in_key_u64,
&out_key,
params.glwe_noise_distribution,
params.pbs_base_log,
params.pbs_level,
params.grouping_factor,
params.ciphertext_modulus,
);
let bootstrapping_key_base = engine.new_multibit_bootstrapping_key(
&in_key,
&out_key,
params.glwe_noise_distribution,
params.pbs_base_log,
params.pbs_level,
params.grouping_factor,
params.ciphertext_modulus,
// Creation of the key switching key
let key_switching_key = allocate_and_generate_new_lwe_keyswitch_key(
&cks.glwe_secret_key.as_lwe_secret_key(),
&cks.lwe_secret_key,
params.ks_base_log,
params.ks_level,
params.lwe_noise_distribution,
CiphertextModulus32::new_native(),
&mut engine.encryption_generator,
);
(key_switching_key, bootstrapping_key_base)
});
let thread_count = ShortintEngine::get_thread_count_for_multi_bit_pbs(
bootstrapping_key.input_lwe_dimension(),
bootstrapping_key.glwe_size().to_glwe_dimension(),
bootstrapping_key.polynomial_size(),
bootstrapping_key.decomposition_base_log(),
bootstrapping_key.decomposition_level_count(),
bootstrapping_key.grouping_factor(),
);
// Creation of the key switching key
let key_switching_key = allocate_and_generate_new_lwe_keyswitch_key(
&cks.glwe_secret_key.as_lwe_secret_key(),
&in_key,
params.ks_base_log,
params.ks_level,
params.lwe_noise_distribution,
CiphertextModulus32::new_native(),
&mut engine.encryption_generator,
);
(key_switching_key, bootstrapping_key_base)
});
Self {
lwe_ksk: key_switching_key,
lwe_bsk: bootstrapping_key,
lwe_bsk: ShortintBootstrappingKey::MultiBit {
fourier_bsk: bootstrapping_key,
thread_count,
deterministic_execution: false,
},
is_2m128,
parameters: *params,
}
}
pub fn intermediate_lwe_dimension(&self) -> LweDimension {
self.lwe_bsk.input_lwe_dimension()
}
fn intermediate_ciphertext_modulus(&self) -> CiphertextModulus32 {
self.lwe_ksk.ciphertext_modulus()
}
}
impl AtomicPatternMut for ServerKeyKS32MB {
@@ -150,7 +207,10 @@ impl AtomicPattern for ServerKeyKS32MB {
}
fn ciphertext_modulus_for_key(&self, key_choice: EncryptionKeyChoice) -> CiphertextModulus {
todo!() // not required normally
match key_choice {
EncryptionKeyChoice::Big => self.parameters.ciphertext_modulus,
EncryptionKeyChoice::Small => CiphertextModulus32::new_native().try_to().unwrap(),
}
}
fn ciphertext_decompression_method(&self) -> tfhe::core_crypto::prelude::MsDecompressionType {
@@ -163,7 +223,48 @@ impl AtomicPattern for ServerKeyKS32MB {
acc: &tfhe::shortint::server_key::LookupTableOwned,
) {
if self.is_2m128 {
todo!("code bizarre");
ShortintEngine::with_thread_local_mut(|engine| {
// :warning: the mean compensation is classical mean compensation + classical
// modulus switch + multi bit bootstrap (not sure this is
// implemented in tfhe-rs yet) :warning:
let (mut ciphertext_buffer, buffers) = engine.get_buffers(
self.intermediate_lwe_dimension(),
self.intermediate_ciphertext_modulus(),
);
keyswitch_lwe_ciphertext_with_scalar_change(
&self.lwe_ksk,
&ct.ct,
&mut ciphertext_buffer,
);
let br_input_modulus_log = self
.lwe_bsk
.polynomial_size()
.to_blind_rotation_input_modulus_log();
let msed: LazyStandardModulusSwitchedLweCiphertext<u32, u32, &[u32]> =
lwe_ciphertext_centered_binary_modulus_switch(
ciphertext_buffer.as_view(),
self.lwe_bsk
.polynomial_size()
.to_blind_rotation_input_modulus_log(),
);
let mut lwe_rescaled_cont =
Vec::with_capacity(msed.lwe_dimension().to_lwe_size().0);
for mask in msed.mask() {
lwe_rescaled_cont.push(mask << (u32::BITS - br_input_modulus_log.0 as u32));
}
lwe_rescaled_cont.push(msed.body() << (u32::BITS - br_input_modulus_log.0 as u32));
let lwe = LweCiphertext::from_container(
lwe_rescaled_cont,
self.lwe_ksk.ciphertext_modulus(),
);
apply_programmable_bootstrap(&self.lwe_bsk, &lwe, &mut ct.ct, &acc.acc, buffers);
})
} else {
ShortintEngine::with_thread_local_mut(|engine| {
let (mut ciphertext_buffer, buffers) = engine.get_buffers(
@@ -195,26 +296,71 @@ impl AtomicPattern for ServerKeyKS32MB {
) -> Vec<Ciphertext> {
let mut acc = lut.acc.clone();
ShortintEngine::with_thread_local_mut(|engine| {
let (mut ciphertext_buffer, buffers) = engine.get_buffers(
self.ciphertext_lwe_dimension_for_key(EncryptionKeyChoice::Small),
self.lwe_ksk.ciphertext_modulus(),
);
if self.is_2m128 {
ShortintEngine::with_thread_local_mut(|engine| {
// :warning: the mean compensation is classical mean compensation + classical
// modulus switch + multi bit bootstrap (not sure this is
// implemented in tfhe-rs yet) :warning:
let (mut ciphertext_buffer, buffers) = engine.get_buffers(
self.intermediate_lwe_dimension(),
self.intermediate_ciphertext_modulus(),
);
// Compute a key switch
keyswitch_lwe_ciphertext_with_scalar_change(
&self.lwe_ksk,
&ct.ct,
&mut ciphertext_buffer,
);
keyswitch_lwe_ciphertext_with_scalar_change(
&self.lwe_ksk,
&ct.ct,
&mut ciphertext_buffer,
);
apply_ms_blind_rotate(
&self.lwe_bsk,
&ciphertext_buffer.as_view(),
&mut acc,
buffers,
);
});
let br_input_modulus_log = self
.lwe_bsk
.polynomial_size()
.to_blind_rotation_input_modulus_log();
let msed: LazyStandardModulusSwitchedLweCiphertext<u32, u32, &[u32]> =
lwe_ciphertext_centered_binary_modulus_switch(
ciphertext_buffer.as_view(),
self.lwe_bsk
.polynomial_size()
.to_blind_rotation_input_modulus_log(),
);
let mut lwe_rescaled_cont =
Vec::with_capacity(msed.lwe_dimension().to_lwe_size().0);
for mask in msed.mask() {
lwe_rescaled_cont.push(mask << (u32::BITS - br_input_modulus_log.0 as u32));
}
lwe_rescaled_cont.push(msed.body() << (u32::BITS - br_input_modulus_log.0 as u32));
let lwe = LweCiphertext::from_container(
lwe_rescaled_cont,
self.lwe_ksk.ciphertext_modulus(),
);
apply_ms_blind_rotate(&self.lwe_bsk, &lwe, &mut acc, buffers);
})
} else {
ShortintEngine::with_thread_local_mut(|engine| {
let (mut ciphertext_buffer, buffers) = engine.get_buffers(
self.ciphertext_lwe_dimension_for_key(EncryptionKeyChoice::Small),
self.lwe_ksk.ciphertext_modulus(),
);
// Compute a key switch
keyswitch_lwe_ciphertext_with_scalar_change(
&self.lwe_ksk,
&ct.ct,
&mut ciphertext_buffer,
);
apply_ms_blind_rotate(
&self.lwe_bsk,
&ciphertext_buffer.as_view(),
&mut acc,
buffers,
);
});
}
// The accumulator has been rotated, we can now proceed with the various sample extractions
let function_count = lut.function_count();
@@ -239,12 +385,14 @@ impl AtomicPattern for ServerKeyKS32MB {
}
fn lookup_table_size(&self) -> tfhe::shortint::server_key::LookupTableSize {
// See KS32 shortint
todo!()
tfhe::shortint::server_key::LookupTableSize::new(
self.lwe_bsk.glwe_size(),
self.lwe_bsk.polynomial_size(),
)
}
fn kind(&self) -> tfhe::shortint::AtomicPatternKind {
todo!() // maybe
tfhe::shortint::AtomicPatternKind::KeySwitch32
}
fn generate_oblivious_pseudo_random(
@@ -282,7 +430,7 @@ impl AtomicPattern for ServerKeyKS32MB {
pub fn create_sk(cks: &ClientKeyKS32MB, is_2m128: bool) -> ServerKey {
let key = ServerKeyKS32MB::new(cks, is_2m128);
let boxed: Box<dyn AtomicPattern> = Box::new(key);
let boxed = Box::new(key);
// tfhe/src/shortint/engine/server_side.rs:21
@@ -292,7 +440,7 @@ pub fn create_sk(cks: &ClientKeyKS32MB, is_2m128: bool) -> ServerKey {
);
let shortint_sk = tfhe::shortint::ServerKey::from_raw_parts(
*boxed,
AtomicPatternServerKey::Dynamic(boxed),
cks.parameters.message_modulus,
cks.parameters.carry_modulus,
max_degree,
@@ -303,7 +451,7 @@ pub fn create_sk(cks: &ClientKeyKS32MB, is_2m128: bool) -> ServerKey {
}
const BENCH_MULTI_BIT_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_GAUSSIAN_2M40: ParamKS32MB = ParamKS32MB {
lwe_dimension: LweDimension(429),
lwe_dimension: LweDimension(429 * 2),
glwe_dimension: GlweDimension(1),
polynomial_size: PolynomialSize(2048),
lwe_noise_distribution: DynamicDistribution::new_gaussian_from_std_dev(StandardDev(2.34899e-6)),
@@ -324,7 +472,7 @@ const BENCH_MULTI_BIT_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_GAUSSIAN_2M40: ParamKS32M
};
const BENCH_MULTI_BIT_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_GAUSSIAN_2M64: ParamKS32MB = ParamKS32MB {
lwe_dimension: LweDimension(444),
lwe_dimension: LweDimension(444 * 2),
glwe_dimension: GlweDimension(1),
polynomial_size: PolynomialSize(2048),
lwe_noise_distribution: DynamicDistribution::new_gaussian_from_std_dev(StandardDev(1.39987e-6)),
@@ -345,7 +493,7 @@ const BENCH_MULTI_BIT_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_GAUSSIAN_2M64: ParamKS32M
};
const BENCH_MULTI_BIT_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_GAUSSIAN_2M128: ParamKS32MB = ParamKS32MB {
lwe_dimension: LweDimension(457),
lwe_dimension: LweDimension(457 * 2),
glwe_dimension: GlweDimension(1),
polynomial_size: PolynomialSize(2048),
lwe_noise_distribution: DynamicDistribution::new_gaussian_from_std_dev(StandardDev(
@@ -742,38 +890,44 @@ fn bench_server_key_binary_function_clean_inputs_with_ks32<F>(
.measurement_time(std::time::Duration::from_secs(60));
let mut rng = rand::thread_rng();
let bit_sizes = [4u32, 8, 16, 32, 64, 128, 256];
for (param_name, param, is_2m128) in KS32_BENCHMARK_PARAM_SET.iter() {
let bench_id;
for bit_size in bit_sizes {
let num_block = bit_size.div_ceil(param.message_modulus.0.ilog2()) as usize;
let bench_data = LazyCell::new(|| {
let cks = ClientKeyKS32MB::new(param);
let sks = create_sk(&cks, *is_2m128);
let bench_id;
let clear_0 = gen_random_u256(&mut rng);
let clear_1 = gen_random_u256(&mut rng);
let bench_data = LazyCell::new(|| {
let cks = ClientKeyKS32MB::new(param);
let sks = create_sk(&cks, *is_2m128);
let ct_0 = cks.encrypt_radix(clear_0, num_block);
let ct_1 = cks.encrypt_radix(clear_1, num_block);
(sks, ct_0, ct_1)
});
let clear_0 = rand::random::<u64>();
let clear_1 = rand::random::<u64>();
bench_id = format!("{bench_name}::{param_name}::{bit_size}_bits");
bench_group.bench_function(&bench_id, |b| {
let (sks, ct_0, ct_1) = (&bench_data.0, &bench_data.1, &bench_data.2);
b.iter(|| {
binary_op(sks, ct_0, ct_1);
})
});
let ct_0 = cks.encrypt_radix(clear_0, num_block);
let ct_1 = cks.encrypt_radix(clear_1, num_block);
(sks, ct_0, ct_1)
});
// write_to_json::<u64, _>(
// &bench_id,
// param,
// *param_name,
// display_name,
// &OperatorType::Atomic,
// bit_size as u32,
// vec![param.message_modulus.0.ilog2(); num_block],
// );
bench_id = format!("{bench_name}::{param_name}::{bit_size}_bits");
bench_group.bench_function(&bench_id, |b| {
let (sks, ct_0, ct_1) = (&bench_data.0, &bench_data.1, &bench_data.2);
b.iter(|| {
binary_op(sks, ct_0, ct_1);
})
});
// write_to_json::<u64, _>(
// &bench_id,
// param,
// *param_name,
// display_name,
// &OperatorType::Atomic,
// bit_size as u32,
// vec![param.message_modulus.0.ilog2(); num_block],
// );
}
}
bench_group.finish()
@@ -3382,16 +3536,15 @@ use cuda::{
cuda_cast_ops, default_cuda_dedup_ops, default_cuda_ops, default_scalar_cuda_ops,
unchecked_cuda_ops, unchecked_scalar_cuda_ops,
};
use tfhe::boolean::parameters::{
DecompositionBaseLog, DecompositionLevelCount, DynamicDistribution, GlweDimension,
LweDimension, PolynomialSize, StandardDev,
};
use tfhe::core_crypto::entities::LweKeyswitchKeyOwned;
use tfhe::core_crypto::prelude::{
GlweSecretKey, LweBskGroupingFactor, LweSecretKey, MonomialDegree,
};
use tfhe::shortint::ciphertext::MaxDegree;
use tfhe::shortint::parameters::CiphertextModulus32;
use tfhe::shortint::parameters::{
CiphertextModulus32, DecompositionBaseLog, DecompositionLevelCount, DynamicDistribution,
GlweDimension, LweDimension, PolynomialSize, StandardDev,
};
use tfhe::shortint::server_key::{
apply_ms_blind_rotate, apply_programmable_bootstrap, ShortintBootstrappingKey,
};

View File

@@ -26,7 +26,7 @@ impl ShortintEngine {
self.new_server_key_with_max_degree(cks, max_degree)
}
pub(crate) fn get_thread_count_for_multi_bit_pbs(
pub fn get_thread_count_for_multi_bit_pbs(
lwe_dimension: LweDimension,
glwe_dimension: GlweDimension,
polynomial_size: PolynomialSize,