refactor!: gate wops behind "experimental" feature

This puts the WOPBS features of shortint and integer
modules behind the "experimental" feature.

Due to the versioning feature, the structs definitions
are not gated behind the "experimental" feature, however
they are only pub(crate) in that case.
This commit is contained in:
tmontaigu
2024-09-26 13:10:11 +02:00
parent d2efa82daf
commit 45effa41d5
12 changed files with 2149 additions and 2049 deletions

View File

@@ -285,12 +285,18 @@ clippy_shortint: install_rs_check_toolchain
RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy \
--features=$(TARGET_ARCH_FEATURE),shortint \
-p $(TFHE_SPEC) -- --no-deps -D warnings
RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy \
--features=$(TARGET_ARCH_FEATURE),shortint,experimental \
-p $(TFHE_SPEC) -- --no-deps -D warnings
.PHONY: clippy_integer # Run clippy lints enabling the integer features
clippy_integer: install_rs_check_toolchain
RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy \
--features=$(TARGET_ARCH_FEATURE),integer \
-p $(TFHE_SPEC) -- --no-deps -D warnings
RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy \
--features=$(TARGET_ARCH_FEATURE),integer,experimental \
-p $(TFHE_SPEC) -- --no-deps -D warnings
.PHONY: clippy # Run clippy lints enabling the boolean, shortint, integer
clippy: install_rs_check_toolchain
@@ -339,6 +345,9 @@ clippy_all_targets: install_rs_check_toolchain
RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy --all-targets \
--features=$(TARGET_ARCH_FEATURE),boolean,shortint,integer,internal-keycache,zk-pok \
-p $(TFHE_SPEC) -- --no-deps -D warnings
RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy --all-targets \
--features=$(TARGET_ARCH_FEATURE),boolean,shortint,integer,internal-keycache,zk-pok,experimental \
-p $(TFHE_SPEC) -- --no-deps -D warnings
.PHONY: clippy_concrete_csprng # Run clippy lints on concrete-csprng
clippy_concrete_csprng: install_rs_check_toolchain

View File

@@ -155,7 +155,7 @@ cargo "${RUST_TOOLCHAIN}" nextest run \
--cargo-profile "${cargo_profile}" \
--package "${tfhe_package}" \
--profile ci \
--features="${ARCH_FEATURE}",integer,internal-keycache,zk-pok,"${avx512_feature}","${gpu_feature}" \
--features="${ARCH_FEATURE}",integer,internal-keycache,zk-pok,experimental,"${avx512_feature}","${gpu_feature}" \
--test-threads "${test_threads}" \
-E "$filter_expression"
@@ -163,7 +163,7 @@ if [[ -z ${multi_bit_argument} ]]; then
cargo "${RUST_TOOLCHAIN}" test \
--profile "${cargo_profile}" \
--package "${tfhe_package}" \
--features="${ARCH_FEATURE}",integer,internal-keycache,"${avx512_feature}","${gpu_feature}" \
--features="${ARCH_FEATURE}",integer,internal-keycache,experimental,"${avx512_feature}","${gpu_feature}" \
--doc \
-- --test-threads="${doctest_threads}" integer::"${gpu_feature}"
fi

View File

@@ -101,7 +101,7 @@ if [[ "${BIG_TESTS_INSTANCE}" != TRUE ]]; then
--cargo-profile "${cargo_profile}" \
--package "${tfhe_package}" \
--profile ci \
--features="${ARCH_FEATURE}",shortint,internal-keycache,zk-pok \
--features="${ARCH_FEATURE}",shortint,internal-keycache,zk-pok,experimental \
--test-threads "${n_threads_small}" \
-E "${filter_expression_small_params}"
@@ -118,7 +118,7 @@ and not test(~smart_add_and_mul)"""
--cargo-profile "${cargo_profile}" \
--package "${tfhe_package}" \
--profile ci \
--features="${ARCH_FEATURE}",shortint,internal-keycache \
--features="${ARCH_FEATURE}",shortint,internal-keycache,experimental \
--test-threads "${n_threads_big}" \
-E "${filter_expression_big_params}"
@@ -126,7 +126,7 @@ and not test(~smart_add_and_mul)"""
cargo "${RUST_TOOLCHAIN}" test \
--profile "${cargo_profile}" \
--package "${tfhe_package}" \
--features="${ARCH_FEATURE}",shortint,internal-keycache \
--features="${ARCH_FEATURE}",shortint,internal-keycache,experimental \
--doc \
-- shortint::
fi
@@ -140,7 +140,7 @@ else
--cargo-profile "${cargo_profile}" \
--package "${tfhe_package}" \
--profile ci \
--features="${ARCH_FEATURE}",shortint,internal-keycache \
--features="${ARCH_FEATURE}",shortint,internal-keycache,experimental \
--test-threads "${n_threads_big}" \
-E "${filter_expression}"
@@ -148,7 +148,7 @@ else
cargo "${RUST_TOOLCHAIN}" test \
--profile "${cargo_profile}" \
--package "${tfhe_package}" \
--features="${ARCH_FEATURE}",shortint,internal-keycache \
--features="${ARCH_FEATURE}",shortint,internal-keycache,experimental \
--doc \
-- --test-threads="${n_threads_big}" shortint::
fi

View File

@@ -2,7 +2,9 @@ use clap::{Arg, ArgAction, Command};
use tfhe::boolean;
use tfhe::boolean::parameters::{BooleanParameters, DEFAULT_PARAMETERS, DEFAULT_PARAMETERS_KS_PBS};
use tfhe::keycache::NamedParam;
use tfhe::shortint::keycache::{KEY_CACHE, KEY_CACHE_KSK, KEY_CACHE_WOPBS};
#[cfg(feature = "experimental")]
use tfhe::shortint::keycache::KEY_CACHE_WOPBS;
use tfhe::shortint::keycache::{KEY_CACHE, KEY_CACHE_KSK};
#[cfg(tarpaulin)]
use tfhe::shortint::parameters::coverage_parameters::{
COVERAGE_PARAM_MESSAGE_2_CARRY_2_COMPACT_PK_KS_PBS,
@@ -14,14 +16,16 @@ use tfhe::shortint::parameters::key_switching::p_fail_2_minus_64::ks_pbs::PARAM_
use tfhe::shortint::parameters::key_switching::ShortintKeySwitchingParameters;
use tfhe::shortint::parameters::{
ClassicPBSParameters, WopbsParameters, ALL_MULTI_BIT_PARAMETER_VEC,
PARAM_MESSAGE_1_CARRY_1_KS_PBS, PARAM_MESSAGE_1_CARRY_2_KS_PBS, PARAM_MESSAGE_1_CARRY_3_KS_PBS,
PARAM_MESSAGE_1_CARRY_4_KS_PBS, PARAM_MESSAGE_1_CARRY_5_KS_PBS, PARAM_MESSAGE_1_CARRY_6_KS_PBS,
PARAM_MESSAGE_2_CARRY_1_KS_PBS, PARAM_MESSAGE_2_CARRY_2_KS_PBS, PARAM_MESSAGE_2_CARRY_3_KS_PBS,
PARAM_MESSAGE_3_CARRY_1_KS_PBS, PARAM_MESSAGE_3_CARRY_2_KS_PBS, PARAM_MESSAGE_3_CARRY_3_KS_PBS,
PARAM_MESSAGE_4_CARRY_4_KS_PBS, WOPBS_PARAM_MESSAGE_1_CARRY_1_KS_PBS,
WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS, WOPBS_PARAM_MESSAGE_3_CARRY_3_KS_PBS,
WOPBS_PARAM_MESSAGE_4_CARRY_4_KS_PBS,
ClassicPBSParameters, ALL_MULTI_BIT_PARAMETER_VEC, PARAM_MESSAGE_1_CARRY_1_KS_PBS,
PARAM_MESSAGE_1_CARRY_2_KS_PBS, PARAM_MESSAGE_1_CARRY_3_KS_PBS, PARAM_MESSAGE_1_CARRY_4_KS_PBS,
PARAM_MESSAGE_1_CARRY_5_KS_PBS, PARAM_MESSAGE_1_CARRY_6_KS_PBS, PARAM_MESSAGE_2_CARRY_1_KS_PBS,
PARAM_MESSAGE_2_CARRY_2_KS_PBS, PARAM_MESSAGE_2_CARRY_3_KS_PBS, PARAM_MESSAGE_3_CARRY_1_KS_PBS,
PARAM_MESSAGE_3_CARRY_2_KS_PBS, PARAM_MESSAGE_3_CARRY_3_KS_PBS, PARAM_MESSAGE_4_CARRY_4_KS_PBS,
};
#[cfg(feature = "experimental")]
use tfhe::shortint::parameters::{
WopbsParameters, WOPBS_PARAM_MESSAGE_1_CARRY_1_KS_PBS, WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS,
WOPBS_PARAM_MESSAGE_3_CARRY_3_KS_PBS, WOPBS_PARAM_MESSAGE_4_CARRY_4_KS_PBS,
};
use tfhe::shortint::MultiBitPBSParameters;
@@ -89,11 +93,14 @@ fn client_server_keys() {
generate_ksk_keys(&KSK_PARAMS);
#[cfg(feature = "experimental")]
{
const WOPBS_PARAMS: [(ClassicPBSParameters, WopbsParameters); 1] = [(
PARAM_MESSAGE_2_CARRY_2_KS_PBS,
WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS,
)];
generate_wopbs_keys(&WOPBS_PARAMS);
}
const BOOLEAN_PARAMS: [BooleanParameters; 2] =
[DEFAULT_PARAMETERS, DEFAULT_PARAMETERS_KS_PBS];
@@ -116,6 +123,8 @@ fn client_server_keys() {
];
generate_pbs_keys(&PBS_KEYS);
#[cfg(feature = "experimental")]
{
const WOPBS_PARAMS: [(ClassicPBSParameters, WopbsParameters); 4] = [
(
PARAM_MESSAGE_1_CARRY_1_KS_PBS,
@@ -138,6 +147,7 @@ fn client_server_keys() {
generate_wopbs_keys(&WOPBS_PARAMS);
}
}
}
fn generate_pbs_keys(params: &[ClassicPBSParameters]) {
println!("Generating shortint (ClientKey, ServerKey)");
@@ -219,6 +229,7 @@ fn generate_ksk_keys(
}
}
#[cfg(feature = "experimental")]
fn generate_wopbs_keys(params: &[(ClassicPBSParameters, WopbsParameters)]) {
println!("Generating woPBS keys");

View File

@@ -1,6 +1,9 @@
#[cfg(feature = "experimental")]
use crate::integer::wopbs::WopbsKey;
use crate::integer::{ClientKey, IntegerKeyKind, ServerKey};
use crate::shortint::{PBSParameters, WopbsParameters};
use crate::shortint::PBSParameters;
#[cfg(feature = "experimental")]
use crate::shortint::WopbsParameters;
use lazy_static::lazy_static;
#[derive(Default)]
@@ -39,8 +42,10 @@ impl IntegerKeyCache {
}
#[derive(Default)]
#[cfg(feature = "experimental")]
pub struct WopbsKeyCache;
#[cfg(feature = "experimental")]
impl WopbsKeyCache {
pub fn get_from_params<P>(&self, (pbs_params, wopbs_params): (P, WopbsParameters)) -> WopbsKey
where
@@ -65,5 +70,8 @@ impl WopbsKeyCache {
lazy_static! {
pub static ref KEY_CACHE: IntegerKeyCache = IntegerKeyCache;
}
#[cfg(feature = "experimental")]
lazy_static! {
pub static ref KEY_CACHE_WOPBS: WopbsKeyCache = WopbsKeyCache;
}

View File

@@ -63,7 +63,10 @@ pub mod parameters;
pub mod prelude;
pub mod public_key;
pub mod server_key;
#[cfg(feature = "experimental")]
pub mod wopbs;
#[cfg(not(feature = "experimental"))]
pub(crate) mod wopbs;
#[cfg(feature = "gpu")]
pub mod gpu;

View File

@@ -3,19 +3,11 @@
//! This module implements the generation of another server public key, which allows to compute
//! an alternative version of the programmable bootstrapping. This does not require the use of a
//! bit of padding.
#[cfg(test)]
#[cfg(all(test, feature = "experimental"))]
mod test;
use super::backward_compatibility::wopbs::WopbsKeyVersions;
use super::ciphertext::RadixCiphertext;
pub use crate::core_crypto::commons::parameters::{CiphertextCount, PlaintextCount};
use crate::core_crypto::prelude::*;
use crate::integer::client_key::utils::i_crt;
use crate::integer::{ClientKey, CrtCiphertext, IntegerCiphertext, ServerKey};
use crate::shortint::ciphertext::{Degree, NoiseLevel};
use crate::shortint::wopbs::WopbsLUTBase;
use crate::shortint::WopbsParameters;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use tfhe_versionable::Versionize;
@@ -25,13 +17,34 @@ pub struct WopbsKey {
wopbs_key: crate::shortint::wopbs::WopbsKey,
}
#[cfg(feature = "experimental")]
pub use experimental::*;
#[cfg(feature = "experimental")]
mod experimental {
pub use crate::core_crypto::commons::parameters::{CiphertextCount, PlaintextCount};
use crate::core_crypto::prelude::*;
use crate::integer::client_key::utils::i_crt;
use crate::integer::{ClientKey, CrtCiphertext, IntegerCiphertext, RadixCiphertext, ServerKey};
use crate::shortint::ciphertext::{Degree, NoiseLevel};
use crate::shortint::WopbsParameters;
use crate::shortint::wopbs::WopbsLUTBase;
use super::WopbsKey;
use rayon::prelude::*;
#[must_use]
pub struct IntegerWopbsLUT {
inner: WopbsLUTBase,
}
impl IntegerWopbsLUT {
pub fn new(small_lut_size: PlaintextCount, output_ciphertext_count: CiphertextCount) -> Self {
pub fn new(
small_lut_size: PlaintextCount,
output_ciphertext_count: CiphertextCount,
) -> Self {
Self {
inner: WopbsLUTBase::new(small_lut_size, output_ciphertext_count),
}
@@ -204,8 +217,8 @@ impl From<crate::shortint::wopbs::WopbsKey> for WopbsKey {
}
impl WopbsKey {
/// Generates the server key required to compute a WoPBS from the client and the server keys.
/// # Example
/// Generates the server key required to compute a WoPBS from the client and the server
/// keys. # Example
/// ```rust
/// use tfhe::integer::gen_keys_radix;
/// use tfhe::integer::wopbs::*;
@@ -457,7 +470,11 @@ impl WopbsKey {
/// let res = cks.decrypt_native_crt(&ct_res);
/// assert_eq!(res, clear);
/// ```
pub fn wopbs_native_crt(&self, ct1: &CrtCiphertext, lut: &IntegerWopbsLUT) -> CrtCiphertext {
pub fn wopbs_native_crt(
&self,
ct1: &CrtCiphertext,
lut: &IntegerWopbsLUT,
) -> CrtCiphertext {
self.circuit_bootstrap_vertical_packing_native_crt(&[ct1.clone()], lut)
}
@@ -565,7 +582,9 @@ impl WopbsKey {
let decoded_val = decode_radix(&encoded_with_deg_val, basis);
let f_val = f(decoded_val % modulus) % modulus;
let encoded_f_val = encode_radix(f_val, basis, block_nb as u64);
for (lut_number, radix_encoded_val) in encoded_f_val.iter().enumerate().take(block_nb) {
for (lut_number, radix_encoded_val) in
encoded_f_val.iter().enumerate().take(block_nb)
{
lut[lut_number][lut_index_val as usize] = radix_encoded_val * delta;
}
}
@@ -603,7 +622,8 @@ impl WopbsKey {
F: Fn(u64) -> u64,
T: IntegerCiphertext,
{
let log_message_modulus = f64::log2((self.wopbs_key.param.message_modulus.0) as f64) as u64;
let log_message_modulus =
f64::log2((self.wopbs_key.param.message_modulus.0) as f64) as u64;
let log_carry_modulus = f64::log2((self.wopbs_key.param.carry_modulus.0) as f64) as u64;
let log_basis = log_message_modulus + log_carry_modulus;
let delta = 64 - log_basis;
@@ -679,7 +699,8 @@ impl WopbsKey {
} else {
1 << total_bit
};
let mut lut = IntegerWopbsLUT::new(PlaintextCount(lut_size), CiphertextCount(basis.len()));
let mut lut =
IntegerWopbsLUT::new(PlaintextCount(lut_size), CiphertextCount(basis.len()));
for value in 0..modulus {
let mut index_lut = 0;
@@ -738,7 +759,8 @@ impl WopbsKey {
} else {
1 << total_bit
};
let mut lut = IntegerWopbsLUT::new(PlaintextCount(lut_size), CiphertextCount(basis.len()));
let mut lut =
IntegerWopbsLUT::new(PlaintextCount(lut_size), CiphertextCount(basis.len()));
let delta: u64 = (1 << 63)
/ (self.wopbs_key.param.message_modulus.0 * self.wopbs_key.param.carry_modulus.0)
@@ -831,7 +853,8 @@ impl WopbsKey {
} else {
1 << total_bit
};
let mut lut = IntegerWopbsLUT::new(PlaintextCount(lut_size), CiphertextCount(basis.len()));
let mut lut =
IntegerWopbsLUT::new(PlaintextCount(lut_size), CiphertextCount(basis.len()));
let basis = ct1.moduli()[0];
let delta: u64 = (1 << 63)
@@ -850,7 +873,9 @@ impl WopbsKey {
}
let f_val = f(decoded_val[0] % modulus, decoded_val[1] % modulus) % modulus;
let encoded_f_val = encode_radix(f_val, basis, block_nb as u64);
for (lut_number, radix_encoded_val) in encoded_f_val.iter().enumerate().take(block_nb) {
for (lut_number, radix_encoded_val) in
encoded_f_val.iter().enumerate().take(block_nb)
{
lut[lut_number][lut_index_val as usize] = radix_encoded_val * delta;
}
}
@@ -917,7 +942,8 @@ impl WopbsKey {
} else {
1 << total_bit
};
let mut lut = IntegerWopbsLUT::new(PlaintextCount(lut_size), CiphertextCount(basis.len()));
let mut lut =
IntegerWopbsLUT::new(PlaintextCount(lut_size), CiphertextCount(basis.len()));
let delta: u64 = (1 << 63)
/ (self.wopbs_key.param.message_modulus.0 * self.wopbs_key.param.carry_modulus.0)
@@ -993,7 +1019,8 @@ impl WopbsKey {
} else {
1 << (2 * total_bit)
};
let mut lut = IntegerWopbsLUT::new(PlaintextCount(lut_size), CiphertextCount(basis.len()));
let mut lut =
IntegerWopbsLUT::new(PlaintextCount(lut_size), CiphertextCount(basis.len()));
for value in 0..1 << (2 * total_bit) {
let value_1 = value % (1 << total_bit);
@@ -1093,7 +1120,8 @@ impl WopbsKey {
lwe_ciphertext_plaintext_sub_assign(
&mut block.ct,
Plaintext(
(1 << (64 - nb_bit_to_extract - 1)) - (1 << (64 - nb_bit_to_extract - 5)),
(1 << (64 - nb_bit_to_extract - 1))
- (1 << (64 - nb_bit_to_extract - 5)),
),
);
@@ -1159,3 +1187,4 @@ impl WopbsKey {
T::from_blocks(blocks)
}
}
}

View File

@@ -24,6 +24,7 @@ use std::fmt::Debug;
mod client_side;
mod public_side;
mod server_side;
#[cfg(feature = "experimental")]
mod wopbs;
thread_local! {
@@ -263,8 +264,6 @@ impl std::fmt::Display for EngineError {
}
}
pub(crate) type EngineResult<T> = Result<T, EngineError>;
/// ShortintEngine
///
/// This 'engine' holds the necessary engines from [`core_crypto`](crate::core_crypto)

View File

@@ -2,7 +2,7 @@
use crate::core_crypto::algorithms::*;
use crate::core_crypto::entities::*;
use crate::shortint::ciphertext::{MaxDegree, MaxNoiseLevel};
use crate::shortint::engine::{EngineResult, ShortintEngine};
use crate::shortint::engine::ShortintEngine;
use crate::shortint::server_key::ShortintBootstrappingKey;
use crate::shortint::wopbs::{WopbsKey, WopbsKeyCreationError};
use crate::shortint::{ClientKey, ServerKey, WopbsParameters};
@@ -13,12 +13,15 @@ impl ShortintEngine {
&mut self,
cks: &ClientKey,
sks: &ServerKey,
) -> EngineResult<WopbsKey> {
) -> crate::Result<WopbsKey> {
if matches!(
sks.bootstrapping_key,
ShortintBootstrappingKey::MultiBit { .. }
) {
return Err(WopbsKeyCreationError::UnsupportedMultiBit.into());
return Err(crate::Error::new(format!(
"{}",
WopbsKeyCreationError::UnsupportedMultiBit
)));
}
let wop_params = cks.parameters.wopbs_parameters().unwrap();

View File

@@ -412,6 +412,10 @@ impl Keycache {
}
}
#[cfg(feature = "experimental")]
mod wopbs {
use super::*;
#[derive(Copy, Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct WopbsParamPair(pub PBSParameters, pub WopbsParameters);
@@ -463,6 +467,7 @@ impl KeycacheWopbsV0 {
self.inner.clear_in_memory_cache();
}
}
}
#[derive(Copy, Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct KeySwitchingKeyParams(
@@ -531,6 +536,10 @@ impl KeycacheKeySwitchingKey {
lazy_static! {
pub static ref KEY_CACHE: Keycache = Keycache::default();
pub static ref KEY_CACHE_WOPBS: KeycacheWopbsV0 = KeycacheWopbsV0::default();
pub static ref KEY_CACHE_KSK: KeycacheKeySwitchingKey = KeycacheKeySwitchingKey::default();
}
#[cfg(feature = "experimental")]
lazy_static! {
pub static ref KEY_CACHE_WOPBS: wopbs::KeycacheWopbsV0 = wopbs::KeycacheWopbsV0::default();
}

View File

@@ -60,7 +60,10 @@ pub mod parameters;
pub mod prelude;
pub mod public_key;
pub mod server_key;
#[cfg(feature = "experimental")]
pub mod wopbs;
#[cfg(not(feature = "experimental"))]
pub(crate) mod wopbs;
pub use ciphertext::{Ciphertext, CompressedCiphertext, PBSOrder};
pub use client_key::ClientKey;

View File

@@ -7,20 +7,45 @@
//! In the case where a padding bit is defined, keys are generated so that there a compatible for
//! both uses.
use crate::core_crypto::entities::*;
use crate::shortint::{ServerKey, WopbsParameters};
use serde::{Deserialize, Serialize};
use tfhe_versionable::Versionize;
use super::backward_compatibility::wopbs::WopbsKeyVersions;
#[cfg(all(test, feature = "experimental"))]
mod test;
// Struct for WoPBS based on the private functional packing keyswitch.
#[derive(Clone, Debug, Serialize, Deserialize, Versionize)]
#[versionize(WopbsKeyVersions)]
pub struct WopbsKey {
//Key for the private functional keyswitch
pub wopbs_server_key: ServerKey,
pub pbs_server_key: ServerKey,
pub cbs_pfpksk: LwePrivateFunctionalPackingKeyswitchKeyListOwned<u64>,
pub ksk_pbs_to_wopbs: LweKeyswitchKeyOwned<u64>,
pub param: WopbsParameters,
}
#[cfg(feature = "experimental")]
pub use experimental::*;
#[cfg(feature = "experimental")]
mod experimental {
use crate::core_crypto::algorithms::*;
use crate::core_crypto::commons::parameters::*;
pub use crate::core_crypto::commons::parameters::{CiphertextCount, PlaintextCount};
use crate::core_crypto::commons::traits::*;
use crate::core_crypto::entities::*;
use crate::core_crypto::fft_impl::fft64::math::fft::Fft;
use crate::shortint::ciphertext::*;
use crate::shortint::engine::ShortintEngine;
use crate::shortint::server_key::ShortintBootstrappingKey;
use crate::shortint::{ClientKey, ServerKey, WopbsParameters};
use serde::{Deserialize, Serialize};
use tfhe_versionable::Versionize;
use super::backward_compatibility::wopbs::WopbsKeyVersions;
use super::WopbsKey;
use crate::shortint::{ClientKey, ServerKey, WopbsParameters};
#[derive(Debug)]
pub enum WopbsKeyCreationError {
@@ -39,27 +64,12 @@ impl std::fmt::Display for WopbsKeyCreationError {
}
}
#[cfg(test)]
mod test;
// Struct for WoPBS based on the private functional packing keyswitch.
#[derive(Clone, Debug, Serialize, Deserialize, Versionize)]
#[versionize(WopbsKeyVersions)]
pub struct WopbsKey {
//Key for the private functional keyswitch
pub wopbs_server_key: ServerKey,
pub pbs_server_key: ServerKey,
pub cbs_pfpksk: LwePrivateFunctionalPackingKeyswitchKeyListOwned<u64>,
pub ksk_pbs_to_wopbs: LweKeyswitchKeyOwned<u64>,
pub param: WopbsParameters,
}
#[must_use]
pub struct WopbsLUTBase {
// Flattened Wopbs LUT
plaintext_list: Vec<u64>,
// How many output ciphertexts will be produced after applying the Wopbs to an input vector of
// ciphertexts encrypting bits
// How many output ciphertexts will be produced after applying the Wopbs to an input vector
// of ciphertexts encrypting bits
output_ciphertext_count: CiphertextCount,
}
@@ -71,7 +81,10 @@ impl WopbsLUTBase {
}
}
pub fn new(small_lut_size: PlaintextCount, output_ciphertext_count: CiphertextCount) -> Self {
pub fn new(
small_lut_size: PlaintextCount,
output_ciphertext_count: CiphertextCount,
) -> Self {
Self {
plaintext_list: vec![0; small_lut_size.0 * output_ciphertext_count.0],
output_ciphertext_count,
@@ -254,8 +267,14 @@ impl WopbsKey {
/// let (cks, sks) = gen_keys(PARAM_MESSAGE_1_CARRY_1_KS_PBS);
/// let wopbs_key = WopbsKey::new_wopbs_key(&cks, &sks, &WOPBS_PARAM_MESSAGE_1_CARRY_1_KS_PBS);
/// ```
pub fn new_wopbs_key(cks: &ClientKey, sks: &ServerKey, parameters: &WopbsParameters) -> Self {
ShortintEngine::with_thread_local_mut(|engine| engine.new_wopbs_key(cks, sks, parameters))
pub fn new_wopbs_key(
cks: &ClientKey,
sks: &ServerKey,
parameters: &WopbsParameters,
) -> Self {
ShortintEngine::with_thread_local_mut(|engine| {
engine.new_wopbs_key(cks, sks, parameters)
})
}
/// Deconstruct a [`WopbsKey`] into its constituents.
@@ -593,7 +612,9 @@ impl WopbsKey {
// trick ( ct - delta/2 + delta/2^4 )
lwe_ciphertext_plaintext_sub_assign(
&mut ct_in.ct,
Plaintext((1 << (64 - nb_bit_to_extract - 1)) - (1 << (64 - nb_bit_to_extract - 5))),
Plaintext(
(1 << (64 - nb_bit_to_extract - 1)) - (1 << (64 - nb_bit_to_extract - 5)),
),
);
let ciphertext = self.extract_bits_circuit_bootstrapping(
@@ -839,7 +860,8 @@ impl WopbsKey {
count,
self.param.ciphertext_modulus,
);
let lut = PolynomialListView::from_container(lut.as_ref(), fourier_bsk.polynomial_size());
let lut =
PolynomialListView::from_container(lut.as_ref(), fourier_bsk.polynomial_size());
let fft = Fft::new(fourier_bsk.polynomial_size());
let fft = fft.as_view();
@@ -896,8 +918,11 @@ impl WopbsKey {
) -> Ciphertext {
let extracted_bits = self.extract_bits(delta_log, ct_in, nb_bit_to_extract);
let ciphertext_list =
self.circuit_bootstrap_with_bits(&extracted_bits, &lut.lut(), LweCiphertextCount(1));
let ciphertext_list = self.circuit_bootstrap_with_bits(
&extracted_bits,
&lut.lut(),
LweCiphertextCount(1),
);
// Here the output list contains a single ciphertext, we can consume the container to
// convert it to a single ciphertext
@@ -918,3 +943,4 @@ impl WopbsKey {
)
}
}
}