mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-04-28 03:01:21 -04:00
refactor(hlapi): add IntegerExpandedServerKey::convert_to_gpu
And use it to convert from CompressedServerKey to CudaServerKey.
This commit is contained in:
committed by
tmontaigu
parent
1a7b7ace47
commit
58dbdf7dd4
@@ -43,6 +43,27 @@ where
|
||||
},
|
||||
}
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
impl<Scalar, ModSwitchScalar> ShortintExpandedBootstrappingKey<Scalar, ModSwitchScalar>
|
||||
where
|
||||
Scalar: UnsignedInteger,
|
||||
ModSwitchScalar: UnsignedInteger,
|
||||
{
|
||||
pub(crate) fn glwe_dimension(&self) -> GlweDimension {
|
||||
match self {
|
||||
Self::Classic { bsk, .. } => bsk.glwe_size().to_glwe_dimension(),
|
||||
Self::MultiBit { bsk, .. } => bsk.glwe_size().to_glwe_dimension(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn polynomial_size(&self) -> PolynomialSize {
|
||||
match self {
|
||||
Self::Classic { bsk, .. } => bsk.polynomial_size(),
|
||||
Self::MultiBit { bsk, .. } => bsk.polynomial_size(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<ModSwitchScalar> ShortintExpandedBootstrappingKey<u64, ModSwitchScalar>
|
||||
where
|
||||
ModSwitchScalar: UnsignedInteger,
|
||||
@@ -170,6 +191,29 @@ pub(crate) enum ExpandedAtomicPatternServerKey {
|
||||
|
||||
pub(crate) type ShortintExpandedServerKey = GenericServerKey<ExpandedAtomicPatternServerKey>;
|
||||
|
||||
#[cfg(feature = "gpu")]
|
||||
impl ShortintExpandedServerKey {
|
||||
pub(crate) fn glwe_dimension(&self) -> GlweDimension {
|
||||
match &self.atomic_pattern {
|
||||
ExpandedAtomicPatternServerKey::Standard(std) => std.bootstrapping_key.glwe_dimension(),
|
||||
ExpandedAtomicPatternServerKey::KeySwitch32(ks32) => {
|
||||
ks32.bootstrapping_key.glwe_dimension()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn polynomial_size(&self) -> PolynomialSize {
|
||||
match &self.atomic_pattern {
|
||||
ExpandedAtomicPatternServerKey::Standard(std) => {
|
||||
std.bootstrapping_key.polynomial_size()
|
||||
}
|
||||
ExpandedAtomicPatternServerKey::KeySwitch32(ks32) => {
|
||||
ks32.bootstrapping_key.polynomial_size()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) enum ExpandedAtomicPatternNoiseSquashingKey {
|
||||
Standard(ShortintExpandedBootstrappingKey<u128, u64>),
|
||||
KeySwitch32(ShortintExpandedBootstrappingKey<u128, u32>),
|
||||
@@ -296,9 +340,95 @@ impl IntegerExpandedServerKey {
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Expand implementations for compressed types (NOT using pre-seeded generators)
|
||||
// =============================================================================
|
||||
#[cfg(feature = "gpu")]
|
||||
impl IntegerExpandedServerKey {
|
||||
pub(crate) fn convert_to_gpu(
|
||||
&self,
|
||||
streams: &crate::core_crypto::gpu::CudaStreams,
|
||||
) -> crate::Result<crate::high_level_api::keys::inner::IntegerCudaServerKey> {
|
||||
use crate::high_level_api::keys::cpk_re_randomization::ReRandomizationKeySwitchingKey;
|
||||
use crate::integer::gpu::key_switching_key::CudaKeySwitchingKeyMaterial;
|
||||
use crate::integer::gpu::list_compression::server_keys::{
|
||||
CudaCompressionKey, CudaDecompressionKey, CudaNoiseSquashingCompressionKey,
|
||||
};
|
||||
use crate::integer::gpu::noise_squashing::keys::CudaNoiseSquashingKey;
|
||||
use crate::integer::gpu::CudaServerKey;
|
||||
|
||||
// Destructure to ensure all fields are handled
|
||||
let Self {
|
||||
compute_key,
|
||||
cpk_key_switching_key_material,
|
||||
compression_key,
|
||||
decompression_key,
|
||||
noise_squashing_key,
|
||||
noise_squashing_compression_key,
|
||||
cpk_re_randomization_key_switching_key_material,
|
||||
} = self;
|
||||
|
||||
let key = CudaServerKey::from_expanded_server_key(compute_key, streams)?;
|
||||
|
||||
let cpk_key_switching_key_material = cpk_key_switching_key_material.as_ref().map(|ksk| {
|
||||
CudaKeySwitchingKeyMaterial::from_key_switching_key_material(&ksk.as_view(), streams)
|
||||
});
|
||||
|
||||
let compression_key = compression_key
|
||||
.as_ref()
|
||||
.map(|ck| CudaCompressionKey::from_compression_key(ck, streams));
|
||||
|
||||
let decompression_key = decompression_key
|
||||
.as_ref()
|
||||
.map(|dk| {
|
||||
CudaDecompressionKey::from_expanded_decompression_key(
|
||||
dk,
|
||||
compute_key.glwe_dimension(),
|
||||
compute_key.polynomial_size(),
|
||||
compute_key.message_modulus,
|
||||
compute_key.carry_modulus,
|
||||
compute_key.ciphertext_modulus,
|
||||
streams,
|
||||
)
|
||||
})
|
||||
.transpose()?;
|
||||
|
||||
let noise_squashing_key = noise_squashing_key
|
||||
.as_ref()
|
||||
.map(|nsk| CudaNoiseSquashingKey::from_expanded_noise_squashing_key(nsk, streams));
|
||||
|
||||
let noise_squashing_compression_key =
|
||||
noise_squashing_compression_key.as_ref().map(|nsck| {
|
||||
CudaNoiseSquashingCompressionKey::from_noise_squashing_compression_key(
|
||||
nsck, streams,
|
||||
)
|
||||
});
|
||||
|
||||
let cpk_re_randomization_key_switching_key_material =
|
||||
cpk_re_randomization_key_switching_key_material
|
||||
.as_ref()
|
||||
.map(|re_rand_ksk| match re_rand_ksk {
|
||||
ReRandomizationKeySwitchingKey::UseCPKEncryptionKSK => {
|
||||
ReRandomizationKeySwitchingKey::UseCPKEncryptionKSK
|
||||
}
|
||||
ReRandomizationKeySwitchingKey::DedicatedKSK(ksk) => {
|
||||
ReRandomizationKeySwitchingKey::DedicatedKSK(
|
||||
CudaKeySwitchingKeyMaterial::from_key_switching_key_material(
|
||||
&ksk.as_view(),
|
||||
streams,
|
||||
),
|
||||
)
|
||||
}
|
||||
});
|
||||
|
||||
Ok(crate::high_level_api::keys::inner::IntegerCudaServerKey {
|
||||
key,
|
||||
cpk_key_switching_key_material,
|
||||
compression_key,
|
||||
decompression_key,
|
||||
noise_squashing_key,
|
||||
noise_squashing_compression_key,
|
||||
cpk_re_randomization_key_switching_key_material,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<ModSwitchScalar> ShortintCompressedBootstrappingKey<ModSwitchScalar>
|
||||
where
|
||||
|
||||
@@ -2,9 +2,7 @@ use super::ClientKey;
|
||||
use crate::backward_compatibility::keys::{CompressedServerKeyVersions, ServerKeyVersions};
|
||||
use crate::conformance::ParameterSetConformant;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::core_crypto::gpu::lwe_keyswitch_key::CudaLweKeyswitchKey;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::core_crypto::gpu::{synchronize_devices, CudaStreams};
|
||||
use crate::core_crypto::gpu::CudaStreams;
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::high_level_api::keys::inner::IntegerCudaServerKey;
|
||||
use crate::high_level_api::keys::{
|
||||
@@ -340,114 +338,13 @@ impl CompressedServerKey {
|
||||
gpu_choice: impl Into<crate::CudaGpuChoice>,
|
||||
) -> CudaServerKey {
|
||||
let streams = gpu_choice.into().build_streams();
|
||||
let key = crate::integer::gpu::CudaServerKey::decompress_from_cpu(
|
||||
&self.integer_key.key,
|
||||
&streams,
|
||||
);
|
||||
let cpk_key_switching_key_material = self
|
||||
let key = self
|
||||
.integer_key
|
||||
.cpk_key_switching_key_material
|
||||
.as_ref()
|
||||
.map(|cpk_ksk_material| {
|
||||
let ksk_material = cpk_ksk_material.decompress();
|
||||
let d_ksk = CudaLweKeyswitchKey::from_lwe_keyswitch_key(
|
||||
&ksk_material.material.key_switching_key,
|
||||
&streams,
|
||||
);
|
||||
CudaKeySwitchingKeyMaterial {
|
||||
lwe_keyswitch_key: d_ksk,
|
||||
destination_key: ksk_material.material.destination_key,
|
||||
cast_rshift: ksk_material.material.cast_rshift,
|
||||
}
|
||||
});
|
||||
|
||||
let cpk_re_randomization_key_switching_key_material = self
|
||||
.integer_key
|
||||
.cpk_re_randomization_key_switching_key_material
|
||||
.as_ref()
|
||||
.map(
|
||||
|cpk_re_randomization_ksk_material| match cpk_re_randomization_ksk_material {
|
||||
CompressedReRandomizationKeySwitchingKey::UseCPKEncryptionKSK => {
|
||||
ReRandomizationKeySwitchingKey::UseCPKEncryptionKSK
|
||||
}
|
||||
CompressedReRandomizationKeySwitchingKey::DedicatedKSK(dedicated_ksk) => {
|
||||
let ksk_material = dedicated_ksk.decompress();
|
||||
let d_ksk = CudaLweKeyswitchKey::from_lwe_keyswitch_key(
|
||||
&ksk_material.material.key_switching_key,
|
||||
&streams,
|
||||
);
|
||||
let d_ksk_material = CudaKeySwitchingKeyMaterial {
|
||||
lwe_keyswitch_key: d_ksk,
|
||||
destination_key: ksk_material.material.destination_key,
|
||||
cast_rshift: ksk_material.material.cast_rshift,
|
||||
};
|
||||
|
||||
ReRandomizationKeySwitchingKey::DedicatedKSK(d_ksk_material)
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
let compression_key: Option<
|
||||
crate::integer::gpu::list_compression::server_keys::CudaCompressionKey,
|
||||
> = self
|
||||
.integer_key
|
||||
.compression_key
|
||||
.as_ref()
|
||||
.map(|compression_key| compression_key.decompress_to_cuda(&streams));
|
||||
let decompression_key: Option<
|
||||
crate::integer::gpu::list_compression::server_keys::CudaDecompressionKey,
|
||||
> = match &self.integer_key.decompression_key {
|
||||
// Convert decompression_key in the (cpu) integer keyset to the GPU if it's defined
|
||||
Some(decompression_key) => {
|
||||
let polynomial_size = decompression_key.key.polynomial_size();
|
||||
let glwe_dimension = decompression_key.key.glwe_size().to_glwe_dimension();
|
||||
let message_modulus = key.message_modulus;
|
||||
let carry_modulus = key.carry_modulus;
|
||||
let ciphertext_modulus = decompression_key.key.ciphertext_modulus();
|
||||
Some(decompression_key.decompress_to_cuda(
|
||||
glwe_dimension,
|
||||
polynomial_size,
|
||||
message_modulus,
|
||||
carry_modulus,
|
||||
ciphertext_modulus,
|
||||
&streams,
|
||||
))
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
|
||||
// Convert noise_squashing_key in the (cpu) integer keyset to the GPU if it's defined
|
||||
let noise_squashing_key: Option<
|
||||
crate::integer::gpu::noise_squashing::keys::CudaNoiseSquashingKey,
|
||||
> = self
|
||||
.integer_key
|
||||
.noise_squashing_key
|
||||
.as_ref()
|
||||
.map(|noise_squashing_key| noise_squashing_key.decompress_to_cuda(&streams));
|
||||
|
||||
// Convert noise_squashing_compression_key in the (cpu) integer keyset to the GPU if it's
|
||||
// defined
|
||||
let noise_squashing_compression_key: Option<
|
||||
crate::integer::gpu::list_compression::server_keys::CudaNoiseSquashingCompressionKey,
|
||||
> = self
|
||||
.integer_key
|
||||
.noise_squashing_compression_key
|
||||
.as_ref()
|
||||
.map(|noise_squashing_compression_key| {
|
||||
noise_squashing_compression_key.decompress_to_cuda(&streams)
|
||||
});
|
||||
|
||||
synchronize_devices(&streams);
|
||||
.expand()
|
||||
.convert_to_gpu(&streams)
|
||||
.expect("Unsupported configuration");
|
||||
CudaServerKey {
|
||||
key: Arc::new(IntegerCudaServerKey {
|
||||
key,
|
||||
cpk_key_switching_key_material,
|
||||
compression_key,
|
||||
decompression_key,
|
||||
noise_squashing_key,
|
||||
noise_squashing_compression_key,
|
||||
cpk_re_randomization_key_switching_key_material,
|
||||
}),
|
||||
key: Arc::new(key),
|
||||
tag: self.tag.clone(),
|
||||
streams,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user