feat(shortint): add CompressedAtomicPatternServerKey

This commit is contained in:
Arthur Meyre
2025-04-28 18:05:07 +02:00
committed by Nicolas Sarlin
parent 7724b7857f
commit 9be9a5d2f4
17 changed files with 918 additions and 329 deletions

View File

@@ -6,6 +6,7 @@ use std::fs::{File, OpenOptions};
use std::io::Write;
use std::path::Path;
use tfhe::keycache::NamedParam;
use tfhe::shortint::atomic_pattern::compressed::CompressedAtomicPatternServerKey;
use tfhe::shortint::keycache::KEY_CACHE;
use tfhe::shortint::parameters::current_params::*;
use tfhe::shortint::parameters::*;
@@ -101,9 +102,7 @@ fn client_server_key_sizes(results_file: &Path) {
);
let sks_compressed = CompressedServerKey::new(cks);
let bsk_compressed_size = sks_compressed
.bootstrapping_key
.bootstrapping_key_size_bytes();
let bsk_compressed_size = sks_compressed.bootstrapping_key_size_bytes();
let test_name = format!("shortint_key_sizes_{}_bsk_compressed", params.name());
write_result(&mut file, &test_name, bsk_compressed_size);
@@ -119,9 +118,7 @@ fn client_server_key_sizes(results_file: &Path) {
println!(
"Element in BSK compressed: {}, size in bytes: {}",
sks_compressed
.bootstrapping_key
.bootstrapping_key_size_elements(),
sks_compressed.bootstrapping_key_size_elements(),
bsk_compressed_size,
);
@@ -170,6 +167,15 @@ fn tuniform_key_set_sizes(results_file: &Path) {
let compressed_sks = CompressedServerKey::new(&cks);
let sks = StandardServerKey::try_from(compressed_sks.decompress()).unwrap();
let std_compressed_ap_key = match &compressed_sks.compressed_ap_server_key {
CompressedAtomicPatternServerKey::Standard(
compressed_standard_atomic_pattern_server_key,
) => compressed_standard_atomic_pattern_server_key,
CompressedAtomicPatternServerKey::KeySwitch32(_) => {
panic!("KS32 is unsupported to measure key sizes at the moment")
}
};
measure_serialized_size(
&sks.atomic_pattern.key_switching_key,
<ClassicPBSParameters as Into<PBSParameters>>::into(param_fhe),
@@ -179,7 +185,7 @@ fn tuniform_key_set_sizes(results_file: &Path) {
&mut file,
);
measure_serialized_size(
&compressed_sks.key_switching_key,
std_compressed_ap_key.key_switching_key(),
<ClassicPBSParameters as Into<PBSParameters>>::into(param_fhe),
&param_fhe_name,
"ksk_compressed",
@@ -196,7 +202,7 @@ fn tuniform_key_set_sizes(results_file: &Path) {
&mut file,
);
measure_serialized_size(
&compressed_sks.bootstrapping_key,
&std_compressed_ap_key.bootstrapping_key(),
<ClassicPBSParameters as Into<PBSParameters>>::into(param_fhe),
&param_fhe_name,
"bsk_compressed",

View File

@@ -355,7 +355,7 @@ where
/// let mut secret_generator = SecretRandomGenerator::<DefaultRandomGenerator>::new(seeder.seed());
///
/// // Create the LweSecretKey
/// let input_lwe_secret_key =
/// let input_lwe_secret_key: LweSecretKeyOwned<u64> =
/// allocate_and_generate_new_binary_lwe_secret_key(input_lwe_dimension, &mut secret_generator);
/// let output_lwe_secret_key = allocate_and_generate_new_binary_lwe_secret_key(
/// output_lwe_dimension,
@@ -383,7 +383,8 @@ where
/// assert!(!ksk.as_ref().iter().all(|&x| x == 0));
/// ```
pub fn generate_seeded_lwe_keyswitch_key<
Scalar,
InputScalar,
OutputScalar,
NoiseDistribution,
InputKeyCont,
OutputKeyCont,
@@ -396,11 +397,12 @@ pub fn generate_seeded_lwe_keyswitch_key<
noise_distribution: NoiseDistribution,
noise_seeder: &mut NoiseSeeder,
) where
Scalar: Encryptable<Uniform, NoiseDistribution>,
InputScalar: UnsignedInteger + CastInto<OutputScalar>,
OutputScalar: Encryptable<Uniform, NoiseDistribution>,
NoiseDistribution: Distribution,
InputKeyCont: Container<Element = Scalar>,
OutputKeyCont: Container<Element = Scalar>,
KSKeyCont: ContainerMut<Element = Scalar>,
InputKeyCont: Container<Element = InputScalar>,
OutputKeyCont: Container<Element = OutputScalar>,
KSKeyCont: ContainerMut<Element = OutputScalar>,
// Maybe Sized allows to pass Box<dyn Seeder>.
NoiseSeeder: Seeder + ?Sized,
{
@@ -418,6 +420,13 @@ pub fn generate_seeded_lwe_keyswitch_key<
lwe_keyswitch_key.output_key_lwe_dimension(),
input_lwe_sk.lwe_dimension()
);
assert!(
lwe_keyswitch_key.decomposition_base_log().0
* lwe_keyswitch_key.decomposition_level_count().0
<= OutputScalar::BITS,
"This operation only supports a DecompositionBaseLog and DecompositionLevelCount product \
smaller than the OutputScalar bit count."
);
let decomp_base_log = lwe_keyswitch_key.decomposition_base_log();
let decomp_level_count = lwe_keyswitch_key.decomposition_level_count();
@@ -426,7 +435,7 @@ pub fn generate_seeded_lwe_keyswitch_key<
// The plaintexts used to encrypt a key element will be stored in this buffer
let mut decomposition_plaintexts_buffer =
PlaintextListOwned::new(Scalar::ZERO, PlaintextCount(decomp_level_count.0));
PlaintextListOwned::new(OutputScalar::ZERO, PlaintextCount(decomp_level_count.0));
let mut generator = EncryptionRandomGenerator::<DefaultRandomGenerator>::new(
lwe_keyswitch_key.compression_seed().seed,
@@ -448,9 +457,13 @@ pub fn generate_seeded_lwe_keyswitch_key<
// Here we take the decomposition term from the native torus, bring it to the torus we
// are working with by dividing by the scaling factor and the encryption will take care
// of mapping that back to the native torus
*message.0 = DecompositionTerm::new(level, decomp_base_log, *input_key_element)
.to_recomposition_summand()
.wrapping_div(ciphertext_modulus.get_power_of_two_scaling_to_native_torus());
*message.0 = DecompositionTerm::new(
level,
decomp_base_log,
CastInto::<OutputScalar>::cast_into(*input_key_element),
)
.to_recomposition_summand()
.wrapping_div(ciphertext_modulus.get_power_of_two_scaling_to_native_torus());
}
encrypt_seeded_lwe_ciphertext_list_with_pre_seeded_generator(
@@ -467,7 +480,8 @@ pub fn generate_seeded_lwe_keyswitch_key<
/// keyswitching key constructed from an input and an output key
/// [`LWE secret key`](`LweSecretKey`).
pub fn allocate_and_generate_new_seeded_lwe_keyswitch_key<
Scalar,
InputScalar,
OutputScalar,
NoiseDistribution,
InputKeyCont,
OutputKeyCont,
@@ -478,19 +492,20 @@ pub fn allocate_and_generate_new_seeded_lwe_keyswitch_key<
decomp_base_log: DecompositionBaseLog,
decomp_level_count: DecompositionLevelCount,
noise_distribution: NoiseDistribution,
ciphertext_modulus: CiphertextModulus<Scalar>,
ciphertext_modulus: CiphertextModulus<OutputScalar>,
noise_seeder: &mut NoiseSeeder,
) -> SeededLweKeyswitchKeyOwned<Scalar>
) -> SeededLweKeyswitchKeyOwned<OutputScalar>
where
Scalar: Encryptable<Uniform, NoiseDistribution>,
InputScalar: UnsignedInteger + CastInto<OutputScalar>,
OutputScalar: Encryptable<Uniform, NoiseDistribution>,
NoiseDistribution: Distribution,
InputKeyCont: Container<Element = Scalar>,
OutputKeyCont: Container<Element = Scalar>,
InputKeyCont: Container<Element = InputScalar>,
OutputKeyCont: Container<Element = OutputScalar>,
// Maybe Sized allows to pass Box<dyn Seeder>.
NoiseSeeder: Seeder + ?Sized,
{
let mut new_lwe_keyswitch_key = SeededLweKeyswitchKeyOwned::new(
Scalar::ZERO,
OutputScalar::ZERO,
decomp_base_log,
decomp_level_count,
input_lwe_sk.lwe_dimension(),

View File

@@ -28,7 +28,7 @@ fn test_seeded_lwe_ksk_gen_equivalence<Scalar: UnsignedTorus + Send + Sync>(
for _ in 0..NB_TESTS {
// Create the LweSecretKey
let input_lwe_secret_key = allocate_and_generate_new_binary_lwe_secret_key(
let input_lwe_secret_key = allocate_and_generate_new_binary_lwe_secret_key::<Scalar, _>(
input_lwe_dimension,
&mut secret_generator,
);

View File

@@ -425,8 +425,10 @@ impl<Scalar: UnsignedInteger, C: ContainerMut<Element = Scalar>> ContiguousEntit
Self: 'this;
}
impl<C: Container<Element = u64>> ParameterSetConformant for SeededLweKeyswitchKey<C> {
type ParameterSet = LweKeyswitchKeyConformanceParams<u64>;
impl<Scalar: UnsignedInteger, C: Container<Element = Scalar>> ParameterSetConformant
for SeededLweKeyswitchKey<C>
{
type ParameterSet = LweKeyswitchKeyConformanceParams<Scalar>;
fn is_conformant(&self, parameter_set: &Self::ParameterSet) -> bool {
let Self {

View File

@@ -10,6 +10,7 @@ use crate::core_crypto::prelude::{
use crate::integer::gpu::UnsignedInteger;
use crate::integer::server_key::num_bits_to_represent_unsigned_value;
use crate::integer::ClientKey;
use crate::shortint::atomic_pattern::compressed::CompressedAtomicPatternServerKey;
use crate::shortint::ciphertext::{MaxDegree, MaxNoiseLevel};
use crate::shortint::engine::ShortintEngine;
use crate::shortint::server_key::ModulusSwitchNoiseReductionKey;
@@ -215,16 +216,21 @@ impl CudaServerKey {
streams: &CudaStreams,
) -> Self {
let crate::shortint::CompressedServerKey {
key_switching_key,
bootstrapping_key,
compressed_ap_server_key,
message_modulus,
carry_modulus,
max_degree,
max_noise_level,
ciphertext_modulus,
pbs_order,
} = cpu_key.key.clone();
// Generate a regular keyset and convert to the GPU
let CompressedAtomicPatternServerKey::Standard(std_key) = compressed_ap_server_key else {
panic!("Only the standard atomic pattern is supported on GPU")
};
let ciphertext_modulus = std_key.ciphertext_modulus();
let (key_switching_key, bootstrapping_key, pbs_order) = std_key.into_raw_parts();
let h_key_switching_key = key_switching_key.par_decompress_into_lwe_keyswitch_key();
let key_switching_key =
CudaLweKeyswitchKey::from_lwe_keyswitch_key(&h_key_switching_key, streams);

View File

@@ -324,17 +324,12 @@ impl ParameterSetConformant for CompressedServerKey {
fn is_conformant(&self, parameter_set: &Self::ParameterSet) -> bool {
let Self { key } = self;
let AtomicPatternParameters::Standard(parameters) = *parameter_set else {
// Server key compression is only supported for classical AP
return false;
};
let expected_max_degree = MaxDegree::integer_radix_server_key(
parameters.message_modulus(),
parameters.carry_modulus(),
parameter_set.message_modulus(),
parameter_set.carry_modulus(),
);
key.is_conformant(&(parameters, expected_max_degree))
key.is_conformant(&(*parameter_set, expected_max_degree))
}
}

View File

@@ -0,0 +1,138 @@
use crate::conformance::ParameterSetConformant;
use crate::core_crypto::algorithms::lwe_keyswitch_key_generation::allocate_and_generate_new_seeded_lwe_keyswitch_key;
use crate::core_crypto::entities::lwe_secret_key::LweSecretKey;
use crate::core_crypto::entities::seeded_lwe_keyswitch_key::SeededLweKeyswitchKeyOwned;
use crate::shortint::atomic_pattern::ks32::KS32AtomicPatternServerKey;
use crate::shortint::backward_compatibility::atomic_pattern::CompressedKS32AtomicPatternServerKeyVersions;
use crate::shortint::client_key::ClientKey;
use crate::shortint::engine::ShortintEngine;
use crate::shortint::parameters::{KeySwitch32PBSParameters, LweDimension};
use crate::shortint::server_key::ShortintCompressedBootstrappingKey;
use serde::{Deserialize, Serialize};
use tfhe_versionable::Versionize;
/// The definition of the compressed server key elements used in the
/// [`KeySwitch32`](crate::shortint::atomic_pattern::AtomicPatternKind::KeySwitch32) atomic pattern
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Versionize)]
#[versionize(CompressedKS32AtomicPatternServerKeyVersions)]
pub struct CompressedKS32AtomicPatternServerKey {
key_switching_key: SeededLweKeyswitchKeyOwned<u32>,
bootstrapping_key: ShortintCompressedBootstrappingKey<u32>,
}
impl CompressedKS32AtomicPatternServerKey {
pub fn new(cks: &ClientKey, engine: &mut ShortintEngine) -> Self {
let params = &cks.parameters;
let pbs_params = params.ks32_parameters().unwrap();
let in_key = LweSecretKey::from_container(
cks.small_lwe_secret_key()
.as_ref()
.iter()
.copied()
.map(|x| x as u32)
.collect::<Vec<_>>(),
);
let out_key = &cks.glwe_secret_key;
let bootstrapping_key_base =
engine.new_compressed_bootstrapping_key_ks32(pbs_params, &in_key, out_key);
// Creation of the key switching key
let key_switching_key = allocate_and_generate_new_seeded_lwe_keyswitch_key(
&cks.large_lwe_secret_key(),
&in_key,
params.ks_base_log(),
params.ks_level(),
pbs_params.lwe_noise_distribution(),
pbs_params.post_keyswitch_ciphertext_modulus(),
&mut engine.seeder,
);
Self::from_raw_parts(key_switching_key, bootstrapping_key_base)
}
pub fn from_raw_parts(
key_switching_key: SeededLweKeyswitchKeyOwned<u32>,
bootstrapping_key: ShortintCompressedBootstrappingKey<u32>,
) -> Self {
assert_eq!(
key_switching_key.input_key_lwe_dimension(),
bootstrapping_key.output_lwe_dimension(),
"Mismatch between the input SeededLweKeyswitchKey LweDimension ({:?}) \
and the ShortintCompressedBootstrappingKey output LweDimension ({:?})",
key_switching_key.input_key_lwe_dimension(),
bootstrapping_key.output_lwe_dimension()
);
assert_eq!(
key_switching_key.output_key_lwe_dimension(),
bootstrapping_key.input_lwe_dimension(),
"Mismatch between the output SeededLweKeyswitchKey LweDimension ({:?}) \
and the ShortintCompressedBootstrappingKey input LweDimension ({:?})",
key_switching_key.output_key_lwe_dimension(),
bootstrapping_key.input_lwe_dimension()
);
Self {
key_switching_key,
bootstrapping_key,
}
}
pub fn ciphertext_lwe_dimension(&self) -> LweDimension {
// KS32 is always KeyswitchBootstrap, meaning Ciphertext is under the big LWE secret key
self.key_switching_key.input_key_lwe_dimension()
}
pub fn key_switching_key(&self) -> &SeededLweKeyswitchKeyOwned<u32> {
&self.key_switching_key
}
pub fn bootstrapping_key(&self) -> &ShortintCompressedBootstrappingKey<u32> {
&self.bootstrapping_key
}
pub fn decompress(&self) -> KS32AtomicPatternServerKey {
let Self {
key_switching_key,
bootstrapping_key,
} = self;
let ciphertext_modulus = bootstrapping_key.ciphertext_modulus();
let (key_switching_key, bootstrapping_key) = rayon::join(
|| {
key_switching_key
.as_view()
.par_decompress_into_lwe_keyswitch_key()
},
|| bootstrapping_key.decompress(),
);
KS32AtomicPatternServerKey::from_raw_parts(
key_switching_key,
bootstrapping_key,
ciphertext_modulus,
)
}
}
impl ParameterSetConformant for CompressedKS32AtomicPatternServerKey {
type ParameterSet = KeySwitch32PBSParameters;
fn is_conformant(&self, parameter_set: &Self::ParameterSet) -> bool {
let Self {
key_switching_key,
bootstrapping_key,
} = self;
let ksk_ok = key_switching_key.is_conformant(&parameter_set.into());
let bsk_ok = bootstrapping_key.is_conformant(&parameter_set.into());
ksk_ok && bsk_ok
}
}

View File

@@ -0,0 +1,95 @@
pub mod ks32;
pub mod standard;
pub use ks32::*;
pub use standard::*;
use super::AtomicPatternServerKey;
use crate::conformance::ParameterSetConformant;
use crate::shortint::backward_compatibility::atomic_pattern::CompressedAtomicPatternServerKeyVersions;
use crate::shortint::client_key::ClientKey;
use crate::shortint::engine::ShortintEngine;
use crate::shortint::parameters::{AtomicPatternParameters, CiphertextModulus, LweDimension};
use serde::{Deserialize, Serialize};
use tfhe_versionable::Versionize;
/// The server key materials for all the supported Atomic Patterns
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Versionize)]
#[versionize(CompressedAtomicPatternServerKeyVersions)]
#[allow(clippy::large_enum_variant)] // The most common variant should be `Standard` so we optimize for it
pub enum CompressedAtomicPatternServerKey {
Standard(CompressedStandardAtomicPatternServerKey),
KeySwitch32(CompressedKS32AtomicPatternServerKey),
}
impl CompressedAtomicPatternServerKey {
pub fn new(cks: &ClientKey, engine: &mut ShortintEngine) -> Self {
let params = &cks.parameters;
match params.ap_parameters().unwrap() {
AtomicPatternParameters::Standard(_) => {
Self::Standard(CompressedStandardAtomicPatternServerKey::new(cks, engine))
}
AtomicPatternParameters::KeySwitch32(_) => {
Self::KeySwitch32(CompressedKS32AtomicPatternServerKey::new(cks, engine))
}
}
}
pub fn ciphertext_lwe_dimension(&self) -> LweDimension {
match self {
Self::Standard(compressed_standard_atomic_pattern_server_key) => {
compressed_standard_atomic_pattern_server_key.ciphertext_lwe_dimension()
}
Self::KeySwitch32(compressed_ks32_atomic_pattern_server_key) => {
compressed_ks32_atomic_pattern_server_key.ciphertext_lwe_dimension()
}
}
}
pub fn ciphertext_modulus(&self) -> CiphertextModulus {
match self {
Self::Standard(compressed_standard_atomic_pattern_server_key) => {
compressed_standard_atomic_pattern_server_key
.bootstrapping_key()
.ciphertext_modulus()
}
Self::KeySwitch32(compressed_ks32_atomic_pattern_server_key) => {
compressed_ks32_atomic_pattern_server_key
.bootstrapping_key()
.ciphertext_modulus()
}
}
}
pub fn decompress(&self) -> AtomicPatternServerKey {
match self {
Self::Standard(compressed_standard_atomic_pattern_server_key) => {
AtomicPatternServerKey::Standard(
compressed_standard_atomic_pattern_server_key.decompress(),
)
}
Self::KeySwitch32(compressed_ks32_atomic_pattern_server_key) => {
AtomicPatternServerKey::KeySwitch32(
compressed_ks32_atomic_pattern_server_key.decompress(),
)
}
}
}
}
impl ParameterSetConformant for CompressedAtomicPatternServerKey {
type ParameterSet = AtomicPatternParameters;
fn is_conformant(&self, parameter_set: &Self::ParameterSet) -> bool {
match (self, parameter_set) {
(Self::Standard(ap), AtomicPatternParameters::Standard(params)) => {
ap.is_conformant(params)
}
(Self::KeySwitch32(ap), AtomicPatternParameters::KeySwitch32(params)) => {
ap.is_conformant(params)
}
_ => false,
}
}
}

View File

@@ -0,0 +1,175 @@
use crate::conformance::ParameterSetConformant;
use crate::core_crypto::algorithms::lwe_keyswitch_key_generation::allocate_and_generate_new_seeded_lwe_keyswitch_key;
use crate::core_crypto::entities::seeded_lwe_keyswitch_key::SeededLweKeyswitchKeyOwned;
use crate::shortint::atomic_pattern::standard::StandardAtomicPatternServerKey;
use crate::shortint::backward_compatibility::atomic_pattern::CompressedStandardAtomicPatternServerKeyVersions;
use crate::shortint::client_key::ClientKey;
use crate::shortint::engine::ShortintEngine;
use crate::shortint::parameters::{CiphertextModulus, LweDimension, PBSOrder, PBSParameters};
use crate::shortint::server_key::ShortintCompressedBootstrappingKey;
use serde::{Deserialize, Serialize};
use tfhe_versionable::Versionize;
/// The definition of the compressed server key elements used in the
/// [`Standard`](crate::shortint::atomic_pattern::AtomicPatternKind::Standard) atomic pattern
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Versionize)]
#[versionize(CompressedStandardAtomicPatternServerKeyVersions)]
pub struct CompressedStandardAtomicPatternServerKey {
key_switching_key: SeededLweKeyswitchKeyOwned<u64>,
bootstrapping_key: ShortintCompressedBootstrappingKey<u64>,
pbs_order: PBSOrder,
}
impl CompressedStandardAtomicPatternServerKey {
pub fn new(cks: &ClientKey, engine: &mut ShortintEngine) -> Self {
let params = &cks.parameters;
let pbs_params_base = params.pbs_parameters().unwrap();
let in_key = &cks.small_lwe_secret_key();
let out_key = &cks.glwe_secret_key;
let bootstrapping_key_base =
engine.new_compressed_bootstrapping_key(pbs_params_base, in_key, out_key);
// Creation of the key switching key
let key_switching_key = allocate_and_generate_new_seeded_lwe_keyswitch_key(
&cks.large_lwe_secret_key(),
&cks.small_lwe_secret_key(),
params.ks_base_log(),
params.ks_level(),
params.lwe_noise_distribution(),
params.ciphertext_modulus(),
&mut engine.seeder,
);
Self::from_raw_parts(
key_switching_key,
bootstrapping_key_base,
pbs_params_base.encryption_key_choice().into(),
)
}
pub fn from_raw_parts(
key_switching_key: SeededLweKeyswitchKeyOwned<u64>,
bootstrapping_key: ShortintCompressedBootstrappingKey<u64>,
pbs_order: PBSOrder,
) -> Self {
assert_eq!(
key_switching_key.input_key_lwe_dimension(),
bootstrapping_key.output_lwe_dimension(),
"Mismatch between the input SeededLweKeyswitchKey LweDimension ({:?}) \
and the ShortintCompressedBootstrappingKey output LweDimension ({:?})",
key_switching_key.input_key_lwe_dimension(),
bootstrapping_key.output_lwe_dimension()
);
assert_eq!(
key_switching_key.output_key_lwe_dimension(),
bootstrapping_key.input_lwe_dimension(),
"Mismatch between the output SeededLweKeyswitchKey LweDimension ({:?}) \
and the ShortintCompressedBootstrappingKey input LweDimension ({:?})",
key_switching_key.output_key_lwe_dimension(),
bootstrapping_key.input_lwe_dimension()
);
assert_eq!(
key_switching_key.ciphertext_modulus(),
bootstrapping_key.ciphertext_modulus(),
"Mismatch between the output SeededLweKeyswitchKey CiphertextModulus ({:?}) \
and the ShortintCompressedBootstrappingKey input CiphertextModulus ({:?})",
key_switching_key.ciphertext_modulus(),
bootstrapping_key.ciphertext_modulus(),
);
Self {
key_switching_key,
bootstrapping_key,
pbs_order,
}
}
pub fn into_raw_parts(
self,
) -> (
SeededLweKeyswitchKeyOwned<u64>,
ShortintCompressedBootstrappingKey<u64>,
PBSOrder,
) {
let Self {
key_switching_key,
bootstrapping_key,
pbs_order,
} = self;
(key_switching_key, bootstrapping_key, pbs_order)
}
pub fn ciphertext_lwe_dimension(&self) -> LweDimension {
match self.pbs_order() {
PBSOrder::KeyswitchBootstrap => self.key_switching_key.input_key_lwe_dimension(),
PBSOrder::BootstrapKeyswitch => self.key_switching_key.output_key_lwe_dimension(),
}
}
pub fn key_switching_key(&self) -> &SeededLweKeyswitchKeyOwned<u64> {
&self.key_switching_key
}
pub fn bootstrapping_key(&self) -> &ShortintCompressedBootstrappingKey<u64> {
&self.bootstrapping_key
}
pub fn pbs_order(&self) -> PBSOrder {
self.pbs_order
}
pub fn ciphertext_modulus(&self) -> CiphertextModulus {
self.bootstrapping_key.ciphertext_modulus()
}
pub fn decompress(&self) -> StandardAtomicPatternServerKey {
let Self {
key_switching_key,
bootstrapping_key,
pbs_order,
} = self;
let (key_switching_key, bootstrapping_key) = rayon::join(
|| {
key_switching_key
.as_view()
.par_decompress_into_lwe_keyswitch_key()
},
|| bootstrapping_key.decompress(),
);
StandardAtomicPatternServerKey::from_raw_parts(
key_switching_key,
bootstrapping_key,
*pbs_order,
)
}
}
impl ParameterSetConformant for CompressedStandardAtomicPatternServerKey {
type ParameterSet = PBSParameters;
fn is_conformant(&self, parameter_set: &Self::ParameterSet) -> bool {
let Self {
key_switching_key,
bootstrapping_key,
pbs_order,
} = self;
let ksk_ok = key_switching_key.is_conformant(&parameter_set.into());
let bsk_ok = bootstrapping_key.is_conformant(&parameter_set.into());
let params_pbs_order: PBSOrder = parameter_set.encryption_key_choice().into();
let pbs_order_ok = *pbs_order == params_pbs_order;
ksk_ok && bsk_ok && pbs_order_ok
}
}

View File

@@ -4,6 +4,7 @@
//! For example, in TFHE the standard atomic pattern is the chain of n linear operations, a
//! Keyswitch and a PBS.
pub mod compressed;
pub mod ks32;
pub mod standard;

View File

@@ -20,9 +20,7 @@ use crate::shortint::server_key::{
decompress_and_apply_lookup_table, switch_modulus_and_compress, LookupTableOwned,
LookupTableSize, ManyLookupTableOwned, ShortintBootstrappingKey,
};
use crate::shortint::{
Ciphertext, CiphertextModulus, ClientKey, EncryptionKeyChoice, PBSOrder, PBSParameters,
};
use crate::shortint::{Ciphertext, CiphertextModulus, ClientKey, PBSOrder, PBSParameters};
/// The definition of the server key elements used in the [`Standard`](AtomicPatternKind::Standard)
/// atomic pattern
@@ -52,11 +50,8 @@ impl ParameterSetConformant for StandardAtomicPatternServerKey {
let ks_key_ok = key_switching_key.is_conformant(&ks_conformance_params);
let pbs_order_ok = matches!(
(*pbs_order, parameter_set.encryption_key_choice()),
(PBSOrder::KeyswitchBootstrap, EncryptionKeyChoice::Big)
| (PBSOrder::BootstrapKeyswitch, EncryptionKeyChoice::Small)
);
let params_pbs_order: PBSOrder = parameter_set.encryption_key_choice().into();
let pbs_order_ok = *pbs_order == params_pbs_order;
pbs_key_ok && ks_key_ok && pbs_order_ok
}

View File

@@ -1,5 +1,8 @@
use tfhe_versionable::VersionsDispatch;
use crate::shortint::atomic_pattern::compressed::ks32::CompressedKS32AtomicPatternServerKey;
use crate::shortint::atomic_pattern::compressed::standard::CompressedStandardAtomicPatternServerKey;
use crate::shortint::atomic_pattern::compressed::CompressedAtomicPatternServerKey;
use crate::shortint::atomic_pattern::{
AtomicPatternServerKey, KS32AtomicPatternServerKey, StandardAtomicPatternServerKey,
};
@@ -29,3 +32,18 @@ pub enum StandardAtomicPatternServerKeyVersions {
pub enum KS32AtomicPatternServerKeyVersions {
V0(KS32AtomicPatternServerKey),
}
#[derive(VersionsDispatch)]
pub enum CompressedAtomicPatternServerKeyVersions {
V0(CompressedAtomicPatternServerKey),
}
#[derive(VersionsDispatch)]
pub enum CompressedStandardAtomicPatternServerKeyVersions {
V0(CompressedStandardAtomicPatternServerKey),
}
#[derive(VersionsDispatch)]
pub enum CompressedKS32AtomicPatternServerKeyVersions {
V0(CompressedKS32AtomicPatternServerKey),
}

View File

@@ -2,6 +2,9 @@ pub mod modulus_switch_noise_reduction;
use crate::core_crypto::entities::*;
use crate::core_crypto::prelude::{Container, PBSOrder, UnsignedInteger};
use crate::shortint::atomic_pattern::compressed::{
CompressedAtomicPatternServerKey, CompressedStandardAtomicPatternServerKey,
};
use crate::shortint::atomic_pattern::{AtomicPatternServerKey, StandardAtomicPatternServerKey};
use crate::shortint::ciphertext::MaxDegree;
use crate::shortint::server_key::*;
@@ -131,7 +134,7 @@ pub enum ServerKeyVersions<AP> {
V2(GenericServerKey<AP>),
}
impl Deprecable for ShortintCompressedBootstrappingKey {
impl<InputScalar: UnsignedInteger> Deprecable for ShortintCompressedBootstrappingKey<InputScalar> {
const TYPE_NAME: &'static str = "ShortintCompressedBootstrappingKey";
const MIN_SUPPORTED_APP_VERSION: &'static str = "TFHE-rs v0.10";
}
@@ -145,10 +148,12 @@ pub enum ShortintCompressedBootstrappingKeyV1 {
},
}
impl Upgrade<ShortintCompressedBootstrappingKey> for ShortintCompressedBootstrappingKeyV1 {
impl<InputScalar: UnsignedInteger> Upgrade<ShortintCompressedBootstrappingKey<InputScalar>>
for ShortintCompressedBootstrappingKeyV1
{
type Error = Infallible;
fn upgrade(self) -> Result<ShortintCompressedBootstrappingKey, Self::Error> {
fn upgrade(self) -> Result<ShortintCompressedBootstrappingKey<InputScalar>, Self::Error> {
Ok(match self {
Self::Classic(seeded_lwe_bootstrap_key) => {
ShortintCompressedBootstrappingKey::Classic {
@@ -168,10 +173,16 @@ impl Upgrade<ShortintCompressedBootstrappingKey> for ShortintCompressedBootstrap
}
#[derive(VersionsDispatch)]
pub enum ShortintCompressedBootstrappingKeyVersions {
V0(Deprecated<ShortintCompressedBootstrappingKey>),
pub enum ShortintCompressedBootstrappingKeyVersions<InputScalar>
where
InputScalar: UnsignedInteger,
{
V0(Deprecated<ShortintCompressedBootstrappingKey<InputScalar>>),
V1(ShortintCompressedBootstrappingKeyV1),
V2(ShortintCompressedBootstrappingKey),
// Here a generic `InputScalar` has been added but it does not requires a new version since it
// is only added through the `CompressedModulusSwitchNoiseReductionKey`, which handles the
// upgrade itself.
V2(ShortintCompressedBootstrappingKey<InputScalar>),
}
impl Deprecable for CompressedServerKey {
@@ -179,9 +190,58 @@ impl Deprecable for CompressedServerKey {
const MIN_SUPPORTED_APP_VERSION: &'static str = "TFHE-rs v0.10";
}
#[derive(Version)]
pub struct CompressedServerKeyV2 {
pub key_switching_key: SeededLweKeyswitchKeyOwned<u64>,
pub bootstrapping_key: ShortintCompressedBootstrappingKey<u64>,
// Size of the message buffer
pub message_modulus: MessageModulus,
// Size of the carry buffer
pub carry_modulus: CarryModulus,
// Maximum number of operations that can be done before emptying the operation buffer
pub max_degree: MaxDegree,
pub max_noise_level: MaxNoiseLevel,
pub ciphertext_modulus: CiphertextModulus,
pub pbs_order: PBSOrder,
}
impl Upgrade<CompressedServerKey> for CompressedServerKeyV2 {
type Error = Infallible;
fn upgrade(self) -> Result<CompressedServerKey, Self::Error> {
let Self {
key_switching_key,
bootstrapping_key,
message_modulus,
carry_modulus,
max_degree,
max_noise_level,
ciphertext_modulus: _, // Ciphertext modulus is on the compressed bootstrapping_key
pbs_order,
} = self;
let compressed_ap_server_key = CompressedAtomicPatternServerKey::Standard(
CompressedStandardAtomicPatternServerKey::from_raw_parts(
key_switching_key,
bootstrapping_key,
pbs_order,
),
);
Ok(CompressedServerKey {
compressed_ap_server_key,
message_modulus,
carry_modulus,
max_degree,
max_noise_level,
})
}
}
#[derive(VersionsDispatch)]
pub enum CompressedServerKeyVersions {
V0(Deprecated<CompressedServerKey>),
V1(Deprecated<CompressedServerKey>),
V2(CompressedServerKey),
V2(CompressedServerKeyV2),
V3(CompressedServerKey),
}

View File

@@ -8,7 +8,7 @@ use crate::shortint::server_key::{
use crate::Error;
use tfhe_versionable::{Upgrade, Version, VersionsDispatch};
use super::LweCiphertextListOwned;
use super::{LweCiphertextListOwned, SeededLweCiphertextListOwned};
#[derive(Version)]
pub struct ModulusSwitchNoiseReductionKeyV0 {
@@ -34,10 +34,11 @@ where
modulus_switch_zeros: modulus_switch_zeros
.downcast_ref::<LweCiphertextListOwned<InputScalar>>()
.ok_or_else(|| {
Error::new(
"Invalid ModulusSwitchNoiseReductionKey, expected scalar size u64"
.to_string(),
)
Error::new(format!(
"Expected u64 as InputScalar while upgrading \
ModulusSwitchNoiseReductionKey, got {}",
std::any::type_name::<InputScalar>(),
))
})?
.clone(),
ms_bound: self.ms_bound,
@@ -56,10 +57,49 @@ where
V1(ModulusSwitchNoiseReductionKey<InputScalar>),
}
#[derive(Version)]
pub struct CompressedModulusSwitchNoiseReductionKeyV0 {
pub modulus_switch_zeros: SeededLweCiphertextListOwned<u64>,
pub ms_bound: NoiseEstimationMeasureBound,
pub ms_r_sigma_factor: RSigmaFactor,
pub ms_input_variance: Variance,
}
impl<InputScalar> Upgrade<CompressedModulusSwitchNoiseReductionKey<InputScalar>>
for CompressedModulusSwitchNoiseReductionKeyV0
where
InputScalar: UnsignedInteger,
{
type Error = Error;
fn upgrade(self) -> Result<CompressedModulusSwitchNoiseReductionKey<InputScalar>, Self::Error> {
let modulus_switch_zeros = &self.modulus_switch_zeros as &dyn Any;
// Keys from previous versions where only stored as u64, we check if the destination
// key is also u64 or we return an error
Ok(CompressedModulusSwitchNoiseReductionKey {
modulus_switch_zeros: modulus_switch_zeros
.downcast_ref::<SeededLweCiphertextListOwned<InputScalar>>()
.ok_or_else(|| {
Error::new(format!(
"Expected u64 as InputScalar while upgrading \
CompressedModulusSwitchNoiseReductionKey, got {}",
std::any::type_name::<InputScalar>(),
))
})?
.clone(),
ms_bound: self.ms_bound,
ms_r_sigma_factor: self.ms_r_sigma_factor,
ms_input_variance: self.ms_input_variance,
})
}
}
#[derive(VersionsDispatch)]
pub enum CompressedModulusSwitchNoiseReductionKeyVersions<InputScalar>
where
InputScalar: UnsignedInteger,
{
V0(CompressedModulusSwitchNoiseReductionKey<InputScalar>),
V0(CompressedModulusSwitchNoiseReductionKeyV0),
V1(CompressedModulusSwitchNoiseReductionKey<InputScalar>),
}

View File

@@ -7,6 +7,7 @@ use crate::core_crypto::commons::parameters::{
};
use crate::core_crypto::commons::traits::{CastInto, Container, UnsignedInteger};
use crate::core_crypto::entities::*;
use crate::shortint::atomic_pattern::compressed::CompressedAtomicPatternServerKey;
use crate::shortint::atomic_pattern::AtomicPatternServerKey;
use crate::shortint::ciphertext::MaxDegree;
use crate::shortint::client_key::secret_encryption_key::SecretEncryptionKeyView;
@@ -338,30 +339,79 @@ impl ShortintEngine {
cks: &ClientKey,
max_degree: MaxDegree,
) -> CompressedServerKey {
let params = &cks.parameters;
let compressed_ap_server_key = CompressedAtomicPatternServerKey::new(cks, self);
let bootstrapping_key = match params.pbs_parameters().unwrap() {
let params = cks.parameters;
let message_modulus = params.message_modulus();
let carry_modulus = params.carry_modulus();
let max_noise_level = params.max_noise_level();
// Pack the keys in the server key set:
CompressedServerKey {
compressed_ap_server_key,
message_modulus,
carry_modulus,
max_degree,
max_noise_level,
}
}
pub fn new_compressed_bootstrapping_key_ks32<
InKeycont: Container<Element = u32> + Sync,
OutKeyCont: Container<Element = u64> + Sync,
>(
&mut self,
pbs_params: KeySwitch32PBSParameters,
in_key: &LweSecretKey<InKeycont>,
out_key: &GlweSecretKey<OutKeyCont>,
) -> ShortintCompressedBootstrappingKey<u32> {
let bsk = self.new_compressed_classic_bootstrapping_key(
in_key,
out_key,
pbs_params.glwe_noise_distribution,
pbs_params.pbs_base_log,
pbs_params.pbs_level,
pbs_params.ciphertext_modulus,
);
let modulus_switch_noise_reduction_key = pbs_params
.modulus_switch_noise_reduction_params
.map(|modulus_switch_noise_reduction_params| {
let seed = self.seeder.seed();
CompressedModulusSwitchNoiseReductionKey::new(
modulus_switch_noise_reduction_params,
in_key,
self,
pbs_params.post_keyswitch_ciphertext_modulus,
pbs_params.lwe_noise_distribution,
CompressionSeed { seed },
)
});
ShortintCompressedBootstrappingKey::Classic {
bsk,
modulus_switch_noise_reduction_key,
}
}
pub fn new_compressed_bootstrapping_key<
InKeycont: Container<Element = u64> + Sync,
OutKeyCont: Container<Element = u64> + Sync,
>(
&mut self,
pbs_params_base: PBSParameters,
in_key: &LweSecretKey<InKeycont>,
out_key: &GlweSecretKey<OutKeyCont>,
) -> ShortintCompressedBootstrappingKey<u64> {
match pbs_params_base {
crate::shortint::PBSParameters::PBS(pbs_params) => {
#[cfg(any(not(feature = "__wasm_api"), feature = "parallel-wasm-api"))]
let bootstrapping_key = par_allocate_and_generate_new_seeded_lwe_bootstrap_key(
&cks.small_lwe_secret_key(),
&cks.glwe_secret_key,
let bootstrapping_key = self.new_compressed_classic_bootstrapping_key(
in_key,
out_key,
pbs_params.glwe_noise_distribution,
pbs_params.pbs_base_log,
pbs_params.pbs_level,
pbs_params.glwe_noise_distribution,
pbs_params.ciphertext_modulus,
&mut self.seeder,
);
#[cfg(all(feature = "__wasm_api", not(feature = "parallel-wasm-api")))]
let bootstrapping_key = allocate_and_generate_new_seeded_lwe_bootstrap_key(
&cks.small_lwe_secret_key(),
&cks.glwe_secret_key,
pbs_params.pbs_base_log,
pbs_params.pbs_level,
pbs_params.glwe_noise_distribution,
pbs_params.ciphertext_modulus,
&mut self.seeder,
);
let modulus_switch_noise_reduction_key = pbs_params
@@ -371,7 +421,7 @@ impl ShortintEngine {
CompressedModulusSwitchNoiseReductionKey::new(
modulus_switch_noise_reduction_params,
&cks.small_lwe_secret_key(),
in_key,
self,
pbs_params.ciphertext_modulus,
pbs_params.lwe_noise_distribution,
@@ -385,60 +435,74 @@ impl ShortintEngine {
}
}
crate::shortint::PBSParameters::MultiBitPBS(pbs_params) => {
#[cfg(any(not(feature = "__wasm_api"), feature = "parallel-wasm-api"))]
let bootstrapping_key =
par_allocate_and_generate_new_seeded_lwe_multi_bit_bootstrap_key(
&cks.small_lwe_secret_key(),
&cks.glwe_secret_key,
pbs_params.pbs_base_log,
pbs_params.pbs_level,
pbs_params.glwe_noise_distribution,
pbs_params.grouping_factor,
pbs_params.ciphertext_modulus,
&mut self.seeder,
);
#[cfg(all(feature = "__wasm_api", not(feature = "parallel-wasm-api")))]
let bootstrapping_key =
allocate_and_generate_new_seeded_lwe_multi_bit_bootstrap_key(
&cks.small_lwe_secret_key(),
&cks.glwe_secret_key,
pbs_params.pbs_base_log,
pbs_params.pbs_level,
pbs_params.glwe_noise_distribution,
pbs_params.grouping_factor,
pbs_params.ciphertext_modulus,
&mut self.seeder,
);
if cfg!(feature = "__wasm_api") && !cfg!(feature = "parallel-wasm-api") {
// WASM and no parallelism -> sequential generation
allocate_and_generate_new_seeded_lwe_multi_bit_bootstrap_key(
in_key,
out_key,
pbs_params.pbs_base_log,
pbs_params.pbs_level,
pbs_params.glwe_noise_distribution,
pbs_params.grouping_factor,
pbs_params.ciphertext_modulus,
&mut self.seeder,
)
} else {
par_allocate_and_generate_new_seeded_lwe_multi_bit_bootstrap_key(
in_key,
out_key,
pbs_params.pbs_base_log,
pbs_params.pbs_level,
pbs_params.glwe_noise_distribution,
pbs_params.grouping_factor,
pbs_params.ciphertext_modulus,
&mut self.seeder,
)
};
ShortintCompressedBootstrappingKey::MultiBit {
seeded_bsk: bootstrapping_key,
deterministic_execution: pbs_params.deterministic_execution,
}
}
};
}
}
// Creation of the key switching key
let key_switching_key = allocate_and_generate_new_seeded_lwe_keyswitch_key(
&cks.large_lwe_secret_key(),
&cks.small_lwe_secret_key(),
params.ks_base_log(),
params.ks_level(),
params.lwe_noise_distribution(),
params.ciphertext_modulus(),
&mut self.seeder,
);
// Pack the keys in the server key set:
CompressedServerKey {
key_switching_key,
bootstrapping_key,
message_modulus: params.message_modulus(),
carry_modulus: params.carry_modulus(),
max_degree,
max_noise_level: params.max_noise_level(),
ciphertext_modulus: params.ciphertext_modulus(),
pbs_order: params.encryption_key_choice().into(),
pub fn new_compressed_classic_bootstrapping_key<
InputScalar: UnsignedInteger + CastInto<u64>,
InKeycont: Container<Element = InputScalar>,
OutKeyCont: Container<Element = u64> + Sync,
>(
&mut self,
in_key: &LweSecretKey<InKeycont>,
out_key: &GlweSecretKey<OutKeyCont>,
glwe_noise_distribution: DynamicDistribution<u64>,
pbs_base_log: DecompositionBaseLog,
pbs_level: DecompositionLevelCount,
ciphertext_modulus: CiphertextModulus,
) -> SeededLweBootstrapKeyOwned<u64> {
if cfg!(feature = "__wasm_api") && !cfg!(feature = "parallel-wasm-api") {
// WASM and no parallelism -> sequential generation
allocate_and_generate_new_seeded_lwe_bootstrap_key(
in_key,
out_key,
pbs_base_log,
pbs_level,
glwe_noise_distribution,
ciphertext_modulus,
&mut self.seeder,
)
} else {
par_allocate_and_generate_new_seeded_lwe_bootstrap_key(
in_key,
out_key,
pbs_base_log,
pbs_level,
glwe_noise_distribution,
ciphertext_modulus,
&mut self.seeder,
)
}
}
}

View File

@@ -1005,10 +1005,12 @@ impl CompressedKeySwitchingKey {
);
assert_eq!(
src_server_key.ciphertext_modulus, dest_server_key.ciphertext_modulus,
src_server_key.ciphertext_modulus(),
dest_server_key.ciphertext_modulus(),
"Mismatch between the source CompressedServerKey CiphertextModulus ({:?}) \
and the destination CompressedServerKey CiphertextModulus ({:?})",
src_server_key.ciphertext_modulus, dest_server_key.ciphertext_modulus,
src_server_key.ciphertext_modulus(),
dest_server_key.ciphertext_modulus(),
);
}
None => assert!(
@@ -1018,9 +1020,16 @@ impl CompressedKeySwitchingKey {
),
}
let std_dest_server_key = dest_server_key
.as_compressed_standard_atomic_pattern_server_key()
.expect(
"Trying to build a shortint::CompressedKeySwitchingKey \
with an unsupported atomic pattern",
);
let dest_bootstrapping_key = std_dest_server_key.bootstrapping_key();
let dst_lwe_dimension = match key_switching_key_material.destination_key {
EncryptionKeyChoice::Big => dest_server_key.bootstrapping_key.output_lwe_dimension(),
EncryptionKeyChoice::Small => dest_server_key.bootstrapping_key.input_lwe_dimension(),
EncryptionKeyChoice::Big => dest_bootstrapping_key.output_lwe_dimension(),
EncryptionKeyChoice::Small => dest_bootstrapping_key.input_lwe_dimension(),
};
assert_eq!(
@@ -1039,13 +1048,13 @@ impl CompressedKeySwitchingKey {
key_switching_key_material
.key_switching_key
.ciphertext_modulus(),
dest_server_key.ciphertext_modulus,
dest_server_key.ciphertext_modulus(),
"Mismatch between the SeededLweKeyswitchKey CiphertextModulus ({:?}) \
and the destination CompressedServerKey CiphertextModulus ({:?})",
key_switching_key_material
.key_switching_key
.ciphertext_modulus(),
dest_server_key.ciphertext_modulus,
dest_server_key.ciphertext_modulus(),
);
Self {

View File

@@ -8,24 +8,32 @@ use super::{
use crate::conformance::ParameterSetConformant;
use crate::core_crypto::fft_impl::fft64::crypto::bootstrap::LweBootstrapKeyConformanceParams;
use crate::core_crypto::prelude::*;
use crate::shortint::atomic_pattern::StandardAtomicPatternServerKey;
use crate::shortint::atomic_pattern::compressed::{
CompressedAtomicPatternServerKey, CompressedStandardAtomicPatternServerKey,
};
use crate::shortint::backward_compatibility::server_key::{
CompressedServerKeyVersions, ShortintCompressedBootstrappingKeyVersions,
};
use crate::shortint::ciphertext::MaxNoiseLevel;
use crate::shortint::engine::ShortintEngine;
use crate::shortint::parameters::{CarryModulus, CiphertextModulus, MessageModulus};
use crate::shortint::parameters::{
AtomicPatternParameters, CarryModulus, CiphertextModulus, MessageModulus,
};
use crate::shortint::server_key::ShortintBootstrappingKey;
use crate::shortint::{ClientKey, PBSParameters, ServerKey};
use crate::shortint::{ClientKey, ServerKey};
use serde::{Deserialize, Serialize};
use tfhe_versionable::Versionize;
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Versionize)]
#[versionize(ShortintCompressedBootstrappingKeyVersions)]
pub enum ShortintCompressedBootstrappingKey {
pub enum ShortintCompressedBootstrappingKey<InputScalar>
where
InputScalar: UnsignedInteger,
{
Classic {
bsk: SeededLweBootstrapKeyOwned<u64>,
modulus_switch_noise_reduction_key: Option<CompressedModulusSwitchNoiseReductionKey<u64>>,
modulus_switch_noise_reduction_key:
Option<CompressedModulusSwitchNoiseReductionKey<InputScalar>>,
},
MultiBit {
seeded_bsk: SeededLweMultiBitBootstrapKeyOwned<u64>,
@@ -33,7 +41,10 @@ pub enum ShortintCompressedBootstrappingKey {
},
}
impl ShortintCompressedBootstrappingKey {
impl<InputScalar> ShortintCompressedBootstrappingKey<InputScalar>
where
InputScalar: UnsignedInteger,
{
pub fn input_lwe_dimension(&self) -> LweDimension {
match self {
Self::Classic { bsk, .. } => bsk.input_lwe_dimension(),
@@ -116,6 +127,91 @@ impl ShortintCompressedBootstrappingKey {
}
}
impl<InputScalar: UnsignedTorus> ShortintCompressedBootstrappingKey<InputScalar> {
pub fn decompress(&self) -> ShortintBootstrappingKey<InputScalar> {
match self {
Self::Classic {
bsk: compressed_bootstrapping_key,
modulus_switch_noise_reduction_key,
} => {
let (fourier_bsk, modulus_switch_noise_reduction_key) = rayon::join(
|| {
let decompressed_bootstrapping_key = compressed_bootstrapping_key
.as_view()
.par_decompress_into_lwe_bootstrap_key();
let mut fourier_bsk = FourierLweBootstrapKeyOwned::new(
decompressed_bootstrapping_key.input_lwe_dimension(),
decompressed_bootstrapping_key.glwe_size(),
decompressed_bootstrapping_key.polynomial_size(),
decompressed_bootstrapping_key.decomposition_base_log(),
decompressed_bootstrapping_key.decomposition_level_count(),
);
par_convert_standard_lwe_bootstrap_key_to_fourier(
&decompressed_bootstrapping_key,
&mut fourier_bsk,
);
fourier_bsk
},
|| {
modulus_switch_noise_reduction_key.as_ref().map(
|modulus_switch_noise_reduction_key| {
modulus_switch_noise_reduction_key.decompress()
},
)
},
);
ShortintBootstrappingKey::Classic {
bsk: fourier_bsk,
modulus_switch_noise_reduction_key,
}
}
Self::MultiBit {
seeded_bsk: compressed_bootstrapping_key,
deterministic_execution,
} => {
let decompressed_bootstrapping_key = compressed_bootstrapping_key
.as_view()
.par_decompress_into_lwe_multi_bit_bootstrap_key();
let mut fourier_bsk = FourierLweMultiBitBootstrapKeyOwned::new(
decompressed_bootstrapping_key.input_lwe_dimension(),
decompressed_bootstrapping_key.glwe_size(),
decompressed_bootstrapping_key.polynomial_size(),
decompressed_bootstrapping_key.decomposition_base_log(),
decompressed_bootstrapping_key.decomposition_level_count(),
decompressed_bootstrapping_key.grouping_factor(),
);
par_convert_standard_lwe_multi_bit_bootstrap_key_to_fourier(
&decompressed_bootstrapping_key,
&mut fourier_bsk,
);
let thread_count = ShortintEngine::with_thread_local_mut(|engine| {
engine.get_thread_count_for_multi_bit_pbs(
fourier_bsk.input_lwe_dimension(),
fourier_bsk.glwe_size().to_glwe_dimension(),
fourier_bsk.polynomial_size(),
fourier_bsk.decomposition_base_log(),
fourier_bsk.decomposition_level_count(),
fourier_bsk.grouping_factor(),
)
});
ShortintBootstrappingKey::MultiBit {
fourier_bsk,
thread_count,
deterministic_execution: *deterministic_execution,
}
}
}
}
}
/// A structure containing a compressed server public key.
///
/// The server key is generated by the client and is meant to be published: the client
@@ -123,8 +219,7 @@ impl ShortintCompressedBootstrappingKey {
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Versionize)]
#[versionize(CompressedServerKeyVersions)]
pub struct CompressedServerKey {
pub key_switching_key: SeededLweKeyswitchKeyOwned<u64>,
pub bootstrapping_key: ShortintCompressedBootstrappingKey,
pub compressed_ap_server_key: CompressedAtomicPatternServerKey,
// Size of the message buffer
pub message_modulus: MessageModulus,
// Size of the carry buffer
@@ -132,8 +227,6 @@ pub struct CompressedServerKey {
// Maximum number of operations that can be done before emptying the operation buffer
pub max_degree: MaxDegree,
pub max_noise_level: MaxNoiseLevel,
pub ciphertext_modulus: CiphertextModulus,
pub pbs_order: PBSOrder,
}
impl CompressedServerKey {
@@ -158,119 +251,21 @@ impl CompressedServerKey {
/// Decompress a [`CompressedServerKey`] into a [`ServerKey`].
pub fn decompress(&self) -> ServerKey {
let Self {
key_switching_key: compressed_key_switching_key,
bootstrapping_key: compressed_bootstrapping_key,
compressed_ap_server_key,
message_modulus,
carry_modulus,
max_degree,
max_noise_level,
ciphertext_modulus,
pbs_order,
} = self;
let (key_switching_key, bootstrapping_key) = rayon::join(
|| {
compressed_key_switching_key
.as_view()
.par_decompress_into_lwe_keyswitch_key()
},
|| match compressed_bootstrapping_key {
ShortintCompressedBootstrappingKey::Classic {
bsk: compressed_bootstrapping_key,
modulus_switch_noise_reduction_key,
} => {
let (fourier_bsk, modulus_switch_noise_reduction_key) = rayon::join(
|| {
let decompressed_bootstrapping_key = compressed_bootstrapping_key
.as_view()
.par_decompress_into_lwe_bootstrap_key();
let mut fourier_bsk = FourierLweBootstrapKeyOwned::new(
decompressed_bootstrapping_key.input_lwe_dimension(),
decompressed_bootstrapping_key.glwe_size(),
decompressed_bootstrapping_key.polynomial_size(),
decompressed_bootstrapping_key.decomposition_base_log(),
decompressed_bootstrapping_key.decomposition_level_count(),
);
par_convert_standard_lwe_bootstrap_key_to_fourier(
&decompressed_bootstrapping_key,
&mut fourier_bsk,
);
fourier_bsk
},
|| {
modulus_switch_noise_reduction_key.as_ref().map(
|modulus_switch_noise_reduction_key| {
modulus_switch_noise_reduction_key.decompress()
},
)
},
);
ShortintBootstrappingKey::Classic {
bsk: fourier_bsk,
modulus_switch_noise_reduction_key,
}
}
ShortintCompressedBootstrappingKey::MultiBit {
seeded_bsk: compressed_bootstrapping_key,
deterministic_execution,
} => {
let decompressed_bootstrapping_key = compressed_bootstrapping_key
.as_view()
.par_decompress_into_lwe_multi_bit_bootstrap_key();
let mut fourier_bsk = FourierLweMultiBitBootstrapKeyOwned::new(
decompressed_bootstrapping_key.input_lwe_dimension(),
decompressed_bootstrapping_key.glwe_size(),
decompressed_bootstrapping_key.polynomial_size(),
decompressed_bootstrapping_key.decomposition_base_log(),
decompressed_bootstrapping_key.decomposition_level_count(),
decompressed_bootstrapping_key.grouping_factor(),
);
par_convert_standard_lwe_multi_bit_bootstrap_key_to_fourier(
&decompressed_bootstrapping_key,
&mut fourier_bsk,
);
let thread_count = ShortintEngine::with_thread_local_mut(|engine| {
engine.get_thread_count_for_multi_bit_pbs(
fourier_bsk.input_lwe_dimension(),
fourier_bsk.glwe_size().to_glwe_dimension(),
fourier_bsk.polynomial_size(),
fourier_bsk.decomposition_base_log(),
fourier_bsk.decomposition_level_count(),
fourier_bsk.grouping_factor(),
)
});
ShortintBootstrappingKey::MultiBit {
fourier_bsk,
thread_count,
deterministic_execution: *deterministic_execution,
}
}
},
);
let message_modulus = *message_modulus;
let carry_modulus = *carry_modulus;
let max_degree = *max_degree;
let max_noise_level = *max_noise_level;
let ciphertext_modulus = *ciphertext_modulus;
let pbs_order = *pbs_order;
let atomic_pattern = StandardAtomicPatternServerKey::from_raw_parts(
key_switching_key,
bootstrapping_key,
pbs_order,
);
let ciphertext_modulus = compressed_ap_server_key.ciphertext_modulus();
ServerKey {
atomic_pattern: atomic_pattern.into(),
atomic_pattern: compressed_ap_server_key.decompress(),
message_modulus,
carry_modulus,
max_degree,
@@ -283,35 +278,26 @@ impl CompressedServerKey {
pub fn into_raw_parts(
self,
) -> (
SeededLweKeyswitchKeyOwned<u64>,
ShortintCompressedBootstrappingKey,
CompressedAtomicPatternServerKey,
MessageModulus,
CarryModulus,
MaxDegree,
MaxNoiseLevel,
CiphertextModulus,
PBSOrder,
) {
let Self {
key_switching_key,
bootstrapping_key,
compressed_ap_server_key,
message_modulus,
carry_modulus,
max_degree,
max_noise_level,
ciphertext_modulus,
pbs_order,
} = self;
(
key_switching_key,
bootstrapping_key,
compressed_ap_server_key,
message_modulus,
carry_modulus,
max_degree,
max_noise_level,
ciphertext_modulus,
pbs_order,
)
}
@@ -322,51 +308,12 @@ impl CompressedServerKey {
/// Panics if the constituents are not compatible with each others.
#[allow(clippy::too_many_arguments)]
pub fn from_raw_parts(
key_switching_key: SeededLweKeyswitchKeyOwned<u64>,
bootstrapping_key: ShortintCompressedBootstrappingKey,
compressed_ap_server_key: CompressedAtomicPatternServerKey,
message_modulus: MessageModulus,
carry_modulus: CarryModulus,
max_degree: MaxDegree,
max_noise_level: MaxNoiseLevel,
ciphertext_modulus: CiphertextModulus,
pbs_order: PBSOrder,
) -> Self {
assert_eq!(
key_switching_key.input_key_lwe_dimension(),
bootstrapping_key.output_lwe_dimension(),
"Mismatch between the input SeededLweKeyswitchKeyOwned LweDimension ({:?}) \
and the ShortintCompressedBootstrappingKey output LweDimension ({:?})",
key_switching_key.input_key_lwe_dimension(),
bootstrapping_key.output_lwe_dimension()
);
assert_eq!(
key_switching_key.output_key_lwe_dimension(),
bootstrapping_key.input_lwe_dimension(),
"Mismatch between the output SeededLweKeyswitchKeyOwned LweDimension ({:?}) \
and the ShortintCompressedBootstrappingKey input LweDimension ({:?})",
key_switching_key.output_key_lwe_dimension(),
bootstrapping_key.input_lwe_dimension()
);
assert_eq!(
key_switching_key.ciphertext_modulus(),
ciphertext_modulus,
"Mismatch between the SeededLweKeyswitchKeyOwned CiphertextModulus ({:?}) \
and the provided CiphertextModulus ({:?})",
key_switching_key.ciphertext_modulus(),
ciphertext_modulus
);
assert_eq!(
bootstrapping_key.ciphertext_modulus(),
ciphertext_modulus,
"Mismatch between the ShortintCompressedBootstrappingKey CiphertextModulus ({:?}) \
and the provided CiphertextModulus ({:?})",
bootstrapping_key.ciphertext_modulus(),
ciphertext_modulus
);
let max_max_degree = MaxDegree::from_msg_carry_modulus(message_modulus, carry_modulus);
assert!(
@@ -375,14 +322,11 @@ impl CompressedServerKey {
);
Self {
key_switching_key,
bootstrapping_key,
compressed_ap_server_key,
message_modulus,
carry_modulus,
max_degree,
max_noise_level,
ciphertext_modulus,
pbs_order,
}
}
@@ -393,15 +337,60 @@ impl CompressedServerKey {
})
}
pub(crate) fn as_compressed_standard_atomic_pattern_server_key(
&self,
) -> Option<&CompressedStandardAtomicPatternServerKey> {
match &self.compressed_ap_server_key {
CompressedAtomicPatternServerKey::Standard(
compressed_standard_atomic_pattern_server_key,
) => Some(compressed_standard_atomic_pattern_server_key),
CompressedAtomicPatternServerKey::KeySwitch32(_) => None,
}
}
pub fn ciphertext_lwe_dimension(&self) -> LweDimension {
match self.pbs_order {
PBSOrder::KeyswitchBootstrap => self.key_switching_key.input_key_lwe_dimension(),
PBSOrder::BootstrapKeyswitch => self.key_switching_key.output_key_lwe_dimension(),
self.compressed_ap_server_key.ciphertext_lwe_dimension()
}
pub fn ciphertext_modulus(&self) -> CiphertextModulus {
self.compressed_ap_server_key.ciphertext_modulus()
}
pub fn bootstrapping_key_size_bytes(&self) -> usize {
match &self.compressed_ap_server_key {
CompressedAtomicPatternServerKey::Standard(
compressed_standard_atomic_pattern_server_key,
) => compressed_standard_atomic_pattern_server_key
.bootstrapping_key()
.bootstrapping_key_size_bytes(),
CompressedAtomicPatternServerKey::KeySwitch32(
compressed_ks32_atomic_pattern_server_key,
) => compressed_ks32_atomic_pattern_server_key
.bootstrapping_key()
.bootstrapping_key_size_bytes(),
}
}
pub fn bootstrapping_key_size_elements(&self) -> usize {
match &self.compressed_ap_server_key {
CompressedAtomicPatternServerKey::Standard(
compressed_standard_atomic_pattern_server_key,
) => compressed_standard_atomic_pattern_server_key
.bootstrapping_key()
.bootstrapping_key_size_elements(),
CompressedAtomicPatternServerKey::KeySwitch32(
compressed_ks32_atomic_pattern_server_key,
) => compressed_ks32_atomic_pattern_server_key
.bootstrapping_key()
.bootstrapping_key_size_elements(),
}
}
}
impl ParameterSetConformant for ShortintCompressedBootstrappingKey {
impl<InputScalar> ParameterSetConformant for ShortintCompressedBootstrappingKey<InputScalar>
where
InputScalar: UnsignedInteger,
{
type ParameterSet = PBSConformanceParams;
fn is_conformant(&self, parameter_set: &Self::ParameterSet) -> bool {
@@ -435,10 +424,9 @@ impl ParameterSetConformant for ShortintCompressedBootstrappingKey {
},
PbsTypeConformanceParams::MultiBit { .. },
) => {
let param: MultiBitBootstrapKeyConformanceParams =
parameter_set.try_into().unwrap();
let param = parameter_set.try_into();
seeded_bsk.is_conformant(&param)
param.is_ok_and(|param| seeded_bsk.is_conformant(&param))
}
_ => false,
}
@@ -446,41 +434,23 @@ impl ParameterSetConformant for ShortintCompressedBootstrappingKey {
}
impl ParameterSetConformant for CompressedServerKey {
type ParameterSet = (PBSParameters, MaxDegree);
type ParameterSet = (AtomicPatternParameters, MaxDegree);
fn is_conformant(&self, (parameter_set, expected_max_degree): &Self::ParameterSet) -> bool {
let Self {
key_switching_key,
bootstrapping_key,
compressed_ap_server_key,
message_modulus,
carry_modulus,
max_degree,
max_noise_level,
ciphertext_modulus,
pbs_order,
} = self;
let params: PBSConformanceParams = parameter_set.into();
let compressed_ap_server_key_ok = compressed_ap_server_key.is_conformant(parameter_set);
let pbs_key_ok = bootstrapping_key.is_conformant(&params);
let param: LweKeyswitchKeyConformanceParams<u64> = parameter_set.into();
let ks_key_ok = key_switching_key.is_conformant(&param);
let pbs_order_ok = matches!(
(*pbs_order, parameter_set.encryption_key_choice()),
(PBSOrder::KeyswitchBootstrap, EncryptionKeyChoice::Big)
| (PBSOrder::BootstrapKeyswitch, EncryptionKeyChoice::Small)
);
pbs_key_ok
&& ks_key_ok
&& pbs_order_ok
compressed_ap_server_key_ok
&& *max_degree == *expected_max_degree
&& *message_modulus == parameter_set.message_modulus()
&& *carry_modulus == parameter_set.carry_modulus()
&& *max_noise_level == parameter_set.max_noise_level()
&& *ciphertext_modulus == parameter_set.ciphertext_modulus()
}
}