refactor(hlapi): add IntegerExpandedServerKey::convert_to_gpu

And use it to convert from CompressedServerKey to CudaServerKey.
This commit is contained in:
Thomas Montaigu
2026-01-28 11:28:11 +01:00
committed by tmontaigu
parent 1a7b7ace47
commit 58dbdf7dd4
2 changed files with 139 additions and 112 deletions

View File

@@ -43,6 +43,27 @@ where
},
}
#[cfg(feature = "gpu")]
impl<Scalar, ModSwitchScalar> ShortintExpandedBootstrappingKey<Scalar, ModSwitchScalar>
where
Scalar: UnsignedInteger,
ModSwitchScalar: UnsignedInteger,
{
pub(crate) fn glwe_dimension(&self) -> GlweDimension {
match self {
Self::Classic { bsk, .. } => bsk.glwe_size().to_glwe_dimension(),
Self::MultiBit { bsk, .. } => bsk.glwe_size().to_glwe_dimension(),
}
}
pub(crate) fn polynomial_size(&self) -> PolynomialSize {
match self {
Self::Classic { bsk, .. } => bsk.polynomial_size(),
Self::MultiBit { bsk, .. } => bsk.polynomial_size(),
}
}
}
impl<ModSwitchScalar> ShortintExpandedBootstrappingKey<u64, ModSwitchScalar>
where
ModSwitchScalar: UnsignedInteger,
@@ -170,6 +191,29 @@ pub(crate) enum ExpandedAtomicPatternServerKey {
pub(crate) type ShortintExpandedServerKey = GenericServerKey<ExpandedAtomicPatternServerKey>;
#[cfg(feature = "gpu")]
impl ShortintExpandedServerKey {
pub(crate) fn glwe_dimension(&self) -> GlweDimension {
match &self.atomic_pattern {
ExpandedAtomicPatternServerKey::Standard(std) => std.bootstrapping_key.glwe_dimension(),
ExpandedAtomicPatternServerKey::KeySwitch32(ks32) => {
ks32.bootstrapping_key.glwe_dimension()
}
}
}
pub(crate) fn polynomial_size(&self) -> PolynomialSize {
match &self.atomic_pattern {
ExpandedAtomicPatternServerKey::Standard(std) => {
std.bootstrapping_key.polynomial_size()
}
ExpandedAtomicPatternServerKey::KeySwitch32(ks32) => {
ks32.bootstrapping_key.polynomial_size()
}
}
}
}
pub(crate) enum ExpandedAtomicPatternNoiseSquashingKey {
Standard(ShortintExpandedBootstrappingKey<u128, u64>),
KeySwitch32(ShortintExpandedBootstrappingKey<u128, u32>),
@@ -296,9 +340,95 @@ impl IntegerExpandedServerKey {
}
}
// =============================================================================
// Expand implementations for compressed types (NOT using pre-seeded generators)
// =============================================================================
#[cfg(feature = "gpu")]
impl IntegerExpandedServerKey {
pub(crate) fn convert_to_gpu(
&self,
streams: &crate::core_crypto::gpu::CudaStreams,
) -> crate::Result<crate::high_level_api::keys::inner::IntegerCudaServerKey> {
use crate::high_level_api::keys::cpk_re_randomization::ReRandomizationKeySwitchingKey;
use crate::integer::gpu::key_switching_key::CudaKeySwitchingKeyMaterial;
use crate::integer::gpu::list_compression::server_keys::{
CudaCompressionKey, CudaDecompressionKey, CudaNoiseSquashingCompressionKey,
};
use crate::integer::gpu::noise_squashing::keys::CudaNoiseSquashingKey;
use crate::integer::gpu::CudaServerKey;
// Destructure to ensure all fields are handled
let Self {
compute_key,
cpk_key_switching_key_material,
compression_key,
decompression_key,
noise_squashing_key,
noise_squashing_compression_key,
cpk_re_randomization_key_switching_key_material,
} = self;
let key = CudaServerKey::from_expanded_server_key(compute_key, streams)?;
let cpk_key_switching_key_material = cpk_key_switching_key_material.as_ref().map(|ksk| {
CudaKeySwitchingKeyMaterial::from_key_switching_key_material(&ksk.as_view(), streams)
});
let compression_key = compression_key
.as_ref()
.map(|ck| CudaCompressionKey::from_compression_key(ck, streams));
let decompression_key = decompression_key
.as_ref()
.map(|dk| {
CudaDecompressionKey::from_expanded_decompression_key(
dk,
compute_key.glwe_dimension(),
compute_key.polynomial_size(),
compute_key.message_modulus,
compute_key.carry_modulus,
compute_key.ciphertext_modulus,
streams,
)
})
.transpose()?;
let noise_squashing_key = noise_squashing_key
.as_ref()
.map(|nsk| CudaNoiseSquashingKey::from_expanded_noise_squashing_key(nsk, streams));
let noise_squashing_compression_key =
noise_squashing_compression_key.as_ref().map(|nsck| {
CudaNoiseSquashingCompressionKey::from_noise_squashing_compression_key(
nsck, streams,
)
});
let cpk_re_randomization_key_switching_key_material =
cpk_re_randomization_key_switching_key_material
.as_ref()
.map(|re_rand_ksk| match re_rand_ksk {
ReRandomizationKeySwitchingKey::UseCPKEncryptionKSK => {
ReRandomizationKeySwitchingKey::UseCPKEncryptionKSK
}
ReRandomizationKeySwitchingKey::DedicatedKSK(ksk) => {
ReRandomizationKeySwitchingKey::DedicatedKSK(
CudaKeySwitchingKeyMaterial::from_key_switching_key_material(
&ksk.as_view(),
streams,
),
)
}
});
Ok(crate::high_level_api::keys::inner::IntegerCudaServerKey {
key,
cpk_key_switching_key_material,
compression_key,
decompression_key,
noise_squashing_key,
noise_squashing_compression_key,
cpk_re_randomization_key_switching_key_material,
})
}
}
impl<ModSwitchScalar> ShortintCompressedBootstrappingKey<ModSwitchScalar>
where

View File

@@ -2,9 +2,7 @@ use super::ClientKey;
use crate::backward_compatibility::keys::{CompressedServerKeyVersions, ServerKeyVersions};
use crate::conformance::ParameterSetConformant;
#[cfg(feature = "gpu")]
use crate::core_crypto::gpu::lwe_keyswitch_key::CudaLweKeyswitchKey;
#[cfg(feature = "gpu")]
use crate::core_crypto::gpu::{synchronize_devices, CudaStreams};
use crate::core_crypto::gpu::CudaStreams;
#[cfg(feature = "gpu")]
use crate::high_level_api::keys::inner::IntegerCudaServerKey;
use crate::high_level_api::keys::{
@@ -340,114 +338,13 @@ impl CompressedServerKey {
gpu_choice: impl Into<crate::CudaGpuChoice>,
) -> CudaServerKey {
let streams = gpu_choice.into().build_streams();
let key = crate::integer::gpu::CudaServerKey::decompress_from_cpu(
&self.integer_key.key,
&streams,
);
let cpk_key_switching_key_material = self
let key = self
.integer_key
.cpk_key_switching_key_material
.as_ref()
.map(|cpk_ksk_material| {
let ksk_material = cpk_ksk_material.decompress();
let d_ksk = CudaLweKeyswitchKey::from_lwe_keyswitch_key(
&ksk_material.material.key_switching_key,
&streams,
);
CudaKeySwitchingKeyMaterial {
lwe_keyswitch_key: d_ksk,
destination_key: ksk_material.material.destination_key,
cast_rshift: ksk_material.material.cast_rshift,
}
});
let cpk_re_randomization_key_switching_key_material = self
.integer_key
.cpk_re_randomization_key_switching_key_material
.as_ref()
.map(
|cpk_re_randomization_ksk_material| match cpk_re_randomization_ksk_material {
CompressedReRandomizationKeySwitchingKey::UseCPKEncryptionKSK => {
ReRandomizationKeySwitchingKey::UseCPKEncryptionKSK
}
CompressedReRandomizationKeySwitchingKey::DedicatedKSK(dedicated_ksk) => {
let ksk_material = dedicated_ksk.decompress();
let d_ksk = CudaLweKeyswitchKey::from_lwe_keyswitch_key(
&ksk_material.material.key_switching_key,
&streams,
);
let d_ksk_material = CudaKeySwitchingKeyMaterial {
lwe_keyswitch_key: d_ksk,
destination_key: ksk_material.material.destination_key,
cast_rshift: ksk_material.material.cast_rshift,
};
ReRandomizationKeySwitchingKey::DedicatedKSK(d_ksk_material)
}
},
);
let compression_key: Option<
crate::integer::gpu::list_compression::server_keys::CudaCompressionKey,
> = self
.integer_key
.compression_key
.as_ref()
.map(|compression_key| compression_key.decompress_to_cuda(&streams));
let decompression_key: Option<
crate::integer::gpu::list_compression::server_keys::CudaDecompressionKey,
> = match &self.integer_key.decompression_key {
// Convert decompression_key in the (cpu) integer keyset to the GPU if it's defined
Some(decompression_key) => {
let polynomial_size = decompression_key.key.polynomial_size();
let glwe_dimension = decompression_key.key.glwe_size().to_glwe_dimension();
let message_modulus = key.message_modulus;
let carry_modulus = key.carry_modulus;
let ciphertext_modulus = decompression_key.key.ciphertext_modulus();
Some(decompression_key.decompress_to_cuda(
glwe_dimension,
polynomial_size,
message_modulus,
carry_modulus,
ciphertext_modulus,
&streams,
))
}
None => None,
};
// Convert noise_squashing_key in the (cpu) integer keyset to the GPU if it's defined
let noise_squashing_key: Option<
crate::integer::gpu::noise_squashing::keys::CudaNoiseSquashingKey,
> = self
.integer_key
.noise_squashing_key
.as_ref()
.map(|noise_squashing_key| noise_squashing_key.decompress_to_cuda(&streams));
// Convert noise_squashing_compression_key in the (cpu) integer keyset to the GPU if it's
// defined
let noise_squashing_compression_key: Option<
crate::integer::gpu::list_compression::server_keys::CudaNoiseSquashingCompressionKey,
> = self
.integer_key
.noise_squashing_compression_key
.as_ref()
.map(|noise_squashing_compression_key| {
noise_squashing_compression_key.decompress_to_cuda(&streams)
});
synchronize_devices(&streams);
.expand()
.convert_to_gpu(&streams)
.expect("Unsupported configuration");
CudaServerKey {
key: Arc::new(IntegerCudaServerKey {
key,
cpk_key_switching_key_material,
compression_key,
decompression_key,
noise_squashing_key,
noise_squashing_compression_key,
cpk_re_randomization_key_switching_key_material,
}),
key: Arc::new(key),
tag: self.tag.clone(),
streams,
}