From 15cab8b413e9245dd3e9701cb36cf6887720216e Mon Sep 17 00:00:00 2001 From: Agnes Leroy Date: Thu, 25 Sep 2025 13:59:55 +0200 Subject: [PATCH] chore(gpu): get decompress size on gpu without calling on_gpu --- .../compressed_ciphertext_list.rs | 53 +++++++++++----- .../ciphertext/compressed_ciphertext_list.rs | 39 ++++++++++++ .../ciphertext/compressed_ciphertext_list.rs | 3 +- .../gpu/list_compression/server_keys.rs | 61 ++++++++++++++++++- 4 files changed, 137 insertions(+), 19 deletions(-) diff --git a/tfhe/src/high_level_api/compressed_ciphertext_list.rs b/tfhe/src/high_level_api/compressed_ciphertext_list.rs index b641cb4dd..0da50ba76 100644 --- a/tfhe/src/high_level_api/compressed_ciphertext_list.rs +++ b/tfhe/src/high_level_api/compressed_ciphertext_list.rs @@ -622,20 +622,32 @@ impl CiphertextList for CompressedCiphertextList { global_state::with_internal_keys(|key| { if let InternalServerKey::Cuda(cuda_key) = key { let streams = &cuda_key.streams; - cuda_key - .key - .decompression_key - .as_ref() - .ok_or_else(|| { - crate::Error::new("Compression key not set in server key".to_owned()) - }) - .map(|decompression_key| { - self.inner.on_gpu(streams).get_decompression_size_on_gpu( - index, - decompression_key, - streams, - ) - }) + match &self.inner { + InnerCompressedCiphertextList::Cpu(ct_list) => cuda_key + .key + .decompression_key + .as_ref() + .ok_or_else(|| { + crate::Error::new("Compression key not set in server key".to_owned()) + }) + .map(|decompression_key| { + ct_list.get_decompression_size_on_gpu(index, decompression_key, streams) + }), + InnerCompressedCiphertextList::Cuda(cuda_ct_list) => cuda_key + .key + .decompression_key + .as_ref() + .ok_or_else(|| { + crate::Error::new("Compression key not set in server key".to_owned()) + }) + .map(|decompression_key| { + cuda_ct_list.get_decompression_size_on_gpu( + index, + decompression_key, + streams, + ) + }), + } } else { Ok(Some(0)) } @@ -1221,7 +1233,7 @@ mod tests { ); } - let compressed_list = compressed_list_init.build().unwrap(); + let mut compressed_list = compressed_list_init.build().unwrap(); let decompress_ct1_size_on_gpu = compressed_list .get_decompression_size_on_gpu(0) .unwrap() @@ -1232,6 +1244,17 @@ mod tests { .unwrap() .unwrap(); check_valid_cuda_malloc_assert_oom(decompress_ct2_size_on_gpu, GpuIndex::new(0)); + compressed_list.move_to_current_device(); + let decompress_ct1_size_on_gpu_1 = compressed_list + .get_decompression_size_on_gpu(0) + .unwrap() + .unwrap(); + let decompress_ct2_size_on_gpu_1 = compressed_list + .get_decompression_size_on_gpu(1) + .unwrap() + .unwrap(); + assert_eq!(decompress_ct1_size_on_gpu, decompress_ct1_size_on_gpu_1); + assert_eq!(decompress_ct2_size_on_gpu, decompress_ct2_size_on_gpu_1); } } } diff --git a/tfhe/src/integer/ciphertext/compressed_ciphertext_list.rs b/tfhe/src/integer/ciphertext/compressed_ciphertext_list.rs index 1e26cceca..63203b99c 100644 --- a/tfhe/src/integer/ciphertext/compressed_ciphertext_list.rs +++ b/tfhe/src/integer/ciphertext/compressed_ciphertext_list.rs @@ -1,6 +1,10 @@ use super::{DataKind, Expandable, RadixCiphertext, SignedRadixCiphertext}; +#[cfg(feature = "gpu")] +use crate::core_crypto::gpu::CudaStreams; use crate::integer::backward_compatibility::ciphertext::CompressedCiphertextListVersions; use crate::integer::compression_keys::{CompressionKey, DecompressionKey}; +#[cfg(feature = "gpu")] +use crate::integer::gpu::list_compression::server_keys::CudaDecompressionKey; use crate::integer::BooleanBlock; use crate::shortint::ciphertext::CompressedCiphertextList as ShortintCompressedCiphertextList; use crate::shortint::Ciphertext; @@ -176,6 +180,41 @@ impl CompressedCiphertextList { .map(|(blocks, kind)| T::from_expanded_blocks(blocks, kind)) .transpose() } + #[cfg(feature = "gpu")] + pub fn get_decompression_size_on_gpu( + &self, + index: usize, + decomp_key: &CudaDecompressionKey, + streams: &CudaStreams, + ) -> Option { + self.get_blocks_of_size_on_gpu(index, decomp_key, streams) + } + #[cfg(feature = "gpu")] + fn get_blocks_of_size_on_gpu( + &self, + index: usize, + decomp_key: &CudaDecompressionKey, + streams: &CudaStreams, + ) -> Option { + let preceding_infos = self.info.get(..index)?; + let current_info = self.info.get(index).copied()?; + let message_modulus = self.packed_list.message_modulus()?; + + let start_block_index: usize = preceding_infos + .iter() + .copied() + .map(|kind| kind.num_blocks(message_modulus)) + .sum(); + + let end_block_index = start_block_index + current_info.num_blocks(message_modulus) - 1; + + Some(decomp_key.get_cpu_list_unpack_size_on_gpu( + &self.packed_list, + start_block_index, + end_block_index, + streams, + )) + } } #[cfg(test)] diff --git a/tfhe/src/integer/gpu/ciphertext/compressed_ciphertext_list.rs b/tfhe/src/integer/gpu/ciphertext/compressed_ciphertext_list.rs index b022bb420..1f787b3a2 100644 --- a/tfhe/src/integer/gpu/ciphertext/compressed_ciphertext_list.rs +++ b/tfhe/src/integer/gpu/ciphertext/compressed_ciphertext_list.rs @@ -125,7 +125,6 @@ impl CudaCompressedCiphertextList { )) } - #[allow(clippy::unnecessary_wraps)] fn get_blocks_of_size_on_gpu( &self, index: usize, @@ -144,7 +143,7 @@ impl CudaCompressedCiphertextList { let end_block_index = start_block_index + current_info.num_blocks(message_modulus) - 1; - Some(decomp_key.get_unpack_size_on_gpu( + Some(decomp_key.get_gpu_list_unpack_size_on_gpu( &self.packed_list, start_block_index, end_block_index, diff --git a/tfhe/src/integer/gpu/list_compression/server_keys.rs b/tfhe/src/integer/gpu/list_compression/server_keys.rs index 27eaf100e..e733b2476 100644 --- a/tfhe/src/integer/gpu/list_compression/server_keys.rs +++ b/tfhe/src/integer/gpu/list_compression/server_keys.rs @@ -20,6 +20,7 @@ use crate::integer::gpu::{ }; use crate::prelude::CastInto; use crate::shortint::ciphertext::{ + CompressedCiphertextList, CompressedSquashedNoiseCiphertextList as ShortintCompressedSquashedNoiseCiphertextList, Degree, NoiseLevel, }; @@ -477,14 +478,14 @@ impl CudaDecompressionKey { } } } - pub fn get_unpack_size_on_gpu( + pub fn get_gpu_list_unpack_size_on_gpu( &self, packed_list: &CudaPackedGlweCiphertextList, start_block_index: usize, end_block_index: usize, streams: &CudaStreams, ) -> u64 { - if packed_list.bodies_count() == 0 && start_block_index == end_block_index { + if start_block_index == end_block_index { return 0; } @@ -500,6 +501,62 @@ impl CudaDecompressionKey { let encryption_polynomial_size = self.polynomial_size; let compression_glwe_dimension = meta.glwe_dimension; let compression_polynomial_size = meta.polynomial_size; + + let indexes_array_len = LweCiphertextCount(indexes_array.len()); + + let message_modulus = self.message_modulus; + let carry_modulus = self.carry_modulus; + + match &self.blind_rotate_key { + CudaBootstrappingKey::Classic(bsk) => { + assert!( + bsk.ms_noise_reduction_configuration.is_none(), + "Decompression key should not do modulus switch noise reduction" + ); + let lwe_dimension = bsk.output_lwe_dimension(); + + get_decompression_size_on_gpu( + streams, + message_modulus, + carry_modulus, + encryption_glwe_dimension, + encryption_polynomial_size, + compression_glwe_dimension, + compression_polynomial_size, + lwe_dimension, + bsk.decomp_base_log(), + bsk.decomp_level_count(), + indexes_array_len.0 as u32, + ) + } + CudaBootstrappingKey::MultiBit(_) => { + panic! {"Compression is currently not compatible with Multi-Bit PBS"} + } + } + } + pub fn get_cpu_list_unpack_size_on_gpu( + &self, + packed_list: &CompressedCiphertextList, + start_block_index: usize, + end_block_index: usize, + streams: &CudaStreams, + ) -> u64 { + if start_block_index == end_block_index { + return 0; + } + + let indexes_array = (start_block_index..=end_block_index) + .map(|x| x as u32) + .collect_vec(); + + let encryption_glwe_dimension = self.glwe_dimension; + let encryption_polynomial_size = self.polynomial_size; + + let compression_polynomial_size = + packed_list.modulus_switched_glwe_ciphertext_list[0].polynomial_size(); + let compression_glwe_dimension = + packed_list.modulus_switched_glwe_ciphertext_list[0].glwe_dimension(); + let indexes_array_len = LweCiphertextCount(indexes_array.len()); let message_modulus = self.message_modulus;