diff --git a/tfhe-benchmark/benches/core_crypto/pbs128_bench.rs b/tfhe-benchmark/benches/core_crypto/pbs128_bench.rs index d55fb3b9c..a3a51e22a 100644 --- a/tfhe-benchmark/benches/core_crypto/pbs128_bench.rs +++ b/tfhe-benchmark/benches/core_crypto/pbs128_bench.rs @@ -23,14 +23,14 @@ fn pbs_128(c: &mut Criterion) { let base_params = BENCH_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128; let lwe_dimension = base_params.lwe_dimension; // From PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128 - let glwe_dimension = noise_params.glwe_dimension; - let polynomial_size = noise_params.polynomial_size; + let glwe_dimension = noise_params.glwe_dimension(); + let polynomial_size = noise_params.polynomial_size(); let lwe_noise_distribution = base_params.lwe_noise_distribution; - let glwe_noise_distribution = noise_params.glwe_noise_distribution; - let pbs_base_log = noise_params.decomp_base_log; - let pbs_level = noise_params.decomp_level_count; + let glwe_noise_distribution = noise_params.glwe_noise_distribution(); + let pbs_base_log = noise_params.decomp_base_log(); + let pbs_level = noise_params.decomp_level_count(); let input_ciphertext_modulus = base_params.ciphertext_modulus; - let output_ciphertext_modulus = noise_params.ciphertext_modulus; + let output_ciphertext_modulus = noise_params.ciphertext_modulus(); let mut boxed_seeder = new_seeder(); let seeder = boxed_seeder.as_mut(); @@ -179,7 +179,8 @@ mod cuda { use tfhe::core_crypto::prelude::*; use tfhe::shortint::engine::ShortintEngine; use tfhe::shortint::parameters::{ - ModulusSwitchType, NOISE_SQUASHING_PARAM_GPU_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128, + ModulusSwitchType, NoiseSquashingParameters, + NOISE_SQUASHING_PARAM_GPU_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128, NOISE_SQUASHING_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128, PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128, }; @@ -196,6 +197,10 @@ mod cuda { let input_params = PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128; let squash_params = NOISE_SQUASHING_PARAM_GPU_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128; + let NoiseSquashingParameters::Classic(squash_params) = squash_params else { + panic!("Multi bit noise squashing PBS currently not supported on GPU"); + }; + let lwe_noise_distribution_u64 = DynamicDistribution::new_t_uniform(46); let ct_modulus_u64: CiphertextModulus = CiphertextModulus::new_native(); @@ -452,8 +457,11 @@ mod cuda { type Scalar = u128; let input_params = PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128; - let squash_params = - NOISE_SQUASHING_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128; + let NoiseSquashingParameters::MultiBit(squash_params) = + NOISE_SQUASHING_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128 + else { + panic!("Expected Multi bit params") + }; let lwe_noise_distribution_u64 = DynamicDistribution::new_t_uniform(46); let ct_modulus_u64: CiphertextModulus = CiphertextModulus::new_native(); diff --git a/tfhe/examples/utilities/params_to_file.rs b/tfhe/examples/utilities/params_to_file.rs index 4ce9eccd6..f5f6e4e2a 100644 --- a/tfhe/examples/utilities/params_to_file.rs +++ b/tfhe/examples/utilities/params_to_file.rs @@ -193,7 +193,7 @@ impl ParamDetails for NoiseSquashingParameters { } fn glwe_dimension(&self) -> GlweDimension { - self.glwe_dimension + self.glwe_dimension() } fn lwe_noise_distribution(&self) -> DynamicDistribution { @@ -201,11 +201,11 @@ impl ParamDetails for NoiseSquashingParameters { } fn glwe_noise_distribution(&self) -> DynamicDistribution { - self.glwe_noise_distribution + self.glwe_noise_distribution() } fn polynomial_size(&self) -> PolynomialSize { - self.polynomial_size + self.polynomial_size() } fn lwe_ciphertext_modulus(&self) -> ParamModulus { @@ -213,7 +213,7 @@ impl ParamDetails for NoiseSquashingParameters { } fn glwe_ciphertext_modulus(&self) -> ParamModulus { - ParamModulus::from_ciphertext_modulus(self.ciphertext_modulus) + ParamModulus::from_ciphertext_modulus(self.ciphertext_modulus()) } } diff --git a/tfhe/src/core_crypto/gpu/algorithms/test/lwe_programmable_bootstrapping_128.rs b/tfhe/src/core_crypto/gpu/algorithms/test/lwe_programmable_bootstrapping_128.rs index bbaa3ded2..8ab788a58 100644 --- a/tfhe/src/core_crypto/gpu/algorithms/test/lwe_programmable_bootstrapping_128.rs +++ b/tfhe/src/core_crypto/gpu/algorithms/test/lwe_programmable_bootstrapping_128.rs @@ -3,7 +3,7 @@ use crate::shortint::parameters::{ DynamicDistribution, NOISE_SQUASHING_PARAM_GPU_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128, PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128, }; -use crate::shortint::prelude::{DecompositionBaseLog, LweDimension}; +use crate::shortint::prelude::DecompositionBaseLog; use crate::core_crypto::algorithms::par_allocate_and_generate_new_lwe_bootstrap_key; use crate::core_crypto::algorithms::test::{FftBootstrapKeys, TestResources}; @@ -75,14 +75,18 @@ pub fn execute_bootstrap_u128( squash_params: NoiseSquashingParameters, input_params: MultiBitPBSParameters, ) { + let NoiseSquashingParameters::Classic(squash_params) = squash_params else { + panic!("Multi bit noise squashing PBS currently not supported on GPU"); + }; + let glwe_dimension = squash_params.glwe_dimension; let polynomial_size = squash_params.polynomial_size; let ciphertext_modulus = squash_params.ciphertext_modulus; let mut rsc = TestResources::new(); - let noise_squashing_test_params: NoiseSquashingTestParams = NoiseSquashingTestParams { - lwe_dimension: LweDimension(input_params.lwe_dimension.0), + let noise_squashing_test_params = NoiseSquashingTestParams:: { + lwe_dimension: input_params.lwe_dimension, glwe_dimension: squash_params.glwe_dimension, polynomial_size: squash_params.polynomial_size, lwe_noise_distribution: DynamicDistribution::new_t_uniform(46), diff --git a/tfhe/src/integer/gpu/noise_squashing/noise_squashing_keys.rs b/tfhe/src/integer/gpu/noise_squashing/noise_squashing_keys.rs index acc9f117d..9b92c5866 100644 --- a/tfhe/src/integer/gpu/noise_squashing/noise_squashing_keys.rs +++ b/tfhe/src/integer/gpu/noise_squashing/noise_squashing_keys.rs @@ -2,37 +2,42 @@ use super::keys::CudaNoiseSquashingKey; use crate::core_crypto::gpu::lwe_bootstrap_key::CudaLweBootstrapKey; use crate::core_crypto::gpu::CudaStreams; use crate::integer::noise_squashing::CompressedNoiseSquashingKey; +use crate::shortint::noise_squashing::CompressedShortint128BootstrappingKey; use crate::shortint::server_key::CompressedModulusSwitchConfiguration; impl CompressedNoiseSquashingKey { pub fn decompress_to_cuda(&self, streams: &CudaStreams) -> CudaNoiseSquashingKey { - let std_bsk = self - .key - .bootstrapping_key() - .as_view() - .par_decompress_into_lwe_bootstrap_key(); + if let CompressedShortint128BootstrappingKey::Classic { + bsk, + modulus_switch_noise_reduction_key, + } = self.key.bootstrapping_key() + { + let std_bsk = bsk.as_view().par_decompress_into_lwe_bootstrap_key(); - let ms_noise_reduction_key = match self.key.modulus_switch_noise_reduction_key() { - CompressedModulusSwitchConfiguration::Standard => None, - CompressedModulusSwitchConfiguration::DriftTechniqueNoiseReduction( - modulus_switch_noise_reduction_key, - ) => Some(modulus_switch_noise_reduction_key.decompress()), - CompressedModulusSwitchConfiguration::CenteredMeanNoiseReduction => { - panic!("Centered MS not supportred on GPU") + let ms_noise_reduction_key = match modulus_switch_noise_reduction_key { + CompressedModulusSwitchConfiguration::Standard => None, + CompressedModulusSwitchConfiguration::DriftTechniqueNoiseReduction( + modulus_switch_noise_reduction_key, + ) => Some(modulus_switch_noise_reduction_key.decompress()), + CompressedModulusSwitchConfiguration::CenteredMeanNoiseReduction => { + panic!("Centered MS not supportred on GPU") + } + }; + + let bootstrapping_key = CudaLweBootstrapKey::from_lwe_bootstrap_key( + &std_bsk, + ms_noise_reduction_key.as_ref(), + streams, + ); + + CudaNoiseSquashingKey { + bootstrapping_key, + message_modulus: self.key.message_modulus(), + carry_modulus: self.key.carry_modulus(), + output_ciphertext_modulus: self.key.output_ciphertext_modulus(), } - }; - - let bootstrapping_key = CudaLweBootstrapKey::from_lwe_bootstrap_key( - &std_bsk, - ms_noise_reduction_key.as_ref(), - streams, - ); - - CudaNoiseSquashingKey { - bootstrapping_key, - message_modulus: self.key.message_modulus(), - carry_modulus: self.key.carry_modulus(), - output_ciphertext_modulus: self.key.output_ciphertext_modulus(), + } else { + panic!("Multi bit noise squashing PBS currently not supported on GPU"); } } } diff --git a/tfhe/src/integer/noise_squashing/keys.rs b/tfhe/src/integer/noise_squashing/keys.rs index be7918ff3..93adee438 100644 --- a/tfhe/src/integer/noise_squashing/keys.rs +++ b/tfhe/src/integer/noise_squashing/keys.rs @@ -236,11 +236,11 @@ impl NoiseSquashingPrivateKey { pub fn new(params: NoiseSquashingParameters) -> Self { assert!( - params.carry_modulus.0 >= params.message_modulus.0, + params.carry_modulus().0 >= params.message_modulus().0, "NoiseSquashingPrivateKey requires its CarryModulus {:?} to be greater \ or equal to its MessageModulus {:?}", - params.carry_modulus.0, - params.message_modulus.0, + params.carry_modulus().0, + params.message_modulus().0, ); Self { diff --git a/tfhe/src/shortint/backward_compatibility/noise_squashing.rs b/tfhe/src/shortint/backward_compatibility/noise_squashing.rs index 37bec952a..252b38e4d 100644 --- a/tfhe/src/shortint/backward_compatibility/noise_squashing.rs +++ b/tfhe/src/shortint/backward_compatibility/noise_squashing.rs @@ -2,7 +2,8 @@ use std::convert::Infallible; use crate::core_crypto::prelude::*; use crate::shortint::noise_squashing::{ - CompressedNoiseSquashingKey, NoiseSquashingKey, NoiseSquashingPrivateKey, + CompressedNoiseSquashingKey, CompressedShortint128BootstrappingKey, NoiseSquashingKey, + NoiseSquashingPrivateKey, Shortint128BootstrappingKey, }; use crate::shortint::parameters::CoreCiphertextModulus; use crate::shortint::server_key::{ @@ -26,7 +27,45 @@ pub struct NoiseSquashingKeyV0 { output_ciphertext_modulus: CoreCiphertextModulus, } -impl Upgrade for NoiseSquashingKeyV0 { +impl Upgrade for NoiseSquashingKeyV0 { + type Error = Infallible; + + fn upgrade(self) -> Result { + let Self { + bootstrapping_key, + modulus_switch_noise_reduction_key, + message_modulus, + carry_modulus, + output_ciphertext_modulus, + } = self; + + Ok(NoiseSquashingKeyV1 { + bootstrapping_key, + modulus_switch_noise_reduction_key: modulus_switch_noise_reduction_key.map_or( + ModulusSwitchConfiguration::Standard, + |modulus_switch_noise_reduction_key| { + ModulusSwitchConfiguration::DriftTechniqueNoiseReduction( + modulus_switch_noise_reduction_key, + ) + }, + ), + message_modulus, + carry_modulus, + output_ciphertext_modulus, + }) + } +} + +#[derive(Version)] +pub struct NoiseSquashingKeyV1 { + bootstrapping_key: Fourier128LweBootstrapKeyOwned, + modulus_switch_noise_reduction_key: ModulusSwitchConfiguration, + message_modulus: MessageModulus, + carry_modulus: CarryModulus, + output_ciphertext_modulus: CoreCiphertextModulus, +} + +impl Upgrade for NoiseSquashingKeyV1 { type Error = Infallible; fn upgrade(self) -> Result { @@ -38,16 +77,13 @@ impl Upgrade for NoiseSquashingKeyV0 { output_ciphertext_modulus, } = self; + let bootstrapping_key = Shortint128BootstrappingKey::Classic { + bsk: bootstrapping_key, + modulus_switch_noise_reduction_key, + }; + Ok(NoiseSquashingKey::from_raw_parts( bootstrapping_key, - modulus_switch_noise_reduction_key.map_or( - ModulusSwitchConfiguration::Standard, - |modulus_switch_noise_reduction_key| { - ModulusSwitchConfiguration::DriftTechniqueNoiseReduction( - modulus_switch_noise_reduction_key, - ) - }, - ), message_modulus, carry_modulus, output_ciphertext_modulus, @@ -58,7 +94,18 @@ impl Upgrade for NoiseSquashingKeyV0 { #[derive(VersionsDispatch)] pub enum NoiseSquashingKeyVersions { V0(NoiseSquashingKeyV0), - V1(NoiseSquashingKey), + V1(NoiseSquashingKeyV1), + V2(NoiseSquashingKey), +} + +#[derive(VersionsDispatch)] +pub enum Shortint128BootstrappingKeyVersions { + V0(Shortint128BootstrappingKey), +} + +#[derive(VersionsDispatch)] +pub enum CompressedShortint128BootstrappingKeyVersions { + V0(CompressedShortint128BootstrappingKey), } #[derive(Version)] @@ -70,7 +117,45 @@ pub struct CompressedNoiseSquashingKeyV0 { output_ciphertext_modulus: CoreCiphertextModulus, } -impl Upgrade for CompressedNoiseSquashingKeyV0 { +impl Upgrade for CompressedNoiseSquashingKeyV0 { + type Error = Infallible; + + fn upgrade(self) -> Result { + let Self { + bootstrapping_key, + modulus_switch_noise_reduction_key, + message_modulus, + carry_modulus, + output_ciphertext_modulus, + } = self; + + Ok(CompressedNoiseSquashingKeyV1 { + bootstrapping_key, + modulus_switch_noise_reduction_key: modulus_switch_noise_reduction_key.map_or( + CompressedModulusSwitchConfiguration::Standard, + |modulus_switch_noise_reduction_key| { + CompressedModulusSwitchConfiguration::DriftTechniqueNoiseReduction( + modulus_switch_noise_reduction_key, + ) + }, + ), + message_modulus, + carry_modulus, + output_ciphertext_modulus, + }) + } +} + +#[derive(Version)] +pub struct CompressedNoiseSquashingKeyV1 { + bootstrapping_key: SeededLweBootstrapKeyOwned, + modulus_switch_noise_reduction_key: CompressedModulusSwitchConfiguration, + message_modulus: MessageModulus, + carry_modulus: CarryModulus, + output_ciphertext_modulus: CoreCiphertextModulus, +} + +impl Upgrade for CompressedNoiseSquashingKeyV1 { type Error = Infallible; fn upgrade(self) -> Result { @@ -82,16 +167,13 @@ impl Upgrade for CompressedNoiseSquashingKeyV0 { output_ciphertext_modulus, } = self; + let bootstrapping_key = CompressedShortint128BootstrappingKey::Classic { + bsk: bootstrapping_key, + modulus_switch_noise_reduction_key, + }; + Ok(CompressedNoiseSquashingKey::from_raw_parts( bootstrapping_key, - modulus_switch_noise_reduction_key.map_or( - CompressedModulusSwitchConfiguration::Standard, - |modulus_switch_noise_reduction_key| { - CompressedModulusSwitchConfiguration::DriftTechniqueNoiseReduction( - modulus_switch_noise_reduction_key, - ) - }, - ), message_modulus, carry_modulus, output_ciphertext_modulus, @@ -102,5 +184,6 @@ impl Upgrade for CompressedNoiseSquashingKeyV0 { #[derive(VersionsDispatch)] pub enum CompressedNoiseSquashingKeyVersions { V0(CompressedNoiseSquashingKeyV0), - V1(CompressedNoiseSquashingKey), + V1(CompressedNoiseSquashingKeyV1), + V2(CompressedNoiseSquashingKey), } diff --git a/tfhe/src/shortint/backward_compatibility/parameters/noise_squashing.rs b/tfhe/src/shortint/backward_compatibility/parameters/noise_squashing.rs index 3cd79fa62..a40954c93 100644 --- a/tfhe/src/shortint/backward_compatibility/parameters/noise_squashing.rs +++ b/tfhe/src/shortint/backward_compatibility/parameters/noise_squashing.rs @@ -1,6 +1,7 @@ use crate::core_crypto::prelude::*; use crate::shortint::parameters::noise_squashing::{ - NoiseSquashingCompressionParameters, NoiseSquashingMultiBitParameters, NoiseSquashingParameters, + NoiseSquashingClassicParameters, NoiseSquashingCompressionParameters, + NoiseSquashingMultiBitParameters, NoiseSquashingParameters, }; use crate::shortint::parameters::{ CoreCiphertextModulus, ModulusSwitchNoiseReductionParams, ModulusSwitchType, @@ -22,10 +23,10 @@ pub struct NoiseSquashingParametersV0 { pub ciphertext_modulus: CoreCiphertextModulus, } -impl Upgrade for NoiseSquashingParametersV0 { +impl Upgrade for NoiseSquashingParametersV0 { type Error = Infallible; - fn upgrade(self) -> Result { + fn upgrade(self) -> Result { let Self { glwe_dimension, polynomial_size, @@ -38,7 +39,7 @@ impl Upgrade for NoiseSquashingParametersV0 { ciphertext_modulus, } = self; - Ok(NoiseSquashingParameters { + Ok(NoiseSquashingParametersV1 { glwe_dimension, polynomial_size, glwe_noise_distribution, @@ -59,18 +60,69 @@ impl Upgrade for NoiseSquashingParametersV0 { } } -#[derive(VersionsDispatch)] -pub enum NoiseSquashingParametersVersions { - V0(NoiseSquashingParametersV0), - V1(NoiseSquashingParameters), +#[derive(Version)] +pub struct NoiseSquashingParametersV1 { + pub glwe_dimension: GlweDimension, + pub polynomial_size: PolynomialSize, + pub glwe_noise_distribution: DynamicDistribution, + pub decomp_base_log: DecompositionBaseLog, + pub decomp_level_count: DecompositionLevelCount, + pub modulus_switch_noise_reduction_params: ModulusSwitchType, + pub message_modulus: MessageModulus, + pub carry_modulus: CarryModulus, + pub ciphertext_modulus: CoreCiphertextModulus, +} + +impl Upgrade for NoiseSquashingParametersV1 { + type Error = Infallible; + + fn upgrade(self) -> Result { + let Self { + glwe_dimension, + polynomial_size, + glwe_noise_distribution, + decomp_base_log, + decomp_level_count, + modulus_switch_noise_reduction_params, + message_modulus, + carry_modulus, + ciphertext_modulus, + } = self; + + Ok(NoiseSquashingParameters::Classic( + NoiseSquashingClassicParameters { + glwe_dimension, + polynomial_size, + glwe_noise_distribution, + decomp_base_log, + decomp_level_count, + modulus_switch_noise_reduction_params, + message_modulus, + carry_modulus, + ciphertext_modulus, + }, + )) + } } #[derive(VersionsDispatch)] -pub enum NoiseSquashingCompressionParametersVersions { - V0(NoiseSquashingCompressionParameters), +pub enum NoiseSquashingParametersVersions { + V0(NoiseSquashingParametersV0), + V1(NoiseSquashingParametersV1), + V2(NoiseSquashingParameters), +} + +#[derive(VersionsDispatch)] +pub enum NoiseSquashingClassicParametersVersions { + V0(NoiseSquashingClassicParameters), } #[derive(VersionsDispatch)] pub enum NoiseSquashingMultiBitParametersVersions { V0(NoiseSquashingMultiBitParameters), } + +#[derive(VersionsDispatch)] +pub enum NoiseSquashingCompressionParametersVersions { + V0(NoiseSquashingCompressionParameters), +} diff --git a/tfhe/src/shortint/keycache.rs b/tfhe/src/shortint/keycache.rs index 3d7646e3c..a5b7a927b 100644 --- a/tfhe/src/shortint/keycache.rs +++ b/tfhe/src/shortint/keycache.rs @@ -6,8 +6,6 @@ use crate::keycache::*; #[cfg(tarpaulin)] use crate::shortint::parameters::coverage_parameters::*; use crate::shortint::parameters::current_params::*; -use crate::shortint::parameters::noise_squashing::NoiseSquashingMultiBitParameters; -use crate::shortint::parameters::parameters_wopbs::*; use crate::shortint::parameters::*; use crate::shortint::wopbs::WopbsKey; use crate::shortint::{ClientKey, KeySwitchingKey, ServerKey}; @@ -489,11 +487,7 @@ fn cpke_params_default_name(params: &CompactPublicKeyEncryptionParameters) -> St } named_params_impl!( NoiseSquashingParameters => - V1_3_NOISE_SQUASHING_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128, -); - -named_params_impl!( NoiseSquashingMultiBitParameters => - V1_3_NOISE_SQUASHING_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128, + V1_3_NOISE_SQUASHING_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,V1_3_NOISE_SQUASHING_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128, ); named_params_impl!( NoiseSquashingCompressionParameters => diff --git a/tfhe/src/shortint/list_compression/server_keys.rs b/tfhe/src/shortint/list_compression/server_keys.rs index f9f43bb43..12c3da08e 100644 --- a/tfhe/src/shortint/list_compression/server_keys.rs +++ b/tfhe/src/shortint/list_compression/server_keys.rs @@ -360,8 +360,8 @@ impl packing_ks_polynomial_size: compression_params.packing_ks_polynomial_size, packing_ks_glwe_dimension: compression_params.packing_ks_glwe_dimension, lwe_per_glwe: compression_params.lwe_per_glwe, - uncompressed_polynomial_size: squashing_params.polynomial_size, - uncompressed_glwe_dimension: squashing_params.glwe_dimension, + uncompressed_polynomial_size: squashing_params.polynomial_size(), + uncompressed_glwe_dimension: squashing_params.glwe_dimension(), cipherext_modulus: compression_params.ciphertext_modulus, } } diff --git a/tfhe/src/shortint/noise_squashing/compressed_server_key.rs b/tfhe/src/shortint/noise_squashing/compressed_server_key.rs index b6cb091b7..b9634b409 100644 --- a/tfhe/src/shortint/noise_squashing/compressed_server_key.rs +++ b/tfhe/src/shortint/noise_squashing/compressed_server_key.rs @@ -3,12 +3,21 @@ use crate::conformance::ParameterSetConformant; use crate::core_crypto::algorithms::lwe_bootstrap_key_conversion::par_convert_standard_lwe_bootstrap_key_to_fourier_128; use crate::core_crypto::algorithms::lwe_bootstrap_key_generation::par_allocate_and_generate_new_seeded_lwe_bootstrap_key; use crate::core_crypto::entities::{Fourier128LweBootstrapKeyOwned, SeededLweBootstrapKeyOwned}; -use crate::shortint::backward_compatibility::noise_squashing::CompressedNoiseSquashingKeyVersions; +use crate::core_crypto::prelude::{ + par_allocate_and_generate_new_seeded_lwe_multi_bit_bootstrap_key, + par_convert_standard_lwe_multi_bit_bootstrap_key_to_fourier_128, + Fourier128LweMultiBitBootstrapKey, SeededLweMultiBitBootstrapKeyOwned, ThreadCount, +}; +use crate::shortint::backward_compatibility::noise_squashing::{ + CompressedNoiseSquashingKeyVersions, CompressedShortint128BootstrappingKeyVersions, +}; use crate::shortint::client_key::atomic_pattern::AtomicPatternClientKey; use crate::shortint::client_key::ClientKey; use crate::shortint::engine::ShortintEngine; +use crate::shortint::noise_squashing::server_key::Shortint128BootstrappingKey; use crate::shortint::parameters::{ CarryModulus, CoreCiphertextModulus, MessageModulus, ModulusSwitchType, + NoiseSquashingParameters, }; use crate::shortint::server_key::{ CompressedModulusSwitchConfiguration, ModulusSwitchNoiseReductionKeyConformanceParams, @@ -16,25 +25,96 @@ use crate::shortint::server_key::{ use serde::{Deserialize, Serialize}; use tfhe_versionable::Versionize; +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Versionize)] +#[versionize(CompressedShortint128BootstrappingKeyVersions)] +pub enum CompressedShortint128BootstrappingKey { + Classic { + bsk: SeededLweBootstrapKeyOwned, + modulus_switch_noise_reduction_key: CompressedModulusSwitchConfiguration, + }, + MultiBit { + bsk: SeededLweMultiBitBootstrapKeyOwned, + thread_count: ThreadCount, + deterministic_execution: bool, + }, +} + +impl CompressedShortint128BootstrappingKey { + fn decompress(&self) -> Shortint128BootstrappingKey { + match self { + Self::Classic { + bsk, + modulus_switch_noise_reduction_key, + } => { + let (bootstrapping_key, modulus_switch_noise_reduction_key) = { + let std_bsk = bsk.as_view().par_decompress_into_lwe_bootstrap_key(); + + let mut fbsk = Fourier128LweBootstrapKeyOwned::new( + std_bsk.input_lwe_dimension(), + std_bsk.glwe_size(), + std_bsk.polynomial_size(), + std_bsk.decomposition_base_log(), + std_bsk.decomposition_level_count(), + ); + + par_convert_standard_lwe_bootstrap_key_to_fourier_128(&std_bsk, &mut fbsk); + + (fbsk, modulus_switch_noise_reduction_key.decompress()) + }; + + Shortint128BootstrappingKey::Classic { + bsk: bootstrapping_key, + modulus_switch_noise_reduction_key, + } + } + Self::MultiBit { + bsk, + thread_count, + deterministic_execution, + } => { + let bsk = bsk + .as_view() + .par_decompress_into_lwe_multi_bit_bootstrap_key(); + + let mut fourier_bsk = Fourier128LweMultiBitBootstrapKey::new( + bsk.input_lwe_dimension(), + bsk.glwe_size(), + bsk.polynomial_size(), + bsk.decomposition_base_log(), + bsk.decomposition_level_count(), + bsk.grouping_factor(), + ); + + par_convert_standard_lwe_multi_bit_bootstrap_key_to_fourier_128( + &bsk, + &mut fourier_bsk, + ); + + Shortint128BootstrappingKey::MultiBit { + bsk: fourier_bsk, + thread_count: *thread_count, + deterministic_execution: *deterministic_execution, + } + } + } + } +} + #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Versionize)] #[versionize(CompressedNoiseSquashingKeyVersions)] pub struct CompressedNoiseSquashingKey { - bootstrapping_key: SeededLweBootstrapKeyOwned, - modulus_switch_noise_reduction_key: CompressedModulusSwitchConfiguration, + bootstrapping_key: CompressedShortint128BootstrappingKey, message_modulus: MessageModulus, carry_modulus: CarryModulus, output_ciphertext_modulus: CoreCiphertextModulus, } impl CompressedNoiseSquashingKey { - pub fn bootstrapping_key(&self) -> &SeededLweBootstrapKeyOwned { + pub fn bootstrapping_key(&self) -> &CompressedShortint128BootstrappingKey { &self.bootstrapping_key } - - pub fn modulus_switch_noise_reduction_key(&self) -> &CompressedModulusSwitchConfiguration { - &self.modulus_switch_noise_reduction_key - } } + impl ClientKey { pub fn new_compressed_noise_squashing_key( &self, @@ -46,61 +126,88 @@ impl ClientKey { let pbs_parameters = std_cks.parameters; - assert_eq!( - pbs_parameters.message_modulus(), - noise_squashing_private_key - .noise_squashing_parameters() - .message_modulus, - "Mismatched MessageModulus between ClientKey {:?} and NoiseSquashingPrivateKey {:?}.", - pbs_parameters.message_modulus(), - noise_squashing_private_key - .noise_squashing_parameters() - .message_modulus - ); - assert_eq!( - pbs_parameters.carry_modulus(), - noise_squashing_private_key - .noise_squashing_parameters() - .carry_modulus, - "Mismatched CarryModulus between ClientKey {:?} and NoiseSquashingPrivateKey {:?}.", - pbs_parameters.carry_modulus(), - noise_squashing_private_key - .noise_squashing_parameters() - .carry_modulus - ); - let noise_squashing_parameters = noise_squashing_private_key.noise_squashing_parameters(); - let (bootstrapping_key, modulus_switch_noise_reduction_key) = - ShortintEngine::with_thread_local_mut(|engine| { - let seeded_bsk = par_allocate_and_generate_new_seeded_lwe_bootstrap_key( - &std_cks.lwe_secret_key, - noise_squashing_private_key.post_noise_squashing_secret_key(), - noise_squashing_parameters.decomp_base_log, - noise_squashing_parameters.decomp_level_count, - noise_squashing_parameters.glwe_noise_distribution, - noise_squashing_parameters.ciphertext_modulus, - &mut engine.seeder, - ); + assert_eq!( + pbs_parameters.message_modulus(), + noise_squashing_parameters.message_modulus(), + "Mismatched MessageModulus between ClientKey {:?} and NoiseSquashingPrivateKey {:?}.", + pbs_parameters.message_modulus(), + noise_squashing_parameters.message_modulus() + ); + assert_eq!( + pbs_parameters.carry_modulus(), + noise_squashing_parameters.carry_modulus(), + "Mismatched CarryModulus between ClientKey {:?} and NoiseSquashingPrivateKey {:?}.", + pbs_parameters.carry_modulus(), + noise_squashing_parameters.carry_modulus() + ); - let modulus_switch_noise_reduction_key = noise_squashing_parameters - .modulus_switch_noise_reduction_params - .to_compressed_modulus_switch_configuration( + let bootstrapping_key = match noise_squashing_parameters { + NoiseSquashingParameters::Classic(params) => { + ShortintEngine::with_thread_local_mut(|engine| { + let seeded_bsk = par_allocate_and_generate_new_seeded_lwe_bootstrap_key( &std_cks.lwe_secret_key, - pbs_parameters.ciphertext_modulus(), - pbs_parameters.lwe_noise_distribution(), - engine, + noise_squashing_private_key.post_noise_squashing_secret_key(), + params.decomp_base_log, + params.decomp_level_count, + params.glwe_noise_distribution, + params.ciphertext_modulus, + &mut engine.seeder, ); - (seeded_bsk, modulus_switch_noise_reduction_key) - }); + let modulus_switch_noise_reduction_key = params + .modulus_switch_noise_reduction_params + .to_compressed_modulus_switch_configuration( + &std_cks.lwe_secret_key, + pbs_parameters.ciphertext_modulus(), + pbs_parameters.lwe_noise_distribution(), + engine, + ); + + CompressedShortint128BootstrappingKey::Classic { + bsk: seeded_bsk, + modulus_switch_noise_reduction_key, + } + }) + } + NoiseSquashingParameters::MultiBit(params) => { + ShortintEngine::with_thread_local_mut(|engine| { + let seeded_bsk = + par_allocate_and_generate_new_seeded_lwe_multi_bit_bootstrap_key( + &std_cks.lwe_secret_key, + noise_squashing_private_key.post_noise_squashing_secret_key(), + params.decomp_base_log, + params.decomp_level_count, + params.glwe_noise_distribution, + params.grouping_factor, + params.ciphertext_modulus, + &mut engine.seeder, + ); + + let thread_count = engine.get_thread_count_for_multi_bit_pbs( + std_cks.lwe_secret_key.lwe_dimension(), + params.glwe_dimension, + params.polynomial_size, + params.decomp_base_log, + params.decomp_level_count, + params.grouping_factor, + ); + + CompressedShortint128BootstrappingKey::MultiBit { + bsk: seeded_bsk, + thread_count, + deterministic_execution: params.deterministic_execution, + } + }) + } + }; CompressedNoiseSquashingKey { bootstrapping_key, - modulus_switch_noise_reduction_key, - output_ciphertext_modulus: noise_squashing_parameters.ciphertext_modulus, - message_modulus: noise_squashing_parameters.message_modulus, - carry_modulus: noise_squashing_parameters.carry_modulus, + output_ciphertext_modulus: noise_squashing_parameters.ciphertext_modulus(), + message_modulus: noise_squashing_parameters.message_modulus(), + carry_modulus: noise_squashing_parameters.carry_modulus(), } } } @@ -114,15 +221,13 @@ impl CompressedNoiseSquashingKey { } pub fn from_raw_parts( - bootstrapping_key: SeededLweBootstrapKeyOwned, - modulus_switch_noise_reduction_key: CompressedModulusSwitchConfiguration, + bootstrapping_key: CompressedShortint128BootstrappingKey, message_modulus: MessageModulus, carry_modulus: CarryModulus, output_ciphertext_modulus: CoreCiphertextModulus, ) -> Self { Self { bootstrapping_key, - modulus_switch_noise_reduction_key, message_modulus, carry_modulus, output_ciphertext_modulus, @@ -130,28 +235,8 @@ impl CompressedNoiseSquashingKey { } pub fn decompress(&self) -> NoiseSquashingKey { - let (bootstrapping_key, modulus_switch_noise_reduction_key) = { - let std_bsk = self - .bootstrapping_key - .as_view() - .par_decompress_into_lwe_bootstrap_key(); - - let mut fbsk = Fourier128LweBootstrapKeyOwned::new( - std_bsk.input_lwe_dimension(), - std_bsk.glwe_size(), - std_bsk.polynomial_size(), - std_bsk.decomposition_base_log(), - std_bsk.decomposition_level_count(), - ); - - par_convert_standard_lwe_bootstrap_key_to_fourier_128(&std_bsk, &mut fbsk); - - (fbsk, self.modulus_switch_noise_reduction_key.decompress()) - }; - NoiseSquashingKey::from_raw_parts( - bootstrapping_key, - modulus_switch_noise_reduction_key, + self.bootstrapping_key.decompress(), self.message_modulus, self.carry_modulus, self.output_ciphertext_modulus, @@ -177,47 +262,76 @@ impl ParameterSetConformant for CompressedNoiseSquashingKey { fn is_conformant(&self, parameter_set: &Self::ParameterSet) -> bool { let Self { bootstrapping_key, - modulus_switch_noise_reduction_key, message_modulus, carry_modulus, output_ciphertext_modulus, } = self; - let Self::ParameterSet { - bootstrapping_key_params: expected_bootstrapping_key_params, - modulus_switch_noise_reduction_params: expected_modulus_switch_noise_reduction_params, - message_modulus: expected_message_modulus, - carry_modulus: expected_carry_modulus, - } = parameter_set; - - let modulus_switch_key_ok = match ( - modulus_switch_noise_reduction_key, - expected_modulus_switch_noise_reduction_params, - ) { - (CompressedModulusSwitchConfiguration::Standard, ModulusSwitchType::Standard) => true, + match (bootstrapping_key, parameter_set) { ( - CompressedModulusSwitchConfiguration::CenteredMeanNoiseReduction, - ModulusSwitchType::CenteredMeanNoiseReduction, - ) => true, - ( - CompressedModulusSwitchConfiguration::DriftTechniqueNoiseReduction(key), - ModulusSwitchType::DriftTechniqueNoiseReduction(params), + CompressedShortint128BootstrappingKey::Classic { + bsk, + modulus_switch_noise_reduction_key, + }, + NoiseSquashingKeyConformanceParams::Classic { + bootstrapping_key_params: expected_bootstrapping_key_params, + modulus_switch_noise_reduction_params: + expected_modulus_switch_noise_reduction_params, + message_modulus: expected_message_modulus, + carry_modulus: expected_carry_modulus, + }, ) => { - let mod_switch_conformance_params = - ModulusSwitchNoiseReductionKeyConformanceParams { - modulus_switch_noise_reduction_params: *params, - lwe_dimension: bootstrapping_key.input_lwe_dimension(), - }; + let lwe_dimension = bsk.input_lwe_dimension(); - key.is_conformant(&mod_switch_conformance_params) + let modulus_switch_key_ok = match ( + modulus_switch_noise_reduction_key, + expected_modulus_switch_noise_reduction_params, + ) { + ( + CompressedModulusSwitchConfiguration::Standard, + ModulusSwitchType::Standard, + ) => true, + ( + CompressedModulusSwitchConfiguration::CenteredMeanNoiseReduction, + ModulusSwitchType::CenteredMeanNoiseReduction, + ) => true, + ( + CompressedModulusSwitchConfiguration::DriftTechniqueNoiseReduction(key), + ModulusSwitchType::DriftTechniqueNoiseReduction(params), + ) => { + let mod_switch_conformance_params = + ModulusSwitchNoiseReductionKeyConformanceParams { + modulus_switch_noise_reduction_params: *params, + lwe_dimension, + }; + + key.is_conformant(&mod_switch_conformance_params) + } + (_, _) => false, + }; + + modulus_switch_key_ok + && bsk.is_conformant(expected_bootstrapping_key_params) + && *output_ciphertext_modulus + == expected_bootstrapping_key_params.ciphertext_modulus + && *message_modulus == *expected_message_modulus + && *carry_modulus == *expected_carry_modulus } - (_, _) => false, - }; - - modulus_switch_key_ok - && bootstrapping_key.is_conformant(expected_bootstrapping_key_params) - && *output_ciphertext_modulus == expected_bootstrapping_key_params.ciphertext_modulus - && *message_modulus == *expected_message_modulus - && *carry_modulus == *expected_carry_modulus + ( + CompressedShortint128BootstrappingKey::MultiBit { bsk, .. }, + NoiseSquashingKeyConformanceParams::MultiBit { + bootstrapping_key_params: expected_bootstrapping_key_params, + message_modulus: expected_message_modulus, + carry_modulus: expected_carry_modulus, + }, + ) => { + bsk.is_conformant(expected_bootstrapping_key_params) + && *output_ciphertext_modulus + == expected_bootstrapping_key_params.ciphertext_modulus + && *message_modulus == *expected_message_modulus + && *carry_modulus == *expected_carry_modulus + } + _ => false, + } } } diff --git a/tfhe/src/shortint/noise_squashing/mod.rs b/tfhe/src/shortint/noise_squashing/mod.rs index 6e9fa5a37..ed4178580 100644 --- a/tfhe/src/shortint/noise_squashing/mod.rs +++ b/tfhe/src/shortint/noise_squashing/mod.rs @@ -4,7 +4,11 @@ mod server_key; #[cfg(test)] pub mod tests; -pub use compressed_server_key::CompressedNoiseSquashingKey; +pub use compressed_server_key::{ + CompressedNoiseSquashingKey, CompressedShortint128BootstrappingKey, +}; pub use private_key::NoiseSquashingPrivateKey; pub(crate) use private_key::NoiseSquashingPrivateKeyView; -pub use server_key::{NoiseSquashingKey, NoiseSquashingKeyConformanceParams}; +pub use server_key::{ + NoiseSquashingKey, NoiseSquashingKeyConformanceParams, Shortint128BootstrappingKey, +}; diff --git a/tfhe/src/shortint/noise_squashing/private_key.rs b/tfhe/src/shortint/noise_squashing/private_key.rs index 8cbf354f6..18226d875 100644 --- a/tfhe/src/shortint/noise_squashing/private_key.rs +++ b/tfhe/src/shortint/noise_squashing/private_key.rs @@ -22,8 +22,8 @@ impl NoiseSquashingPrivateKey { pub fn new(params: NoiseSquashingParameters) -> Self { let post_noise_squashing_secret_key = ShortintEngine::with_thread_local_mut(|engine| { allocate_and_generate_new_binary_glwe_secret_key( - params.glwe_dimension, - params.polynomial_size, + params.glwe_dimension(), + params.polynomial_size(), &mut engine.secret_generator, ) }); @@ -56,11 +56,11 @@ impl NoiseSquashingPrivateKey { ) -> Self { assert_eq!( post_noise_squashing_secret_key.polynomial_size(), - params.polynomial_size + params.polynomial_size() ); assert_eq!( post_noise_squashing_secret_key.glwe_dimension(), - params.glwe_dimension + params.glwe_dimension() ); Self { post_noise_squashing_secret_key, @@ -83,9 +83,9 @@ impl<'a> From<&'a NoiseSquashingPrivateKey> for NoiseSquashingPrivateKeyView<'a> Self { post_noise_squashing_secret_key: &value.post_noise_squashing_secret_key, encoding: ShortintEncoding { - ciphertext_modulus: value.params.ciphertext_modulus, - message_modulus: value.params.message_modulus, - carry_modulus: value.params.carry_modulus, + ciphertext_modulus: value.params.ciphertext_modulus(), + message_modulus: value.params.message_modulus(), + carry_modulus: value.params.carry_modulus(), padding_bit: PaddingBit::Yes, }, } diff --git a/tfhe/src/shortint/noise_squashing/server_key.rs b/tfhe/src/shortint/noise_squashing/server_key.rs index 69b2825a0..425a6fbc3 100644 --- a/tfhe/src/shortint/noise_squashing/server_key.rs +++ b/tfhe/src/shortint/noise_squashing/server_key.rs @@ -11,8 +11,17 @@ use crate::core_crypto::algorithms::lwe_programmable_bootstrapping::{ use crate::core_crypto::entities::{Fourier128LweBootstrapKeyOwned, LweCiphertext}; use crate::core_crypto::fft_impl::fft128::math::fft::Fft128; use crate::core_crypto::fft_impl::fft64::crypto::bootstrap::LweBootstrapKeyConformanceParams; +use crate::core_crypto::prelude::fft128_lwe_multi_bit_bootstrap_key::Fourier128LweMultiBitBootstrapKeyOwned; +use crate::core_crypto::prelude::{ + multi_bit_programmable_bootstrap_f128_lwe_ciphertext, + par_allocate_and_generate_new_lwe_multi_bit_bootstrap_key, + par_convert_standard_lwe_multi_bit_bootstrap_key_to_fourier_128, GlweSize, + MultiBitBootstrapKeyConformanceParams, ThreadCount, +}; use crate::shortint::atomic_pattern::{AtomicPattern, AtomicPatternParameters}; -use crate::shortint::backward_compatibility::noise_squashing::NoiseSquashingKeyVersions; +use crate::shortint::backward_compatibility::noise_squashing::{ + NoiseSquashingKeyVersions, Shortint128BootstrappingKeyVersions, +}; use crate::shortint::ciphertext::{Ciphertext, SquashedNoiseCiphertext}; use crate::shortint::client_key::atomic_pattern::AtomicPatternClientKey; use crate::shortint::client_key::ClientKey; @@ -22,6 +31,7 @@ use crate::shortint::parameters::noise_squashing::NoiseSquashingParameters; use crate::shortint::parameters::{ CarryModulus, CoreCiphertextModulus, MessageModulus, ModulusSwitchType, PBSOrder, PBSParameters, }; +use crate::shortint::prelude::{LweDimension, PolynomialSize}; use crate::shortint::server_key::{ ModulusSwitchConfiguration, ModulusSwitchNoiseReductionKeyConformanceParams, ServerKey, StandardServerKeyView, @@ -29,11 +39,47 @@ use crate::shortint::server_key::{ use serde::{Deserialize, Serialize}; use tfhe_versionable::Versionize; +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Versionize)] +#[versionize(Shortint128BootstrappingKeyVersions)] +pub enum Shortint128BootstrappingKey { + Classic { + bsk: Fourier128LweBootstrapKeyOwned, + modulus_switch_noise_reduction_key: ModulusSwitchConfiguration, + }, + MultiBit { + bsk: Fourier128LweMultiBitBootstrapKeyOwned, + thread_count: ThreadCount, + deterministic_execution: bool, + }, +} + +impl Shortint128BootstrappingKey { + fn output_lwe_dimension(&self) -> LweDimension { + match self { + Self::Classic { bsk, .. } => bsk.output_lwe_dimension(), + Self::MultiBit { bsk, .. } => bsk.output_lwe_dimension(), + } + } + + fn glwe_size(&self) -> GlweSize { + match self { + Self::Classic { bsk, .. } => bsk.glwe_size(), + Self::MultiBit { bsk, .. } => bsk.glwe_size(), + } + } + + fn polynomial_size(&self) -> PolynomialSize { + match self { + Self::Classic { bsk, .. } => bsk.polynomial_size(), + Self::MultiBit { bsk, .. } => bsk.polynomial_size(), + } + } +} + #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Versionize)] #[versionize(NoiseSquashingKeyVersions)] pub struct NoiseSquashingKey { - bootstrapping_key: Fourier128LweBootstrapKeyOwned, - modulus_switch_noise_reduction_key: ModulusSwitchConfiguration, + bootstrapping_key: Shortint128BootstrappingKey, message_modulus: MessageModulus, carry_modulus: CarryModulus, output_ciphertext_modulus: CoreCiphertextModulus, @@ -50,71 +96,110 @@ impl ClientKey { let pbs_parameters = std_cks.parameters; - assert_eq!( - pbs_parameters.message_modulus(), - noise_squashing_private_key - .noise_squashing_parameters() - .message_modulus, - "Incompatible MessageModulus ClientKey {:?}, NoiseSquashingPrivateKey {:?}.", - pbs_parameters.message_modulus(), - noise_squashing_private_key - .noise_squashing_parameters() - .message_modulus, - ); - assert_eq!( - pbs_parameters.carry_modulus(), - noise_squashing_private_key - .noise_squashing_parameters() - .carry_modulus, - "Incompatible CarryModulus ClientKey {:?}, NoiseSquashingPrivateKey {:?}", - pbs_parameters.carry_modulus(), - noise_squashing_private_key - .noise_squashing_parameters() - .carry_modulus, - ); - let noise_squashing_parameters = noise_squashing_private_key.noise_squashing_parameters(); - let (bootstrapping_key, modulus_switch_noise_reduction_key) = - ShortintEngine::with_thread_local_mut(|engine| { - let std_bsk = par_allocate_and_generate_new_lwe_bootstrap_key( - &std_cks.lwe_secret_key, - noise_squashing_private_key.post_noise_squashing_secret_key(), - noise_squashing_parameters.decomp_base_log, - noise_squashing_parameters.decomp_level_count, - noise_squashing_parameters.glwe_noise_distribution, - noise_squashing_parameters.ciphertext_modulus, - &mut engine.encryption_generator, - ); + assert_eq!( + pbs_parameters.message_modulus(), + noise_squashing_parameters.message_modulus(), + "Incompatible MessageModulus ClientKey {:?}, NoiseSquashingPrivateKey {:?}.", + pbs_parameters.message_modulus(), + noise_squashing_parameters.message_modulus(), + ); + assert_eq!( + pbs_parameters.carry_modulus(), + noise_squashing_parameters.carry_modulus(), + "Incompatible CarryModulus ClientKey {:?}, NoiseSquashingPrivateKey {:?}", + pbs_parameters.carry_modulus(), + noise_squashing_parameters.carry_modulus(), + ); - let mut fbsk = Fourier128LweBootstrapKeyOwned::new( - std_bsk.input_lwe_dimension(), - std_bsk.glwe_size(), - std_bsk.polynomial_size(), - std_bsk.decomposition_base_log(), - std_bsk.decomposition_level_count(), - ); - - par_convert_standard_lwe_bootstrap_key_to_fourier_128(&std_bsk, &mut fbsk); - - let modulus_switch_noise_reduction_key = noise_squashing_parameters - .modulus_switch_noise_reduction_params - .to_modulus_switch_configuration( + let bootstrapping_key = match noise_squashing_parameters { + NoiseSquashingParameters::Classic(params) => { + ShortintEngine::with_thread_local_mut(|engine| { + let std_bsk = par_allocate_and_generate_new_lwe_bootstrap_key( &std_cks.lwe_secret_key, - pbs_parameters.ciphertext_modulus(), - pbs_parameters.lwe_noise_distribution(), - engine, + noise_squashing_private_key.post_noise_squashing_secret_key(), + params.decomp_base_log, + params.decomp_level_count, + params.glwe_noise_distribution, + params.ciphertext_modulus, + &mut engine.encryption_generator, ); - (fbsk, modulus_switch_noise_reduction_key) - }); + let mut fbsk = Fourier128LweBootstrapKeyOwned::new( + std_bsk.input_lwe_dimension(), + std_bsk.glwe_size(), + std_bsk.polynomial_size(), + std_bsk.decomposition_base_log(), + std_bsk.decomposition_level_count(), + ); + + par_convert_standard_lwe_bootstrap_key_to_fourier_128(&std_bsk, &mut fbsk); + + let modulus_switch_noise_reduction_key = params + .modulus_switch_noise_reduction_params + .to_modulus_switch_configuration( + &std_cks.lwe_secret_key, + pbs_parameters.ciphertext_modulus(), + pbs_parameters.lwe_noise_distribution(), + engine, + ); + + Shortint128BootstrappingKey::Classic { + bsk: fbsk, + modulus_switch_noise_reduction_key, + } + }) + } + NoiseSquashingParameters::MultiBit(params) => { + ShortintEngine::with_thread_local_mut(|engine| { + let std_bsk = par_allocate_and_generate_new_lwe_multi_bit_bootstrap_key( + &std_cks.lwe_secret_key, + noise_squashing_private_key.post_noise_squashing_secret_key(), + params.decomp_base_log, + params.decomp_level_count, + params.grouping_factor, + params.glwe_noise_distribution, + params.ciphertext_modulus, + &mut engine.encryption_generator, + ); + + let mut fbsk = Fourier128LweMultiBitBootstrapKeyOwned::new( + std_bsk.input_lwe_dimension(), + std_bsk.glwe_size(), + std_bsk.polynomial_size(), + std_bsk.decomposition_base_log(), + std_bsk.decomposition_level_count(), + std_bsk.grouping_factor(), + ); + + par_convert_standard_lwe_multi_bit_bootstrap_key_to_fourier_128( + &std_bsk, &mut fbsk, + ); + + let thread_count = engine.get_thread_count_for_multi_bit_pbs( + std_cks.lwe_secret_key.lwe_dimension(), + params.glwe_dimension, + params.polynomial_size, + params.decomp_base_log, + params.decomp_level_count, + params.grouping_factor, + ); + + Shortint128BootstrappingKey::MultiBit { + bsk: fbsk, + thread_count, + deterministic_execution: params.deterministic_execution, + } + }) + } + }; NoiseSquashingKey { bootstrapping_key, - modulus_switch_noise_reduction_key, - output_ciphertext_modulus: noise_squashing_parameters.ciphertext_modulus, - message_modulus: noise_squashing_parameters.message_modulus, - carry_modulus: noise_squashing_parameters.carry_modulus, + output_ciphertext_modulus: noise_squashing_parameters.ciphertext_modulus(), + message_modulus: noise_squashing_parameters.message_modulus(), + carry_modulus: noise_squashing_parameters.carry_modulus(), } } } @@ -128,15 +213,13 @@ impl NoiseSquashingKey { } pub fn from_raw_parts( - bootstrapping_key: Fourier128LweBootstrapKeyOwned, - modulus_switch_noise_reduction_key: ModulusSwitchConfiguration, + bootstrapping_key: Shortint128BootstrappingKey, message_modulus: MessageModulus, carry_modulus: CarryModulus, output_ciphertext_modulus: CoreCiphertextModulus, ) -> Self { Self { bootstrapping_key, - modulus_switch_noise_reduction_key, message_modulus, carry_modulus, output_ciphertext_modulus, @@ -146,15 +229,13 @@ impl NoiseSquashingKey { pub fn into_raw_parts( self, ) -> ( - Fourier128LweBootstrapKeyOwned, - ModulusSwitchConfiguration, + Shortint128BootstrappingKey, MessageModulus, CarryModulus, CoreCiphertextModulus, ) { let Self { bootstrapping_key, - modulus_switch_noise_reduction_key, message_modulus, carry_modulus, output_ciphertext_modulus, @@ -162,7 +243,6 @@ impl NoiseSquashingKey { ( bootstrapping_key, - modulus_switch_noise_reduction_key, message_modulus, carry_modulus, output_ciphertext_modulus, @@ -230,7 +310,7 @@ impl NoiseSquashingKey { ciphertext: &Ciphertext, src_server_key: StandardServerKeyView, ) -> SquashedNoiseCiphertext { - let lwe_before_noise_squashing = match src_server_key.atomic_pattern.pbs_order { + let lwe_before_ms = match src_server_key.atomic_pattern.pbs_order { // Under the big key, first need to keyswitch PBSOrder::KeyswitchBootstrap => { let mut after_ks_ct = LweCiphertext::new( @@ -256,15 +336,6 @@ impl NoiseSquashingKey { PBSOrder::BootstrapKeyswitch => ciphertext.ct.clone(), }; - let br_input_modulus_log = self - .bootstrapping_key - .polynomial_size() - .to_blind_rotation_input_modulus_log(); - - let lwe_ciphertext_to_squash_noise = self - .modulus_switch_noise_reduction_key - .lwe_ciphertext_modulus_switch(&lwe_before_noise_squashing, br_input_modulus_log); - let output_lwe_size = self.bootstrapping_key.output_lwe_dimension().to_lwe_size(); let output_message_modulus = self.message_modulus; let output_carry_modulus = self.carry_modulus; @@ -280,20 +351,6 @@ impl NoiseSquashingKey { let bsk_glwe_size = self.bootstrapping_key.glwe_size(); let bsk_polynomial_size = self.bootstrapping_key.polynomial_size(); - let fft = Fft128::new(bsk_polynomial_size); - let fft = fft.as_view(); - - let mem_requirement = blind_rotate_f128_lwe_ciphertext_mem_optimized_requirement::( - bsk_glwe_size, - bsk_polynomial_size, - fft, - ) - .unwrap() - .try_unaligned_bytes_required() - .unwrap(); - - // CarryModulus set to 1, as the output ciphertext does not have a carry space, mod == 1, - // means carry max == 0 let delta = compute_delta( output_ciphertext_modulus, output_message_modulus, @@ -312,19 +369,58 @@ impl NoiseSquashingKey { |x| x, ); - ShortintEngine::with_thread_local_mut(|engine| { - let buffers = engine.get_computation_buffers(); - buffers.resize(mem_requirement); + match &self.bootstrapping_key { + Shortint128BootstrappingKey::Classic { + bsk, + modulus_switch_noise_reduction_key, + } => { + let bsk_glwe_size = bsk.glwe_size(); + let bsk_polynomial_size = bsk.polynomial_size(); - blind_rotate_f128_lwe_ciphertext_mem_optimized( - &lwe_ciphertext_to_squash_noise, - res.lwe_ciphertext_mut(), - &id_lut, - &self.bootstrapping_key, - fft, - buffers.stack(), - ); - }); + let fft = Fft128::new(bsk_polynomial_size); + let fft = fft.as_view(); + + let mem_requirement = blind_rotate_f128_lwe_ciphertext_mem_optimized_requirement::< + u128, + >(bsk_glwe_size, bsk_polynomial_size, fft) + .unwrap() + .try_unaligned_bytes_required() + .unwrap(); + + let br_input_modulus_log = + bsk.polynomial_size().to_blind_rotation_input_modulus_log(); + let lwe_ciphertext_to_squash_noise = modulus_switch_noise_reduction_key + .lwe_ciphertext_modulus_switch(&lwe_before_ms, br_input_modulus_log); + + ShortintEngine::with_thread_local_mut(|engine| { + let buffers = engine.get_computation_buffers(); + buffers.resize(mem_requirement); + + blind_rotate_f128_lwe_ciphertext_mem_optimized( + &lwe_ciphertext_to_squash_noise, + res.lwe_ciphertext_mut(), + &id_lut, + bsk, + fft, + buffers.stack(), + ); + }); + } + Shortint128BootstrappingKey::MultiBit { + bsk, + thread_count, + deterministic_execution, + } => { + multi_bit_programmable_bootstrap_f128_lwe_ciphertext( + &lwe_before_ms, + res.lwe_ciphertext_mut(), + &id_lut, + bsk, + *thread_count, + *deterministic_execution, + ); + } + } res.set_degree(ciphertext.degree); @@ -345,11 +441,18 @@ impl NoiseSquashingKey { } #[derive(Clone, Copy)] -pub struct NoiseSquashingKeyConformanceParams { - pub bootstrapping_key_params: LweBootstrapKeyConformanceParams, - pub modulus_switch_noise_reduction_params: ModulusSwitchType, - pub message_modulus: MessageModulus, - pub carry_modulus: CarryModulus, +pub enum NoiseSquashingKeyConformanceParams { + Classic { + bootstrapping_key_params: LweBootstrapKeyConformanceParams, + modulus_switch_noise_reduction_params: ModulusSwitchType, + message_modulus: MessageModulus, + carry_modulus: CarryModulus, + }, + MultiBit { + bootstrapping_key_params: MultiBitBootstrapKeyConformanceParams, + message_modulus: MessageModulus, + carry_modulus: CarryModulus, + }, } impl TryFrom<(PBSParameters, NoiseSquashingParameters)> for NoiseSquashingKeyConformanceParams { @@ -358,34 +461,48 @@ impl TryFrom<(PBSParameters, NoiseSquashingParameters)> for NoiseSquashingKeyCon fn try_from( (pbs_params, noise_squashing_params): (PBSParameters, NoiseSquashingParameters), ) -> Result { - if pbs_params.message_modulus() != noise_squashing_params.message_modulus - || pbs_params.carry_modulus() != noise_squashing_params.carry_modulus + if pbs_params.message_modulus() != noise_squashing_params.message_modulus() + || pbs_params.carry_modulus() != noise_squashing_params.carry_modulus() { return Err(crate::Error::new(format!( "Incompatible MessageModulus (PBS {:?}, NoiseSquashing {:?}) \ or CarryModulus (PBS {:?}, NoiseSquashing {:?}) \ when creating NoiseSquashingKeyConformanceParams", pbs_params.message_modulus(), - noise_squashing_params.message_modulus, + noise_squashing_params.message_modulus(), pbs_params.carry_modulus(), - noise_squashing_params.carry_modulus + noise_squashing_params.carry_modulus() ))); } - Ok(Self { - bootstrapping_key_params: LweBootstrapKeyConformanceParams { - input_lwe_dimension: pbs_params.lwe_dimension(), - output_glwe_size: noise_squashing_params.glwe_dimension.to_glwe_size(), - polynomial_size: noise_squashing_params.polynomial_size, - decomp_base_log: noise_squashing_params.decomp_base_log, - decomp_level_count: noise_squashing_params.decomp_level_count, - ciphertext_modulus: noise_squashing_params.ciphertext_modulus, - }, - modulus_switch_noise_reduction_params: noise_squashing_params - .modulus_switch_noise_reduction_params, - message_modulus: noise_squashing_params.message_modulus, - carry_modulus: noise_squashing_params.carry_modulus, - }) + match noise_squashing_params { + NoiseSquashingParameters::Classic(params) => Ok(Self::Classic { + bootstrapping_key_params: LweBootstrapKeyConformanceParams { + input_lwe_dimension: pbs_params.lwe_dimension(), + output_glwe_size: params.glwe_dimension.to_glwe_size(), + polynomial_size: params.polynomial_size, + decomp_base_log: params.decomp_base_log, + decomp_level_count: params.decomp_level_count, + ciphertext_modulus: params.ciphertext_modulus, + }, + modulus_switch_noise_reduction_params: params.modulus_switch_noise_reduction_params, + message_modulus: params.message_modulus, + carry_modulus: params.carry_modulus, + }), + NoiseSquashingParameters::MultiBit(params) => Ok(Self::MultiBit { + bootstrapping_key_params: MultiBitBootstrapKeyConformanceParams { + input_lwe_dimension: pbs_params.lwe_dimension(), + output_glwe_size: params.glwe_dimension.to_glwe_size(), + polynomial_size: params.polynomial_size, + decomp_base_log: params.decomp_base_log, + decomp_level_count: params.decomp_level_count, + grouping_factor: params.grouping_factor, + ciphertext_modulus: params.ciphertext_modulus, + }, + message_modulus: params.message_modulus, + carry_modulus: params.carry_modulus, + }), + } } } @@ -414,47 +531,73 @@ impl ParameterSetConformant for NoiseSquashingKey { fn is_conformant(&self, parameter_set: &Self::ParameterSet) -> bool { let Self { bootstrapping_key, - modulus_switch_noise_reduction_key, message_modulus, carry_modulus, output_ciphertext_modulus, } = self; - let Self::ParameterSet { - bootstrapping_key_params: expected_bootstrapping_key_params, - modulus_switch_noise_reduction_params: expected_modulus_switch_noise_reduction_params, - message_modulus: expected_message_modulus, - carry_modulus: expected_carry_modulus, - } = parameter_set; - - let modulus_switch_key_ok = match ( - modulus_switch_noise_reduction_key, - expected_modulus_switch_noise_reduction_params, - ) { - (ModulusSwitchConfiguration::Standard, ModulusSwitchType::Standard) => true, + match (bootstrapping_key, parameter_set) { ( - ModulusSwitchConfiguration::CenteredMeanNoiseReduction, - ModulusSwitchType::CenteredMeanNoiseReduction, - ) => true, - ( - ModulusSwitchConfiguration::DriftTechniqueNoiseReduction(key), - ModulusSwitchType::DriftTechniqueNoiseReduction(params), + Shortint128BootstrappingKey::Classic { + bsk, + modulus_switch_noise_reduction_key, + }, + NoiseSquashingKeyConformanceParams::Classic { + bootstrapping_key_params: expected_bootstrapping_key_params, + modulus_switch_noise_reduction_params: + expected_modulus_switch_noise_reduction_params, + message_modulus: expected_message_modulus, + carry_modulus: expected_carry_modulus, + }, ) => { - let mod_switch_conformance_params = - ModulusSwitchNoiseReductionKeyConformanceParams { - modulus_switch_noise_reduction_params: *params, - lwe_dimension: bootstrapping_key.input_lwe_dimension(), - }; + let lwe_dimension = bsk.input_lwe_dimension(); - key.is_conformant(&mod_switch_conformance_params) + let modulus_switch_key_ok = match ( + modulus_switch_noise_reduction_key, + expected_modulus_switch_noise_reduction_params, + ) { + (ModulusSwitchConfiguration::Standard, ModulusSwitchType::Standard) => true, + ( + ModulusSwitchConfiguration::CenteredMeanNoiseReduction, + ModulusSwitchType::CenteredMeanNoiseReduction, + ) => true, + ( + ModulusSwitchConfiguration::DriftTechniqueNoiseReduction(key), + ModulusSwitchType::DriftTechniqueNoiseReduction(params), + ) => { + let mod_switch_conformance_params = + ModulusSwitchNoiseReductionKeyConformanceParams { + modulus_switch_noise_reduction_params: *params, + lwe_dimension, + }; + + key.is_conformant(&mod_switch_conformance_params) + } + (_, _) => false, + }; + + modulus_switch_key_ok + && bsk.is_conformant(expected_bootstrapping_key_params) + && *output_ciphertext_modulus + == expected_bootstrapping_key_params.ciphertext_modulus + && *message_modulus == *expected_message_modulus + && *carry_modulus == *expected_carry_modulus } - (_, _) => false, - }; - - modulus_switch_key_ok - && bootstrapping_key.is_conformant(expected_bootstrapping_key_params) - && *output_ciphertext_modulus == expected_bootstrapping_key_params.ciphertext_modulus - && *message_modulus == *expected_message_modulus - && *carry_modulus == *expected_carry_modulus + ( + Shortint128BootstrappingKey::MultiBit { bsk, .. }, + NoiseSquashingKeyConformanceParams::MultiBit { + bootstrapping_key_params: expected_bootstrapping_key_params, + message_modulus: expected_message_modulus, + carry_modulus: expected_carry_modulus, + }, + ) => { + bsk.is_conformant(expected_bootstrapping_key_params) + && *output_ciphertext_modulus + == expected_bootstrapping_key_params.ciphertext_modulus + && *message_modulus == *expected_message_modulus + && *carry_modulus == *expected_carry_modulus + } + _ => false, + } } } diff --git a/tfhe/src/shortint/noise_squashing/tests.rs b/tfhe/src/shortint/noise_squashing/tests.rs index da0d1a555..ecfc0eeca 100644 --- a/tfhe/src/shortint/noise_squashing/tests.rs +++ b/tfhe/src/shortint/noise_squashing/tests.rs @@ -7,12 +7,28 @@ use rand::prelude::*; use rand::thread_rng; #[test] -fn test_noise_squashing_ci_run_filter() { - let keycache_entry = KEY_CACHE.get_from_param(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128); - let (cks, sks) = (keycache_entry.client_key(), keycache_entry.server_key()); - let noise_squashing_private_key = NoiseSquashingPrivateKey::new( +fn test_classic_noise_squashing_ci_run_filter() { + test_noise_squashing( + PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128, NOISE_SQUASHING_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128, ); +} + +#[test] +fn test_multi_bit_noise_squashing_ci_run_filter() { + test_noise_squashing( + PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128, + NOISE_SQUASHING_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128, + ); +} + +fn test_noise_squashing( + classic_params: impl Into, + noise_squashing_params: NoiseSquashingParameters, +) { + let keycache_entry = KEY_CACHE.get_from_param(classic_params); + let (cks, sks) = (keycache_entry.client_key(), keycache_entry.server_key()); + let noise_squashing_private_key = NoiseSquashingPrivateKey::new(noise_squashing_params); let decompressed_noise_squashing_key = { let compressed_noise_squashing_key = CompressedNoiseSquashingKey::new(cks, &noise_squashing_private_key); diff --git a/tfhe/src/shortint/parameters/aliases.rs b/tfhe/src/shortint/parameters/aliases.rs index 802ade433..eb5f41ac2 100644 --- a/tfhe/src/shortint/parameters/aliases.rs +++ b/tfhe/src/shortint/parameters/aliases.rs @@ -51,7 +51,6 @@ use super::current_params::{ V1_3_NOISE_SQUASHING_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128, }; use super::NoiseSquashingCompressionParameters; -use crate::shortint::parameters::noise_squashing::NoiseSquashingMultiBitParameters; // Aliases @@ -131,7 +130,7 @@ pub const NOISE_SQUASHING_PARAM_GPU_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128: V1_3_NOISE_SQUASHING_PARAM_GPU_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128; pub const NOISE_SQUASHING_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128: - NoiseSquashingMultiBitParameters = + NoiseSquashingParameters = V1_3_NOISE_SQUASHING_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128; // GPU 2^-64 diff --git a/tfhe/src/shortint/parameters/noise_squashing.rs b/tfhe/src/shortint/parameters/noise_squashing.rs index 5f79451ce..12cc50370 100644 --- a/tfhe/src/shortint/parameters/noise_squashing.rs +++ b/tfhe/src/shortint/parameters/noise_squashing.rs @@ -10,7 +10,99 @@ use tfhe_versionable::Versionize; #[derive(Copy, Clone, Debug, PartialEq, Serialize, Deserialize, Versionize)] #[versionize(NoiseSquashingParametersVersions)] -pub struct NoiseSquashingParameters { +pub enum NoiseSquashingParameters { + Classic(NoiseSquashingClassicParameters), + MultiBit(NoiseSquashingMultiBitParameters), +} + +impl NoiseSquashingParameters { + pub fn polynomial_size(&self) -> PolynomialSize { + match self { + Self::Classic(noise_squashing_classic_parameters) => { + noise_squashing_classic_parameters.polynomial_size + } + Self::MultiBit(noise_squashing_multi_bit_parameters) => { + noise_squashing_multi_bit_parameters.polynomial_size + } + } + } + + pub fn glwe_dimension(&self) -> GlweDimension { + match self { + Self::Classic(noise_squashing_classic_parameters) => { + noise_squashing_classic_parameters.glwe_dimension + } + Self::MultiBit(noise_squashing_multi_bit_parameters) => { + noise_squashing_multi_bit_parameters.glwe_dimension + } + } + } + + pub fn message_modulus(&self) -> MessageModulus { + match self { + Self::Classic(noise_squashing_classic_parameters) => { + noise_squashing_classic_parameters.message_modulus + } + Self::MultiBit(noise_squashing_multi_bit_parameters) => { + noise_squashing_multi_bit_parameters.message_modulus + } + } + } + pub fn carry_modulus(&self) -> CarryModulus { + match self { + Self::Classic(noise_squashing_classic_parameters) => { + noise_squashing_classic_parameters.carry_modulus + } + Self::MultiBit(noise_squashing_multi_bit_parameters) => { + noise_squashing_multi_bit_parameters.carry_modulus + } + } + } + pub fn decomp_base_log(&self) -> DecompositionBaseLog { + match self { + Self::Classic(noise_squashing_classic_parameters) => { + noise_squashing_classic_parameters.decomp_base_log + } + Self::MultiBit(noise_squashing_multi_bit_parameters) => { + noise_squashing_multi_bit_parameters.decomp_base_log + } + } + } + pub fn decomp_level_count(&self) -> DecompositionLevelCount { + match self { + Self::Classic(noise_squashing_classic_parameters) => { + noise_squashing_classic_parameters.decomp_level_count + } + Self::MultiBit(noise_squashing_multi_bit_parameters) => { + noise_squashing_multi_bit_parameters.decomp_level_count + } + } + } + pub fn glwe_noise_distribution(&self) -> DynamicDistribution { + match self { + Self::Classic(noise_squashing_classic_parameters) => { + noise_squashing_classic_parameters.glwe_noise_distribution + } + Self::MultiBit(noise_squashing_multi_bit_parameters) => { + noise_squashing_multi_bit_parameters.glwe_noise_distribution + } + } + } + pub fn ciphertext_modulus(&self) -> CoreCiphertextModulus { + match self { + Self::Classic(noise_squashing_classic_parameters) => { + noise_squashing_classic_parameters.ciphertext_modulus + } + Self::MultiBit(noise_squashing_multi_bit_parameters) => { + noise_squashing_multi_bit_parameters.ciphertext_modulus + } + } + } +} + +#[derive(Copy, Clone, Debug, PartialEq, Serialize, Deserialize, Versionize)] +#[versionize(NoiseSquashingClassicParametersVersions)] +pub struct NoiseSquashingClassicParameters { pub glwe_dimension: GlweDimension, pub polynomial_size: PolynomialSize, pub glwe_noise_distribution: DynamicDistribution, @@ -22,20 +114,6 @@ pub struct NoiseSquashingParameters { pub ciphertext_modulus: CoreCiphertextModulus, } -#[derive(Copy, Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize, Versionize)] -#[versionize(NoiseSquashingCompressionParametersVersions)] -pub struct NoiseSquashingCompressionParameters { - pub packing_ks_level: DecompositionLevelCount, - pub packing_ks_base_log: DecompositionBaseLog, - pub packing_ks_polynomial_size: PolynomialSize, - pub packing_ks_glwe_dimension: GlweDimension, - pub lwe_per_glwe: LweCiphertextCount, - pub packing_ks_key_noise_distribution: DynamicDistribution, - pub message_modulus: MessageModulus, - pub carry_modulus: CarryModulus, - pub ciphertext_modulus: CoreCiphertextModulus, -} - #[derive(Copy, Clone, Debug, PartialEq, Serialize, Deserialize, Versionize)] #[versionize(NoiseSquashingMultiBitParametersVersions)] pub struct NoiseSquashingMultiBitParameters { @@ -48,4 +126,19 @@ pub struct NoiseSquashingMultiBitParameters { pub message_modulus: MessageModulus, pub carry_modulus: CarryModulus, pub ciphertext_modulus: CoreCiphertextModulus, + pub deterministic_execution: bool, +} + +#[derive(Copy, Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize, Versionize)] +#[versionize(NoiseSquashingCompressionParametersVersions)] +pub struct NoiseSquashingCompressionParameters { + pub packing_ks_level: DecompositionLevelCount, + pub packing_ks_base_log: DecompositionBaseLog, + pub packing_ks_polynomial_size: PolynomialSize, + pub packing_ks_glwe_dimension: GlweDimension, + pub lwe_per_glwe: LweCiphertextCount, + pub packing_ks_key_noise_distribution: DynamicDistribution, + pub message_modulus: MessageModulus, + pub carry_modulus: CarryModulus, + pub ciphertext_modulus: CoreCiphertextModulus, } diff --git a/tfhe/src/shortint/parameters/v1_1/noise_squashing/p_fail_2_minus_128/mod.rs b/tfhe/src/shortint/parameters/v1_1/noise_squashing/p_fail_2_minus_128/mod.rs index ec8fefdb6..6999b2764 100644 --- a/tfhe/src/shortint/parameters/v1_1/noise_squashing/p_fail_2_minus_128/mod.rs +++ b/tfhe/src/shortint/parameters/v1_1/noise_squashing/p_fail_2_minus_128/mod.rs @@ -1,3 +1,4 @@ +use crate::shortint::parameters::noise_squashing::NoiseSquashingClassicParameters; use crate::shortint::parameters::{ CarryModulus, CoreCiphertextModulus, DecompositionBaseLog, DecompositionLevelCount, DynamicDistribution, GlweDimension, LweCiphertextCount, MessageModulus, @@ -6,7 +7,7 @@ use crate::shortint::parameters::{ }; pub const V1_1_NOISE_SQUASHING_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128: - NoiseSquashingParameters = NoiseSquashingParameters { + NoiseSquashingParameters = NoiseSquashingParameters::Classic(NoiseSquashingClassicParameters { glwe_dimension: GlweDimension(2), polynomial_size: PolynomialSize(2048), glwe_noise_distribution: DynamicDistribution::new_t_uniform(30), @@ -23,4 +24,4 @@ pub const V1_1_NOISE_SQUASHING_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128: message_modulus: MessageModulus(4), carry_modulus: CarryModulus(4), ciphertext_modulus: CoreCiphertextModulus::::new_native(), -}; +}); diff --git a/tfhe/src/shortint/parameters/v1_2/noise_squashing/p_fail_2_minus_128/mod.rs b/tfhe/src/shortint/parameters/v1_2/noise_squashing/p_fail_2_minus_128/mod.rs index 04e194146..3b9606232 100644 --- a/tfhe/src/shortint/parameters/v1_2/noise_squashing/p_fail_2_minus_128/mod.rs +++ b/tfhe/src/shortint/parameters/v1_2/noise_squashing/p_fail_2_minus_128/mod.rs @@ -1,3 +1,4 @@ +use crate::shortint::parameters::noise_squashing::NoiseSquashingClassicParameters; use crate::shortint::parameters::{ CarryModulus, CoreCiphertextModulus, DecompositionBaseLog, DecompositionLevelCount, DynamicDistribution, GlweDimension, LweCiphertextCount, MessageModulus, @@ -6,7 +7,7 @@ use crate::shortint::parameters::{ }; pub const V1_2_NOISE_SQUASHING_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128: - NoiseSquashingParameters = NoiseSquashingParameters { + NoiseSquashingParameters = NoiseSquashingParameters::Classic(NoiseSquashingClassicParameters { glwe_dimension: GlweDimension(2), polynomial_size: PolynomialSize(2048), glwe_noise_distribution: DynamicDistribution::new_t_uniform(30), @@ -23,4 +24,4 @@ pub const V1_2_NOISE_SQUASHING_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128: message_modulus: MessageModulus(4), carry_modulus: CarryModulus(4), ciphertext_modulus: CoreCiphertextModulus::::new_native(), -}; +}); diff --git a/tfhe/src/shortint/parameters/v1_3/mod.rs b/tfhe/src/shortint/parameters/v1_3/mod.rs index 91958138b..d7d8cb214 100644 --- a/tfhe/src/shortint/parameters/v1_3/mod.rs +++ b/tfhe/src/shortint/parameters/v1_3/mod.rs @@ -43,7 +43,6 @@ pub use noise_squashing::p_fail_2_minus_128::*; #[cfg(feature = "hpu")] pub use hpu::*; -use crate::shortint::parameters::noise_squashing::NoiseSquashingMultiBitParameters; use crate::shortint::parameters::{ ClassicPBSParameters, CompactPublicKeyEncryptionParameters, CompressionParameters, KeySwitch32PBSParameters, MultiBitPBSParameters, NoiseSquashingCompressionParameters, @@ -1701,13 +1700,11 @@ pub const VEC_ALL_NOISE_SQUASHING_PARAMETERS: [(&NoiseSquashingParameters, &str) ), ]; -pub const VEC_ALL_NOISE_SQUASHING_MULTI_BIT_PARAMETERS: [( - &NoiseSquashingMultiBitParameters, - &str, -); 1] = [( - &V1_3_NOISE_SQUASHING_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128, - "V1_3_NOISE_SQUASHING_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128", -)]; +pub const VEC_ALL_NOISE_SQUASHING_MULTI_BIT_PARAMETERS: [(&NoiseSquashingParameters, &str); 1] = + [( + &V1_3_NOISE_SQUASHING_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128, + "V1_3_NOISE_SQUASHING_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128", + )]; /// All [`NoiseSquashingCompressionParameters`] in this module. pub const VEC_ALL_NOISE_SQUASHING_COMPRESSION_PARAMETERS: [( diff --git a/tfhe/src/shortint/parameters/v1_3/noise_squashing/p_fail_2_minus_128/mod.rs b/tfhe/src/shortint/parameters/v1_3/noise_squashing/p_fail_2_minus_128/mod.rs index d78e5ae52..d61f7e82d 100644 --- a/tfhe/src/shortint/parameters/v1_3/noise_squashing/p_fail_2_minus_128/mod.rs +++ b/tfhe/src/shortint/parameters/v1_3/noise_squashing/p_fail_2_minus_128/mod.rs @@ -1,4 +1,6 @@ -use crate::shortint::parameters::noise_squashing::NoiseSquashingMultiBitParameters; +use crate::shortint::parameters::noise_squashing::{ + NoiseSquashingClassicParameters, NoiseSquashingMultiBitParameters, +}; use crate::shortint::parameters::{ CarryModulus, CoreCiphertextModulus, DecompositionBaseLog, DecompositionLevelCount, DynamicDistribution, GlweDimension, LweBskGroupingFactor, LweCiphertextCount, MessageModulus, @@ -8,7 +10,7 @@ use crate::shortint::parameters::{ }; pub const V1_3_NOISE_SQUASHING_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128: - NoiseSquashingParameters = NoiseSquashingParameters { + NoiseSquashingParameters = NoiseSquashingParameters::Classic(NoiseSquashingClassicParameters { glwe_dimension: GlweDimension(2), polynomial_size: PolynomialSize(2048), glwe_noise_distribution: DynamicDistribution::new_t_uniform(30), @@ -25,7 +27,7 @@ pub const V1_3_NOISE_SQUASHING_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128: message_modulus: MessageModulus(4), carry_modulus: CarryModulus(4), ciphertext_modulus: CoreCiphertextModulus::::new_native(), -}; +}); pub const V1_3_NOISE_SQUASHING_COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128: NoiseSquashingCompressionParameters = NoiseSquashingCompressionParameters { @@ -41,7 +43,7 @@ pub const V1_3_NOISE_SQUASHING_COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M12 }; pub const V1_3_NOISE_SQUASHING_PARAM_GPU_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128: - NoiseSquashingParameters = NoiseSquashingParameters { + NoiseSquashingParameters = NoiseSquashingParameters::Classic(NoiseSquashingClassicParameters { glwe_dimension: GlweDimension(1), polynomial_size: PolynomialSize(4096), glwe_noise_distribution: DynamicDistribution::new_t_uniform(30), @@ -58,10 +60,10 @@ pub const V1_3_NOISE_SQUASHING_PARAM_GPU_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128 message_modulus: MessageModulus(4), carry_modulus: CarryModulus(4), ciphertext_modulus: CoreCiphertextModulus::::new_native(), -}; +}); pub const V1_3_NOISE_SQUASHING_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128: - NoiseSquashingMultiBitParameters = NoiseSquashingMultiBitParameters { + NoiseSquashingParameters = NoiseSquashingParameters::MultiBit(NoiseSquashingMultiBitParameters { glwe_dimension: GlweDimension(2), polynomial_size: PolynomialSize(2048), glwe_noise_distribution: DynamicDistribution::new_t_uniform(30), @@ -71,4 +73,5 @@ pub const V1_3_NOISE_SQUASHING_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_ message_modulus: MessageModulus(4), carry_modulus: CarryModulus(4), ciphertext_modulus: CoreCiphertextModulus::::new_native(), -}; + deterministic_execution: false, +});