refactor(gpu): multibit decompression

This commit is contained in:
Enzo Di Maria
2025-10-27 10:23:28 +01:00
committed by Agnès Leroy
parent 867f8fb579
commit 026cc376ed
8 changed files with 221 additions and 101 deletions

View File

@@ -17,9 +17,9 @@ uint64_t scratch_cuda_integer_decompress_radix_ciphertext_64(
uint32_t encryption_glwe_dimension, uint32_t encryption_polynomial_size,
uint32_t compression_glwe_dimension, uint32_t compression_polynomial_size,
uint32_t lwe_dimension, uint32_t pbs_level, uint32_t pbs_base_log,
uint32_t num_blocks_to_decompress, uint32_t message_modulus,
uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory,
PBS_MS_REDUCTION_T noise_reduction_type);
uint32_t grouping_factor, uint32_t num_blocks_to_decompress,
uint32_t message_modulus, uint32_t carry_modulus, PBS_TYPE pbs_type,
bool allocate_gpu_memory, PBS_MS_REDUCTION_T noise_reduction_type);
void cuda_integer_compress_radix_ciphertext_64(
CudaStreamsFFI streams, CudaPackedGlweCiphertextListFFI *glwe_array_out,

View File

@@ -22,21 +22,21 @@ uint64_t scratch_cuda_integer_decompress_radix_ciphertext_64(
uint32_t encryption_glwe_dimension, uint32_t encryption_polynomial_size,
uint32_t compression_glwe_dimension, uint32_t compression_polynomial_size,
uint32_t lwe_dimension, uint32_t pbs_level, uint32_t pbs_base_log,
uint32_t num_blocks_to_decompress, uint32_t message_modulus,
uint32_t carry_modulus, PBS_TYPE pbs_type, bool allocate_gpu_memory,
PBS_MS_REDUCTION_T noise_reduction_type) {
uint32_t grouping_factor, uint32_t num_blocks_to_decompress,
uint32_t message_modulus, uint32_t carry_modulus, PBS_TYPE pbs_type,
bool allocate_gpu_memory, PBS_MS_REDUCTION_T noise_reduction_type) {
// Decompression doesn't keyswitch, so big and small dimensions are the same
int_radix_params encryption_params(
pbs_type, encryption_glwe_dimension, encryption_polynomial_size,
lwe_dimension, lwe_dimension, 0, 0, pbs_level, pbs_base_log, 0,
message_modulus, carry_modulus, noise_reduction_type);
lwe_dimension, lwe_dimension, 0, 0, pbs_level, pbs_base_log,
grouping_factor, message_modulus, carry_modulus, noise_reduction_type);
int_radix_params compression_params(
pbs_type, compression_glwe_dimension, compression_polynomial_size,
lwe_dimension, compression_glwe_dimension * compression_polynomial_size,
0, 0, pbs_level, pbs_base_log, 0, message_modulus, carry_modulus,
noise_reduction_type);
0, 0, pbs_level, pbs_base_log, grouping_factor, message_modulus,
carry_modulus, noise_reduction_type);
return scratch_cuda_integer_decompress_radix_ciphertext<uint64_t>(
CudaStreams(streams), (int_decompression<uint64_t> **)mem_ptr,

View File

@@ -1656,6 +1656,7 @@ unsafe extern "C" {
lwe_dimension: u32,
pbs_level: u32,
pbs_base_log: u32,
grouping_factor: u32,
num_blocks_to_decompress: u32,
message_modulus: u32,
carry_modulus: u32,

View File

@@ -1,8 +1,9 @@
use crate::core_crypto::gpu::lwe_bootstrap_key::CudaLweBootstrapKey;
use crate::core_crypto::gpu::lwe_multi_bit_bootstrap_key::CudaLweMultiBitBootstrapKey;
use crate::core_crypto::gpu::CudaStreams;
use crate::core_crypto::prelude::{
allocate_and_generate_new_lwe_packing_keyswitch_key, par_generate_lwe_bootstrap_key,
LweBootstrapKey,
par_generate_lwe_multi_bit_bootstrap_key, LweBootstrapKey, LweMultiBitBootstrapKey,
};
use crate::integer::compression_keys::{CompressionKey, CompressionPrivateKeys};
use crate::integer::gpu::list_compression::server_keys::{
@@ -12,6 +13,7 @@ use crate::integer::gpu::server_key::CudaBootstrappingKey;
use crate::integer::RadixClientKey;
use crate::shortint::client_key::atomic_pattern::AtomicPatternClientKey;
use crate::shortint::engine::ShortintEngine;
use crate::shortint::parameters::PBSParameters;
use crate::shortint::EncryptionKeyChoice;
impl RadixClientKey {
@@ -58,34 +60,71 @@ impl RadixClientKey {
let cuda_compression_key =
CudaCompressionKey::from_compression_key(&glwe_compression_key, streams);
// Decompression key
let mut bsk = LweBootstrapKey::new(
0u64,
self.parameters().glwe_dimension().to_glwe_size(),
self.parameters().polynomial_size(),
private_compression_key.params.br_base_log(),
private_compression_key.params.br_level(),
compression_params
.packing_ks_glwe_dimension()
.to_equivalent_lwe_dimension(compression_params.packing_ks_polynomial_size()),
self.parameters().ciphertext_modulus(),
);
let blind_rotate_key = match std_cks.parameters {
PBSParameters::PBS(_) => {
let mut bsk = LweBootstrapKey::new(
0u64,
self.parameters().glwe_dimension().to_glwe_size(),
self.parameters().polynomial_size(),
private_compression_key.params.br_base_log(),
private_compression_key.params.br_level(),
compression_params
.packing_ks_glwe_dimension()
.to_equivalent_lwe_dimension(
compression_params.packing_ks_polynomial_size(),
),
self.parameters().ciphertext_modulus(),
);
ShortintEngine::with_thread_local_mut(|engine| {
par_generate_lwe_bootstrap_key(
&private_compression_key
.post_packing_ks_key
.as_lwe_secret_key(),
&std_cks.glwe_secret_key,
&mut bsk,
self.parameters().glwe_noise_distribution(),
&mut engine.encryption_generator,
);
});
ShortintEngine::with_thread_local_mut(|engine| {
par_generate_lwe_bootstrap_key(
&private_compression_key
.post_packing_ks_key
.as_lwe_secret_key(),
&std_cks.glwe_secret_key,
&mut bsk,
self.parameters().glwe_noise_distribution(),
&mut engine.encryption_generator,
);
});
let blind_rotate_key = CudaBootstrappingKey::Classic(
CudaLweBootstrapKey::from_lwe_bootstrap_key(&bsk, None, streams),
);
CudaBootstrappingKey::Classic(CudaLweBootstrapKey::from_lwe_bootstrap_key(
&bsk, None, streams,
))
}
PBSParameters::MultiBitPBS(pbs_params) => {
let mut bsk = LweMultiBitBootstrapKey::new(
0u64,
self.parameters().glwe_dimension().to_glwe_size(),
self.parameters().polynomial_size(),
private_compression_key.params.br_base_log(),
private_compression_key.params.br_level(),
compression_params
.packing_ks_glwe_dimension()
.to_equivalent_lwe_dimension(
compression_params.packing_ks_polynomial_size(),
),
pbs_params.grouping_factor,
self.parameters().ciphertext_modulus(),
);
ShortintEngine::with_thread_local_mut(|engine| {
par_generate_lwe_multi_bit_bootstrap_key(
&private_compression_key
.post_packing_ks_key
.as_lwe_secret_key(),
&std_cks.glwe_secret_key,
&mut bsk,
self.parameters().glwe_noise_distribution(),
&mut engine.encryption_generator,
);
});
CudaBootstrappingKey::MultiBit(
CudaLweMultiBitBootstrapKey::from_lwe_multi_bit_bootstrap_key(&bsk, streams),
)
}
};
let cuda_decompression_key = CudaDecompressionKey {
blind_rotate_key,

View File

@@ -1,4 +1,5 @@
use crate::core_crypto::gpu::lwe_bootstrap_key::CudaLweBootstrapKey;
use crate::core_crypto::gpu::lwe_multi_bit_bootstrap_key::CudaLweMultiBitBootstrapKey;
use crate::core_crypto::gpu::CudaStreams;
use crate::core_crypto::prelude::{CiphertextModulus, GlweDimension, PolynomialSize};
use crate::integer::compression_keys::{
@@ -44,8 +45,30 @@ impl CompressedDecompressionKey {
ciphertext_modulus,
}
}
crate::shortint::list_compression::CompressedDecompressionKey::MultiBit { .. } => {
todo!()
crate::shortint::list_compression::CompressedDecompressionKey::MultiBit {
multi_bit_blind_rotate_key,
lwe_per_glwe,
} => {
let h_bootstrap_key = multi_bit_blind_rotate_key
.as_view()
.par_decompress_into_lwe_multi_bit_bootstrap_key();
let d_bootstrap_key = CudaLweMultiBitBootstrapKey::from_lwe_multi_bit_bootstrap_key(
&h_bootstrap_key,
streams,
);
let blind_rotate_key = CudaBootstrappingKey::MultiBit(d_bootstrap_key);
CudaDecompressionKey {
blind_rotate_key,
lwe_per_glwe: *lwe_per_glwe,
glwe_dimension,
polynomial_size,
message_modulus,
carry_modulus,
ciphertext_modulus,
}
}
}
}

View File

@@ -5,7 +5,7 @@ use crate::core_crypto::gpu::CudaStreams;
use crate::core_crypto::prelude::packed_integers::PackedIntegers;
use crate::core_crypto::prelude::{
glwe_ciphertext_size, glwe_mask_size, CiphertextModulus, CiphertextModulusLog,
GlweCiphertextCount, LweCiphertextCount, PolynomialSize, UnsignedInteger,
GlweCiphertextCount, LweBskGroupingFactor, LweCiphertextCount, PolynomialSize, UnsignedInteger,
};
use crate::error;
use crate::integer::ciphertext::{DataKind, NoiseSquashingCompressionKey};
@@ -16,7 +16,7 @@ use crate::integer::gpu::ciphertext::CudaRadixCiphertext;
use crate::integer::gpu::server_key::CudaBootstrappingKey;
use crate::integer::gpu::{
cuda_backend_compress, cuda_backend_decompress, cuda_backend_get_compression_size_on_gpu,
cuda_backend_get_decompression_size_on_gpu, cuda_memcpy_async_gpu_to_gpu,
cuda_backend_get_decompression_size_on_gpu, cuda_memcpy_async_gpu_to_gpu, PBSType,
};
use crate::prelude::CastInto;
use crate::shortint::ciphertext::{
@@ -420,7 +420,7 @@ impl CudaDecompressionKey {
let carry_modulus = self.carry_modulus;
let ciphertext_modulus = self.ciphertext_modulus;
match &self.blind_rotate_key {
let output_lwe = match &self.blind_rotate_key {
CudaBootstrappingKey::Classic(bsk) => {
assert!(
bsk.ms_noise_reduction_configuration.is_none(),
@@ -450,39 +450,70 @@ impl CudaDecompressionKey {
lwe_dimension,
bsk.decomp_base_log(),
bsk.decomp_level_count(),
LweBskGroupingFactor(0),
PBSType::Classical,
indexes_array.as_slice(),
indexes_array_len.0 as u32,
);
}
streams.synchronize();
let degree = match kind {
DataKind::Unsigned(_) | DataKind::Signed(_) | DataKind::String { .. } => {
Degree::new(message_modulus.0 - 1)
}
DataKind::Boolean => Degree::new(1),
};
let first_block_info = CudaBlockInfo {
degree,
message_modulus,
carry_modulus,
atomic_pattern: AtomicPatternKind::Standard(PBSOrder::KeyswitchBootstrap),
noise_level: NoiseLevel::NOMINAL,
};
let blocks = vec![first_block_info; output_lwe.0.lwe_ciphertext_count.0];
Ok(CudaRadixCiphertext {
d_blocks: output_lwe,
info: CudaRadixCiphertextInfo { blocks },
})
output_lwe
}
CudaBootstrappingKey::MultiBit(_) => {
panic! {"Compression is currently not compatible with Multi-Bit PBS"}
CudaBootstrappingKey::MultiBit(bsk) => {
let lwe_dimension = bsk.output_lwe_dimension();
let mut output_lwe = CudaLweCiphertextList::new(
lwe_dimension,
indexes_array_len,
ciphertext_modulus,
streams,
);
unsafe {
cuda_backend_decompress(
streams,
&mut output_lwe,
packed_list,
&bsk.d_vec,
message_modulus,
carry_modulus,
encryption_glwe_dimension,
encryption_polynomial_size,
compression_glwe_dimension,
compression_polynomial_size,
lwe_dimension,
bsk.decomp_base_log(),
bsk.decomp_level_count(),
bsk.grouping_factor,
PBSType::MultiBit,
indexes_array.as_slice(),
indexes_array_len.0 as u32,
);
}
output_lwe
}
}
};
let degree = match kind {
DataKind::Unsigned(_) | DataKind::Signed(_) | DataKind::String { .. } => {
Degree::new(message_modulus.0 - 1)
}
DataKind::Boolean => Degree::new(1),
};
let first_block_info = CudaBlockInfo {
degree,
message_modulus,
carry_modulus,
atomic_pattern: AtomicPatternKind::Standard(PBSOrder::KeyswitchBootstrap),
noise_level: NoiseLevel::NOMINAL,
};
let blocks = vec![first_block_info; output_lwe.0.lwe_ciphertext_count.0];
Ok(CudaRadixCiphertext {
d_blocks: output_lwe,
info: CudaRadixCiphertextInfo { blocks },
})
}
pub fn get_gpu_list_unpack_size_on_gpu(
&self,
@@ -532,11 +563,29 @@ impl CudaDecompressionKey {
lwe_dimension,
bsk.decomp_base_log(),
bsk.decomp_level_count(),
LweBskGroupingFactor(0),
PBSType::Classical,
indexes_array_len.0 as u32,
)
}
CudaBootstrappingKey::MultiBit(_) => {
panic! {"Compression is currently not compatible with Multi-Bit PBS"}
CudaBootstrappingKey::MultiBit(bsk) => {
let lwe_dimension = bsk.output_lwe_dimension();
cuda_backend_get_decompression_size_on_gpu(
streams,
message_modulus,
carry_modulus,
encryption_glwe_dimension,
encryption_polynomial_size,
compression_glwe_dimension,
compression_polynomial_size,
lwe_dimension,
bsk.decomp_base_log(),
bsk.decomp_level_count(),
bsk.grouping_factor,
PBSType::MultiBit,
indexes_array_len.0 as u32,
)
}
}
}
@@ -587,11 +636,29 @@ impl CudaDecompressionKey {
lwe_dimension,
bsk.decomp_base_log(),
bsk.decomp_level_count(),
LweBskGroupingFactor(0),
PBSType::Classical,
indexes_array_len.0 as u32,
)
}
CudaBootstrappingKey::MultiBit(_) => {
panic! {"Compression is currently not compatible with Multi-Bit PBS"}
CudaBootstrappingKey::MultiBit(bsk) => {
let lwe_dimension = bsk.output_lwe_dimension();
cuda_backend_get_decompression_size_on_gpu(
streams,
message_modulus,
carry_modulus,
encryption_glwe_dimension,
encryption_polynomial_size,
compression_glwe_dimension,
compression_polynomial_size,
lwe_dimension,
bsk.decomp_base_log(),
bsk.decomp_level_count(),
bsk.grouping_factor,
PBSType::MultiBit,
indexes_array_len.0 as u32,
)
}
}
}

View File

@@ -891,30 +891,11 @@ pub(crate) unsafe fn cuda_backend_decompress<B: Numeric>(
lwe_dimension: LweDimension,
pbs_base_log: DecompositionBaseLog,
pbs_level: DecompositionLevelCount,
grouping_factor: LweBskGroupingFactor,
pbs_type: PBSType,
vec_indexes: &[u32],
num_blocks_to_decompress: u32,
) {
assert_eq!(
streams.gpu_indexes[0],
lwe_array_out.0.d_vec.gpu_index(0),
"GPU error: first stream is on GPU {}, first output pointer is on GPU {}",
streams.gpu_indexes[0].get(),
lwe_array_out.0.d_vec.gpu_index(0).get(),
);
assert_eq!(
streams.gpu_indexes[0],
glwe_in.data.gpu_index(0),
"GPU error: first stream is on GPU {}, first input pointer is on GPU {}",
streams.gpu_indexes[0].get(),
glwe_in.data.gpu_index(0).get(),
);
assert_eq!(
streams.gpu_indexes[0],
bootstrapping_key.gpu_index(0),
"GPU error: first stream is on GPU {}, first bsk pointer is on GPU {}",
streams.gpu_indexes[0].get(),
bootstrapping_key.gpu_index(0).get(),
);
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
let mut lwe_array_out_ffi = prepare_cuda_lwe_ct_ffi(lwe_array_out);
@@ -930,10 +911,11 @@ pub(crate) unsafe fn cuda_backend_decompress<B: Numeric>(
lwe_dimension.0 as u32,
pbs_level.0 as u32,
pbs_base_log.0 as u32,
grouping_factor.0 as u32,
num_blocks_to_decompress,
message_modulus.0 as u32,
carry_modulus.0 as u32,
PBSType::Classical as u32,
pbs_type as u32,
true,
PBSMSNoiseReductionType::NoReduction as u32,
);
@@ -1031,6 +1013,8 @@ pub(crate) fn cuda_backend_get_decompression_size_on_gpu(
lwe_dimension: LweDimension,
pbs_base_log: DecompositionBaseLog,
pbs_level: DecompositionLevelCount,
grouping_factor: LweBskGroupingFactor,
pbs_type: PBSType,
num_blocks_to_decompress: u32,
) -> u64 {
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
@@ -1045,10 +1029,11 @@ pub(crate) fn cuda_backend_get_decompression_size_on_gpu(
lwe_dimension.0 as u32,
pbs_level.0 as u32,
pbs_base_log.0 as u32,
grouping_factor.0 as u32,
num_blocks_to_decompress,
message_modulus.0 as u32,
carry_modulus.0 as u32,
PBSType::Classical as u32,
pbs_type as u32,
false,
PBSMSNoiseReductionType::NoReduction as u32,
)

View File

@@ -1,4 +1,7 @@
use crate::shortint::parameters::list_compression::ClassicCompressionParameters;
use crate::core_crypto::prelude::LweBskGroupingFactor;
use crate::shortint::parameters::list_compression::{
ClassicCompressionParameters, MultiBitCompressionParameters,
};
use crate::shortint::parameters::{
CiphertextModulusLog, CompressionParameters, DecompositionBaseLog, DecompositionLevelCount,
DynamicDistribution, GlweDimension, LweCiphertextCount, PolynomialSize, StandardDev,
@@ -20,9 +23,9 @@ pub const V1_5_COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128: CompressionPa
/// p-fail = 2^-129.275, algorithmic cost ~ 41458
pub const V1_5_COMP_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128:
CompressionParameters = CompressionParameters::Classic(ClassicCompressionParameters {
CompressionParameters = CompressionParameters::MultiBit(MultiBitCompressionParameters {
br_level: DecompositionLevelCount(1),
br_base_log: DecompositionBaseLog(23),
br_base_log: DecompositionBaseLog(22),
packing_ks_level: DecompositionLevelCount(3),
packing_ks_base_log: DecompositionBaseLog(4),
packing_ks_polynomial_size: PolynomialSize(256),
@@ -30,6 +33,7 @@ pub const V1_5_COMP_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFOR
lwe_per_glwe: LweCiphertextCount(256),
storage_log_modulus: CiphertextModulusLog(12),
packing_ks_key_noise_distribution: DynamicDistribution::new_t_uniform(43),
decompression_grouping_factor: LweBskGroupingFactor(4),
});
/// p-fail = 2^-128.218, algorithmic cost ~ 42199
@@ -50,9 +54,9 @@ pub const V1_5_COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128: CompressionPa
/// p-fail = 2^-128.218, algorithmic cost ~ 42199
pub const V1_5_COMP_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M128:
CompressionParameters = CompressionParameters::Classic(ClassicCompressionParameters {
CompressionParameters = CompressionParameters::MultiBit(MultiBitCompressionParameters {
br_level: DecompositionLevelCount(1),
br_base_log: DecompositionBaseLog(23),
br_base_log: DecompositionBaseLog(22),
packing_ks_level: DecompositionLevelCount(2),
packing_ks_base_log: DecompositionBaseLog(6),
packing_ks_polynomial_size: PolynomialSize(256),
@@ -62,4 +66,5 @@ pub const V1_5_COMP_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIA
packing_ks_key_noise_distribution: DynamicDistribution::new_gaussian_from_std_dev(StandardDev(
1.339775301998614e-07,
)),
decompression_grouping_factor: LweBskGroupingFactor(4),
});