From 0223913aef982609b567cf0824958a17168dd6ab Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Mon, 27 Oct 2025 16:47:10 +0100 Subject: [PATCH] chore: make functions consistent to generate keyswitching keys - so that normal and seeded variants have similar APIs --- .../client_key/atomic_pattern/ks32.rs | 37 ++++++++-------- .../shortint/client_key/atomic_pattern/mod.rs | 11 +++-- .../client_key/atomic_pattern/standard.rs | 43 +++++++++---------- tfhe/src/shortint/key_switching_key/mod.rs | 8 ++-- 4 files changed, 50 insertions(+), 49 deletions(-) diff --git a/tfhe/src/shortint/client_key/atomic_pattern/ks32.rs b/tfhe/src/shortint/client_key/atomic_pattern/ks32.rs index 24491d484..f0282a8c2 100644 --- a/tfhe/src/shortint/client_key/atomic_pattern/ks32.rs +++ b/tfhe/src/shortint/client_key/atomic_pattern/ks32.rs @@ -259,35 +259,32 @@ impl KS32AtomicPatternClientKey { } } - pub(crate) fn new_seeded_keyswitching_key( + pub(crate) fn new_seeded_keyswitching_key_with_engine( &self, input_secret_key: &SecretEncryptionKeyView<'_>, params: ShortintKeySwitchingParameters, + engine: &mut ShortintEngine, ) -> SeededLweKeyswitchKeyOwned { match params.destination_key { - EncryptionKeyChoice::Big => ShortintEngine::with_thread_local_mut(|engine| { - allocate_and_generate_new_seeded_lwe_keyswitch_key( + EncryptionKeyChoice::Big => allocate_and_generate_new_seeded_lwe_keyswitch_key( + &input_secret_key.lwe_secret_key, + &self.large_lwe_secret_key(), + params.ks_base_log, + params.ks_level, + self.parameters.glwe_noise_distribution(), + self.parameters.ciphertext_modulus(), + &mut engine.seeder, + ), + EncryptionKeyChoice::Small => { + let ksk = allocate_and_generate_new_seeded_lwe_keyswitch_key( &input_secret_key.lwe_secret_key, - &self.large_lwe_secret_key(), + &self.small_lwe_secret_key(), params.ks_base_log, params.ks_level, - self.parameters.glwe_noise_distribution(), - self.parameters.ciphertext_modulus(), + self.parameters.lwe_noise_distribution(), + self.parameters.post_keyswitch_ciphertext_modulus(), &mut engine.seeder, - ) - }), - EncryptionKeyChoice::Small => { - let ksk = ShortintEngine::with_thread_local_mut(|engine| { - allocate_and_generate_new_seeded_lwe_keyswitch_key( - &input_secret_key.lwe_secret_key, - &self.small_lwe_secret_key(), - params.ks_base_log, - params.ks_level, - self.parameters.lwe_noise_distribution(), - self.parameters.post_keyswitch_ciphertext_modulus(), - &mut engine.seeder, - ) - }); + ); let shift = u64::BITS - u32::BITS; SeededLweKeyswitchKeyOwned::from_container( diff --git a/tfhe/src/shortint/client_key/atomic_pattern/mod.rs b/tfhe/src/shortint/client_key/atomic_pattern/mod.rs index 9e1f92d9a..b70991ba1 100644 --- a/tfhe/src/shortint/client_key/atomic_pattern/mod.rs +++ b/tfhe/src/shortint/client_key/atomic_pattern/mod.rs @@ -211,14 +211,19 @@ impl AtomicPatternClientKey { } } - pub(crate) fn new_seeded_keyswitching_key( + pub(crate) fn new_seeded_keyswitching_key_with_engine( &self, input_secret_key: &SecretEncryptionKeyView<'_>, params: ShortintKeySwitchingParameters, + engine: &mut ShortintEngine, ) -> SeededLweKeyswitchKeyOwned { match self { - Self::Standard(ap) => ap.new_seeded_keyswitching_key(input_secret_key, params), - Self::KeySwitch32(ap) => ap.new_seeded_keyswitching_key(input_secret_key, params), + Self::Standard(ap) => { + ap.new_seeded_keyswitching_key_with_engine(input_secret_key, params, engine) + } + Self::KeySwitch32(ap) => { + ap.new_seeded_keyswitching_key_with_engine(input_secret_key, params, engine) + } } } } diff --git a/tfhe/src/shortint/client_key/atomic_pattern/standard.rs b/tfhe/src/shortint/client_key/atomic_pattern/standard.rs index 8bffb4dc7..74588f3ab 100644 --- a/tfhe/src/shortint/client_key/atomic_pattern/standard.rs +++ b/tfhe/src/shortint/client_key/atomic_pattern/standard.rs @@ -331,34 +331,31 @@ impl StandardAtomicPatternClientKey { } } - pub(crate) fn new_seeded_keyswitching_key( + pub(crate) fn new_seeded_keyswitching_key_with_engine( &self, input_secret_key: &SecretEncryptionKeyView<'_>, params: ShortintKeySwitchingParameters, + engine: &mut ShortintEngine, ) -> SeededLweKeyswitchKeyOwned { match params.destination_key { - EncryptionKeyChoice::Big => ShortintEngine::with_thread_local_mut(|engine| { - allocate_and_generate_new_seeded_lwe_keyswitch_key( - &input_secret_key.lwe_secret_key, - &self.large_lwe_secret_key(), - params.ks_base_log, - params.ks_level, - self.parameters().glwe_noise_distribution(), - self.parameters().ciphertext_modulus(), - &mut engine.seeder, - ) - }), - EncryptionKeyChoice::Small => ShortintEngine::with_thread_local_mut(|engine| { - allocate_and_generate_new_seeded_lwe_keyswitch_key( - &input_secret_key.lwe_secret_key, - &self.small_lwe_secret_key(), - params.ks_base_log, - params.ks_level, - self.parameters().lwe_noise_distribution(), - self.parameters().ciphertext_modulus(), - &mut engine.seeder, - ) - }), + EncryptionKeyChoice::Big => allocate_and_generate_new_seeded_lwe_keyswitch_key( + &input_secret_key.lwe_secret_key, + &self.large_lwe_secret_key(), + params.ks_base_log, + params.ks_level, + self.parameters().glwe_noise_distribution(), + self.parameters().ciphertext_modulus(), + &mut engine.seeder, + ), + EncryptionKeyChoice::Small => allocate_and_generate_new_seeded_lwe_keyswitch_key( + &input_secret_key.lwe_secret_key, + &self.small_lwe_secret_key(), + params.ks_base_log, + params.ks_level, + self.parameters().lwe_noise_distribution(), + self.parameters().ciphertext_modulus(), + &mut engine.seeder, + ), } } } diff --git a/tfhe/src/shortint/key_switching_key/mod.rs b/tfhe/src/shortint/key_switching_key/mod.rs index 14b90c507..3f943deaa 100644 --- a/tfhe/src/shortint/key_switching_key/mod.rs +++ b/tfhe/src/shortint/key_switching_key/mod.rs @@ -1089,9 +1089,11 @@ impl<'keys> CompressedKeySwitchingKeyBuildHelper<'keys> { let output_cks = output_key_pair.0; // Creation of the key switching key - let key_switching_key = output_cks - .atomic_pattern - .new_seeded_keyswitching_key(&input_secret_key, params); + let key_switching_key = ShortintEngine::with_thread_local_mut(|engine| { + output_cks + .atomic_pattern + .new_seeded_keyswitching_key_with_engine(&input_secret_key, params, engine) + }); let full_message_modulus_input = input_secret_key.carry_modulus.0 * input_secret_key.message_modulus.0;