chore(gpu): get decompress size on gpu without calling on_gpu

This commit is contained in:
Agnes Leroy
2025-09-25 13:59:55 +02:00
committed by Agnès Leroy
parent 23d46ba2bc
commit 15cab8b413
4 changed files with 137 additions and 19 deletions

View File

@@ -622,20 +622,32 @@ impl CiphertextList for CompressedCiphertextList {
global_state::with_internal_keys(|key| {
if let InternalServerKey::Cuda(cuda_key) = key {
let streams = &cuda_key.streams;
cuda_key
.key
.decompression_key
.as_ref()
.ok_or_else(|| {
crate::Error::new("Compression key not set in server key".to_owned())
})
.map(|decompression_key| {
self.inner.on_gpu(streams).get_decompression_size_on_gpu(
index,
decompression_key,
streams,
)
})
match &self.inner {
InnerCompressedCiphertextList::Cpu(ct_list) => cuda_key
.key
.decompression_key
.as_ref()
.ok_or_else(|| {
crate::Error::new("Compression key not set in server key".to_owned())
})
.map(|decompression_key| {
ct_list.get_decompression_size_on_gpu(index, decompression_key, streams)
}),
InnerCompressedCiphertextList::Cuda(cuda_ct_list) => cuda_key
.key
.decompression_key
.as_ref()
.ok_or_else(|| {
crate::Error::new("Compression key not set in server key".to_owned())
})
.map(|decompression_key| {
cuda_ct_list.get_decompression_size_on_gpu(
index,
decompression_key,
streams,
)
}),
}
} else {
Ok(Some(0))
}
@@ -1221,7 +1233,7 @@ mod tests {
);
}
let compressed_list = compressed_list_init.build().unwrap();
let mut compressed_list = compressed_list_init.build().unwrap();
let decompress_ct1_size_on_gpu = compressed_list
.get_decompression_size_on_gpu(0)
.unwrap()
@@ -1232,6 +1244,17 @@ mod tests {
.unwrap()
.unwrap();
check_valid_cuda_malloc_assert_oom(decompress_ct2_size_on_gpu, GpuIndex::new(0));
compressed_list.move_to_current_device();
let decompress_ct1_size_on_gpu_1 = compressed_list
.get_decompression_size_on_gpu(0)
.unwrap()
.unwrap();
let decompress_ct2_size_on_gpu_1 = compressed_list
.get_decompression_size_on_gpu(1)
.unwrap()
.unwrap();
assert_eq!(decompress_ct1_size_on_gpu, decompress_ct1_size_on_gpu_1);
assert_eq!(decompress_ct2_size_on_gpu, decompress_ct2_size_on_gpu_1);
}
}
}

View File

@@ -1,6 +1,10 @@
use super::{DataKind, Expandable, RadixCiphertext, SignedRadixCiphertext};
#[cfg(feature = "gpu")]
use crate::core_crypto::gpu::CudaStreams;
use crate::integer::backward_compatibility::ciphertext::CompressedCiphertextListVersions;
use crate::integer::compression_keys::{CompressionKey, DecompressionKey};
#[cfg(feature = "gpu")]
use crate::integer::gpu::list_compression::server_keys::CudaDecompressionKey;
use crate::integer::BooleanBlock;
use crate::shortint::ciphertext::CompressedCiphertextList as ShortintCompressedCiphertextList;
use crate::shortint::Ciphertext;
@@ -176,6 +180,41 @@ impl CompressedCiphertextList {
.map(|(blocks, kind)| T::from_expanded_blocks(blocks, kind))
.transpose()
}
#[cfg(feature = "gpu")]
pub fn get_decompression_size_on_gpu(
&self,
index: usize,
decomp_key: &CudaDecompressionKey,
streams: &CudaStreams,
) -> Option<u64> {
self.get_blocks_of_size_on_gpu(index, decomp_key, streams)
}
#[cfg(feature = "gpu")]
fn get_blocks_of_size_on_gpu(
&self,
index: usize,
decomp_key: &CudaDecompressionKey,
streams: &CudaStreams,
) -> Option<u64> {
let preceding_infos = self.info.get(..index)?;
let current_info = self.info.get(index).copied()?;
let message_modulus = self.packed_list.message_modulus()?;
let start_block_index: usize = preceding_infos
.iter()
.copied()
.map(|kind| kind.num_blocks(message_modulus))
.sum();
let end_block_index = start_block_index + current_info.num_blocks(message_modulus) - 1;
Some(decomp_key.get_cpu_list_unpack_size_on_gpu(
&self.packed_list,
start_block_index,
end_block_index,
streams,
))
}
}
#[cfg(test)]

View File

@@ -125,7 +125,6 @@ impl CudaCompressedCiphertextList {
))
}
#[allow(clippy::unnecessary_wraps)]
fn get_blocks_of_size_on_gpu(
&self,
index: usize,
@@ -144,7 +143,7 @@ impl CudaCompressedCiphertextList {
let end_block_index = start_block_index + current_info.num_blocks(message_modulus) - 1;
Some(decomp_key.get_unpack_size_on_gpu(
Some(decomp_key.get_gpu_list_unpack_size_on_gpu(
&self.packed_list,
start_block_index,
end_block_index,

View File

@@ -20,6 +20,7 @@ use crate::integer::gpu::{
};
use crate::prelude::CastInto;
use crate::shortint::ciphertext::{
CompressedCiphertextList,
CompressedSquashedNoiseCiphertextList as ShortintCompressedSquashedNoiseCiphertextList, Degree,
NoiseLevel,
};
@@ -477,14 +478,14 @@ impl CudaDecompressionKey {
}
}
}
pub fn get_unpack_size_on_gpu(
pub fn get_gpu_list_unpack_size_on_gpu(
&self,
packed_list: &CudaPackedGlweCiphertextList<u64>,
start_block_index: usize,
end_block_index: usize,
streams: &CudaStreams,
) -> u64 {
if packed_list.bodies_count() == 0 && start_block_index == end_block_index {
if start_block_index == end_block_index {
return 0;
}
@@ -500,6 +501,62 @@ impl CudaDecompressionKey {
let encryption_polynomial_size = self.polynomial_size;
let compression_glwe_dimension = meta.glwe_dimension;
let compression_polynomial_size = meta.polynomial_size;
let indexes_array_len = LweCiphertextCount(indexes_array.len());
let message_modulus = self.message_modulus;
let carry_modulus = self.carry_modulus;
match &self.blind_rotate_key {
CudaBootstrappingKey::Classic(bsk) => {
assert!(
bsk.ms_noise_reduction_configuration.is_none(),
"Decompression key should not do modulus switch noise reduction"
);
let lwe_dimension = bsk.output_lwe_dimension();
get_decompression_size_on_gpu(
streams,
message_modulus,
carry_modulus,
encryption_glwe_dimension,
encryption_polynomial_size,
compression_glwe_dimension,
compression_polynomial_size,
lwe_dimension,
bsk.decomp_base_log(),
bsk.decomp_level_count(),
indexes_array_len.0 as u32,
)
}
CudaBootstrappingKey::MultiBit(_) => {
panic! {"Compression is currently not compatible with Multi-Bit PBS"}
}
}
}
pub fn get_cpu_list_unpack_size_on_gpu(
&self,
packed_list: &CompressedCiphertextList,
start_block_index: usize,
end_block_index: usize,
streams: &CudaStreams,
) -> u64 {
if start_block_index == end_block_index {
return 0;
}
let indexes_array = (start_block_index..=end_block_index)
.map(|x| x as u32)
.collect_vec();
let encryption_glwe_dimension = self.glwe_dimension;
let encryption_polynomial_size = self.polynomial_size;
let compression_polynomial_size =
packed_list.modulus_switched_glwe_ciphertext_list[0].polynomial_size();
let compression_glwe_dimension =
packed_list.modulus_switched_glwe_ciphertext_list[0].glwe_dimension();
let indexes_array_len = LweCiphertextCount(indexes_array.len());
let message_modulus = self.message_modulus;