mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-09 14:47:56 -05:00
chore(gpu): get decompress size on gpu without calling on_gpu
This commit is contained in:
@@ -622,20 +622,32 @@ impl CiphertextList for CompressedCiphertextList {
|
||||
global_state::with_internal_keys(|key| {
|
||||
if let InternalServerKey::Cuda(cuda_key) = key {
|
||||
let streams = &cuda_key.streams;
|
||||
cuda_key
|
||||
.key
|
||||
.decompression_key
|
||||
.as_ref()
|
||||
.ok_or_else(|| {
|
||||
crate::Error::new("Compression key not set in server key".to_owned())
|
||||
})
|
||||
.map(|decompression_key| {
|
||||
self.inner.on_gpu(streams).get_decompression_size_on_gpu(
|
||||
index,
|
||||
decompression_key,
|
||||
streams,
|
||||
)
|
||||
})
|
||||
match &self.inner {
|
||||
InnerCompressedCiphertextList::Cpu(ct_list) => cuda_key
|
||||
.key
|
||||
.decompression_key
|
||||
.as_ref()
|
||||
.ok_or_else(|| {
|
||||
crate::Error::new("Compression key not set in server key".to_owned())
|
||||
})
|
||||
.map(|decompression_key| {
|
||||
ct_list.get_decompression_size_on_gpu(index, decompression_key, streams)
|
||||
}),
|
||||
InnerCompressedCiphertextList::Cuda(cuda_ct_list) => cuda_key
|
||||
.key
|
||||
.decompression_key
|
||||
.as_ref()
|
||||
.ok_or_else(|| {
|
||||
crate::Error::new("Compression key not set in server key".to_owned())
|
||||
})
|
||||
.map(|decompression_key| {
|
||||
cuda_ct_list.get_decompression_size_on_gpu(
|
||||
index,
|
||||
decompression_key,
|
||||
streams,
|
||||
)
|
||||
}),
|
||||
}
|
||||
} else {
|
||||
Ok(Some(0))
|
||||
}
|
||||
@@ -1221,7 +1233,7 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
let compressed_list = compressed_list_init.build().unwrap();
|
||||
let mut compressed_list = compressed_list_init.build().unwrap();
|
||||
let decompress_ct1_size_on_gpu = compressed_list
|
||||
.get_decompression_size_on_gpu(0)
|
||||
.unwrap()
|
||||
@@ -1232,6 +1244,17 @@ mod tests {
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
check_valid_cuda_malloc_assert_oom(decompress_ct2_size_on_gpu, GpuIndex::new(0));
|
||||
compressed_list.move_to_current_device();
|
||||
let decompress_ct1_size_on_gpu_1 = compressed_list
|
||||
.get_decompression_size_on_gpu(0)
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
let decompress_ct2_size_on_gpu_1 = compressed_list
|
||||
.get_decompression_size_on_gpu(1)
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
assert_eq!(decompress_ct1_size_on_gpu, decompress_ct1_size_on_gpu_1);
|
||||
assert_eq!(decompress_ct2_size_on_gpu, decompress_ct2_size_on_gpu_1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
use super::{DataKind, Expandable, RadixCiphertext, SignedRadixCiphertext};
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::core_crypto::gpu::CudaStreams;
|
||||
use crate::integer::backward_compatibility::ciphertext::CompressedCiphertextListVersions;
|
||||
use crate::integer::compression_keys::{CompressionKey, DecompressionKey};
|
||||
#[cfg(feature = "gpu")]
|
||||
use crate::integer::gpu::list_compression::server_keys::CudaDecompressionKey;
|
||||
use crate::integer::BooleanBlock;
|
||||
use crate::shortint::ciphertext::CompressedCiphertextList as ShortintCompressedCiphertextList;
|
||||
use crate::shortint::Ciphertext;
|
||||
@@ -176,6 +180,41 @@ impl CompressedCiphertextList {
|
||||
.map(|(blocks, kind)| T::from_expanded_blocks(blocks, kind))
|
||||
.transpose()
|
||||
}
|
||||
/// Reports the GPU memory needed to decompress the element stored at `index`
/// of this (CPU-resident) compressed list.
///
/// * `index` - position of the element inside the list.
/// * `decomp_key` - CUDA decompression key used for the size computation.
/// * `streams` - CUDA streams handed through to the size computation.
///
/// Returns whatever `get_blocks_of_size_on_gpu` yields for the same
/// arguments: `None` when the lookup for `index` fails, otherwise `Some(size)`
/// (presumably a byte count — TODO confirm the unit against the CUDA backend).
#[cfg(feature = "gpu")]
pub fn get_decompression_size_on_gpu(
    &self,
    index: usize,
    decomp_key: &CudaDecompressionKey,
    streams: &CudaStreams,
) -> Option<u64> {
    // Public entry point; the block-range computation lives in the private helper.
    Self::get_blocks_of_size_on_gpu(self, index, decomp_key, streams)
}
|
||||
/// Computes the GPU size required to unpack the blocks belonging to the
/// element at `index`.
///
/// The element's block range is derived from `self.info`: the first block is
/// the sum of the block counts of every preceding entry, and the last block
/// is `first + num_blocks(current) - 1` (inclusive). That range is then passed
/// to `decomp_key.get_cpu_list_unpack_size_on_gpu` together with the packed
/// list and `streams`.
///
/// Returns `None` when `index` is out of range for `self.info` or when the
/// packed list reports no message modulus; otherwise `Some(size)`.
#[cfg(feature = "gpu")]
fn get_blocks_of_size_on_gpu(
    &self,
    index: usize,
    decomp_key: &CudaDecompressionKey,
    streams: &CudaStreams,
) -> Option<u64> {
    // Each of these lookups short-circuits to `None` on failure.
    let kinds_before = self.info.get(..index)?;
    let kind = *self.info.get(index)?;
    let message_modulus = self.packed_list.message_modulus()?;

    // Blocks occupied by all entries stored ahead of the requested one.
    let mut first_block: usize = 0;
    for earlier_kind in kinds_before.iter().copied() {
        first_block += earlier_kind.num_blocks(message_modulus);
    }

    // Inclusive index of the last block of the requested entry.
    // NOTE(review): this subtraction would underflow if `num_blocks` returned 0
    // while `first_block` is 0 — confirm `num_blocks` is always >= 1.
    let last_block = first_block + kind.num_blocks(message_modulus) - 1;

    let unpack_size = decomp_key.get_cpu_list_unpack_size_on_gpu(
        &self.packed_list,
        first_block,
        last_block,
        streams,
    );
    Some(unpack_size)
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -125,7 +125,6 @@ impl CudaCompressedCiphertextList {
|
||||
))
|
||||
}
|
||||
|
||||
#[allow(clippy::unnecessary_wraps)]
|
||||
fn get_blocks_of_size_on_gpu(
|
||||
&self,
|
||||
index: usize,
|
||||
@@ -144,7 +143,7 @@ impl CudaCompressedCiphertextList {
|
||||
|
||||
let end_block_index = start_block_index + current_info.num_blocks(message_modulus) - 1;
|
||||
|
||||
Some(decomp_key.get_unpack_size_on_gpu(
|
||||
Some(decomp_key.get_gpu_list_unpack_size_on_gpu(
|
||||
&self.packed_list,
|
||||
start_block_index,
|
||||
end_block_index,
|
||||
|
||||
@@ -20,6 +20,7 @@ use crate::integer::gpu::{
|
||||
};
|
||||
use crate::prelude::CastInto;
|
||||
use crate::shortint::ciphertext::{
|
||||
CompressedCiphertextList,
|
||||
CompressedSquashedNoiseCiphertextList as ShortintCompressedSquashedNoiseCiphertextList, Degree,
|
||||
NoiseLevel,
|
||||
};
|
||||
@@ -477,14 +478,14 @@ impl CudaDecompressionKey {
|
||||
}
|
||||
}
|
||||
}
|
||||
pub fn get_unpack_size_on_gpu(
|
||||
pub fn get_gpu_list_unpack_size_on_gpu(
|
||||
&self,
|
||||
packed_list: &CudaPackedGlweCiphertextList<u64>,
|
||||
start_block_index: usize,
|
||||
end_block_index: usize,
|
||||
streams: &CudaStreams,
|
||||
) -> u64 {
|
||||
if packed_list.bodies_count() == 0 && start_block_index == end_block_index {
|
||||
if start_block_index == end_block_index {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -500,6 +501,62 @@ impl CudaDecompressionKey {
|
||||
let encryption_polynomial_size = self.polynomial_size;
|
||||
let compression_glwe_dimension = meta.glwe_dimension;
|
||||
let compression_polynomial_size = meta.polynomial_size;
|
||||
|
||||
let indexes_array_len = LweCiphertextCount(indexes_array.len());
|
||||
|
||||
let message_modulus = self.message_modulus;
|
||||
let carry_modulus = self.carry_modulus;
|
||||
|
||||
match &self.blind_rotate_key {
|
||||
CudaBootstrappingKey::Classic(bsk) => {
|
||||
assert!(
|
||||
bsk.ms_noise_reduction_configuration.is_none(),
|
||||
"Decompression key should not do modulus switch noise reduction"
|
||||
);
|
||||
let lwe_dimension = bsk.output_lwe_dimension();
|
||||
|
||||
get_decompression_size_on_gpu(
|
||||
streams,
|
||||
message_modulus,
|
||||
carry_modulus,
|
||||
encryption_glwe_dimension,
|
||||
encryption_polynomial_size,
|
||||
compression_glwe_dimension,
|
||||
compression_polynomial_size,
|
||||
lwe_dimension,
|
||||
bsk.decomp_base_log(),
|
||||
bsk.decomp_level_count(),
|
||||
indexes_array_len.0 as u32,
|
||||
)
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(_) => {
|
||||
panic! {"Compression is currently not compatible with Multi-Bit PBS"}
|
||||
}
|
||||
}
|
||||
}
|
||||
pub fn get_cpu_list_unpack_size_on_gpu(
|
||||
&self,
|
||||
packed_list: &CompressedCiphertextList,
|
||||
start_block_index: usize,
|
||||
end_block_index: usize,
|
||||
streams: &CudaStreams,
|
||||
) -> u64 {
|
||||
if start_block_index == end_block_index {
|
||||
return 0;
|
||||
}
|
||||
|
||||
let indexes_array = (start_block_index..=end_block_index)
|
||||
.map(|x| x as u32)
|
||||
.collect_vec();
|
||||
|
||||
let encryption_glwe_dimension = self.glwe_dimension;
|
||||
let encryption_polynomial_size = self.polynomial_size;
|
||||
|
||||
let compression_polynomial_size =
|
||||
packed_list.modulus_switched_glwe_ciphertext_list[0].polynomial_size();
|
||||
let compression_glwe_dimension =
|
||||
packed_list.modulus_switched_glwe_ciphertext_list[0].glwe_dimension();
|
||||
|
||||
let indexes_array_len = LweCiphertextCount(indexes_array.len());
|
||||
|
||||
let message_modulus = self.message_modulus;
|
||||
|
||||
Reference in New Issue
Block a user