From 58dbdf7dd451e10ee0f523d2f7cdeba549ef2b8a Mon Sep 17 00:00:00 2001 From: Thomas Montaigu Date: Wed, 28 Jan 2026 11:28:11 +0100 Subject: [PATCH] refactor(hlapi): add IntegerExpandedServerKey::convert_to_gpu And use it to convert from CompressedServerKey to CudaServerKey. --- tfhe/src/high_level_api/keys/expanded.rs | 136 ++++++++++++++++++++++- tfhe/src/high_level_api/keys/server.rs | 115 +------------------ 2 files changed, 139 insertions(+), 112 deletions(-) diff --git a/tfhe/src/high_level_api/keys/expanded.rs b/tfhe/src/high_level_api/keys/expanded.rs index c81231d3e..f5331092b 100644 --- a/tfhe/src/high_level_api/keys/expanded.rs +++ b/tfhe/src/high_level_api/keys/expanded.rs @@ -43,6 +43,27 @@ where }, } +#[cfg(feature = "gpu")] +impl ShortintExpandedBootstrappingKey +where + Scalar: UnsignedInteger, + ModSwitchScalar: UnsignedInteger, +{ + pub(crate) fn glwe_dimension(&self) -> GlweDimension { + match self { + Self::Classic { bsk, .. } => bsk.glwe_size().to_glwe_dimension(), + Self::MultiBit { bsk, .. } => bsk.glwe_size().to_glwe_dimension(), + } + } + + pub(crate) fn polynomial_size(&self) -> PolynomialSize { + match self { + Self::Classic { bsk, .. } => bsk.polynomial_size(), + Self::MultiBit { bsk, .. } => bsk.polynomial_size(), + } + } +} + impl ShortintExpandedBootstrappingKey where ModSwitchScalar: UnsignedInteger, @@ -170,6 +191,29 @@ pub(crate) enum ExpandedAtomicPatternServerKey { pub(crate) type ShortintExpandedServerKey = GenericServerKey; +#[cfg(feature = "gpu")] +impl ShortintExpandedServerKey { + pub(crate) fn glwe_dimension(&self) -> GlweDimension { + match &self.atomic_pattern { + ExpandedAtomicPatternServerKey::Standard(std) => std.bootstrapping_key.glwe_dimension(), + ExpandedAtomicPatternServerKey::KeySwitch32(ks32) => { + ks32.bootstrapping_key.glwe_dimension() + } + } + } + + pub(crate) fn polynomial_size(&self) -> PolynomialSize { + match &self.atomic_pattern { + ExpandedAtomicPatternServerKey::Standard(std) => { + std.bootstrapping_key.polynomial_size() + } + ExpandedAtomicPatternServerKey::KeySwitch32(ks32) => { + ks32.bootstrapping_key.polynomial_size() + } + } + } +} + pub(crate) enum ExpandedAtomicPatternNoiseSquashingKey { Standard(ShortintExpandedBootstrappingKey), KeySwitch32(ShortintExpandedBootstrappingKey), @@ -296,9 +340,95 @@ impl IntegerExpandedServerKey { } } -// ============================================================================= -// Expand implementations for compressed types (NOT using pre-seeded generators) -// ============================================================================= +#[cfg(feature = "gpu")] +impl IntegerExpandedServerKey { + pub(crate) fn convert_to_gpu( + &self, + streams: &crate::core_crypto::gpu::CudaStreams, + ) -> crate::Result { + use crate::high_level_api::keys::cpk_re_randomization::ReRandomizationKeySwitchingKey; + use crate::integer::gpu::key_switching_key::CudaKeySwitchingKeyMaterial; + use crate::integer::gpu::list_compression::server_keys::{ + CudaCompressionKey, CudaDecompressionKey, CudaNoiseSquashingCompressionKey, + }; + use crate::integer::gpu::noise_squashing::keys::CudaNoiseSquashingKey; + use crate::integer::gpu::CudaServerKey; + + // Destructure to ensure all fields are handled + let Self { + compute_key, + cpk_key_switching_key_material, + compression_key, + decompression_key, + noise_squashing_key, + noise_squashing_compression_key, + cpk_re_randomization_key_switching_key_material, + } = self; + + let key = CudaServerKey::from_expanded_server_key(compute_key, streams)?; + + let cpk_key_switching_key_material = cpk_key_switching_key_material.as_ref().map(|ksk| { + CudaKeySwitchingKeyMaterial::from_key_switching_key_material(&ksk.as_view(), streams) + }); + + let compression_key = compression_key + .as_ref() + .map(|ck| CudaCompressionKey::from_compression_key(ck, streams)); + + let decompression_key = decompression_key + .as_ref() + .map(|dk| { + CudaDecompressionKey::from_expanded_decompression_key( + dk, + compute_key.glwe_dimension(), + compute_key.polynomial_size(), + compute_key.message_modulus, + compute_key.carry_modulus, + compute_key.ciphertext_modulus, + streams, + ) + }) + .transpose()?; + + let noise_squashing_key = noise_squashing_key + .as_ref() + .map(|nsk| CudaNoiseSquashingKey::from_expanded_noise_squashing_key(nsk, streams)); + + let noise_squashing_compression_key = + noise_squashing_compression_key.as_ref().map(|nsck| { + CudaNoiseSquashingCompressionKey::from_noise_squashing_compression_key( + nsck, streams, + ) + }); + + let cpk_re_randomization_key_switching_key_material = + cpk_re_randomization_key_switching_key_material + .as_ref() + .map(|re_rand_ksk| match re_rand_ksk { + ReRandomizationKeySwitchingKey::UseCPKEncryptionKSK => { + ReRandomizationKeySwitchingKey::UseCPKEncryptionKSK + } + ReRandomizationKeySwitchingKey::DedicatedKSK(ksk) => { + ReRandomizationKeySwitchingKey::DedicatedKSK( + CudaKeySwitchingKeyMaterial::from_key_switching_key_material( + &ksk.as_view(), + streams, + ), + ) + } + }); + + Ok(crate::high_level_api::keys::inner::IntegerCudaServerKey { + key, + cpk_key_switching_key_material, + compression_key, + decompression_key, + noise_squashing_key, + noise_squashing_compression_key, + cpk_re_randomization_key_switching_key_material, + }) + } +} impl ShortintCompressedBootstrappingKey where diff --git a/tfhe/src/high_level_api/keys/server.rs b/tfhe/src/high_level_api/keys/server.rs index 6dfcaf047..17273786d 100644 --- a/tfhe/src/high_level_api/keys/server.rs +++ b/tfhe/src/high_level_api/keys/server.rs @@ -2,9 +2,7 @@ use super::ClientKey; use crate::backward_compatibility::keys::{CompressedServerKeyVersions, ServerKeyVersions}; use crate::conformance::ParameterSetConformant; #[cfg(feature = "gpu")] -use crate::core_crypto::gpu::lwe_keyswitch_key::CudaLweKeyswitchKey; -#[cfg(feature = "gpu")] -use crate::core_crypto::gpu::{synchronize_devices, CudaStreams}; +use crate::core_crypto::gpu::CudaStreams; #[cfg(feature = "gpu")] use crate::high_level_api::keys::inner::IntegerCudaServerKey; use crate::high_level_api::keys::{ @@ -340,114 +338,13 @@ impl CompressedServerKey { gpu_choice: impl Into, ) -> CudaServerKey { let streams = gpu_choice.into().build_streams(); - let key = crate::integer::gpu::CudaServerKey::decompress_from_cpu( - &self.integer_key.key, - &streams, - ); - let cpk_key_switching_key_material = self + let key = self .integer_key - .cpk_key_switching_key_material - .as_ref() - .map(|cpk_ksk_material| { - let ksk_material = cpk_ksk_material.decompress(); - let d_ksk = CudaLweKeyswitchKey::from_lwe_keyswitch_key( - &ksk_material.material.key_switching_key, - &streams, - ); - CudaKeySwitchingKeyMaterial { - lwe_keyswitch_key: d_ksk, - destination_key: ksk_material.material.destination_key, - cast_rshift: ksk_material.material.cast_rshift, - } - }); - - let cpk_re_randomization_key_switching_key_material = self - .integer_key - .cpk_re_randomization_key_switching_key_material - .as_ref() - .map( - |cpk_re_randomization_ksk_material| match cpk_re_randomization_ksk_material { - CompressedReRandomizationKeySwitchingKey::UseCPKEncryptionKSK => { - ReRandomizationKeySwitchingKey::UseCPKEncryptionKSK - } - CompressedReRandomizationKeySwitchingKey::DedicatedKSK(dedicated_ksk) => { - let ksk_material = dedicated_ksk.decompress(); - let d_ksk = CudaLweKeyswitchKey::from_lwe_keyswitch_key( - &ksk_material.material.key_switching_key, - &streams, - ); - let d_ksk_material = CudaKeySwitchingKeyMaterial { - lwe_keyswitch_key: d_ksk, - destination_key: ksk_material.material.destination_key, - cast_rshift: ksk_material.material.cast_rshift, - }; - - ReRandomizationKeySwitchingKey::DedicatedKSK(d_ksk_material) - } - }, - ); - - let compression_key: Option< - crate::integer::gpu::list_compression::server_keys::CudaCompressionKey, - > = self - .integer_key - .compression_key - .as_ref() - .map(|compression_key| compression_key.decompress_to_cuda(&streams)); - let decompression_key: Option< - crate::integer::gpu::list_compression::server_keys::CudaDecompressionKey, - > = match &self.integer_key.decompression_key { - // Convert decompression_key in the (cpu) integer keyset to the GPU if it's defined - Some(decompression_key) => { - let polynomial_size = decompression_key.key.polynomial_size(); - let glwe_dimension = decompression_key.key.glwe_size().to_glwe_dimension(); - let message_modulus = key.message_modulus; - let carry_modulus = key.carry_modulus; - let ciphertext_modulus = decompression_key.key.ciphertext_modulus(); - Some(decompression_key.decompress_to_cuda( - glwe_dimension, - polynomial_size, - message_modulus, - carry_modulus, - ciphertext_modulus, - &streams, - )) - } - None => None, - }; - - // Convert noise_squashing_key in the (cpu) integer keyset to the GPU if it's defined - let noise_squashing_key: Option< - crate::integer::gpu::noise_squashing::keys::CudaNoiseSquashingKey, - > = self - .integer_key - .noise_squashing_key - .as_ref() - .map(|noise_squashing_key| noise_squashing_key.decompress_to_cuda(&streams)); - - // Convert noise_squashing_compression_key in the (cpu) integer keyset to the GPU if it's - // defined - let noise_squashing_compression_key: Option< - crate::integer::gpu::list_compression::server_keys::CudaNoiseSquashingCompressionKey, - > = self - .integer_key - .noise_squashing_compression_key - .as_ref() - .map(|noise_squashing_compression_key| { - noise_squashing_compression_key.decompress_to_cuda(&streams) - }); - - synchronize_devices(&streams); + .expand() + .convert_to_gpu(&streams) + .expect("Unsupported configuration"); CudaServerKey { - key: Arc::new(IntegerCudaServerKey { - key, - cpk_key_switching_key_material, - compression_key, - decompression_key, - noise_squashing_key, - noise_squashing_compression_key, - cpk_re_randomization_key_switching_key_material, - }), + key: Arc::new(key), tag: self.tag.clone(), streams, }