mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-07 22:04:10 -05:00
refactor(shortint): use ShortintBootstrappingKey in DecompressionKey
This commit is contained in:
@@ -15,7 +15,6 @@ use crate::shortint::atomic_pattern::{
|
||||
AtomicPatternServerKey, KS32AtomicPatternServerKey, StandardAtomicPatternServerKey,
|
||||
};
|
||||
|
||||
use crate::shortint::list_compression::CompressedDecompressionKey;
|
||||
use crate::shortint::noise_squashing::atomic_pattern::ks32::KS32AtomicPatternNoiseSquashingKey;
|
||||
use crate::shortint::noise_squashing::atomic_pattern::standard::StandardAtomicPatternNoiseSquashingKey;
|
||||
use crate::shortint::noise_squashing::atomic_pattern::AtomicPatternNoiseSquashingKey;
|
||||
@@ -402,8 +401,12 @@ impl integer::compression_keys::CompressedDecompressionKey {
|
||||
);
|
||||
|
||||
Self {
|
||||
key: shortint::list_compression::CompressedDecompressionKey::Classic {
|
||||
blind_rotate_key: core_bsk,
|
||||
key: crate::shortint::list_compression::CompressedDecompressionKey {
|
||||
bsk: ShortintCompressedBootstrappingKey::Classic {
|
||||
bsk: core_bsk,
|
||||
modulus_switch_noise_reduction_key:
|
||||
CompressedModulusSwitchConfiguration::Standard,
|
||||
},
|
||||
lwe_per_glwe: compression_params.lwe_per_glwe(),
|
||||
},
|
||||
}
|
||||
@@ -416,52 +419,16 @@ impl integer::compression_keys::CompressedDecompressionKey {
|
||||
where
|
||||
Gen: ByteRandomGenerator + ParallelByteRandomGenerator,
|
||||
{
|
||||
match &self.key {
|
||||
CompressedDecompressionKey::Classic {
|
||||
blind_rotate_key,
|
||||
let crate::shortint::list_compression::CompressedDecompressionKey {
|
||||
ref bsk,
|
||||
lwe_per_glwe,
|
||||
} = self.key;
|
||||
|
||||
integer::compression_keys::DecompressionKey {
|
||||
key: crate::shortint::list_compression::DecompressionKey {
|
||||
bsk: bsk.decompress_with_pre_seeded_generator(generator),
|
||||
lwe_per_glwe,
|
||||
} => {
|
||||
let core_fourier_bsk =
|
||||
par_decompress_bootstrap_key_to_fourier_with_pre_seeded_generator(
|
||||
blind_rotate_key,
|
||||
generator,
|
||||
);
|
||||
|
||||
integer::compression_keys::DecompressionKey {
|
||||
key: shortint::list_compression::DecompressionKey::Classic {
|
||||
blind_rotate_key: core_fourier_bsk,
|
||||
lwe_per_glwe: *lwe_per_glwe,
|
||||
},
|
||||
}
|
||||
}
|
||||
CompressedDecompressionKey::MultiBit {
|
||||
multi_bit_blind_rotate_key,
|
||||
lwe_per_glwe,
|
||||
} => {
|
||||
let core_fourier_bsk =
|
||||
par_decompress_seeded_lwe_multi_bit_bootstrap_key_to_fourier_with_pre_seeded_generator(
|
||||
multi_bit_blind_rotate_key,
|
||||
generator,
|
||||
);
|
||||
|
||||
let thread_count =
|
||||
shortint::engine::ShortintEngine::get_thread_count_for_multi_bit_pbs(
|
||||
core_fourier_bsk.input_lwe_dimension(),
|
||||
core_fourier_bsk.glwe_size().to_glwe_dimension(),
|
||||
core_fourier_bsk.polynomial_size(),
|
||||
core_fourier_bsk.decomposition_base_log(),
|
||||
core_fourier_bsk.decomposition_level_count(),
|
||||
core_fourier_bsk.grouping_factor(),
|
||||
);
|
||||
|
||||
integer::compression_keys::DecompressionKey {
|
||||
key: shortint::list_compression::DecompressionKey::MultiBit {
|
||||
multi_bit_blind_rotate_key: core_fourier_bsk,
|
||||
lwe_per_glwe: *lwe_per_glwe,
|
||||
thread_count,
|
||||
},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,6 +9,9 @@ use crate::integer::gpu::list_compression::server_keys::{
|
||||
CudaCompressionKey, CudaDecompressionKey,
|
||||
};
|
||||
use crate::integer::gpu::server_key::CudaBootstrappingKey;
|
||||
use crate::shortint::server_key::{
|
||||
CompressedModulusSwitchConfiguration, ShortintCompressedBootstrappingKey,
|
||||
};
|
||||
use crate::shortint::{CarryModulus, MessageModulus};
|
||||
|
||||
impl CompressedDecompressionKey {
|
||||
@@ -21,14 +24,22 @@ impl CompressedDecompressionKey {
|
||||
ciphertext_modulus: CiphertextModulus<u64>,
|
||||
streams: &CudaStreams,
|
||||
) -> CudaDecompressionKey {
|
||||
match &self.key {
|
||||
crate::shortint::list_compression::CompressedDecompressionKey::Classic {
|
||||
blind_rotate_key,
|
||||
lwe_per_glwe,
|
||||
let crate::shortint::list_compression::CompressedDecompressionKey {
|
||||
ref bsk,
|
||||
lwe_per_glwe,
|
||||
} = self.key;
|
||||
|
||||
match bsk {
|
||||
ShortintCompressedBootstrappingKey::Classic {
|
||||
bsk,
|
||||
modulus_switch_noise_reduction_key,
|
||||
} => {
|
||||
let h_bootstrap_key = blind_rotate_key
|
||||
.as_view()
|
||||
.par_decompress_into_lwe_bootstrap_key();
|
||||
assert_eq!(
|
||||
modulus_switch_noise_reduction_key,
|
||||
&CompressedModulusSwitchConfiguration::Standard
|
||||
);
|
||||
|
||||
let h_bootstrap_key = bsk.as_view().par_decompress_into_lwe_bootstrap_key();
|
||||
|
||||
let d_bootstrap_key =
|
||||
CudaLweBootstrapKey::from_lwe_bootstrap_key(&h_bootstrap_key, None, streams);
|
||||
@@ -37,7 +48,7 @@ impl CompressedDecompressionKey {
|
||||
|
||||
CudaDecompressionKey {
|
||||
blind_rotate_key,
|
||||
lwe_per_glwe: *lwe_per_glwe,
|
||||
lwe_per_glwe,
|
||||
glwe_dimension,
|
||||
polynomial_size,
|
||||
message_modulus,
|
||||
@@ -45,11 +56,11 @@ impl CompressedDecompressionKey {
|
||||
ciphertext_modulus,
|
||||
}
|
||||
}
|
||||
crate::shortint::list_compression::CompressedDecompressionKey::MultiBit {
|
||||
multi_bit_blind_rotate_key,
|
||||
lwe_per_glwe,
|
||||
ShortintCompressedBootstrappingKey::MultiBit {
|
||||
seeded_bsk,
|
||||
deterministic_execution: _,
|
||||
} => {
|
||||
let h_bootstrap_key = multi_bit_blind_rotate_key
|
||||
let h_bootstrap_key = seeded_bsk
|
||||
.as_view()
|
||||
.par_decompress_into_lwe_multi_bit_bootstrap_key();
|
||||
|
||||
@@ -62,7 +73,7 @@ impl CompressedDecompressionKey {
|
||||
|
||||
CudaDecompressionKey {
|
||||
blind_rotate_key,
|
||||
lwe_per_glwe: *lwe_per_glwe,
|
||||
lwe_per_glwe,
|
||||
glwe_dimension,
|
||||
polynomial_size,
|
||||
message_modulus,
|
||||
|
||||
@@ -7,7 +7,10 @@ use crate::shortint::list_compression::{
|
||||
NoiseSquashingCompressionPrivateKey,
|
||||
};
|
||||
use crate::shortint::parameters::LweCiphertextCount;
|
||||
use crate::shortint::server_key::ShortintBootstrappingKey;
|
||||
use crate::shortint::server_key::{
|
||||
CompressedModulusSwitchConfiguration, ModulusSwitchConfiguration, ShortintBootstrappingKey,
|
||||
ShortintCompressedBootstrappingKey,
|
||||
};
|
||||
use crate::Error;
|
||||
use tfhe_versionable::deprecation::{Deprecable, Deprecated};
|
||||
use tfhe_versionable::{Upgrade, Version, VersionsDispatch};
|
||||
@@ -62,8 +65,11 @@ impl Upgrade<DecompressionKey> for DecompressionKeyV1 {
|
||||
lwe_per_glwe,
|
||||
} = self;
|
||||
|
||||
Ok(DecompressionKey::Classic {
|
||||
blind_rotate_key,
|
||||
Ok(DecompressionKey {
|
||||
bsk: ShortintBootstrappingKey::Classic {
|
||||
bsk: blind_rotate_key,
|
||||
modulus_switch_noise_reduction_key: ModulusSwitchConfiguration::Standard,
|
||||
},
|
||||
lwe_per_glwe,
|
||||
})
|
||||
}
|
||||
@@ -108,8 +114,11 @@ impl Upgrade<CompressedDecompressionKey> for CompressedDecompressionKeyV1 {
|
||||
lwe_per_glwe,
|
||||
} = self;
|
||||
|
||||
Ok(CompressedDecompressionKey::Classic {
|
||||
blind_rotate_key,
|
||||
Ok(CompressedDecompressionKey {
|
||||
bsk: ShortintCompressedBootstrappingKey::Classic {
|
||||
bsk: blind_rotate_key,
|
||||
modulus_switch_noise_reduction_key: CompressedModulusSwitchConfiguration::Standard,
|
||||
},
|
||||
lwe_per_glwe,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -5,17 +5,15 @@ use super::{
|
||||
};
|
||||
use crate::conformance::ParameterSetConformant;
|
||||
use crate::core_crypto::prelude::{
|
||||
par_convert_standard_lwe_bootstrap_key_to_fourier,
|
||||
par_convert_standard_lwe_multi_bit_bootstrap_key_to_fourier, CiphertextModulus,
|
||||
CiphertextModulusLog, FourierLweBootstrapKey, FourierLweMultiBitBootstrapKey, GlweSize,
|
||||
LweCiphertextCount, LwePackingKeyswitchKeyConformanceParams, PolynomialSize,
|
||||
SeededLweBootstrapKeyOwned, SeededLweMultiBitBootstrapKeyOwned, SeededLwePackingKeyswitchKey,
|
||||
CiphertextModulus, CiphertextModulusLog, GlweSize, LweCiphertextCount,
|
||||
LwePackingKeyswitchKeyConformanceParams, PolynomialSize, SeededLwePackingKeyswitchKey,
|
||||
};
|
||||
use crate::shortint::backward_compatibility::list_compression::{
|
||||
CompressedCompressionKeyVersions, CompressedDecompressionKeyVersions,
|
||||
CompressedNoiseSquashingCompressionKeyVersions,
|
||||
};
|
||||
use crate::shortint::client_key::ClientKey;
|
||||
use crate::shortint::server_key::ShortintCompressedBootstrappingKey;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt::Debug;
|
||||
use tfhe_versionable::Versionize;
|
||||
@@ -45,120 +43,28 @@ impl CompressedCompressionKey {
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Versionize)]
|
||||
#[versionize(CompressedDecompressionKeyVersions)]
|
||||
pub enum CompressedDecompressionKey {
|
||||
Classic {
|
||||
blind_rotate_key: SeededLweBootstrapKeyOwned<u64>,
|
||||
lwe_per_glwe: LweCiphertextCount,
|
||||
},
|
||||
MultiBit {
|
||||
multi_bit_blind_rotate_key: SeededLweMultiBitBootstrapKeyOwned<u64>,
|
||||
lwe_per_glwe: LweCiphertextCount,
|
||||
},
|
||||
pub struct CompressedDecompressionKey {
|
||||
pub(crate) bsk: ShortintCompressedBootstrappingKey<u64>,
|
||||
pub(crate) lwe_per_glwe: LweCiphertextCount,
|
||||
}
|
||||
|
||||
impl CompressedDecompressionKey {
|
||||
pub fn glwe_size(&self) -> GlweSize {
|
||||
match self {
|
||||
Self::Classic {
|
||||
blind_rotate_key, ..
|
||||
} => blind_rotate_key.glwe_size(),
|
||||
Self::MultiBit {
|
||||
multi_bit_blind_rotate_key,
|
||||
..
|
||||
} => multi_bit_blind_rotate_key.glwe_size(),
|
||||
}
|
||||
self.bsk.glwe_size()
|
||||
}
|
||||
pub fn polynomial_size(&self) -> PolynomialSize {
|
||||
match self {
|
||||
Self::Classic {
|
||||
blind_rotate_key, ..
|
||||
} => blind_rotate_key.polynomial_size(),
|
||||
Self::MultiBit {
|
||||
multi_bit_blind_rotate_key,
|
||||
..
|
||||
} => multi_bit_blind_rotate_key.polynomial_size(),
|
||||
}
|
||||
self.bsk.polynomial_size()
|
||||
}
|
||||
pub fn ciphertext_modulus(&self) -> CiphertextModulus<u64> {
|
||||
match self {
|
||||
Self::Classic {
|
||||
blind_rotate_key, ..
|
||||
} => blind_rotate_key.ciphertext_modulus(),
|
||||
Self::MultiBit {
|
||||
multi_bit_blind_rotate_key,
|
||||
..
|
||||
} => multi_bit_blind_rotate_key.ciphertext_modulus(),
|
||||
}
|
||||
self.bsk.ciphertext_modulus()
|
||||
}
|
||||
|
||||
pub fn decompress(&self) -> DecompressionKey {
|
||||
match self {
|
||||
Self::Classic {
|
||||
blind_rotate_key,
|
||||
lwe_per_glwe,
|
||||
} => {
|
||||
let blind_rotate_key = blind_rotate_key
|
||||
.as_view()
|
||||
.par_decompress_into_lwe_bootstrap_key();
|
||||
let bsk = self.bsk.decompress();
|
||||
|
||||
let mut fourier_bsk = FourierLweBootstrapKey::new(
|
||||
blind_rotate_key.input_lwe_dimension(),
|
||||
blind_rotate_key.glwe_size(),
|
||||
blind_rotate_key.polynomial_size(),
|
||||
blind_rotate_key.decomposition_base_log(),
|
||||
blind_rotate_key.decomposition_level_count(),
|
||||
);
|
||||
|
||||
// Conversion to fourier domain
|
||||
par_convert_standard_lwe_bootstrap_key_to_fourier(
|
||||
&blind_rotate_key,
|
||||
&mut fourier_bsk,
|
||||
);
|
||||
|
||||
DecompressionKey::Classic {
|
||||
blind_rotate_key: fourier_bsk,
|
||||
lwe_per_glwe: *lwe_per_glwe,
|
||||
}
|
||||
}
|
||||
Self::MultiBit {
|
||||
multi_bit_blind_rotate_key,
|
||||
lwe_per_glwe,
|
||||
} => {
|
||||
let multi_bit_blind_rotate_key = multi_bit_blind_rotate_key
|
||||
.as_view()
|
||||
.par_decompress_into_lwe_multi_bit_bootstrap_key();
|
||||
|
||||
let mut fourier_bsk = FourierLweMultiBitBootstrapKey::new(
|
||||
multi_bit_blind_rotate_key.input_lwe_dimension(),
|
||||
multi_bit_blind_rotate_key.glwe_size(),
|
||||
multi_bit_blind_rotate_key.polynomial_size(),
|
||||
multi_bit_blind_rotate_key.decomposition_base_log(),
|
||||
multi_bit_blind_rotate_key.decomposition_level_count(),
|
||||
multi_bit_blind_rotate_key.grouping_factor(),
|
||||
);
|
||||
|
||||
// Conversion to fourier domain
|
||||
par_convert_standard_lwe_multi_bit_bootstrap_key_to_fourier(
|
||||
&multi_bit_blind_rotate_key,
|
||||
&mut fourier_bsk,
|
||||
);
|
||||
|
||||
let thread_count =
|
||||
crate::shortint::engine::ShortintEngine::get_thread_count_for_multi_bit_pbs(
|
||||
fourier_bsk.input_lwe_dimension(),
|
||||
fourier_bsk.glwe_size().to_glwe_dimension(),
|
||||
fourier_bsk.polynomial_size(),
|
||||
fourier_bsk.decomposition_base_log(),
|
||||
fourier_bsk.decomposition_level_count(),
|
||||
fourier_bsk.grouping_factor(),
|
||||
);
|
||||
|
||||
DecompressionKey::MultiBit {
|
||||
multi_bit_blind_rotate_key: fourier_bsk,
|
||||
lwe_per_glwe: *lwe_per_glwe,
|
||||
thread_count,
|
||||
}
|
||||
}
|
||||
DecompressionKey {
|
||||
bsk,
|
||||
lwe_per_glwe: self.lwe_per_glwe,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -208,30 +114,11 @@ impl ParameterSetConformant for CompressedDecompressionKey {
|
||||
type ParameterSet = CompressionKeyConformanceParams;
|
||||
|
||||
fn is_conformant(&self, parameter_set: &Self::ParameterSet) -> bool {
|
||||
match self {
|
||||
Self::Classic {
|
||||
blind_rotate_key,
|
||||
lwe_per_glwe,
|
||||
} => {
|
||||
let Ok(params) = parameter_set.try_into() else {
|
||||
return false;
|
||||
};
|
||||
let Self { bsk, lwe_per_glwe } = self;
|
||||
|
||||
blind_rotate_key.is_conformant(¶ms)
|
||||
&& *lwe_per_glwe == parameter_set.lwe_per_glwe
|
||||
}
|
||||
Self::MultiBit {
|
||||
multi_bit_blind_rotate_key,
|
||||
lwe_per_glwe,
|
||||
} => {
|
||||
let Ok(params) = parameter_set.try_into() else {
|
||||
return false;
|
||||
};
|
||||
let params = parameter_set.into();
|
||||
|
||||
multi_bit_blind_rotate_key.is_conformant(¶ms)
|
||||
&& *lwe_per_glwe == parameter_set.lwe_per_glwe
|
||||
}
|
||||
}
|
||||
bsk.is_conformant(¶ms) && *lwe_per_glwe == parameter_set.lwe_per_glwe
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -6,9 +6,8 @@ use crate::shortint::ciphertext::{CompressedCiphertextList, CompressedCiphertext
|
||||
use crate::shortint::engine::ShortintEngine;
|
||||
use crate::shortint::parameters::{CarryModulus, MessageModulus, NoiseLevel};
|
||||
use crate::shortint::server_key::{
|
||||
apply_multi_bit_blind_rotate, apply_standard_blind_rotate,
|
||||
generate_lookup_table_with_output_encoding, unchecked_scalar_mul_assign, LookupTableOwned,
|
||||
LookupTableSize,
|
||||
apply_ms_blind_rotate, generate_lookup_table_with_output_encoding, unchecked_scalar_mul_assign,
|
||||
LookupTableOwned, LookupTableSize,
|
||||
};
|
||||
use crate::shortint::{Ciphertext, MaxNoiseLevel};
|
||||
use rayon::iter::ParallelIterator;
|
||||
@@ -229,47 +228,11 @@ impl DecompressionKey {
|
||||
|
||||
let mut glwe_out = decompression_rescale.acc.clone();
|
||||
|
||||
let log_modulus = self
|
||||
.out_polynomial_size()
|
||||
.to_blind_rotation_input_modulus_log();
|
||||
ShortintEngine::with_thread_local_mut(|engine| {
|
||||
let buffers = engine.get_computation_buffers();
|
||||
|
||||
match self {
|
||||
Self::Classic {
|
||||
blind_rotate_key, ..
|
||||
} => {
|
||||
let msed_lwe =
|
||||
lwe_ciphertext_modulus_switch(intermediate_lwe.as_view(), log_modulus);
|
||||
ShortintEngine::with_thread_local_mut(|engine| {
|
||||
let buffers = engine.get_computation_buffers();
|
||||
|
||||
apply_standard_blind_rotate(
|
||||
blind_rotate_key,
|
||||
&msed_lwe,
|
||||
&mut glwe_out,
|
||||
buffers,
|
||||
);
|
||||
});
|
||||
}
|
||||
Self::MultiBit {
|
||||
multi_bit_blind_rotate_key,
|
||||
thread_count,
|
||||
..
|
||||
} => {
|
||||
let multi_bit_msed_lwe = StandardMultiBitModulusSwitchedCt {
|
||||
input: intermediate_lwe.as_view(),
|
||||
log_modulus,
|
||||
grouping_factor: multi_bit_blind_rotate_key.grouping_factor(),
|
||||
};
|
||||
|
||||
apply_multi_bit_blind_rotate(
|
||||
&multi_bit_msed_lwe,
|
||||
&mut glwe_out,
|
||||
multi_bit_blind_rotate_key,
|
||||
*thread_count,
|
||||
true,
|
||||
);
|
||||
}
|
||||
}
|
||||
apply_ms_blind_rotate(&self.bsk, &intermediate_lwe, &mut glwe_out, buffers);
|
||||
});
|
||||
|
||||
let mut output_br = LweCiphertext::new(
|
||||
0,
|
||||
|
||||
@@ -16,6 +16,10 @@ use crate::shortint::client_key::ClientKey;
|
||||
use crate::shortint::engine::ShortintEngine;
|
||||
use crate::shortint::noise_squashing::NoiseSquashingPrivateKeyView;
|
||||
use crate::shortint::parameters::{CompressionParameters, NoiseSquashingCompressionParameters};
|
||||
use crate::shortint::server_key::{
|
||||
CompressedModulusSwitchConfiguration, ModulusSwitchConfiguration, ShortintBootstrappingKey,
|
||||
ShortintCompressedBootstrappingKey,
|
||||
};
|
||||
use crate::shortint::{EncryptionKeyChoice, ShortintParameterSet};
|
||||
use std::fmt::Debug;
|
||||
|
||||
@@ -154,8 +158,11 @@ impl CompressionPrivateKeys {
|
||||
pbs_params.ciphertext_modulus(),
|
||||
);
|
||||
|
||||
DecompressionKey::Classic {
|
||||
blind_rotate_key,
|
||||
DecompressionKey {
|
||||
bsk: ShortintBootstrappingKey::Classic {
|
||||
bsk: blind_rotate_key,
|
||||
modulus_switch_noise_reduction_key: ModulusSwitchConfiguration::Standard,
|
||||
},
|
||||
lwe_per_glwe: classic_compression_parameters.lwe_per_glwe,
|
||||
}
|
||||
}
|
||||
@@ -179,10 +186,13 @@ impl CompressionPrivateKeys {
|
||||
multi_bit_compression_parameters.decompression_grouping_factor,
|
||||
);
|
||||
|
||||
DecompressionKey::MultiBit {
|
||||
multi_bit_blind_rotate_key,
|
||||
DecompressionKey {
|
||||
bsk: ShortintBootstrappingKey::MultiBit {
|
||||
fourier_bsk: multi_bit_blind_rotate_key,
|
||||
thread_count,
|
||||
deterministic_execution: true,
|
||||
},
|
||||
lwe_per_glwe: multi_bit_compression_parameters.lwe_per_glwe,
|
||||
thread_count,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -215,8 +225,12 @@ impl CompressionPrivateKeys {
|
||||
)
|
||||
});
|
||||
|
||||
CompressedDecompressionKey::Classic {
|
||||
blind_rotate_key,
|
||||
CompressedDecompressionKey {
|
||||
bsk: ShortintCompressedBootstrappingKey::Classic {
|
||||
bsk: blind_rotate_key,
|
||||
modulus_switch_noise_reduction_key:
|
||||
CompressedModulusSwitchConfiguration::Standard,
|
||||
},
|
||||
lwe_per_glwe: classic_compression_parameters.lwe_per_glwe,
|
||||
}
|
||||
}
|
||||
@@ -234,8 +248,11 @@ impl CompressionPrivateKeys {
|
||||
)
|
||||
});
|
||||
|
||||
CompressedDecompressionKey::MultiBit {
|
||||
multi_bit_blind_rotate_key,
|
||||
CompressedDecompressionKey {
|
||||
bsk: ShortintCompressedBootstrappingKey::MultiBit {
|
||||
seeded_bsk: multi_bit_blind_rotate_key,
|
||||
deterministic_execution: true,
|
||||
},
|
||||
lwe_per_glwe: multi_bit_compression_parameters.lwe_per_glwe,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
use super::private_key::NoiseSquashingCompressionPrivateKey;
|
||||
use super::CompressionPrivateKeys;
|
||||
use crate::conformance::ParameterSetConformant;
|
||||
use crate::core_crypto::fft_impl::fft64::crypto::bootstrap::LweBootstrapKeyConformanceParams;
|
||||
use crate::core_crypto::prelude::*;
|
||||
use crate::shortint::atomic_pattern::AtomicPatternParameters;
|
||||
use crate::shortint::backward_compatibility::list_compression::{
|
||||
@@ -15,6 +14,10 @@ use crate::shortint::parameters::{
|
||||
CompressionParameters, NoiseSquashingCompressionParameters, NoiseSquashingParameters,
|
||||
PolynomialSize,
|
||||
};
|
||||
use crate::shortint::prelude::ModulusSwitchType;
|
||||
use crate::shortint::server_key::{
|
||||
PBSConformanceParams, PbsTypeConformanceParams, ShortintBootstrappingKey,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt::Debug;
|
||||
use tfhe_versionable::Versionize;
|
||||
@@ -29,53 +32,22 @@ pub struct CompressionKey {
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Versionize)]
|
||||
#[versionize(DecompressionKeyVersions)]
|
||||
pub enum DecompressionKey {
|
||||
Classic {
|
||||
blind_rotate_key: FourierLweBootstrapKeyOwned,
|
||||
lwe_per_glwe: LweCiphertextCount,
|
||||
},
|
||||
MultiBit {
|
||||
multi_bit_blind_rotate_key: FourierLweMultiBitBootstrapKeyOwned,
|
||||
lwe_per_glwe: LweCiphertextCount,
|
||||
thread_count: ThreadCount,
|
||||
},
|
||||
pub struct DecompressionKey {
|
||||
pub(crate) bsk: ShortintBootstrappingKey<u64>,
|
||||
pub(crate) lwe_per_glwe: LweCiphertextCount,
|
||||
}
|
||||
|
||||
impl DecompressionKey {
|
||||
pub fn out_glwe_size(&self) -> GlweSize {
|
||||
match self {
|
||||
Self::Classic {
|
||||
blind_rotate_key, ..
|
||||
} => blind_rotate_key.glwe_size(),
|
||||
Self::MultiBit {
|
||||
multi_bit_blind_rotate_key,
|
||||
..
|
||||
} => multi_bit_blind_rotate_key.glwe_size(),
|
||||
}
|
||||
self.bsk.glwe_size()
|
||||
}
|
||||
|
||||
pub fn out_polynomial_size(&self) -> PolynomialSize {
|
||||
match self {
|
||||
Self::Classic {
|
||||
blind_rotate_key, ..
|
||||
} => blind_rotate_key.polynomial_size(),
|
||||
Self::MultiBit {
|
||||
multi_bit_blind_rotate_key,
|
||||
..
|
||||
} => multi_bit_blind_rotate_key.polynomial_size(),
|
||||
}
|
||||
self.bsk.polynomial_size()
|
||||
}
|
||||
|
||||
pub fn output_lwe_dimension(&self) -> LweDimension {
|
||||
match self {
|
||||
Self::Classic {
|
||||
blind_rotate_key, ..
|
||||
} => blind_rotate_key.output_lwe_dimension(),
|
||||
Self::MultiBit {
|
||||
multi_bit_blind_rotate_key,
|
||||
..
|
||||
} => multi_bit_blind_rotate_key.output_lwe_dimension(),
|
||||
}
|
||||
self.bsk.output_lwe_dimension()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -193,39 +165,16 @@ impl ParameterSetConformant for DecompressionKey {
|
||||
type ParameterSet = CompressionKeyConformanceParams;
|
||||
|
||||
fn is_conformant(&self, parameter_set: &Self::ParameterSet) -> bool {
|
||||
match self {
|
||||
Self::Classic {
|
||||
blind_rotate_key,
|
||||
lwe_per_glwe,
|
||||
} => {
|
||||
let Ok(params) = parameter_set.try_into() else {
|
||||
return false;
|
||||
};
|
||||
let Self { bsk, lwe_per_glwe } = self;
|
||||
|
||||
blind_rotate_key.is_conformant(¶ms)
|
||||
&& *lwe_per_glwe == parameter_set.lwe_per_glwe
|
||||
}
|
||||
Self::MultiBit {
|
||||
multi_bit_blind_rotate_key,
|
||||
lwe_per_glwe,
|
||||
thread_count,
|
||||
} => {
|
||||
let Ok(params) = parameter_set.try_into() else {
|
||||
return false;
|
||||
};
|
||||
let params = parameter_set.into();
|
||||
|
||||
multi_bit_blind_rotate_key.is_conformant(¶ms)
|
||||
&& *lwe_per_glwe == parameter_set.lwe_per_glwe
|
||||
&& thread_count.0 > 0
|
||||
}
|
||||
}
|
||||
*lwe_per_glwe == parameter_set.lwe_per_glwe && bsk.is_conformant(¶ms)
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&CompressionKeyConformanceParams> for LweBootstrapKeyConformanceParams<u64> {
|
||||
type Error = String;
|
||||
|
||||
fn try_from(value: &CompressionKeyConformanceParams) -> Result<Self, String> {
|
||||
impl From<&CompressionKeyConformanceParams> for PBSConformanceParams {
|
||||
fn from(value: &CompressionKeyConformanceParams) -> Self {
|
||||
let CompressionKeyConformanceParams {
|
||||
br_level,
|
||||
br_base_log,
|
||||
@@ -237,51 +186,27 @@ impl TryFrom<&CompressionKeyConformanceParams> for LweBootstrapKeyConformancePar
|
||||
..
|
||||
} = value;
|
||||
|
||||
if decompression_grouping_factor.is_some() {
|
||||
return Err("Expected classic PBS decompression conformance parameters, found multi bit parameters".to_owned());
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
decomp_base_log: *br_base_log,
|
||||
decomp_level_count: *br_level,
|
||||
input_lwe_dimension: packing_ks_glwe_dimension
|
||||
.to_equivalent_lwe_dimension(*packing_ks_polynomial_size),
|
||||
output_glwe_size: uncompressed_glwe_dimension.to_glwe_size(),
|
||||
polynomial_size: *uncompressed_polynomial_size,
|
||||
ciphertext_modulus: value.cipherext_modulus,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<&CompressionKeyConformanceParams> for MultiBitBootstrapKeyConformanceParams<u64> {
|
||||
type Error = String;
|
||||
|
||||
fn try_from(value: &CompressionKeyConformanceParams) -> Result<Self, String> {
|
||||
let CompressionKeyConformanceParams {
|
||||
br_level,
|
||||
br_base_log,
|
||||
packing_ks_polynomial_size,
|
||||
packing_ks_glwe_dimension,
|
||||
uncompressed_polynomial_size,
|
||||
uncompressed_glwe_dimension,
|
||||
decompression_grouping_factor,
|
||||
..
|
||||
} = value;
|
||||
|
||||
let Some(grouping_factor) = decompression_grouping_factor.as_ref() else {
|
||||
return Err("Expected multi bit PBS decompression conformance parameters, found classic parameters".to_owned());
|
||||
#[allow(clippy::option_if_let_else)]
|
||||
let pbs_type = if let Some(grouping_factor) = decompression_grouping_factor.as_ref() {
|
||||
PbsTypeConformanceParams::MultiBit {
|
||||
lwe_bsk_grouping_factor: *grouping_factor,
|
||||
}
|
||||
} else {
|
||||
PbsTypeConformanceParams::Classic {
|
||||
modulus_switch_noise_reduction: ModulusSwitchType::Standard,
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
decomp_base_log: *br_base_log,
|
||||
decomp_level_count: *br_level,
|
||||
input_lwe_dimension: packing_ks_glwe_dimension
|
||||
Self {
|
||||
in_lwe_dimension: packing_ks_glwe_dimension
|
||||
.to_equivalent_lwe_dimension(*packing_ks_polynomial_size),
|
||||
output_glwe_size: uncompressed_glwe_dimension.to_glwe_size(),
|
||||
polynomial_size: *uncompressed_polynomial_size,
|
||||
out_glwe_dimension: *uncompressed_glwe_dimension,
|
||||
out_polynomial_size: *uncompressed_polynomial_size,
|
||||
base_log: *br_base_log,
|
||||
level: *br_level,
|
||||
pbs_type,
|
||||
ciphertext_modulus: value.cipherext_modulus,
|
||||
grouping_factor: *grouping_factor,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1361,18 +1361,20 @@ impl LweClassicFftBootstrap<DynLwe, DynLwe, LookupTable<Vec<u64>>> for Decompres
|
||||
accumulator: &LookupTable<Vec<u64>>,
|
||||
side_resources: &mut Self::SideResources,
|
||||
) {
|
||||
match self {
|
||||
Self::Classic {
|
||||
blind_rotate_key,
|
||||
lwe_per_glwe: _,
|
||||
} => {
|
||||
match (input, output) {
|
||||
(DynLwe::U64(input), DynLwe::U64(output)) => blind_rotate_key
|
||||
.lwe_classic_fft_pbs(input, output, &accumulator.acc, side_resources),
|
||||
_ => panic!("DecompressionKey only supports DynLwe::U64 for noise simulation"),
|
||||
match &self.bsk {
|
||||
ShortintBootstrappingKey::Classic {
|
||||
bsk,
|
||||
modulus_switch_noise_reduction_key: _,
|
||||
} => match (input, output) {
|
||||
(DynLwe::U64(input), DynLwe::U64(output)) => {
|
||||
bsk.lwe_classic_fft_pbs(input, output, &accumulator.acc, side_resources)
|
||||
}
|
||||
}
|
||||
Self::MultiBit { .. } => {
|
||||
_ => panic!(
|
||||
"DecompressionKey only supports DynLwe::U64 for noise
|
||||
simulation"
|
||||
),
|
||||
},
|
||||
ShortintBootstrappingKey::MultiBit { .. } => {
|
||||
panic!("Tried to compute a classic PBS with a multi bit DecompressionKey")
|
||||
}
|
||||
}
|
||||
@@ -1824,12 +1826,12 @@ impl NoiseSimulationLweFourierBsk {
|
||||
}
|
||||
|
||||
pub fn matches_actual_shortint_decomp_key(&self, decomp_key: &DecompressionKey) -> bool {
|
||||
match decomp_key {
|
||||
DecompressionKey::Classic {
|
||||
blind_rotate_key,
|
||||
lwe_per_glwe: _,
|
||||
} => self.matches_actual_bsk(blind_rotate_key),
|
||||
DecompressionKey::MultiBit { .. } => false,
|
||||
match &decomp_key.bsk {
|
||||
ShortintBootstrappingKey::Classic {
|
||||
bsk,
|
||||
modulus_switch_noise_reduction_key: _,
|
||||
} => self.matches_actual_bsk(bsk),
|
||||
ShortintBootstrappingKey::MultiBit { .. } => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:5759c7618de2fd921519f145fbde201d6db0994faf6de4864061c94e4ea34f6a
|
||||
size 46269030
|
||||
oid sha256:dbbd1066b91dc53722c66188afca92c8a347104c94ae494d39648bae8b426d6f
|
||||
size 46269035
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:e18bde6098e48b111d6230724ec5fb4807da63f4aeca9aab3b8308df8929173a
|
||||
size 52053435
|
||||
oid sha256:0f3fda3335497ea54265726cd91d1f1eb12e91e2b1f21b4b7ae194b1e06eecc0
|
||||
size 52053452
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:697dc1f5878a483964b75d9d2eaf5d364510eab84ad1ea46bdaeabc0a74063ea
|
||||
size 92635750
|
||||
oid sha256:47ecb37b990154fd65d5c404ee2b51930c650b4cae6ea2dccfa6a923c89c9f58
|
||||
size 92635743
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:b2b26204f294efc567147baaf20cf5b1d40caceb696ffd5642926ca04afee8bb
|
||||
oid sha256:7b0ff91a905d2bce4a83575ce0070396b1c4ec10b83a3dd301bac06964ef5f92
|
||||
size 108393737
|
||||
|
||||
Reference in New Issue
Block a user