From c6a493954b9a3bbb06616376e52ee47137f348cc Mon Sep 17 00:00:00 2001 From: Nicolas Sarlin Date: Fri, 2 May 2025 16:03:37 +0200 Subject: [PATCH] feat(shortint): insert the AP inside the ServerKey --- tfhe/examples/utilities/shortint_key_sizes.rs | 9 +- tfhe/src/high_level_api/booleans/oprf.rs | 4 +- .../compressed_ciphertext_list.rs | 4 +- tfhe/src/high_level_api/config.rs | 4 +- .../integers/unsigned/tests/gpu.rs | 5 +- tfhe/src/high_level_api/keys/inner.rs | 15 +- tfhe/src/high_level_api/keys/server.rs | 10 +- tfhe/src/high_level_api/tests/mod.rs | 5 +- tfhe/src/integer/oprf.rs | 8 +- tfhe/src/integer/server_key/mod.rs | 22 +- tfhe/src/integer/wopbs/mod.rs | 22 +- tfhe/src/shortint/atomic_pattern/standard.rs | 25 +- .../backward_compatibility/server_key/mod.rs | 71 ++- tfhe/src/shortint/ciphertext/compact_list.rs | 4 +- tfhe/src/shortint/engine/mod.rs | 15 +- tfhe/src/shortint/engine/server_side.rs | 19 +- tfhe/src/shortint/engine/wopbs/mod.rs | 35 +- tfhe/src/shortint/key_switching_key/mod.rs | 84 ++- tfhe/src/shortint/keycache.rs | 6 +- .../shortint/list_compression/server_keys.rs | 15 +- .../shortint/noise_squashing/server_key.rs | 43 +- tfhe/src/shortint/oprf.rs | 157 +++--- .../parameters/compact_public_key_only.rs | 14 + tfhe/src/shortint/server_key/add.rs | 7 +- tfhe/src/shortint/server_key/bitwise_op.rs | 5 +- tfhe/src/shortint/server_key/bivariate_pbs.rs | 10 +- tfhe/src/shortint/server_key/comp_op.rs | 6 +- tfhe/src/shortint/server_key/compressed.rs | 9 +- tfhe/src/shortint/server_key/div_mod.rs | 6 +- tfhe/src/shortint/server_key/mod.rs | 478 +++++++----------- .../modulus_switched_compression.rs | 194 +------ tfhe/src/shortint/server_key/mul.rs | 7 +- tfhe/src/shortint/server_key/neg.rs | 7 +- tfhe/src/shortint/server_key/scalar_add.rs | 7 +- .../shortint/server_key/scalar_bitwise_op.rs | 5 +- .../src/shortint/server_key/scalar_div_mod.rs | 6 +- tfhe/src/shortint/server_key/scalar_mul.rs | 7 +- tfhe/src/shortint/server_key/scalar_sub.rs | 7 +- tfhe/src/shortint/server_key/shift.rs | 7 +- tfhe/src/shortint/server_key/sub.rs | 7 +- tfhe/src/shortint/wopbs/mod.rs | 94 ++-- tfhe/src/shortint/wopbs/test.rs | 4 +- 42 files changed, 745 insertions(+), 724 deletions(-) diff --git a/tfhe/examples/utilities/shortint_key_sizes.rs b/tfhe/examples/utilities/shortint_key_sizes.rs index 3dcca551c..4eaaa7238 100644 --- a/tfhe/examples/utilities/shortint_key_sizes.rs +++ b/tfhe/examples/utilities/shortint_key_sizes.rs @@ -9,6 +9,7 @@ use tfhe::keycache::NamedParam; use tfhe::shortint::keycache::KEY_CACHE; use tfhe::shortint::parameters::current_params::*; use tfhe::shortint::parameters::*; +use tfhe::shortint::server_key::{StandardServerKey, StandardServerKeyView}; use tfhe::shortint::{ ClassicPBSParameters, ClientKey, CompactPrivateKey, CompressedCompactPublicKey, CompressedKeySwitchingKey, CompressedServerKey, PBSParameters, @@ -58,7 +59,7 @@ fn client_server_key_sizes(results_file: &Path) { let keys = KEY_CACHE.get_from_param(params); let cks = keys.client_key(); - let sks = keys.server_key(); + let sks = StandardServerKeyView::try_from(keys.server_key().as_view()).unwrap(); let ksk_size = sks.key_switching_key_size_bytes(); let test_name = format!("shortint_key_sizes_{}_ksk", params.name()); @@ -167,10 +168,10 @@ fn tuniform_key_set_sizes(results_file: &Path) { let param_fhe_name = param_fhe.name(); let cks = ClientKey::new(param_fhe); let compressed_sks = CompressedServerKey::new(&cks); - let sks = compressed_sks.decompress(); + let sks = StandardServerKey::try_from(compressed_sks.decompress()).unwrap(); measure_serialized_size( - &sks.key_switching_key, + &sks.atomic_pattern.key_switching_key, >::into(param_fhe), ¶m_fhe_name, "ksk", @@ -187,7 +188,7 @@ fn tuniform_key_set_sizes(results_file: &Path) { ); measure_serialized_size( - &sks.bootstrapping_key, + &sks.atomic_pattern.bootstrapping_key, >::into(param_fhe), ¶m_fhe_name, "bsk", diff --git a/tfhe/src/high_level_api/booleans/oprf.rs b/tfhe/src/high_level_api/booleans/oprf.rs index 942022010..270ef554a 100644 --- a/tfhe/src/high_level_api/booleans/oprf.rs +++ b/tfhe/src/high_level_api/booleans/oprf.rs @@ -32,7 +32,9 @@ impl FheBool { pub fn generate_oblivious_pseudo_random(seed: Seed) -> Self { let (ciphertext, tag) = global_state::with_internal_keys(|key| match key { InternalServerKey::Cpu(key) => { - let ct = key.pbs_key().key.generate_oblivious_pseudo_random(seed, 1); + let sk = &key.pbs_key().key; + + let ct = sk.generate_oblivious_pseudo_random(seed, 1); ( InnerBoolean::Cpu(BooleanBlock::new_unchecked(ct)), key.tag.clone(), diff --git a/tfhe/src/high_level_api/compressed_ciphertext_list.rs b/tfhe/src/high_level_api/compressed_ciphertext_list.rs index 17dfa89ba..45849923c 100644 --- a/tfhe/src/high_level_api/compressed_ciphertext_list.rs +++ b/tfhe/src/high_level_api/compressed_ciphertext_list.rs @@ -882,8 +882,8 @@ mod tests { #[cfg(feature = "strings")] #[test] fn test_compressed_strings_cpu() { - let params = PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128.into(); - let config = crate::ConfigBuilder::with_custom_parameters::(params) + let params = PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128; + let config = crate::ConfigBuilder::with_custom_parameters(params) .enable_compression(COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128) .build(); diff --git a/tfhe/src/high_level_api/config.rs b/tfhe/src/high_level_api/config.rs index e5b85c04b..d5cabaf53 100644 --- a/tfhe/src/high_level_api/config.rs +++ b/tfhe/src/high_level_api/config.rs @@ -62,7 +62,7 @@ impl ConfigBuilder { pub fn with_custom_parameters

(block_parameters: P) -> Self where - P: Into, + P: Into, { Self { config: Config { @@ -85,7 +85,7 @@ impl ConfigBuilder { pub fn use_custom_parameters

(mut self, block_parameters: P) -> Self where - P: Into, + P: Into, { self.config.inner = IntegerConfig::new(block_parameters.into()); self diff --git a/tfhe/src/high_level_api/integers/unsigned/tests/gpu.rs b/tfhe/src/high_level_api/integers/unsigned/tests/gpu.rs index 13d97a576..9335d52ba 100644 --- a/tfhe/src/high_level_api/integers/unsigned/tests/gpu.rs +++ b/tfhe/src/high_level_api/integers/unsigned/tests/gpu.rs @@ -1,7 +1,8 @@ use crate::high_level_api::traits::AddAssignSizeOnGpu; use crate::prelude::{check_valid_cuda_malloc, FheTryEncrypt}; +use crate::shortint::atomic_pattern::AtomicPatternParameters; use crate::shortint::parameters::PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS; -use crate::shortint::{ClassicPBSParameters, PBSParameters}; +use crate::shortint::ClassicPBSParameters; use crate::{set_server_key, ClientKey, ConfigBuilder, FheUint32, GpuIndex}; use rand::Rng; @@ -9,7 +10,7 @@ use rand::Rng; /// /// Crates a client key, with the given parameters or default params in None were given /// and sets the gpu server key for the current thread -pub(crate) fn setup_gpu(params: Option>) -> ClientKey { +pub(crate) fn setup_gpu(params: Option>) -> ClientKey { let config = params .map_or_else(ConfigBuilder::default, |p| { ConfigBuilder::with_custom_parameters(p.into()) diff --git a/tfhe/src/high_level_api/keys/inner.rs b/tfhe/src/high_level_api/keys/inner.rs index 62e73f06d..b9dde9cbf 100644 --- a/tfhe/src/high_level_api/keys/inner.rs +++ b/tfhe/src/high_level_api/keys/inner.rs @@ -11,12 +11,13 @@ use crate::integer::noise_squashing::{ }; use crate::integer::public_key::CompactPublicKey; use crate::integer::CompressedCompactPublicKey; +use crate::shortint::atomic_pattern::AtomicPatternParameters; use crate::shortint::key_switching_key::KeySwitchingKeyConformanceParams; use crate::shortint::parameters::list_compression::CompressionParameters; use crate::shortint::parameters::{ CompactPublicKeyEncryptionParameters, NoiseSquashingParameters, ShortintKeySwitchingParameters, }; -use crate::shortint::{EncryptionKeyChoice, MessageModulus, PBSParameters}; +use crate::shortint::{EncryptionKeyChoice, MessageModulus}; use crate::{Config, Error}; use serde::{Deserialize, Serialize}; use tfhe_csprng::seeders::Seed; @@ -25,7 +26,7 @@ use tfhe_versionable::Versionize; #[derive(Copy, Clone, Debug, serde::Serialize, serde::Deserialize, Versionize)] #[versionize(IntegerConfigVersions)] pub(crate) struct IntegerConfig { - pub(crate) block_parameters: crate::shortint::PBSParameters, + pub(crate) block_parameters: crate::shortint::atomic_pattern::AtomicPatternParameters, pub(crate) dedicated_compact_public_key_parameters: Option<( crate::shortint::parameters::CompactPublicKeyEncryptionParameters, crate::shortint::parameters::ShortintKeySwitchingParameters, @@ -35,7 +36,9 @@ pub(crate) struct IntegerConfig { } impl IntegerConfig { - pub(crate) fn new(block_parameters: crate::shortint::PBSParameters) -> Self { + pub(crate) fn new( + block_parameters: crate::shortint::atomic_pattern::AtomicPatternParameters, + ) -> Self { Self { block_parameters, dedicated_compact_public_key_parameters: None, @@ -518,7 +521,7 @@ impl IntegerCompressedCompactPublicKey { } pub struct IntegerServerKeyConformanceParams { - pub sk_param: PBSParameters, + pub sk_param: AtomicPatternParameters, pub cpk_param: Option<( CompactPublicKeyEncryptionParameters, ShortintKeySwitchingParameters, @@ -541,7 +544,7 @@ impl> From for IntegerServerKeyConformanceParams { impl TryFrom<( - PBSParameters, + AtomicPatternParameters, CompactPublicKeyEncryptionParameters, ShortintKeySwitchingParameters, )> for KeySwitchingKeyConformanceParams @@ -550,7 +553,7 @@ impl fn try_from( (sk_params, cpk_params, ks_params): ( - PBSParameters, + AtomicPatternParameters, CompactPublicKeyEncryptionParameters, ShortintKeySwitchingParameters, ), diff --git a/tfhe/src/high_level_api/keys/server.rs b/tfhe/src/high_level_api/keys/server.rs index ffe7797cf..26bb0703e 100644 --- a/tfhe/src/high_level_api/keys/server.rs +++ b/tfhe/src/high_level_api/keys/server.rs @@ -451,7 +451,7 @@ mod test { PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128, PARAM_PKE_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128, }; - use crate::shortint::{ClassicPBSParameters, PBSParameters}; + use crate::shortint::ClassicPBSParameters; use crate::{ClientKey, CompressedServerKey, ConfigBuilder, ServerKey}; #[test] @@ -569,7 +569,7 @@ mod test { modifier(&mut sk_param); - let sk_param = PBSParameters::PBS(sk_param); + let sk_param = sk_param.into(); let conformance_params = IntegerServerKeyConformanceParams { sk_param, @@ -595,7 +595,7 @@ mod test { let ck = ClientKey::generate(config); let sk = ServerKey::new(&ck); - let sk_param = PBSParameters::PBS(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128); + let sk_param = PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128.into(); cpk_params.encryption_lwe_dimension.0 += 1; @@ -724,7 +724,7 @@ mod test { modifier(&mut sk_param); - let sk_param = PBSParameters::PBS(sk_param); + let sk_param = sk_param.into(); let conformance_params = IntegerServerKeyConformanceParams { sk_param, @@ -750,7 +750,7 @@ mod test { let ck = ClientKey::generate(config); let sk = CompressedServerKey::new(&ck); - let sk_param = PBSParameters::PBS(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128); + let sk_param = PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128.into(); cpk_params.encryption_lwe_dimension.0 += 1; diff --git a/tfhe/src/high_level_api/tests/mod.rs b/tfhe/src/high_level_api/tests/mod.rs index 8e87bf2a7..c87dec2fe 100644 --- a/tfhe/src/high_level_api/tests/mod.rs +++ b/tfhe/src/high_level_api/tests/mod.rs @@ -8,13 +8,14 @@ use crate::high_level_api::{ generate_keys, ClientKey, ConfigBuilder, FheBool, FheUint256, FheUint8, PublicKey, ServerKey, }; use crate::integer::U256; -use crate::shortint::{ClassicPBSParameters, PBSParameters}; +use crate::shortint::atomic_pattern::AtomicPatternParameters; +use crate::shortint::ClassicPBSParameters; use crate::{ set_server_key, CompactPublicKey, CompressedPublicKey, CompressedServerKey, FheUint32, Tag, }; use std::fmt::Debug; -pub(crate) fn setup_cpu(params: Option>) -> ClientKey { +pub(crate) fn setup_cpu(params: Option>) -> ClientKey { let config = params .map_or_else(ConfigBuilder::default, |p| { ConfigBuilder::with_custom_parameters(p.into()) diff --git a/tfhe/src/integer/oprf.rs b/tfhe/src/integer/oprf.rs index f7e8285ff..07388634a 100644 --- a/tfhe/src/integer/oprf.rs +++ b/tfhe/src/integer/oprf.rs @@ -38,6 +38,8 @@ impl ServerKey { let random_bits_count = range_log_size; + let sk = &self.key; + assert!(self.message_modulus().0.is_power_of_two()); let message_bits_count = self.message_modulus().0.ilog2() as u64; @@ -61,11 +63,9 @@ impl ServerKey { assert!(top_message_bits_count <= message_bits_count); - self.key - .generate_oblivious_pseudo_random(seed, top_message_bits_count) + sk.generate_oblivious_pseudo_random(seed, top_message_bits_count) } else { - self.key - .generate_oblivious_pseudo_random(seed, message_bits_count) + sk.generate_oblivious_pseudo_random(seed, message_bits_count) } } else { self.key.create_trivial(0) diff --git a/tfhe/src/integer/server_key/mod.rs b/tfhe/src/integer/server_key/mod.rs index ad2d21f80..e94e0ee89 100644 --- a/tfhe/src/integer/server_key/mod.rs +++ b/tfhe/src/integer/server_key/mod.rs @@ -12,10 +12,11 @@ use super::backward_compatibility::server_key::{CompressedServerKeyVersions, Ser use crate::conformance::ParameterSetConformant; use crate::core_crypto::prelude::UnsignedInteger; use crate::integer::client_key::ClientKey; +use crate::shortint::atomic_pattern::AtomicPatternParameters; use crate::shortint::ciphertext::{Degree, MaxDegree}; /// Error returned when the carry buffer is full. pub use crate::shortint::CheckError; -use crate::shortint::{CarryModulus, MessageModulus, PBSParameters}; +use crate::shortint::{CarryModulus, MessageModulus}; pub use radix::scalar_mul::ScalarMultiplier; pub use radix::scalar_sub::TwosComplementNegation; pub use radix_parallel::{MatchValues, MiniUnsignedInteger, Reciprocable}; @@ -195,12 +196,12 @@ impl ServerKey { } pub fn deterministic_pbs_execution(&self) -> bool { - self.key.deterministic_pbs_execution() + self.key.deterministic_execution() } pub fn set_deterministic_pbs_execution(&mut self, new_deterministic_execution: bool) { self.key - .set_deterministic_pbs_execution(new_deterministic_execution); + .set_deterministic_execution(new_deterministic_execution); } pub fn message_modulus(&self) -> MessageModulus { @@ -303,7 +304,7 @@ where } impl ParameterSetConformant for ServerKey { - type ParameterSet = PBSParameters; + type ParameterSet = AtomicPatternParameters; fn is_conformant(&self, parameter_set: &Self::ParameterSet) -> bool { let Self { key } = self; @@ -318,17 +319,22 @@ impl ParameterSetConformant for ServerKey { } impl ParameterSetConformant for CompressedServerKey { - type ParameterSet = PBSParameters; + type ParameterSet = AtomicPatternParameters; fn is_conformant(&self, parameter_set: &Self::ParameterSet) -> bool { let Self { key } = self; + let AtomicPatternParameters::Standard(parameters) = *parameter_set else { + // Server key compression is only supported for classical AP + return false; + }; + let expected_max_degree = MaxDegree::integer_radix_server_key( - parameter_set.message_modulus(), - parameter_set.carry_modulus(), + parameters.message_modulus(), + parameters.carry_modulus(), ); - key.is_conformant(&(*parameter_set, expected_max_degree)) + key.is_conformant(&(parameters, expected_max_degree)) } } diff --git a/tfhe/src/integer/wopbs/mod.rs b/tfhe/src/integer/wopbs/mod.rs index 872817747..cc929a06e 100644 --- a/tfhe/src/integer/wopbs/mod.rs +++ b/tfhe/src/integer/wopbs/mod.rs @@ -22,7 +22,9 @@ mod experimental { use crate::core_crypto::prelude::*; use crate::integer::client_key::utils::i_crt; use crate::integer::{ClientKey, CrtCiphertext, IntegerCiphertext, RadixCiphertext, ServerKey}; + use crate::shortint::atomic_pattern::AtomicPattern; use crate::shortint::ciphertext::{Degree, NoiseLevel}; + use crate::shortint::server_key::StandardServerKeyView; use crate::shortint::WopbsParameters; use crate::shortint::wopbs::WopbsLUTBase; @@ -230,10 +232,16 @@ mod experimental { sks: &ServerKey, parameters: &WopbsParameters, ) -> Self { + let sk = StandardServerKeyView::try_from(sks.key.as_view()).unwrap_or_else(|_| { + panic!( + "Wopbs is not supported by the chosen atomic pattern: {:?}", + sks.key.atomic_pattern.kind() + ) + }); Self { wopbs_key: crate::shortint::wopbs::WopbsKey::new_wopbs_key( &cks.as_ref().key, - &sks.key, + sk, parameters, ), } @@ -253,10 +261,17 @@ mod experimental { cks: &IntegerClientKey, sks: &ServerKey, ) -> Self { + let sk = StandardServerKeyView::try_from(sks.key.as_view()).unwrap_or_else(|_| { + panic!( + "Wopbs is not supported by the chosen atomic pattern: {:?}", + sks.key.atomic_pattern.kind() + ) + }); + Self { wopbs_key: crate::shortint::wopbs::WopbsKey::new_wopbs_key_only_for_wopbs( &cks.as_ref().key, - &sks.key, + sk, ), } } @@ -302,6 +317,7 @@ mod experimental { let extract_bits_output_lwe_size = self .wopbs_key .wopbs_server_key + .atomic_pattern .key_switching_key .output_key_lwe_dimension() .to_lwe_size(); @@ -390,6 +406,7 @@ mod experimental { let extract_bits_output_lwe_size = self .wopbs_key .wopbs_server_key + .atomic_pattern .key_switching_key .output_key_lwe_dimension() .to_lwe_size(); @@ -1087,6 +1104,7 @@ mod experimental { let extract_bits_output_lwe_size = self .wopbs_key .wopbs_server_key + .atomic_pattern .key_switching_key .output_key_lwe_dimension() .to_lwe_size(); diff --git a/tfhe/src/shortint/atomic_pattern/standard.rs b/tfhe/src/shortint/atomic_pattern/standard.rs index 258251a3b..680252674 100644 --- a/tfhe/src/shortint/atomic_pattern/standard.rs +++ b/tfhe/src/shortint/atomic_pattern/standard.rs @@ -20,6 +20,7 @@ use crate::shortint::ciphertext::{ NoiseLevel, }; use crate::shortint::engine::ShortintEngine; +use crate::shortint::oprf::generate_pseudo_random_from_pbs; use crate::shortint::server_key::{ apply_modulus_switch_noise_reduction, apply_programmable_bootstrap_no_ms_noise_reduction, LookupTableOwned, LookupTableSize, ManyLookupTableOwned, ShortintBootstrappingKey, @@ -128,7 +129,10 @@ impl AtomicPattern for StandardAtomicPatternServerKey { fn apply_lookup_table_assign(&self, ct: &mut Ciphertext, acc: &LookupTableOwned) { ShortintEngine::with_thread_local_mut(|engine| { - let (mut ciphertext_buffer, buffers) = engine.get_buffers(todo!()); + let (mut ciphertext_buffer, buffers) = engine.get_buffers( + self.intermediate_lwe_dimension(), + CiphertextModulus::new_native(), + ); match self.pbs_order { PBSOrder::KeyswitchBootstrap => { @@ -197,7 +201,13 @@ impl AtomicPattern for StandardAtomicPatternServerKey { random_bits_count: u64, full_bits_count: u64, ) -> (LweCiphertextOwned, Degree) { - let (ct, degree) = todo!(); + let (ct, degree) = generate_pseudo_random_from_pbs( + &self.bootstrapping_key, + seed, + random_bits_count, + full_bits_count, + self.ciphertext_modulus(), + ); match self.pbs_order { PBSOrder::KeyswitchBootstrap => (ct, degree), @@ -218,7 +228,8 @@ impl AtomicPattern for StandardAtomicPatternServerKey { fn switch_modulus_and_compress(&self, ct: &Ciphertext) -> CompressedModulusSwitchedCiphertext { let compressed_modulus_switched_lwe_ciphertext = ShortintEngine::with_thread_local_mut(|engine| { - let (mut ciphertext_buffer, _) = engine.get_buffers(todo!()); + let (mut ciphertext_buffer, _) = engine + .get_buffers(self.intermediate_lwe_dimension(), self.ciphertext_modulus()); let input_ct = match self.pbs_order { PBSOrder::KeyswitchBootstrap => { @@ -301,7 +312,8 @@ impl AtomicPattern for StandardAtomicPatternServerKey { ); ShortintEngine::with_thread_local_mut(|engine| { - let (mut ciphertext_buffer, buffers) = engine.get_buffers(todo!()); + let (mut ciphertext_buffer, buffers) = + engine.get_buffers(self.intermediate_lwe_dimension(), self.ciphertext_modulus()); match self.pbs_order { PBSOrder::KeyswitchBootstrap => { @@ -355,7 +367,10 @@ impl StandardAtomicPatternServerKey { let mut acc = lut.acc.clone(); ShortintEngine::with_thread_local_mut(|engine| { - let (mut ciphertext_buffer, buffers) = engine.get_buffers(todo!()); + let (mut ciphertext_buffer, buffers) = engine.get_buffers( + self.intermediate_lwe_dimension(), + CiphertextModulus::new_native(), + ); // Compute a key switch keyswitch_lwe_ciphertext(&self.key_switching_key, &ct.ct, &mut ciphertext_buffer); diff --git a/tfhe/src/shortint/backward_compatibility/server_key/mod.rs b/tfhe/src/shortint/backward_compatibility/server_key/mod.rs index 3406a1cdb..3791f0a0e 100644 --- a/tfhe/src/shortint/backward_compatibility/server_key/mod.rs +++ b/tfhe/src/shortint/backward_compatibility/server_key/mod.rs @@ -1,9 +1,16 @@ pub mod modulus_switch_noise_reduction; use crate::core_crypto::entities::*; -use crate::core_crypto::prelude::Container; +use crate::core_crypto::prelude::{Container, PBSOrder}; +use crate::shortint::atomic_pattern::{AtomicPatternServerKey, StandardAtomicPatternServerKey}; +use crate::shortint::ciphertext::MaxDegree; use crate::shortint::server_key::*; +use crate::shortint::{CarryModulus, CiphertextModulus, MaxNoiseLevel, MessageModulus}; +use crate::Error; + +use std::any::{Any, TypeId}; use std::convert::Infallible; + use tfhe_versionable::deprecation::{Deprecable, Deprecated}; use tfhe_versionable::{Upgrade, Version, VersionsDispatch}; @@ -49,10 +56,68 @@ impl Deprecable for ServerKey { const MIN_SUPPORTED_APP_VERSION: &'static str = "TFHE-rs v0.10"; } +#[derive(Version)] +pub struct ServerKeyV1 { + pub key_switching_key: LweKeyswitchKeyOwned, + pub bootstrapping_key: ShortintBootstrappingKey, + pub message_modulus: MessageModulus, + pub carry_modulus: CarryModulus, + pub max_degree: MaxDegree, + pub max_noise_level: MaxNoiseLevel, + pub ciphertext_modulus: CiphertextModulus, + pub pbs_order: PBSOrder, +} + +impl Upgrade> for ServerKeyV1 { + type Error = Error; + + fn upgrade(self) -> Result, Self::Error> { + let std_ap = StandardAtomicPatternServerKey::from_raw_parts( + self.key_switching_key, + self.bootstrapping_key, + self.pbs_order, + ); + + if TypeId::of::() == TypeId::of::() { + let ap = AtomicPatternServerKey::Standard(std_ap); + let sk = ServerKey::from_raw_parts( + ap, + self.message_modulus, + self.carry_modulus, + self.max_degree, + self.max_noise_level, + ); + Ok((&sk as &dyn Any) + .downcast_ref::>() + .unwrap() // We know from the TypeId that AP is of the right type so we can unwrap + .clone()) + } else if TypeId::of::() == TypeId::of::() { + let sk = StandardServerKey::from_raw_parts( + std_ap, + self.message_modulus, + self.carry_modulus, + self.max_degree, + self.max_noise_level, + ); + Ok((&sk as &dyn Any) + .downcast_ref::>() + .unwrap() // We know from the TypeId that AP is of the right type so we can unwrap + .clone()) + } else { + Err(Error::new( + "ServerKey from TFHE-rs 1.0 and before can only be deserialized to the classical \ +Atomic Pattern" + .to_string(), + )) + } + } +} + #[derive(VersionsDispatch)] -pub enum ServerKeyVersions { +pub enum ServerKeyVersions { V0(Deprecated), - V1(ServerKey), + V1(ServerKeyV1), + V2(GenericServerKey), } impl Deprecable for ShortintCompressedBootstrappingKey { diff --git a/tfhe/src/shortint/ciphertext/compact_list.rs b/tfhe/src/shortint/ciphertext/compact_list.rs index d12a2b9b3..5e1472a9f 100644 --- a/tfhe/src/shortint/ciphertext/compact_list.rs +++ b/tfhe/src/shortint/ciphertext/compact_list.rs @@ -5,6 +5,7 @@ use super::standard::Ciphertext; use crate::conformance::ParameterSetConformant; use crate::core_crypto::commons::traits::ContiguousEntityContainer; use crate::core_crypto::entities::*; +use crate::shortint::atomic_pattern::AtomicPattern; use crate::shortint::backward_compatibility::ciphertext::CompactCiphertextListVersions; pub use crate::shortint::parameters::ShortintCompactCiphertextListCastingMode; use crate::shortint::parameters::{ @@ -118,8 +119,7 @@ impl CompactCiphertextList { None => &vec![None; output_lwe_ciphertext_list.lwe_ciphertext_count().0], }; - let atomic_pattern = - AtomicPatternKind::Standard(casting_key.dest_server_key.pbs_order); + let atomic_pattern = casting_key.dest_server_key.atomic_pattern.kind(); let res = output_lwe_ciphertext_list .par_iter() diff --git a/tfhe/src/shortint/engine/mod.rs b/tfhe/src/shortint/engine/mod.rs index 396879eee..5fa9eecd8 100644 --- a/tfhe/src/shortint/engine/mod.rs +++ b/tfhe/src/shortint/engine/mod.rs @@ -17,7 +17,7 @@ use crate::core_crypto::prelude::{ContainerMut, GlweSize}; use crate::core_crypto::seeders::new_seeder; use crate::shortint::ciphertext::{Degree, MaxDegree}; use crate::shortint::prelude::PolynomialSize; -use crate::shortint::{CarryModulus, MessageModulus, ServerKey}; +use crate::shortint::{CarryModulus, MessageModulus}; use std::cell::RefCell; use std::fmt::Debug; @@ -326,19 +326,12 @@ impl ShortintEngine { /// - [`ComputationBuffers`] used by the FFT during the PBS pub fn get_buffers( &mut self, - server_key: &ServerKey, + lwe_dimension: LweDimension, + ciphertext_modulus: CiphertextModulus, ) -> (LweCiphertextMutView<'_, u64>, &mut ComputationBuffers) { - let lwe_dimension = match server_key.pbs_order { - super::PBSOrder::KeyswitchBootstrap => { - server_key.key_switching_key.output_key_lwe_dimension() - } - super::PBSOrder::BootstrapKeyswitch => { - server_key.key_switching_key.input_key_lwe_dimension() - } - }; ( self.ciphertext_buffers - .as_lwe(lwe_dimension, server_key.ciphertext_modulus), + .as_lwe(lwe_dimension, ciphertext_modulus), &mut self.computation_buffers, ) } diff --git a/tfhe/src/shortint/engine/server_side.rs b/tfhe/src/shortint/engine/server_side.rs index b4a58d417..b60cba3b3 100644 --- a/tfhe/src/shortint/engine/server_side.rs +++ b/tfhe/src/shortint/engine/server_side.rs @@ -7,6 +7,7 @@ use crate::core_crypto::commons::parameters::{ }; use crate::core_crypto::commons::traits::Container; use crate::core_crypto::entities::*; +use crate::shortint::atomic_pattern::StandardAtomicPatternServerKey; use crate::shortint::ciphertext::MaxDegree; use crate::shortint::client_key::secret_encryption_key::SecretEncryptionKeyView; use crate::shortint::parameters::{EncryptionKeyChoice, ShortintKeySwitchingParameters}; @@ -89,16 +90,20 @@ impl ShortintEngine { &mut self.encryption_generator, ); + let atomic_pattern = StandardAtomicPatternServerKey::from_raw_parts( + key_switching_key, + bootstrapping_key_base, + pbs_params_base.encryption_key_choice().into(), + ); + // Pack the keys in the server key set: ServerKey { - key_switching_key, - bootstrapping_key: bootstrapping_key_base, - message_modulus: params.message_modulus(), - carry_modulus: params.carry_modulus(), + atomic_pattern: atomic_pattern.into(), + message_modulus: cks.parameters.message_modulus(), + carry_modulus: cks.parameters.carry_modulus(), max_degree, - max_noise_level: params.max_noise_level(), - ciphertext_modulus: params.ciphertext_modulus(), - pbs_order: params.encryption_key_choice().into(), + max_noise_level: cks.parameters.max_noise_level(), + ciphertext_modulus: cks.parameters.ciphertext_modulus(), } } diff --git a/tfhe/src/shortint/engine/wopbs/mod.rs b/tfhe/src/shortint/engine/wopbs/mod.rs index 8fa39231d..1ac3ffcbf 100644 --- a/tfhe/src/shortint/engine/wopbs/mod.rs +++ b/tfhe/src/shortint/engine/wopbs/mod.rs @@ -1,21 +1,24 @@ //! # WARNING: this module is experimental. use crate::core_crypto::algorithms::*; use crate::core_crypto::entities::*; +use crate::shortint::atomic_pattern::StandardAtomicPatternServerKey; use crate::shortint::ciphertext::{MaxDegree, MaxNoiseLevel}; use crate::shortint::engine::ShortintEngine; -use crate::shortint::server_key::ShortintBootstrappingKey; +use crate::shortint::server_key::{ + ShortintBootstrappingKey, StandardServerKey, StandardServerKeyView, +}; use crate::shortint::wopbs::{WopbsKey, WopbsKeyCreationError}; -use crate::shortint::{ClientKey, ServerKey, WopbsParameters}; +use crate::shortint::{ClientKey, WopbsParameters}; impl ShortintEngine { // Creates a key when ONLY a wopbs is used. pub(crate) fn new_wopbs_key_only_for_wopbs( &mut self, cks: &ClientKey, - sks: &ServerKey, + sks: StandardServerKeyView<'_>, ) -> crate::Result { if matches!( - sks.bootstrapping_key, + sks.atomic_pattern.bootstrapping_key, ShortintBootstrappingKey::MultiBit { .. } ) { return Err(crate::Error::new(format!( @@ -36,12 +39,12 @@ impl ShortintEngine { &mut self.encryption_generator, ); - let sks_cpy = sks.clone(); + let sks_cpy = sks.owned(); let wopbs_key = WopbsKey { wopbs_server_key: sks_cpy.clone(), cbs_pfpksk, - ksk_pbs_to_wopbs: sks.key_switching_key.clone(), + ksk_pbs_to_wopbs: sks.atomic_pattern.key_switching_key.clone(), param: wop_params, pbs_server_key: sks_cpy, }; @@ -52,7 +55,7 @@ impl ShortintEngine { pub(crate) fn new_wopbs_key( &mut self, cks: &ClientKey, - sks: &ServerKey, + sks: StandardServerKeyView<'_>, parameters: &WopbsParameters, ) -> WopbsKey { //Independent client key generation dedicated to the WoPBS @@ -142,12 +145,17 @@ impl ShortintEngine { parameters.carry_modulus, ); - let wopbs_server_key = ServerKey { + let wopbs_atomic_pattern = StandardAtomicPatternServerKey { key_switching_key: ksk_wopbs_large_to_wopbs_small, bootstrapping_key: ShortintBootstrappingKey::Classic { bsk: small_bsk, modulus_switch_noise_reduction_key: None, }, + pbs_order: cks.parameters.encryption_key_choice().into(), + }; + + let wopbs_server_key = StandardServerKey { + atomic_pattern: wopbs_atomic_pattern, message_modulus: parameters.message_modulus, carry_modulus: parameters.carry_modulus, max_degree: MaxDegree::from_msg_carry_modulus( @@ -156,7 +164,6 @@ impl ShortintEngine { ), max_noise_level: max_noise_level_wopbs, ciphertext_modulus: parameters.ciphertext_modulus, - pbs_order: cks.parameters.encryption_key_choice().into(), }; let max_noise_level_pbs = MaxNoiseLevel::from_msg_carry_modulus( @@ -164,9 +171,14 @@ impl ShortintEngine { cks.parameters.carry_modulus(), ); - let pbs_server_key = ServerKey { + let pbs_atomic_pattern = StandardAtomicPatternServerKey { key_switching_key: ksk_wopbs_large_to_pbs_small, - bootstrapping_key: sks.bootstrapping_key.clone(), + bootstrapping_key: sks.atomic_pattern.bootstrapping_key.clone(), + pbs_order: cks.parameters.encryption_key_choice().into(), + }; + + let pbs_server_key = StandardServerKey { + atomic_pattern: pbs_atomic_pattern, message_modulus: cks.parameters.message_modulus(), carry_modulus: cks.parameters.carry_modulus(), max_degree: MaxDegree::from_msg_carry_modulus( @@ -175,7 +187,6 @@ impl ShortintEngine { ), max_noise_level: max_noise_level_pbs, ciphertext_modulus: cks.parameters.ciphertext_modulus(), - pbs_order: cks.parameters.encryption_key_choice().into(), }; WopbsKey { diff --git a/tfhe/src/shortint/key_switching_key/mod.rs b/tfhe/src/shortint/key_switching_key/mod.rs index de34252fd..57c9e8628 100644 --- a/tfhe/src/shortint/key_switching_key/mod.rs +++ b/tfhe/src/shortint/key_switching_key/mod.rs @@ -7,6 +7,7 @@ use crate::core_crypto::prelude::{ keyswitch_lwe_ciphertext, Cleartext, LweKeyswitchKeyConformanceParams, LweKeyswitchKeyOwned, SeededLweKeyswitchKeyOwned, }; +use crate::shortint::atomic_pattern::AtomicPattern; use crate::shortint::ciphertext::Degree; use crate::shortint::client_key::secret_encryption_key::SecretEncryptionKeyView; use crate::shortint::engine::ShortintEngine; @@ -24,6 +25,7 @@ use super::backward_compatibility::key_switching_key::{ CompressedKeySwitchingKeyMaterialVersions, CompressedKeySwitchingKeyVersions, KeySwitchingKeyMaterialVersions, KeySwitchingKeyVersions, }; +use super::server_key::{StandardServerKey, StandardServerKeyView}; #[cfg(test)] mod test; @@ -71,7 +73,7 @@ impl KeySwitchingKeyMaterial { // It is a bit of a hack, but at this point it seems ok pub(crate) struct KeySwitchingKeyBuildHelper<'keys> { pub(crate) key_switching_key_material: KeySwitchingKeyMaterial, - pub(crate) dest_server_key: &'keys ServerKey, + pub(crate) dest_server_key: StandardServerKeyView<'keys>, pub(crate) src_server_key: Option<&'keys ServerKey>, } @@ -83,7 +85,7 @@ pub(crate) struct KeySwitchingKeyBuildHelper<'keys> { #[versionize(KeySwitchingKeyVersions)] pub struct KeySwitchingKey { pub(crate) key_switching_key_material: KeySwitchingKeyMaterial, - pub(crate) dest_server_key: ServerKey, + pub(crate) dest_server_key: StandardServerKey, pub(crate) src_server_key: Option, } @@ -97,7 +99,7 @@ impl From> for KeySwitchingKey { Self { key_switching_key_material, - dest_server_key: dest_server_key.to_owned(), + dest_server_key: dest_server_key.owned(), src_server_key: src_server_key.map(ToOwned::to_owned), } } @@ -113,7 +115,7 @@ pub struct KeySwitchingKeyMaterialView<'key> { #[derive(Clone, Copy, Debug, PartialEq)] pub struct KeySwitchingKeyView<'keys> { pub(crate) key_switching_key_material: KeySwitchingKeyMaterialView<'keys>, - pub(crate) dest_server_key: &'keys ServerKey, + pub(crate) dest_server_key: StandardServerKeyView<'keys>, pub(crate) src_server_key: Option<&'keys ServerKey>, } @@ -151,6 +153,13 @@ impl<'keys> KeySwitchingKeyBuildHelper<'keys> { without providing a source ServerKey, this is not supported" ); } + let dest_server_key = output_key_pair.1.as_view().try_into().unwrap_or_else(|_| { + panic!( + "Trying to build a shortint::KeySwitchingKey with an unsupported atomic \ + pattern: {:?}", + output_key_pair.1.atomic_pattern.kind() + ) + }); let nb_bits_input: i8 = full_message_modulus_input.ilog2().try_into().unwrap(); let nb_bits_output: i8 = full_message_modulus_output.ilog2().try_into().unwrap(); @@ -162,7 +171,7 @@ impl<'keys> KeySwitchingKeyBuildHelper<'keys> { cast_rshift: nb_bits_output - nb_bits_input, destination_key: params.destination_key, }, - dest_server_key: output_key_pair.1, + dest_server_key, src_server_key: input_key_pair.1, } } @@ -214,13 +223,19 @@ impl KeySwitchingKey { KeySwitchingKeyView { key_switching_key_material: key_switching_key_material.as_view(), - dest_server_key, + dest_server_key: dest_server_key.as_view(), src_server_key: src_server_key.as_ref(), } } /// Deconstruct a [`KeySwitchingKey`] into its constituents. - pub fn into_raw_parts(self) -> (KeySwitchingKeyMaterial, ServerKey, Option) { + pub fn into_raw_parts( + self, + ) -> ( + KeySwitchingKeyMaterial, + StandardServerKey, + Option, + ) { let Self { key_switching_key_material, dest_server_key, @@ -249,6 +264,14 @@ impl KeySwitchingKey { dest_server_key: ServerKey, src_server_key: Option, ) -> Self { + let ap = dest_server_key.atomic_pattern.kind(); + let dest_server_key: StandardServerKey = dest_server_key.try_into().unwrap_or_else(|_| { + panic!( + "Trying to build a shortint::KeySwitchingKey with an unsupported atomic \ + pattern: {ap:?}" + ) + }); + match src_server_key { Some(ref src_server_key) => { let src_lwe_dimension = src_server_key.ciphertext_lwe_dimension(); @@ -281,8 +304,14 @@ impl KeySwitchingKey { } let dst_lwe_dimension = match key_switching_key_material.destination_key { - EncryptionKeyChoice::Big => dest_server_key.bootstrapping_key.output_lwe_dimension(), - EncryptionKeyChoice::Small => dest_server_key.bootstrapping_key.input_lwe_dimension(), + EncryptionKeyChoice::Big => dest_server_key + .atomic_pattern + .bootstrapping_key + .output_lwe_dimension(), + EncryptionKeyChoice::Small => dest_server_key + .atomic_pattern + .bootstrapping_key + .input_lwe_dimension(), }; assert_eq!( @@ -360,7 +389,7 @@ impl<'keys> KeySwitchingKeyView<'keys> { self, ) -> ( KeySwitchingKeyMaterialView<'keys>, - &'keys ServerKey, + StandardServerKeyView<'keys>, Option<&'keys ServerKey>, ) { let Self { @@ -391,6 +420,15 @@ impl<'keys> KeySwitchingKeyView<'keys> { dest_server_key: &'keys ServerKey, src_server_key: Option<&'keys ServerKey>, ) -> Self { + let dest_server_key: StandardServerKeyView = + dest_server_key.as_view().try_into().unwrap_or_else(|_| { + panic!( + "Trying to build a shortint::KeySwitchingKey with an unsupported atomic \ + pattern: {:?}", + dest_server_key.atomic_pattern.kind() + ) + }); + match src_server_key { Some(src_server_key) => { let src_lwe_dimension = src_server_key.ciphertext_lwe_dimension(); @@ -423,8 +461,14 @@ impl<'keys> KeySwitchingKeyView<'keys> { } let dst_lwe_dimension = match key_switching_key_material.destination_key { - EncryptionKeyChoice::Big => dest_server_key.bootstrapping_key.output_lwe_dimension(), - EncryptionKeyChoice::Small => dest_server_key.bootstrapping_key.input_lwe_dimension(), + EncryptionKeyChoice::Big => dest_server_key + .atomic_pattern + .bootstrapping_key + .output_lwe_dimension(), + EncryptionKeyChoice::Small => dest_server_key + .atomic_pattern + .bootstrapping_key + .input_lwe_dimension(), }; assert_eq!( @@ -510,15 +554,18 @@ impl<'keys> KeySwitchingKeyView<'keys> { let output_lwe_size = match self.key_switching_key_material.destination_key { EncryptionKeyChoice::Big => self .dest_server_key + .atomic_pattern .bootstrapping_key .output_lwe_dimension() .to_lwe_size(), EncryptionKeyChoice::Small => self .dest_server_key + .atomic_pattern .bootstrapping_key .input_lwe_dimension() .to_lwe_size(), }; + let mut keyswitched = self .dest_server_key .unchecked_create_trivial_with_lwe_size(Cleartext(0), output_lwe_size); @@ -570,7 +617,7 @@ impl<'keys> KeySwitchingKeyView<'keys> { let res = { let destination_pbs_order: PBSOrder = self.key_switching_key_material.destination_key.into(); - if destination_pbs_order == self.dest_server_key.pbs_order { + if destination_pbs_order == self.dest_server_key.atomic_pattern.pbs_order { CastCiphertext::CorrectKey(keyswitched) } else { // We are arriving under the wrong key for the dest_server_key @@ -586,7 +633,7 @@ impl<'keys> KeySwitchingKeyView<'keys> { ); keyswitch_lwe_ciphertext( - &self.dest_server_key.key_switching_key, + &self.dest_server_key.atomic_pattern.key_switching_key, &wrong_key_ct.ct, &mut correct_key_ct.ct, ); @@ -637,7 +684,7 @@ impl<'keys> KeySwitchingKeyView<'keys> { let buffers = engine.get_computation_buffers(); let acc = self.dest_server_key.generate_lookup_table(function); apply_programmable_bootstrap( - &self.dest_server_key.bootstrapping_key, + &self.dest_server_key.atomic_pattern.bootstrapping_key, &wrong_key_ct.ct, &mut correct_key_ct.ct, &acc.acc, @@ -688,7 +735,7 @@ impl<'keys> KeySwitchingKeyView<'keys> { function(n >> cast_rshift) }); apply_programmable_bootstrap( - &self.dest_server_key.bootstrapping_key, + &self.dest_server_key.atomic_pattern.bootstrapping_key, &wrong_key_ct.ct, &mut correct_key_ct.ct, &acc.acc, @@ -744,7 +791,7 @@ impl<'keys> KeySwitchingKeyView<'keys> { let buffers = engine.get_computation_buffers(); let acc = self.dest_server_key.generate_lookup_table(function); apply_programmable_bootstrap( - &self.dest_server_key.bootstrapping_key, + &self.dest_server_key.atomic_pattern.bootstrapping_key, &wrong_key_ct.ct, &mut correct_key_ct.ct, &acc.acc, @@ -895,7 +942,8 @@ impl CompressedKeySwitchingKey { pub fn decompress(&self) -> KeySwitchingKey { KeySwitchingKey { key_switching_key_material: self.key_switching_key_material.decompress(), - dest_server_key: self.dest_server_key.decompress(), + // CompressedServerKey are only supported for the Classical AP + dest_server_key: self.dest_server_key.decompress().try_into().unwrap(), src_server_key: self .src_server_key .as_ref() diff --git a/tfhe/src/shortint/keycache.rs b/tfhe/src/shortint/keycache.rs index 390573559..524f6f9cf 100644 --- a/tfhe/src/shortint/keycache.rs +++ b/tfhe/src/shortint/keycache.rs @@ -609,7 +609,11 @@ mod wopbs { let params = params.into(); let key = KEY_CACHE.get_from_param(params.0); let wk = self.inner.get_with_closure(params, &mut |_| { - WopbsKey::new_wopbs_key(&key.inner.0, &key.inner.1, ¶ms.1) + WopbsKey::new_wopbs_key( + &key.inner.0, + key.inner.1.as_view().try_into().unwrap(), + ¶ms.1, + ) }); SharedWopbsKey { inner: key.inner, diff --git a/tfhe/src/shortint/list_compression/server_keys.rs b/tfhe/src/shortint/list_compression/server_keys.rs index 26f7953f0..7a13bcdd8 100644 --- a/tfhe/src/shortint/list_compression/server_keys.rs +++ b/tfhe/src/shortint/list_compression/server_keys.rs @@ -1,6 +1,7 @@ use super::CompressionPrivateKeys; use crate::conformance::ParameterSetConformant; use crate::core_crypto::prelude::*; +use crate::shortint::atomic_pattern::AtomicPatternParameters; use crate::shortint::backward_compatibility::list_compression::{ CompressionKeyVersions, DecompressionKeyVersions, }; @@ -10,7 +11,7 @@ use crate::shortint::parameters::{CompressionParameters, PolynomialSize}; use crate::shortint::server_key::{ PBSConformanceParams, PbsTypeConformanceParams, ShortintBootstrappingKey, }; -use crate::shortint::{EncryptionKeyChoice, PBSParameters}; +use crate::shortint::EncryptionKeyChoice; use serde::{Deserialize, Serialize}; use std::fmt::Debug; use tfhe_versionable::Versionize; @@ -120,8 +121,10 @@ pub struct CompressionKeyConformanceParams { pub cipherext_modulus: CiphertextModulus, } -impl From<(PBSParameters, CompressionParameters)> for CompressionKeyConformanceParams { - fn from((pbs_params, compression_params): (PBSParameters, CompressionParameters)) -> Self { +impl From<(AtomicPatternParameters, CompressionParameters)> for CompressionKeyConformanceParams { + fn from( + (ap_params, compression_params): (AtomicPatternParameters, CompressionParameters), + ) -> Self { Self { br_level: compression_params.br_level, br_base_log: compression_params.br_base_log, @@ -131,9 +134,9 @@ impl From<(PBSParameters, CompressionParameters)> for CompressionKeyConformanceP packing_ks_glwe_dimension: compression_params.packing_ks_glwe_dimension, lwe_per_glwe: compression_params.lwe_per_glwe, storage_log_modulus: compression_params.storage_log_modulus, - uncompressed_polynomial_size: pbs_params.polynomial_size(), - uncompressed_glwe_dimension: pbs_params.glwe_dimension(), - cipherext_modulus: pbs_params.ciphertext_modulus(), + uncompressed_polynomial_size: ap_params.polynomial_size(), + uncompressed_glwe_dimension: ap_params.glwe_dimension(), + cipherext_modulus: ap_params.ciphertext_modulus(), } } } diff --git a/tfhe/src/shortint/noise_squashing/server_key.rs b/tfhe/src/shortint/noise_squashing/server_key.rs index 4565097d3..7a30f5173 100644 --- a/tfhe/src/shortint/noise_squashing/server_key.rs +++ b/tfhe/src/shortint/noise_squashing/server_key.rs @@ -11,6 +11,7 @@ use crate::core_crypto::algorithms::lwe_programmable_bootstrapping::{ use crate::core_crypto::entities::{Fourier128LweBootstrapKeyOwned, LweCiphertext}; use crate::core_crypto::fft_impl::fft128::math::fft::Fft128; use crate::core_crypto::fft_impl::fft64::crypto::bootstrap::LweBootstrapKeyConformanceParams; +use crate::shortint::atomic_pattern::{AtomicPattern, AtomicPatternParameters}; use crate::shortint::backward_compatibility::noise_squashing::NoiseSquashingKeyVersions; use crate::shortint::ciphertext::{Ciphertext, SquashedNoiseCiphertext}; use crate::shortint::client_key::ClientKey; @@ -23,6 +24,7 @@ use crate::shortint::parameters::{ }; use crate::shortint::server_key::{ ModulusSwitchNoiseReductionKey, ModulusSwitchNoiseReductionKeyConformanceParams, ServerKey, + StandardServerKeyView, }; use serde::{Deserialize, Serialize}; use tfhe_versionable::Versionize; @@ -213,25 +215,40 @@ impl NoiseSquashingKey { )); } + // For the moment, noise squashing is only implemented for the Standard AP + let src_server_key: StandardServerKeyView = + src_server_key.as_view().try_into().map_err(|_| { + crate::error!( + "Noise squashing is not supported by the selected atomic pattern ({:?})", + src_server_key.atomic_pattern.kind() + ) + })?; + Ok(self.unchecked_squash_ciphertext_noise(ciphertext, src_server_key)) } pub fn unchecked_squash_ciphertext_noise( &self, ciphertext: &Ciphertext, - src_server_key: &ServerKey, + src_server_key: StandardServerKeyView, ) -> SquashedNoiseCiphertext { - let mut lwe_before_noise_squashing = match src_server_key.pbs_order { + let mut lwe_before_noise_squashing = match src_server_key.atomic_pattern.pbs_order { // Under the big key, first need to keyswitch PBSOrder::KeyswitchBootstrap => { let mut after_ks_ct = LweCiphertext::new( 0u64, - src_server_key.key_switching_key.output_lwe_size(), - src_server_key.key_switching_key.ciphertext_modulus(), + src_server_key + .atomic_pattern + .key_switching_key + .output_lwe_size(), + src_server_key + .atomic_pattern + .key_switching_key + .ciphertext_modulus(), ); keyswitch_lwe_ciphertext( - &src_server_key.key_switching_key, + &src_server_key.atomic_pattern.key_switching_key, &ciphertext.ct, &mut after_ks_ct, ); @@ -382,6 +399,22 @@ impl TryFrom<(PBSParameters, NoiseSquashingParameters)> for NoiseSquashingKeyCon } } +impl TryFrom<(AtomicPatternParameters, NoiseSquashingParameters)> + for NoiseSquashingKeyConformanceParams +{ + type Error = crate::Error; + + fn try_from( + (ap_params, noise_squashing_params): (AtomicPatternParameters, NoiseSquashingParameters), + ) -> Result { + match ap_params { + AtomicPatternParameters::Standard(pbs_params) => { + (pbs_params, noise_squashing_params).try_into() + } + } + } +} + impl ParameterSetConformant for NoiseSquashingKey { type ParameterSet = NoiseSquashingKeyConformanceParams; diff --git a/tfhe/src/shortint/oprf.rs b/tfhe/src/shortint/oprf.rs index a2482e392..bf51fdbef 100644 --- a/tfhe/src/shortint/oprf.rs +++ b/tfhe/src/shortint/oprf.rs @@ -1,14 +1,18 @@ +use super::server_key::{ + apply_programmable_bootstrap_no_ms_noise_reduction, GenericServerKey, LookupTableSize, + ShortintBootstrappingKey, +}; use super::Ciphertext; use crate::core_crypto::fft_impl::common::modulus_switch; use crate::core_crypto::prelude::{ - keyswitch_lwe_ciphertext, lwe_ciphertext_plaintext_add_assign, CiphertextModulus, - CiphertextModulusLog, LweCiphertext, LweSize, Plaintext, + lwe_ciphertext_plaintext_add_assign, CiphertextModulus, CiphertextModulusLog, LweCiphertext, + LweCiphertextOwned, LweSize, Plaintext, }; +use crate::shortint::atomic_pattern::AtomicPattern; use crate::shortint::ciphertext::Degree; use crate::shortint::engine::ShortintEngine; -use crate::shortint::parameters::{AtomicPatternKind, NoiseLevel}; -use crate::shortint::server_key::apply_programmable_bootstrap_no_ms_noise_reduction; -use crate::shortint::{PBSOrder, ServerKey}; +use crate::shortint::parameters::NoiseLevel; +use crate::shortint::server_key::generate_lookup_table_no_encode; use tfhe_csprng::seeders::Seed; pub fn sha3_hash(values: &mut [u64], seed: Seed) { @@ -54,7 +58,75 @@ pub fn create_random_from_seed_modulus_switched( ct } -impl ServerKey { + +/// Uniformly generates a random encrypted value in `[0, 2^random_bits_count[`, using a PBS. +/// +/// `full_bits_count` is the size of the lwe message, ie the shortint message + carry + padding +/// bit. +/// The output in in the form 0000rrr000noise (rbc=3, fbc=7) +/// The encryted value is oblivious to the server. +/// +/// It is the reponsiblity of the calling AP to transform this into a shortint ciphertext. The +/// returned LWE is in the post PBS state, so a Keyswitch might be needed if the order is PBS-KS. +pub(crate) fn generate_pseudo_random_from_pbs( + bootstrapping_key: &ShortintBootstrappingKey, + seed: Seed, + random_bits_count: u64, + full_bits_count: u64, + ciphertext_modulus: CiphertextModulus, +) -> (LweCiphertextOwned, Degree) { + assert!( + random_bits_count <= full_bits_count, + "The number of random bits asked for (={random_bits_count}) is bigger than full_bits_count (={full_bits_count})" + ); + + let in_lwe_size = bootstrapping_key.input_lwe_dimension().to_lwe_size(); + + let seeded = create_random_from_seed_modulus_switched( + seed, + in_lwe_size, + bootstrapping_key + .polynomial_size() + .to_blind_rotation_input_modulus_log(), + ciphertext_modulus, + ); + + let p = 1 << random_bits_count; + let degree = p - 1; + + let delta = 1_u64 << (64 - full_bits_count); + + let poly_delta = 2 * bootstrapping_key.polynomial_size().0 as u64 / p; + + let lut_size = LookupTableSize::new( + bootstrapping_key.glwe_size(), + bootstrapping_key.polynomial_size(), + ); + let acc = generate_lookup_table_no_encode(lut_size, ciphertext_modulus, |x| { + (2 * (x / poly_delta) + 1) * delta / 2 + }); + + let out_lwe_size = bootstrapping_key.output_lwe_dimension().to_lwe_size(); + + let mut ct = LweCiphertext::new(0, out_lwe_size, ciphertext_modulus); + + ShortintEngine::with_thread_local_mut(|engine| { + let buffers = engine.get_computation_buffers(); + + apply_programmable_bootstrap_no_ms_noise_reduction( + bootstrapping_key, + &seeded, + &mut ct, + &acc, + buffers, + ); + }); + + lwe_ciphertext_plaintext_add_assign(&mut ct, Plaintext(degree * delta / 2)); + (ct, Degree(degree)) +} + +impl GenericServerKey { /// Uniformly generates a random encrypted value in `[0, 2^random_bits_count[` /// `2^random_bits_count` must be smaller than the message modulus /// The encryted value is oblivious to the server @@ -102,82 +174,19 @@ impl ServerKey { "The number of random bits asked for (={random_bits_count}) is bigger than carry_bits_count (={carry_bits_count}) + message_bits_count(={message_bits_count})", ); - self.generate_oblivious_pseudo_random_custom_encoding( + let (ct, degree) = self.atomic_pattern.generate_oblivious_pseudo_random( seed, random_bits_count, 1 + carry_bits_count + message_bits_count, - ) - } - - /// Uniformly generates a random encrypted value in `[0, 2^random_bits_count[` - /// The output in in the form 0000rrr000noise (rbc=3, fbc=7) - /// The encryted value is oblivious to the server - pub(crate) fn generate_oblivious_pseudo_random_custom_encoding( - &self, - seed: Seed, - random_bits_count: u64, - full_bits_count: u64, - ) -> Ciphertext { - assert!( - random_bits_count <= full_bits_count, - "The number of random bits asked for (={random_bits_count}) is bigger than full_bits_count (={full_bits_count})" ); - let in_lwe_size = self.bootstrapping_key.input_lwe_dimension().to_lwe_size(); - - let seeded = create_random_from_seed_modulus_switched( - seed, - in_lwe_size, - self.bootstrapping_key - .polynomial_size() - .to_blind_rotation_input_modulus_log(), - self.ciphertext_modulus, - ); - - let p = 1 << random_bits_count; - - let delta = 1_u64 << (64 - full_bits_count); - - let poly_delta = 2 * self.bootstrapping_key.polynomial_size().0 as u64 / p; - - let acc = self.generate_lookup_table_no_encode(|x| (2 * (x / poly_delta) + 1) * delta / 2); - - let out_lwe_size = self.bootstrapping_key.output_lwe_dimension().to_lwe_size(); - - let mut ct = LweCiphertext::new(0, out_lwe_size, self.ciphertext_modulus); - - ShortintEngine::with_thread_local_mut(|engine| { - let buffers = engine.get_computation_buffers(); - - apply_programmable_bootstrap_no_ms_noise_reduction( - &self.bootstrapping_key, - &seeded, - &mut ct, - &acc, - buffers, - ); - }); - - lwe_ciphertext_plaintext_add_assign(&mut ct, Plaintext((p - 1) * delta / 2)); - - let ct = match self.pbs_order { - PBSOrder::KeyswitchBootstrap => ct, - PBSOrder::BootstrapKeyswitch => { - let mut ct_ksed = LweCiphertext::new(0, in_lwe_size, self.ciphertext_modulus); - - keyswitch_lwe_ciphertext(&self.key_switching_key, &ct, &mut ct_ksed); - - ct_ksed - } - }; - Ciphertext::new( ct, - Degree::new(p - 1), + degree, NoiseLevel::NOMINAL, self.message_modulus, self.carry_modulus, - AtomicPatternKind::Standard(self.pbs_order), + self.atomic_pattern.kind(), ) } } @@ -224,12 +233,12 @@ pub(crate) mod test { let img = sk.generate_oblivious_pseudo_random(seed, random_bits_count); - let lwe_size = sk.bootstrapping_key.input_lwe_dimension().to_lwe_size(); + let lwe_size = params.lwe_dimension().to_lwe_size(); let ct = create_random_from_seed_modulus_switched( seed, lwe_size, - sk.bootstrapping_key + params .polynomial_size() .to_blind_rotation_input_modulus_log(), sk.ciphertext_modulus, diff --git a/tfhe/src/shortint/parameters/compact_public_key_only.rs b/tfhe/src/shortint/parameters/compact_public_key_only.rs index dfdd53bd1..18428e20d 100644 --- a/tfhe/src/shortint/parameters/compact_public_key_only.rs +++ b/tfhe/src/shortint/parameters/compact_public_key_only.rs @@ -1,3 +1,4 @@ +use crate::shortint::atomic_pattern::AtomicPatternParameters; use crate::shortint::backward_compatibility::parameters::compact_public_key_only::{ CompactCiphertextListExpansionKindVersions, CompactPublicKeyEncryptionParametersVersions, }; @@ -170,3 +171,16 @@ impl TryFrom for CompactPublicKeyEncryptionParameters { params.try_into() } } + +impl TryFrom for CompactPublicKeyEncryptionParameters { + type Error = Error; + + fn try_from(value: AtomicPatternParameters) -> Result { + match value { + AtomicPatternParameters::Standard(pbsparameters) => { + let params: ShortintParameterSet = pbsparameters.into(); + params.try_into() + } + } + } +} diff --git a/tfhe/src/shortint/server_key/add.rs b/tfhe/src/shortint/server_key/add.rs index 1197a19bb..c6e31c1b5 100644 --- a/tfhe/src/shortint/server_key/add.rs +++ b/tfhe/src/shortint/server_key/add.rs @@ -1,10 +1,11 @@ use super::{CiphertextNoiseDegree, SmartCleaningOperation}; use crate::core_crypto::algorithms::*; +use crate::shortint::atomic_pattern::AtomicPattern; use crate::shortint::ciphertext::Degree; -use crate::shortint::server_key::CheckError; -use crate::shortint::{Ciphertext, MaxNoiseLevel, ServerKey}; +use crate::shortint::server_key::{CheckError, GenericServerKey}; +use crate::shortint::{Ciphertext, MaxNoiseLevel}; -impl ServerKey { +impl GenericServerKey { /// Compute homomorphically an addition between two ciphertexts encrypting integer values. /// /// This function, like all "default" operations (i.e. not smart, checked or unchecked), will diff --git a/tfhe/src/shortint/server_key/bitwise_op.rs b/tfhe/src/shortint/server_key/bitwise_op.rs index 445791161..ca7c813ef 100644 --- a/tfhe/src/shortint/server_key/bitwise_op.rs +++ b/tfhe/src/shortint/server_key/bitwise_op.rs @@ -1,9 +1,10 @@ -use super::ServerKey; use crate::core_crypto::algorithms::lwe_ciphertext_opposite_assign; +use crate::shortint::atomic_pattern::AtomicPattern; use crate::shortint::ciphertext::Degree; +use crate::shortint::server_key::GenericServerKey; use crate::shortint::{CheckError, Ciphertext}; -impl ServerKey { +impl GenericServerKey { /// Compute homomorphically an AND between two ciphertexts encrypting integer values. /// /// This function, like all "default" operations (i.e. not smart, checked or unchecked), will diff --git a/tfhe/src/shortint/server_key/bivariate_pbs.rs b/tfhe/src/shortint/server_key/bivariate_pbs.rs index 4278fc24d..4bd881288 100644 --- a/tfhe/src/shortint/server_key/bivariate_pbs.rs +++ b/tfhe/src/shortint/server_key/bivariate_pbs.rs @@ -1,7 +1,9 @@ -use super::{CheckError, CiphertextNoiseDegree, LookupTable, ServerKey}; +use super::{CheckError, CiphertextNoiseDegree, LookupTable}; use crate::core_crypto::prelude::container::Container; +use crate::shortint::atomic_pattern::AtomicPattern; use crate::shortint::ciphertext::{Degree, MaxDegree, NoiseLevel}; use crate::shortint::server_key::add::unchecked_add_assign; +use crate::shortint::server_key::GenericServerKey; use crate::shortint::{Ciphertext, MessageModulus}; use std::cmp::Ordering; @@ -22,8 +24,8 @@ pub type BivariateLookupTableView<'a> = BivariateLookupTable<&'a [u64]>; /// Returns whether it is possible to pack lhs and rhs into a unique /// ciphertext without exceeding the max storable value using the formula: /// `unique_ciphertext = (lhs * factor) + rhs` -fn ciphertexts_can_be_packed_without_exceeding_space_or_noise( - server_key: &ServerKey, +fn ciphertexts_can_be_packed_without_exceeding_space_or_noise( + server_key: &GenericServerKey, lhs: CiphertextNoiseDegree, rhs: CiphertextNoiseDegree, factor: u64, @@ -50,7 +52,7 @@ fn ciphertexts_can_be_packed_without_exceeding_space_or_noise( Ok(()) } -impl ServerKey { +impl GenericServerKey { /// Generates a bivariate accumulator pub fn generate_lookup_table_bivariate_with_factor( &self, diff --git a/tfhe/src/shortint/server_key/comp_op.rs b/tfhe/src/shortint/server_key/comp_op.rs index fd869b2a6..6c175d2ec 100644 --- a/tfhe/src/shortint/server_key/comp_op.rs +++ b/tfhe/src/shortint/server_key/comp_op.rs @@ -1,5 +1,5 @@ -use super::ServerKey; -use crate::shortint::server_key::CheckError; +use crate::shortint::atomic_pattern::AtomicPattern; +use crate::shortint::server_key::{CheckError, GenericServerKey}; use crate::shortint::Ciphertext; // # Note: @@ -8,7 +8,7 @@ use crate::shortint::Ciphertext; // however, comparisons like equality do not have that, "==" does not have and "===", // ">=" is greater of equal, not greater_assign. -impl ServerKey { +impl GenericServerKey { /// Compute homomorphically a `>` between two ciphertexts encrypting integer values. /// /// This function, like all "default" operations (i.e. not smart, checked or unchecked), will diff --git a/tfhe/src/shortint/server_key/compressed.rs b/tfhe/src/shortint/server_key/compressed.rs index f0b127df1..d26cb25f2 100644 --- a/tfhe/src/shortint/server_key/compressed.rs +++ b/tfhe/src/shortint/server_key/compressed.rs @@ -8,6 +8,7 @@ use super::{ use crate::conformance::ParameterSetConformant; use crate::core_crypto::fft_impl::fft64::crypto::bootstrap::LweBootstrapKeyConformanceParams; use crate::core_crypto::prelude::*; +use crate::shortint::atomic_pattern::StandardAtomicPatternServerKey; use crate::shortint::backward_compatibility::server_key::{ CompressedServerKeyVersions, ShortintCompressedBootstrappingKeyVersions, }; @@ -262,15 +263,19 @@ impl CompressedServerKey { let ciphertext_modulus = *ciphertext_modulus; let pbs_order = *pbs_order; - ServerKey { + let atomic_pattern = StandardAtomicPatternServerKey::from_raw_parts( key_switching_key, bootstrapping_key, + pbs_order, + ); + + ServerKey { + atomic_pattern: atomic_pattern.into(), message_modulus, carry_modulus, max_degree, max_noise_level, ciphertext_modulus, - pbs_order, } } diff --git a/tfhe/src/shortint/server_key/div_mod.rs b/tfhe/src/shortint/server_key/div_mod.rs index 479f0c002..95fbf177c 100644 --- a/tfhe/src/shortint/server_key/div_mod.rs +++ b/tfhe/src/shortint/server_key/div_mod.rs @@ -1,6 +1,8 @@ -use crate::shortint::{Ciphertext, ServerKey}; +use crate::shortint::atomic_pattern::AtomicPattern; +use crate::shortint::server_key::GenericServerKey; +use crate::shortint::Ciphertext; -impl ServerKey { +impl GenericServerKey { /// Compute a division between two ciphertexts. /// /// The result is returned in a _new_ ciphertext. diff --git a/tfhe/src/shortint/server_key/mod.rs b/tfhe/src/shortint/server_key/mod.rs index 6eb5af642..1397ac481 100644 --- a/tfhe/src/shortint/server_key/mod.rs +++ b/tfhe/src/shortint/server_key/mod.rs @@ -49,9 +49,9 @@ use crate::shortint::engine::{ ShortintEngine, }; use crate::shortint::parameters::{ - AtomicPatternKind, CarryModulus, CiphertextConformanceParams, CiphertextModulus, MessageModulus, + CarryModulus, CiphertextConformanceParams, CiphertextModulus, MessageModulus, }; -use crate::shortint::{EncryptionKeyChoice, PBSOrder, PaddingBit, ShortintEncoding}; +use crate::shortint::{PaddingBit, ShortintEncoding}; use aligned_vec::ABox; use serde::{Deserialize, Serialize}; use std::fmt::{Debug, Display, Formatter}; @@ -74,6 +74,10 @@ pub mod pbs_stats { #[cfg(feature = "pbs-stats")] pub use pbs_stats::*; +use super::atomic_pattern::{ + AtomicPattern, AtomicPatternMut, AtomicPatternParameters, AtomicPatternServerKey, + StandardAtomicPatternServerKey, +}; use super::backward_compatibility::server_key::{ SerializableShortintBootstrappingKeyVersions, ServerKeyVersions, }; @@ -413,11 +417,10 @@ impl ShortintBootstrappingKey { /// /// The server key is generated by the client and is meant to be published: the client /// sends it to the server so it can compute homomorphic circuits. -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Versionize)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Versionize)] #[versionize(ServerKeyVersions)] -pub struct ServerKey { - pub key_switching_key: LweKeyswitchKeyOwned, - pub bootstrapping_key: ShortintBootstrappingKey, +pub struct GenericServerKey { + pub atomic_pattern: AP, // Size of the message buffer pub message_modulus: MessageModulus, // Size of the carry buffer @@ -427,10 +430,95 @@ pub struct ServerKey { pub max_noise_level: MaxNoiseLevel, // Modulus use for computations on the ciphertext pub ciphertext_modulus: CiphertextModulus, - pub pbs_order: PBSOrder, } -/// Represents the number of elements in a [`LookupTable`] represented by a Glwe ciphertext +impl GenericServerKey<&AP> { + pub fn owned(&self) -> GenericServerKey { + GenericServerKey { + atomic_pattern: self.atomic_pattern.clone(), + message_modulus: self.message_modulus, + carry_modulus: self.carry_modulus, + max_degree: self.max_degree, + max_noise_level: self.max_noise_level, + ciphertext_modulus: self.ciphertext_modulus, + } + } +} + +pub type ServerKey = GenericServerKey; +pub type StandardServerKey = GenericServerKey; +pub type ServerKeyView<'key> = GenericServerKey<&'key AtomicPatternServerKey>; +pub type StandardServerKeyView<'key> = GenericServerKey<&'key StandardAtomicPatternServerKey>; + +// Manual implementation of Copy because the derive will require AP to be Copy, +// which is actually overrestrictive: https://github.com/rust-lang/rust/issues/26925 +impl Copy for StandardServerKeyView<'_> {} + +impl From for ServerKey { + fn from(value: StandardServerKey) -> Self { + let atomic_pattern = AtomicPatternServerKey::Standard(value.atomic_pattern); + + Self { + atomic_pattern, + message_modulus: value.message_modulus, + carry_modulus: value.carry_modulus, + max_degree: value.max_degree, + max_noise_level: value.max_noise_level, + ciphertext_modulus: value.ciphertext_modulus, + } + } +} + +#[derive(Debug)] +pub struct UnsupportedOperation; + +impl Display for UnsupportedOperation { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "Operation not supported by the current configuration") + } +} + +impl std::error::Error for UnsupportedOperation {} + +impl TryFrom for StandardServerKey { + type Error = UnsupportedOperation; + + fn try_from(value: ServerKey) -> Result { + let atomic_pattern = match value.atomic_pattern { + AtomicPatternServerKey::Standard(ap) => ap, + }; + + Ok(Self { + atomic_pattern, + message_modulus: value.message_modulus, + carry_modulus: value.carry_modulus, + max_degree: value.max_degree, + max_noise_level: value.max_noise_level, + ciphertext_modulus: value.ciphertext_modulus, + }) + } +} + +impl<'key> TryFrom> for StandardServerKeyView<'key> { + type Error = UnsupportedOperation; + + fn try_from(value: ServerKeyView<'key>) -> Result { + let atomic_pattern = match value.atomic_pattern { + AtomicPatternServerKey::Standard(ap) => ap, + }; + + Ok(Self { + atomic_pattern, + message_modulus: value.message_modulus, + carry_modulus: value.carry_modulus, + max_degree: value.max_degree, + max_noise_level: value.max_noise_level, + ciphertext_modulus: value.ciphertext_modulus, + }) + } +} + +/// The number of elements in a [`LookupTable`] represented by a Glwe ciphertext #[derive(Copy, Clone, Debug)] pub struct LookupTableSize { glwe_size: GlweSize, @@ -509,47 +597,41 @@ impl ServerKey { engine.new_server_key_with_max_degree(cks, max_degree) }) } +} +impl GenericServerKey { pub fn ciphertext_lwe_dimension(&self) -> LweDimension { - match self.pbs_order { - PBSOrder::KeyswitchBootstrap => self.key_switching_key.input_key_lwe_dimension(), - PBSOrder::BootstrapKeyswitch => self.key_switching_key.output_key_lwe_dimension(), + self.atomic_pattern.ciphertext_lwe_dimension() + } + + pub fn as_view(&self) -> GenericServerKey<&AP> { + GenericServerKey { + atomic_pattern: &self.atomic_pattern, + message_modulus: self.message_modulus, + carry_modulus: self.carry_modulus, + max_degree: self.max_degree, + max_noise_level: self.max_noise_level, + ciphertext_modulus: self.ciphertext_modulus, } } /// Deconstruct a [`ServerKey`] into its constituents. - pub fn into_raw_parts( - self, - ) -> ( - LweKeyswitchKeyOwned, - ShortintBootstrappingKey, - MessageModulus, - CarryModulus, - MaxDegree, - MaxNoiseLevel, - CiphertextModulus, - PBSOrder, - ) { + pub fn into_raw_parts(self) -> (AP, MessageModulus, CarryModulus, MaxDegree, MaxNoiseLevel) { let Self { - key_switching_key, - bootstrapping_key, + atomic_pattern, message_modulus, carry_modulus, max_degree, max_noise_level, - ciphertext_modulus, - pbs_order, + ciphertext_modulus: _, } = self; ( - key_switching_key, - bootstrapping_key, + atomic_pattern, message_modulus, carry_modulus, max_degree, max_noise_level, - ciphertext_modulus, - pbs_order, ) } @@ -558,44 +640,13 @@ impl ServerKey { /// # Panics /// /// Panics if the constituents are not compatible with each others. - #[allow(clippy::too_many_arguments)] pub fn from_raw_parts( - key_switching_key: LweKeyswitchKeyOwned, - bootstrapping_key: ShortintBootstrappingKey, + atomic_pattern_key: AP, message_modulus: MessageModulus, carry_modulus: CarryModulus, max_degree: MaxDegree, max_noise_level: MaxNoiseLevel, - ciphertext_modulus: CiphertextModulus, - pbs_order: PBSOrder, ) -> Self { - assert_eq!( - key_switching_key.input_key_lwe_dimension(), - bootstrapping_key.output_lwe_dimension(), - "Mismatch between the input LweKeyswitchKey LweDimension ({:?}) \ - and the ShortintBootstrappingKey output LweDimension ({:?})", - key_switching_key.input_key_lwe_dimension(), - bootstrapping_key.output_lwe_dimension() - ); - - assert_eq!( - key_switching_key.output_key_lwe_dimension(), - bootstrapping_key.input_lwe_dimension(), - "Mismatch between the output LweKeyswitchKey LweDimension ({:?}) \ - and the ShortintBootstrappingKey input LweDimension ({:?})", - key_switching_key.output_key_lwe_dimension(), - bootstrapping_key.input_lwe_dimension() - ); - - assert_eq!( - key_switching_key.ciphertext_modulus(), - ciphertext_modulus, - "Mismatch between the LweKeyswitchKey CiphertextModulus ({:?}) \ - and the provided CiphertextModulus ({:?})", - key_switching_key.ciphertext_modulus(), - ciphertext_modulus - ); - let max_max_degree = MaxDegree::from_msg_carry_modulus(message_modulus, carry_modulus); assert!( @@ -603,27 +654,22 @@ impl ServerKey { "Maximum valid MaxDegree is {max_max_degree:?}, got ({max_degree:?})" ); + let ciphertext_modulus = atomic_pattern_key.ciphertext_modulus(); + Self { - key_switching_key, - bootstrapping_key, + atomic_pattern: atomic_pattern_key, message_modulus, carry_modulus, max_degree, max_noise_level, ciphertext_modulus, - pbs_order, } } pub fn conformance_params(&self) -> CiphertextConformanceParams { let lwe_dim = self.ciphertext_lwe_dimension(); - let ms_decompression_method = match &self.bootstrapping_key { - ShortintBootstrappingKey::Classic { .. } => MsDecompressionType::ClassicPbs, - ShortintBootstrappingKey::MultiBit { fourier_bsk, .. } => { - MsDecompressionType::MultiBitPbs(fourier_bsk.grouping_factor()) - } - }; + let ms_decompression_method = self.atomic_pattern.ciphertext_decompression_method(); let ct_params = LweCiphertextConformanceParams { lwe_dim, @@ -636,7 +682,7 @@ impl ServerKey { message_modulus: self.message_modulus, carry_modulus: self.carry_modulus, degree: Degree::new(self.message_modulus.0 - 1), - atomic_pattern: AtomicPatternKind::Standard(self.pbs_order), + atomic_pattern: self.atomic_pattern.kind(), noise_level: NoiseLevel::NOMINAL, } } @@ -678,10 +724,7 @@ impl ServerKey { where F: Fn(u64) -> u64, { - let size = LookupTableSize::new( - self.bootstrapping_key.glwe_size(), - self.bootstrapping_key.polynomial_size(), - ); + let size = self.atomic_pattern.lookup_table_size(); generate_lookup_table( size, self.ciphertext_modulus, @@ -691,17 +734,6 @@ impl ServerKey { ) } - pub(crate) fn generate_lookup_table_no_encode(&self, f: F) -> GlweCiphertextOwned - where - F: Fn(u64) -> u64, - { - let size = LookupTableSize::new( - self.bootstrapping_key.glwe_size(), - self.bootstrapping_key.polynomial_size(), - ); - generate_lookup_table_no_encode(size, self.ciphertext_modulus, f) - } - /// Given a function as input, constructs the lookup table working on the message bits /// Carry bits are ignored /// @@ -770,17 +802,18 @@ impl ServerKey { &self, functions: &[&dyn Fn(u64) -> u64], ) -> ManyLookupTableOwned { + let lut_size = self.atomic_pattern.lookup_table_size(); let mut acc = GlweCiphertext::new( 0, - self.bootstrapping_key.glwe_size(), - self.bootstrapping_key.polynomial_size(), + lut_size.glwe_size(), + lut_size.polynomial_size(), self.ciphertext_modulus, ); let (input_max_degree, sample_extraction_stride, per_function_output_degree) = fill_many_lut_accumulator( &mut acc, - self.bootstrapping_key.polynomial_size(), - self.bootstrapping_key.glwe_size(), + lut_size.polynomial_size(), + lut_size.glwe_size(), self.message_modulus, self.carry_modulus, functions, @@ -831,41 +864,7 @@ impl ServerKey { return; } - ShortintEngine::with_thread_local_mut(|engine| { - let (mut ciphertext_buffer, buffers) = engine.get_buffers(self); - match self.pbs_order { - PBSOrder::KeyswitchBootstrap => { - keyswitch_lwe_ciphertext( - &self.key_switching_key, - &ct.ct, - &mut ciphertext_buffer, - ); - - apply_programmable_bootstrap( - &self.bootstrapping_key, - &ciphertext_buffer, - &mut ct.ct, - &acc.acc, - buffers, - ); - } - PBSOrder::BootstrapKeyswitch => { - apply_programmable_bootstrap( - &self.bootstrapping_key, - &ct.ct, - &mut ciphertext_buffer, - &acc.acc, - buffers, - ); - - keyswitch_lwe_ciphertext( - &self.key_switching_key, - &ciphertext_buffer, - &mut ct.ct, - ); - } - } - }); + self.atomic_pattern.apply_lookup_table_assign(ct, acc); ct.degree = acc.degree; ct.set_noise_level_to_nominal(); @@ -909,12 +908,19 @@ impl ServerKey { pub fn apply_many_lookup_table( &self, ct: &Ciphertext, - acc: &ManyLookupTableOwned, + lut: &ManyLookupTableOwned, ) -> Vec { - match self.pbs_order { - PBSOrder::KeyswitchBootstrap => self.keyswitch_programmable_bootstrap_many_lut(ct, acc), - PBSOrder::BootstrapKeyswitch => self.programmable_bootstrap_keyswitch_many_lut(ct, acc), + if ct.is_trivial() { + return self.trivial_pbs_many_lut(ct, lut); } + + let mut results = self.atomic_pattern.apply_many_lookup_table(ct, lut); + + for ct in results.iter_mut() { + ct.set_noise_level_to_nominal(); + } + + results } /// Applies the given function to the message of a ciphertext @@ -1139,20 +1145,13 @@ impl ServerKey { lwe_size, self.message_modulus, self.carry_modulus, - AtomicPatternKind::Standard(self.pbs_order), + self.atomic_pattern.kind(), self.ciphertext_modulus, ) } pub fn unchecked_create_trivial(&self, value: u64) -> Ciphertext { - let lwe_size = match self.pbs_order { - PBSOrder::KeyswitchBootstrap => { - self.bootstrapping_key.output_lwe_dimension().to_lwe_size() - } - PBSOrder::BootstrapKeyswitch => { - self.bootstrapping_key.input_lwe_dimension().to_lwe_size() - } - }; + let lwe_size = self.atomic_pattern.ciphertext_lwe_dimension().to_lwe_size(); self.unchecked_create_trivial_with_lwe_size(Cleartext(value), lwe_size) } @@ -1170,29 +1169,8 @@ impl ServerKey { ct.set_noise_level(NoiseLevel::ZERO, self.max_noise_level); } - pub fn bootstrapping_key_size_elements(&self) -> usize { - self.bootstrapping_key.bootstrapping_key_size_elements() - } - - pub fn bootstrapping_key_size_bytes(&self) -> usize { - self.bootstrapping_key.bootstrapping_key_size_bytes() - } - - pub fn key_switching_key_size_elements(&self) -> usize { - self.key_switching_key.as_ref().len() - } - - pub fn key_switching_key_size_bytes(&self) -> usize { - std::mem::size_of_val(self.key_switching_key.as_ref()) - } - - pub fn deterministic_pbs_execution(&self) -> bool { - self.bootstrapping_key.deterministic_pbs_execution() - } - - pub fn set_deterministic_pbs_execution(&mut self, new_deterministic_execution: bool) { - self.bootstrapping_key - .set_deterministic_pbs_execution(new_deterministic_execution); + pub fn deterministic_execution(&self) -> bool { + self.atomic_pattern.deterministic_execution() } fn trivial_pbs_assign(&self, ct: &mut Ciphertext, acc: &LookupTableOwned) { @@ -1208,7 +1186,8 @@ impl ServerKey { .decode(Plaintext(*ct.ct.get_body().data)) .0; - let box_size = self.bootstrapping_key.polynomial_size().0 / modulus_sup as usize; + let lut_size = self.atomic_pattern.lookup_table_size(); + let box_size = lut_size.polynomial_size().0 / modulus_sup as usize; let result = if ct_value >= modulus_sup { // padding bit is 1 let ct_value = ct_value % modulus_sup; @@ -1233,7 +1212,8 @@ impl ServerKey { .decode(Plaintext(*ct.ct.get_body().data)) .0; - let box_size = self.bootstrapping_key.polynomial_size().0 / modulus_sup as usize; + let lut_size = self.atomic_pattern.lookup_table_size(); + let box_size = lut_size.polynomial_size().0 / modulus_sup as usize; let padding_bit_set = ct_value >= modulus_sup; let first_result_index_in_lut = { @@ -1279,104 +1259,12 @@ impl ServerKey { outputs } +} - pub(crate) fn keyswitch_programmable_bootstrap_many_lut( - &self, - ct: &Ciphertext, - lut: &ManyLookupTableOwned, - ) -> Vec { - if ct.is_trivial() { - return self.trivial_pbs_many_lut(ct, lut); - } - - let mut acc = lut.acc.clone(); - - ShortintEngine::with_thread_local_mut(|engine| { - // Compute the programmable bootstrapping with fixed test polynomial - let (mut ciphertext_buffer, buffers) = engine.get_buffers(self); - - // Compute a key switch - keyswitch_lwe_ciphertext(&self.key_switching_key, &ct.ct, &mut ciphertext_buffer); - - apply_blind_rotate( - &self.bootstrapping_key, - &ciphertext_buffer.as_view(), - &mut acc, - buffers, - ); - }); - - // The accumulator has been rotated, we can now proceed with the various sample extractions - let function_count = lut.function_count(); - let mut outputs = Vec::with_capacity(function_count); - - for (fn_idx, output_degree) in lut.per_function_output_degree.iter().enumerate() { - let monomial_degree = MonomialDegree(fn_idx * lut.sample_extraction_stride); - let mut output_shortint_ct = ct.clone(); - - extract_lwe_sample_from_glwe_ciphertext( - &acc, - &mut output_shortint_ct.ct, - monomial_degree, - ); - - output_shortint_ct.degree = *output_degree; - output_shortint_ct.set_noise_level_to_nominal(); - outputs.push(output_shortint_ct); - } - - outputs - } - - pub(crate) fn programmable_bootstrap_keyswitch_many_lut( - &self, - ct: &Ciphertext, - lut: &ManyLookupTableOwned, - ) -> Vec { - if ct.is_trivial() { - return self.trivial_pbs_many_lut(ct, lut); - } - - let mut acc = lut.acc.clone(); - - ShortintEngine::with_thread_local_mut(|engine| { - // Compute the programmable bootstrapping with fixed test polynomial - let buffers = engine.get_computation_buffers(); - - apply_blind_rotate(&self.bootstrapping_key, &ct.ct, &mut acc, buffers); - }); - - // The accumulator has been rotated, we can now proceed with the various sample extractions - let function_count = lut.function_count(); - let mut outputs = Vec::with_capacity(function_count); - - let mut tmp_lwe_ciphertext = LweCiphertext::new( - 0u64, - self.key_switching_key - .input_key_lwe_dimension() - .to_lwe_size(), - self.key_switching_key.ciphertext_modulus(), - ); - - for (fn_idx, output_degree) in lut.per_function_output_degree.iter().enumerate() { - let monomial_degree = MonomialDegree(fn_idx * lut.sample_extraction_stride); - extract_lwe_sample_from_glwe_ciphertext(&acc, &mut tmp_lwe_ciphertext, monomial_degree); - - let mut output_shortint_ct = ct.clone(); - - // Compute a key switch - keyswitch_lwe_ciphertext( - &self.key_switching_key, - &tmp_lwe_ciphertext, - &mut output_shortint_ct.ct, - ); - - output_shortint_ct.degree = *output_degree; - output_shortint_ct.set_noise_level_to_nominal(); - outputs.push(output_shortint_ct); - } - - outputs +impl GenericServerKey { + pub fn set_deterministic_execution(&mut self, new_deterministic_execution: bool) { + self.atomic_pattern + .set_deterministic_execution(new_deterministic_execution); } } @@ -1422,7 +1310,7 @@ impl SmartCleaningOperation { } } -impl ServerKey { +impl GenericServerKey { /// Before doing an operations on 2 inputs which validity is described by /// `is_operation_possible`, one or both the inputs may need to be cleaned (carry removal and /// noise reinitilization) with a PBS @@ -1781,37 +1669,19 @@ impl ParameterSetConformant for ShortintBootstrappingKey { } impl ParameterSetConformant for ServerKey { - type ParameterSet = (PBSParameters, MaxDegree); + type ParameterSet = (AtomicPatternParameters, MaxDegree); fn is_conformant(&self, (parameter_set, expected_max_degree): &Self::ParameterSet) -> bool { let Self { - key_switching_key, - bootstrapping_key, message_modulus, + atomic_pattern, carry_modulus, max_degree, max_noise_level, ciphertext_modulus, - pbs_order, } = self; - let params: PBSConformanceParams = parameter_set.into(); - - let pbs_key_ok = bootstrapping_key.is_conformant(¶ms); - - let param: LweKeyswitchKeyConformanceParams = parameter_set.into(); - - let ks_key_ok = key_switching_key.is_conformant(¶m); - - let pbs_order_ok = matches!( - (*pbs_order, parameter_set.encryption_key_choice()), - (PBSOrder::KeyswitchBootstrap, EncryptionKeyChoice::Big) - | (PBSOrder::BootstrapKeyswitch, EncryptionKeyChoice::Small) - ); - - pbs_key_ok - && ks_key_ok - && pbs_order_ok + atomic_pattern.is_conformant(parameter_set) && *max_degree == *expected_max_degree && *message_modulus == parameter_set.message_modulus() && *carry_modulus == parameter_set.carry_modulus() @@ -1819,3 +1689,43 @@ impl ParameterSetConformant for ServerKey { && *ciphertext_modulus == parameter_set.ciphertext_modulus() } } + +impl StandardServerKeyView<'_> { + pub fn bootstrapping_key_size_elements(&self) -> usize { + self.atomic_pattern + .bootstrapping_key + .bootstrapping_key_size_elements() + } + + pub fn bootstrapping_key_size_bytes(&self) -> usize { + self.atomic_pattern + .bootstrapping_key + .bootstrapping_key_size_bytes() + } + + pub fn key_switching_key_size_elements(&self) -> usize { + self.atomic_pattern.key_switching_key.as_ref().len() + } + + pub fn key_switching_key_size_bytes(&self) -> usize { + std::mem::size_of_val(self.atomic_pattern.key_switching_key.as_ref()) + } +} + +impl StandardServerKey { + pub fn bootstrapping_key_size_elements(&self) -> usize { + self.as_view().bootstrapping_key_size_elements() + } + + pub fn bootstrapping_key_size_bytes(&self) -> usize { + self.as_view().bootstrapping_key_size_bytes() + } + + pub fn key_switching_key_size_elements(&self) -> usize { + self.as_view().key_switching_key_size_elements() + } + + pub fn key_switching_key_size_bytes(&self) -> usize { + self.as_view().key_switching_key_size_bytes() + } +} diff --git a/tfhe/src/shortint/server_key/modulus_switched_compression.rs b/tfhe/src/shortint/server_key/modulus_switched_compression.rs index 9ddf71526..f69c3c738 100644 --- a/tfhe/src/shortint/server_key/modulus_switched_compression.rs +++ b/tfhe/src/shortint/server_key/modulus_switched_compression.rs @@ -1,20 +1,9 @@ -use super::compressed_modulus_switched_multi_bit_lwe_ciphertext::CompressedModulusSwitchedMultiBitLweCiphertext; -use super::{ - apply_modulus_switch_noise_reduction, apply_programmable_bootstrap_no_ms_noise_reduction, - extract_lwe_sample_from_glwe_ciphertext, multi_bit_deterministic_blind_rotate_assign, - GlweCiphertext, ShortintBootstrappingKey, -}; -use crate::core_crypto::commons::parameters::MonomialDegree; -use crate::core_crypto::prelude::compressed_modulus_switched_lwe_ciphertext::CompressedModulusSwitchedLweCiphertext; -use crate::core_crypto::prelude::{keyswitch_lwe_ciphertext, ComputationBuffers, LweCiphertext}; -use crate::shortint::ciphertext::{ - CompressedModulusSwitchedCiphertext, InternalCompressedModulusSwitchedCiphertext, NoiseLevel, -}; -use crate::shortint::engine::ShortintEngine; -use crate::shortint::server_key::LookupTableOwned; -use crate::shortint::{Ciphertext, PBSOrder, ServerKey}; +use crate::shortint::atomic_pattern::AtomicPattern; +use crate::shortint::ciphertext::CompressedModulusSwitchedCiphertext; +use crate::shortint::server_key::{GenericServerKey, LookupTableOwned}; +use crate::shortint::Ciphertext; -impl ServerKey { +impl GenericServerKey { /// Compresses a ciphertext to have a smaller serialization size /// /// See [`CompressedModulusSwitchedCiphertext#example`] for usage @@ -22,76 +11,7 @@ impl ServerKey { &self, ct: &Ciphertext, ) -> CompressedModulusSwitchedCiphertext { - let compressed_modulus_switched_lwe_ciphertext = - ShortintEngine::with_thread_local_mut(|engine| { - let (mut ciphertext_buffer, _) = engine.get_buffers(self); - let input_ct = match self.pbs_order { - PBSOrder::KeyswitchBootstrap => { - keyswitch_lwe_ciphertext( - &self.key_switching_key, - &ct.ct, - &mut ciphertext_buffer, - ); - ciphertext_buffer.as_view() - } - PBSOrder::BootstrapKeyswitch => ct.ct.as_view(), - }; - - match &self.bootstrapping_key { - ShortintBootstrappingKey::Classic { - bsk, - modulus_switch_noise_reduction_key, - } => { - let log_modulus = - bsk.polynomial_size().to_blind_rotation_input_modulus_log(); - - let input_improved_before_ms; - - // The solution suggested by clippy does not work because of the capture of - // `input_improved_before_ms` - #[allow(clippy::option_if_let_else)] - let input_modulus_switch = if let Some(modulus_switch_noise_reduction_key) = - modulus_switch_noise_reduction_key - { - input_improved_before_ms = apply_modulus_switch_noise_reduction( - modulus_switch_noise_reduction_key, - log_modulus, - &input_ct, - ); - - input_improved_before_ms.as_view() - } else { - input_ct - }; - - InternalCompressedModulusSwitchedCiphertext::Classic( - CompressedModulusSwitchedLweCiphertext::compress( - &input_modulus_switch, - log_modulus, - ), - ) - } - ShortintBootstrappingKey::MultiBit { fourier_bsk, .. } => { - InternalCompressedModulusSwitchedCiphertext::MultiBit( - CompressedModulusSwitchedMultiBitLweCiphertext::compress( - &input_ct, - self.bootstrapping_key - .polynomial_size() - .to_blind_rotation_input_modulus_log(), - fourier_bsk.grouping_factor(), - ), - ) - } - } - }); - - CompressedModulusSwitchedCiphertext { - compressed_modulus_switched_lwe_ciphertext, - degree: ct.degree, - message_modulus: ct.message_modulus, - carry_modulus: ct.carry_modulus, - atomic_pattern: ct.atomic_pattern, - } + self.atomic_pattern.switch_modulus_and_compress(ct) } /// Decompresses a compressed ciphertext @@ -145,105 +65,7 @@ impl ServerKey { compressed_ct: &CompressedModulusSwitchedCiphertext, acc: &LookupTableOwned, ) -> Ciphertext { - let mut output = LweCiphertext::from_container( - vec![0; self.ciphertext_lwe_dimension().to_lwe_size().0], - self.ciphertext_modulus, - ); - - ShortintEngine::with_thread_local_mut(|engine| { - let (mut ciphertext_buffer, buffers) = engine.get_buffers(self); - - match self.pbs_order { - PBSOrder::KeyswitchBootstrap => { - self.bootstrap_for_decompression( - compressed_ct, - &mut output.as_mut_view(), - acc, - buffers, - ); - } - PBSOrder::BootstrapKeyswitch => { - self.bootstrap_for_decompression( - compressed_ct, - &mut ciphertext_buffer, - acc, - buffers, - ); - keyswitch_lwe_ciphertext( - &self.key_switching_key, - &ciphertext_buffer, - &mut output, - ); - } - } - }); - - Ciphertext::new( - output, - acc.degree, - NoiseLevel::NOMINAL, - compressed_ct.message_modulus, - compressed_ct.carry_modulus, - compressed_ct.atomic_pattern, - ) - } - - fn bootstrap_for_decompression( - &self, - compressed_ct: &CompressedModulusSwitchedCiphertext, - out_ct: &mut LweCiphertext<&mut [u64]>, - acc: &LookupTableOwned, - buffers: &mut ComputationBuffers, - ) { - match &self.bootstrapping_key { - ShortintBootstrappingKey::Classic { .. } => { - let ct = match &compressed_ct.compressed_modulus_switched_lwe_ciphertext { - InternalCompressedModulusSwitchedCiphertext::Classic(a) => a.extract(), - InternalCompressedModulusSwitchedCiphertext::MultiBit(_) => { - panic!("Compression was done targeting a MultiBit bootstrap decompression, cannot decompress with a Classic bootstrapping key") - } - }; - apply_programmable_bootstrap_no_ms_noise_reduction( - &self.bootstrapping_key, - &ct, - out_ct, - &acc.acc, - buffers, - ); - } - ShortintBootstrappingKey::MultiBit { - fourier_bsk, - thread_count, - deterministic_execution: _, - } => { - let ct = match &compressed_ct.compressed_modulus_switched_lwe_ciphertext { - InternalCompressedModulusSwitchedCiphertext::MultiBit(a) => a.extract(), - InternalCompressedModulusSwitchedCiphertext::Classic(_) => { - panic!("Compression was done targeting a Classic bootstrap decompression, cannot decompress with a MultiBit bootstrapping key") - } - }; - - let mut local_accumulator = GlweCiphertext::new( - 0, - acc.acc.glwe_size(), - acc.acc.polynomial_size(), - acc.acc.ciphertext_modulus(), - ); - local_accumulator.as_mut().copy_from_slice(acc.acc.as_ref()); - - multi_bit_deterministic_blind_rotate_assign( - &ct, - &mut local_accumulator, - fourier_bsk, - *thread_count, - ); - - extract_lwe_sample_from_glwe_ciphertext( - &local_accumulator, - out_ct, - MonomialDegree(0), - ); - } - } + self.atomic_pattern + .decompress_and_apply_lookup_table(compressed_ct, acc) } } diff --git a/tfhe/src/shortint/server_key/mul.rs b/tfhe/src/shortint/server_key/mul.rs index 3a81d1118..d75cff842 100644 --- a/tfhe/src/shortint/server_key/mul.rs +++ b/tfhe/src/shortint/server_key/mul.rs @@ -1,10 +1,11 @@ use super::add::unchecked_add_assign; -use super::{CiphertextNoiseDegree, ServerKey}; +use super::CiphertextNoiseDegree; +use crate::shortint::atomic_pattern::AtomicPattern; use crate::shortint::ciphertext::Degree; -use crate::shortint::server_key::CheckError; +use crate::shortint::server_key::{CheckError, GenericServerKey}; use crate::shortint::Ciphertext; -impl ServerKey { +impl GenericServerKey { /// Multiply two ciphertexts together without checks. /// /// Return the "least significant bits" of the multiplication, i.e., the result modulus the diff --git a/tfhe/src/shortint/server_key/neg.rs b/tfhe/src/shortint/server_key/neg.rs index 93c1448cb..3d33a2b68 100644 --- a/tfhe/src/shortint/server_key/neg.rs +++ b/tfhe/src/shortint/server_key/neg.rs @@ -1,11 +1,12 @@ use super::CiphertextNoiseDegree; use crate::core_crypto::algorithms::*; use crate::core_crypto::entities::*; +use crate::shortint::atomic_pattern::AtomicPattern; use crate::shortint::ciphertext::Degree; -use crate::shortint::server_key::CheckError; -use crate::shortint::{Ciphertext, PaddingBit, ServerKey}; +use crate::shortint::server_key::{CheckError, GenericServerKey}; +use crate::shortint::{Ciphertext, PaddingBit}; -impl ServerKey { +impl GenericServerKey { /// Compute homomorphically a negation of a ciphertext. /// /// This checks that the negation is possible. In the case where the carry buffers are full, diff --git a/tfhe/src/shortint/server_key/scalar_add.rs b/tfhe/src/shortint/server_key/scalar_add.rs index 01270abf0..e5c3443ce 100644 --- a/tfhe/src/shortint/server_key/scalar_add.rs +++ b/tfhe/src/shortint/server_key/scalar_add.rs @@ -1,11 +1,12 @@ use super::CiphertextNoiseDegree; use crate::core_crypto::algorithms::*; use crate::core_crypto::entities::*; +use crate::shortint::atomic_pattern::AtomicPattern; use crate::shortint::ciphertext::Degree; -use crate::shortint::server_key::CheckError; -use crate::shortint::{Ciphertext, PaddingBit, ServerKey}; +use crate::shortint::server_key::{CheckError, GenericServerKey}; +use crate::shortint::{Ciphertext, PaddingBit}; -impl ServerKey { +impl GenericServerKey { /// Compute homomorphically an addition between a ciphertext and a scalar. /// /// The result is returned in a _new_ ciphertext. diff --git a/tfhe/src/shortint/server_key/scalar_bitwise_op.rs b/tfhe/src/shortint/server_key/scalar_bitwise_op.rs index dfda9a777..f2fd66931 100644 --- a/tfhe/src/shortint/server_key/scalar_bitwise_op.rs +++ b/tfhe/src/shortint/server_key/scalar_bitwise_op.rs @@ -1,8 +1,9 @@ -use super::ServerKey; +use crate::shortint::atomic_pattern::AtomicPattern; use crate::shortint::ciphertext::Degree; +use crate::shortint::server_key::GenericServerKey; use crate::shortint::Ciphertext; -impl ServerKey { +impl GenericServerKey { /// Compute homomorphically a bitwise AND between a ciphertext and a clear value /// /// diff --git a/tfhe/src/shortint/server_key/scalar_div_mod.rs b/tfhe/src/shortint/server_key/scalar_div_mod.rs index a3f5fff91..e20c16816 100644 --- a/tfhe/src/shortint/server_key/scalar_div_mod.rs +++ b/tfhe/src/shortint/server_key/scalar_div_mod.rs @@ -1,7 +1,9 @@ +use crate::shortint::atomic_pattern::AtomicPattern; use crate::shortint::ciphertext::Degree; -use crate::shortint::{Ciphertext, ServerKey}; +use crate::shortint::server_key::GenericServerKey; +use crate::shortint::Ciphertext; -impl ServerKey { +impl GenericServerKey { /// Alias to [`unchecked_scalar_div`](`Self::unchecked_scalar_div`) provided for convenience /// /// This function, like all "default" operations (i.e. not smart, checked or unchecked), will diff --git a/tfhe/src/shortint/server_key/scalar_mul.rs b/tfhe/src/shortint/server_key/scalar_mul.rs index 170592ec9..b182cb688 100644 --- a/tfhe/src/shortint/server_key/scalar_mul.rs +++ b/tfhe/src/shortint/server_key/scalar_mul.rs @@ -1,11 +1,12 @@ use super::CiphertextNoiseDegree; use crate::core_crypto::algorithms::*; use crate::core_crypto::entities::*; +use crate::shortint::atomic_pattern::AtomicPattern; use crate::shortint::ciphertext::Degree; -use crate::shortint::server_key::CheckError; -use crate::shortint::{Ciphertext, MaxNoiseLevel, ServerKey}; +use crate::shortint::server_key::{CheckError, GenericServerKey}; +use crate::shortint::{Ciphertext, MaxNoiseLevel}; -impl ServerKey { +impl GenericServerKey { /// Compute homomorphically a multiplication of a ciphertext by a scalar. /// /// This function, like all "default" operations (i.e. not smart, checked or unchecked), will diff --git a/tfhe/src/shortint/server_key/scalar_sub.rs b/tfhe/src/shortint/server_key/scalar_sub.rs index ecb1681dd..e34bd6901 100644 --- a/tfhe/src/shortint/server_key/scalar_sub.rs +++ b/tfhe/src/shortint/server_key/scalar_sub.rs @@ -1,11 +1,12 @@ use super::CiphertextNoiseDegree; use crate::core_crypto::algorithms::*; use crate::core_crypto::entities::*; +use crate::shortint::atomic_pattern::AtomicPattern; use crate::shortint::ciphertext::Degree; -use crate::shortint::server_key::CheckError; -use crate::shortint::{Ciphertext, MessageModulus, PaddingBit, ServerKey}; +use crate::shortint::server_key::{CheckError, GenericServerKey}; +use crate::shortint::{Ciphertext, MessageModulus, PaddingBit}; -impl ServerKey { +impl GenericServerKey { /// Compute homomorphically a subtraction of a ciphertext by a scalar. /// /// The result is returned in a _new_ ciphertext. diff --git a/tfhe/src/shortint/server_key/shift.rs b/tfhe/src/shortint/server_key/shift.rs index 2d90b9493..e60ccaa67 100644 --- a/tfhe/src/shortint/server_key/shift.rs +++ b/tfhe/src/shortint/server_key/shift.rs @@ -1,10 +1,11 @@ use super::CiphertextNoiseDegree; +use crate::shortint::atomic_pattern::AtomicPattern; use crate::shortint::ciphertext::Degree; use crate::shortint::server_key::scalar_mul::unchecked_scalar_mul_assign; -use crate::shortint::server_key::CheckError; -use crate::shortint::{Ciphertext, ServerKey}; +use crate::shortint::server_key::{CheckError, GenericServerKey}; +use crate::shortint::Ciphertext; -impl ServerKey { +impl GenericServerKey { /// Compute homomorphically a right shift of the bits. /// /// This returns a new ciphertext. diff --git a/tfhe/src/shortint/server_key/sub.rs b/tfhe/src/shortint/server_key/sub.rs index 0fd77367a..1dd8a4303 100644 --- a/tfhe/src/shortint/server_key/sub.rs +++ b/tfhe/src/shortint/server_key/sub.rs @@ -1,10 +1,11 @@ use super::{CiphertextNoiseDegree, SmartCleaningOperation}; use crate::core_crypto::algorithms::*; +use crate::shortint::atomic_pattern::AtomicPattern; use crate::shortint::ciphertext::Degree; -use crate::shortint::server_key::CheckError; -use crate::shortint::{Ciphertext, ServerKey}; +use crate::shortint::server_key::{CheckError, GenericServerKey}; +use crate::shortint::Ciphertext; -impl ServerKey { +impl GenericServerKey { /// Compute homomorphically a subtraction between two ciphertexts. /// /// This returns a new ciphertext. diff --git a/tfhe/src/shortint/wopbs/mod.rs b/tfhe/src/shortint/wopbs/mod.rs index 635f4c8bd..c2c9067f5 100644 --- a/tfhe/src/shortint/wopbs/mod.rs +++ b/tfhe/src/shortint/wopbs/mod.rs @@ -9,7 +9,7 @@ use crate::core_crypto::entities::*; -use crate::shortint::{ServerKey, WopbsParameters}; +use crate::shortint::WopbsParameters; use serde::{Deserialize, Serialize}; #[cfg(all(test, feature = "experimental"))] @@ -20,8 +20,8 @@ mod test; #[cfg_attr(dylint_lib = "tfhe_lints", allow(serialize_without_versionize))] pub struct WopbsKey { //Key for the private functional keyswitch - pub wopbs_server_key: ServerKey, - pub pbs_server_key: ServerKey, + pub wopbs_server_key: StandardServerKey, + pub pbs_server_key: StandardServerKey, pub cbs_pfpksk: LwePrivateFunctionalPackingKeyswitchKeyListOwned, pub ksk_pbs_to_wopbs: LweKeyswitchKeyOwned, pub param: WopbsParameters, @@ -30,6 +30,8 @@ pub struct WopbsKey { #[cfg(feature = "experimental")] pub use experimental::*; +use super::server_key::StandardServerKey; + #[cfg(feature = "experimental")] mod experimental { use crate::core_crypto::algorithms::*; @@ -37,9 +39,12 @@ mod experimental { use crate::core_crypto::commons::traits::*; use crate::core_crypto::entities::*; use crate::core_crypto::fft_impl::fft64::math::fft::Fft; + use crate::shortint::atomic_pattern::AtomicPattern; use crate::shortint::ciphertext::*; use crate::shortint::engine::ShortintEngine; - use crate::shortint::server_key::ShortintBootstrappingKey; + use crate::shortint::server_key::{ + ShortintBootstrappingKey, StandardServerKey, StandardServerKeyView, + }; use super::WopbsKey; use crate::shortint::{ClientKey, ServerKey, WopbsParameters}; @@ -243,9 +248,12 @@ mod experimental { /// /// // Generate the client key and the server key: /// let (cks, sks) = gen_keys(LEGACY_WOPBS_ONLY_8_BLOCKS_PARAM_MESSAGE_1_CARRY_1_KS_PBS); - /// let wopbs_key = WopbsKey::new_wopbs_key_only_for_wopbs(&cks, &sks); + /// let wopbs_key = WopbsKey::new_wopbs_key_only_for_wopbs(&cks, sks.as_view().try_into().unwrap()); /// ``` - pub fn new_wopbs_key_only_for_wopbs(cks: &ClientKey, sks: &ServerKey) -> Self { + pub fn new_wopbs_key_only_for_wopbs( + cks: &ClientKey, + sks: StandardServerKeyView<'_>, + ) -> Self { ShortintEngine::with_thread_local_mut(|engine| { engine.new_wopbs_key_only_for_wopbs(cks, sks).unwrap() }) @@ -262,11 +270,12 @@ mod experimental { /// /// // Generate the client key and the server key: /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128); - /// let wopbs_key = WopbsKey::new_wopbs_key(&cks, &sks, &LEGACY_WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS); + /// let wopbs_key = WopbsKey::new_wopbs_key(&cks, sks.as_view().try_into().unwrap(), + /// &LEGACY_WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS); /// ``` pub fn new_wopbs_key( cks: &ClientKey, - sks: &ServerKey, + sks: StandardServerKeyView<'_>, parameters: &WopbsParameters, ) -> Self { ShortintEngine::with_thread_local_mut(|engine| { @@ -278,8 +287,8 @@ mod experimental { pub fn into_raw_parts( self, ) -> ( - ServerKey, - ServerKey, + StandardServerKey, + StandardServerKey, LwePrivateFunctionalPackingKeyswitchKeyListOwned, LweKeyswitchKeyOwned, WopbsParameters, @@ -307,8 +316,8 @@ mod experimental { /// /// Panics if the constituents are not compatible with each others. pub fn from_raw_parts( - wopbs_server_key: ServerKey, - pbs_server_key: ServerKey, + wopbs_server_key: StandardServerKey, + pbs_server_key: StandardServerKey, cbs_pfpksk: LwePrivateFunctionalPackingKeyswitchKeyListOwned, ksk_pbs_to_wopbs: LweKeyswitchKeyOwned, param: WopbsParameters, @@ -348,7 +357,8 @@ mod experimental { /// /// // Generate the client key and the server key: /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128); - /// let wopbs_key = WopbsKey::new_wopbs_key(&cks, &sks, &LEGACY_WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS); + /// let std_sks = sks.as_view().try_into().unwrap(); + /// let wopbs_key = WopbsKey::new_wopbs_key(&cks, std_sks, &LEGACY_WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS); /// let message_modulus = LEGACY_WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS.message_modulus.0; /// let m = 2; /// let ct = cks.encrypt(m); @@ -364,7 +374,12 @@ mod experimental { // The function is applied only on the message modulus bits let basis = ct.message_modulus.0 * ct.carry_modulus.0; let delta = 64 - f64::log2(basis as f64).ceil() as u64 - 1; - let poly_size = self.wopbs_server_key.bootstrapping_key.polynomial_size().0; + let poly_size = self + .wopbs_server_key + .atomic_pattern + .bootstrapping_key + .polynomial_size() + .0; let mut lut = ShortintWopbsLUT::new(PlaintextCount(poly_size)); for (i, value) in lut.iter_mut().enumerate().take(basis as usize) { *value = f(i as u64 % ct.message_modulus.0) << delta; @@ -385,7 +400,8 @@ mod experimental { /// /// // Generate the client key and the server key: /// let (cks, sks) = gen_keys(LEGACY_WOPBS_ONLY_4_BLOCKS_PARAM_MESSAGE_2_CARRY_2_KS_PBS); - /// let wopbs_key = WopbsKey::new_wopbs_key_only_for_wopbs(&cks, &sks); + /// let std_sks = sks.as_view().try_into().unwrap(); + /// let wopbs_key = WopbsKey::new_wopbs_key_only_for_wopbs(&cks, std_sks); /// let message_modulus = LEGACY_WOPBS_ONLY_4_BLOCKS_PARAM_MESSAGE_2_CARRY_2_KS_PBS.message_modulus.0; /// let m = 2; /// let ct = cks.encrypt_without_padding(m); @@ -401,7 +417,12 @@ mod experimental { // The function is applied only on the message modulus bits let basis = ct.message_modulus.0 * ct.carry_modulus.0; let delta = 64 - f64::log2((basis) as f64).ceil() as u64; - let poly_size = self.wopbs_server_key.bootstrapping_key.polynomial_size().0; + let poly_size = self + .wopbs_server_key + .atomic_pattern + .bootstrapping_key + .polynomial_size() + .0; let mut vec_lut = vec![0; poly_size]; for (i, value) in vec_lut.iter_mut().enumerate().take(basis as usize) { *value = f(i as u64 % ct.message_modulus.0) << delta; @@ -422,7 +443,7 @@ mod experimental { /// /// // Generate the client key and the server key: /// let (cks, sks) = gen_keys(LEGACY_WOPBS_PARAM_MESSAGE_3_CARRY_3_KS_PBS); - /// let wopbs_key = WopbsKey::new_wopbs_key_only_for_wopbs(&cks, &sks); + /// let wopbs_key = WopbsKey::new_wopbs_key_only_for_wopbs(&cks, sks.as_view().try_into().unwrap()); /// let message_modulus = MessageModulus(5); /// let m = 2; /// let ct = cks.encrypt_native_crt(m, message_modulus); @@ -438,7 +459,12 @@ mod experimental { // The function is applied only on the message modulus bits let basis = ct.message_modulus.0 * ct.carry_modulus.0; let nb_bit = f64::log2((basis) as f64).ceil() as u64; - let poly_size = self.wopbs_server_key.bootstrapping_key.polynomial_size().0; + let poly_size = self + .wopbs_server_key + .atomic_pattern + .bootstrapping_key + .polynomial_size() + .0; let mut lut = ShortintWopbsLUT::new(PlaintextCount(poly_size)); for i in 0..basis { let index_lut = (((i % basis) << nb_bit) / basis) as usize; @@ -462,7 +488,8 @@ mod experimental { /// /// // Generate the client key and the server key: /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2_KS_PBS); - /// let wopbs_key = WopbsKey::new_wopbs_key(&cks, &sks, &LEGACY_WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS); + /// let std_sks = sks.as_view().try_into().unwrap(); + /// let wopbs_key = WopbsKey::new_wopbs_key(&cks, std_sks, &LEGACY_WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS); /// let mut rng = rand::thread_rng(); /// let message_modulus = LEGACY_WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS.message_modulus.0; /// let ct = cks.encrypt(rng.gen::() % message_modulus); @@ -498,7 +525,8 @@ mod experimental { /// /// // Generate the client key and the server key: /// let (cks, sks) = gen_keys(LEGACY_WOPBS_ONLY_4_BLOCKS_PARAM_MESSAGE_2_CARRY_2_KS_PBS); - /// let wopbs_key = WopbsKey::new_wopbs_key_only_for_wopbs(&cks, &sks); + /// let std_sks = sks.as_view().try_into().unwrap(); + /// let wopbs_key = WopbsKey::new_wopbs_key_only_for_wopbs(&cks, std_sks); /// let mut rng = rand::thread_rng(); /// let message_modulus = LEGACY_WOPBS_ONLY_4_BLOCKS_PARAM_MESSAGE_2_CARRY_2_KS_PBS.message_modulus.0; /// let ct = cks.encrypt(rng.gen::() % message_modulus); @@ -540,7 +568,8 @@ mod experimental { /// let mut msg_1_carry_0_params = LEGACY_WOPBS_ONLY_8_BLOCKS_PARAM_MESSAGE_1_CARRY_1_KS_PBS; /// msg_1_carry_0_params.carry_modulus = CarryModulus(1); /// let (cks, sks) = gen_keys(msg_1_carry_0_params); - /// let wopbs_key = WopbsKey::new_wopbs_key_only_for_wopbs(&cks, &sks); + /// let std_sks = sks.as_view().try_into().unwrap(); + /// let wopbs_key = WopbsKey::new_wopbs_key_only_for_wopbs(&cks, std_sks); /// let mut rng = rand::thread_rng(); /// let ct = cks.encrypt_without_padding(rng.gen::() % 2); /// let lut = vec![1_u64 << 63; wopbs_key.param.polynomial_size.0].into(); @@ -584,7 +613,8 @@ mod experimental { /// use tfhe::shortint::wopbs::*; /// /// let (cks, sks) = gen_keys(LEGACY_WOPBS_PARAM_MESSAGE_3_CARRY_3_KS_PBS); - /// let wopbs_key = WopbsKey::new_wopbs_key_only_for_wopbs(&cks, &sks); + /// let std_sks = sks.as_view().try_into().unwrap(); + /// let wopbs_key = WopbsKey::new_wopbs_key_only_for_wopbs(&cks, std_sks); /// let msg = 2; /// let modulus = MessageModulus(5); /// let ct = cks.encrypt_native_crt(msg, modulus); @@ -635,6 +665,7 @@ mod experimental { let server_key = &self.wopbs_server_key; let lwe_size = server_key + .atomic_pattern .key_switching_key .output_key_lwe_dimension() .to_lwe_size(); @@ -665,8 +696,8 @@ mod experimental { { let server_key = &self.wopbs_server_key; - let bsk = &server_key.bootstrapping_key; - let ksk = &server_key.key_switching_key; + let bsk = &server_key.atomic_pattern.bootstrapping_key; + let ksk = &server_key.atomic_pattern.key_switching_key; let fft = Fft::new(bsk.polynomial_size()); let fft = fft.as_view(); @@ -752,15 +783,20 @@ mod experimental { let acc = self.pbs_server_key.generate_lookup_table(|x| x); ShortintEngine::with_thread_local_mut(|engine| { - let (mut ciphertext_buffer, buffers) = engine.get_buffers(&self.pbs_server_key); + let (mut ciphertext_buffer, buffers) = engine.get_buffers( + self.pbs_server_key + .atomic_pattern + .intermediate_lwe_dimension(), + self.pbs_server_key.atomic_pattern.ciphertext_modulus(), + ); // Compute a key switch keyswitch_lwe_ciphertext( - &self.pbs_server_key.key_switching_key, + &self.pbs_server_key.atomic_pattern.key_switching_key, &ct_in.ct, &mut ciphertext_buffer, ); - let ct_out = match &self.pbs_server_key.bootstrapping_key { + let ct_out = match &self.pbs_server_key.atomic_pattern.bootstrapping_key { ShortintBootstrappingKey::Classic { bsk: fourier_bsk, modulus_switch_noise_reduction_key: _, @@ -853,7 +889,7 @@ mod experimental { LutCont: Container, { let sks = &self.wopbs_server_key; - let fourier_bsk = &sks.bootstrapping_key; + let fourier_bsk = &sks.atomic_pattern.bootstrapping_key; let output_lwe_size = fourier_bsk.output_lwe_dimension().to_lwe_size(); @@ -889,7 +925,7 @@ mod experimental { let stack = buffers.stack(); - match &sks.bootstrapping_key { + match &sks.atomic_pattern.bootstrapping_key { ShortintBootstrappingKey::Classic{bsk, modulus_switch_noise_reduction_key:_ } => { circuit_bootstrap_boolean_vertical_packing_lwe_ciphertext_list_mem_optimized( extracted_bits, diff --git a/tfhe/src/shortint/wopbs/test.rs b/tfhe/src/shortint/wopbs/test.rs index c189dde14..a12d5e8e5 100644 --- a/tfhe/src/shortint/wopbs/test.rs +++ b/tfhe/src/shortint/wopbs/test.rs @@ -136,7 +136,7 @@ fn generate_lut(params: (ClassicPBSParameters, WopbsParameters)) { #[cfg(not(tarpaulin))] fn generate_lut_wop_only(params: WopbsParameters) { let (cks, sks) = gen_keys(params); - let wopbs_key = WopbsKey::new_wopbs_key_only_for_wopbs(&cks, &sks); + let wopbs_key = WopbsKey::new_wopbs_key_only_for_wopbs(&cks, sks.as_view().try_into().unwrap()); let mut rng = rand::thread_rng(); let mut tmp = 0; @@ -186,7 +186,7 @@ fn generate_lut_modulus(params: (ClassicPBSParameters, WopbsParameters)) { #[cfg(not(tarpaulin))] fn generate_lut_modulus_not_power_of_two(params: WopbsParameters) { let (cks, sks) = gen_keys(params); - let wopbs_key = WopbsKey::new_wopbs_key_only_for_wopbs(&cks, &sks); + let wopbs_key = WopbsKey::new_wopbs_key_only_for_wopbs(&cks, sks.as_view().try_into().unwrap()); let mut rng = rand::thread_rng();