mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-10 07:08:03 -05:00
feat(gpu): implement conversion from CudaCompressedCiphertextList to CompressedCiphertextList
This commit is contained in:
@@ -37,8 +37,8 @@ __global__ void pack(Torus *array_out, Torus *array_in, uint32_t log_modulus,
|
||||
|
||||
template <typename Torus>
|
||||
__host__ void host_pack(cudaStream_t stream, uint32_t gpu_index,
|
||||
Torus *array_out, Torus *array_in, uint32_t num_inputs,
|
||||
uint32_t body_count, int_compression<Torus> *mem_ptr) {
|
||||
Torus *array_out, Torus *array_in, uint32_t body_count,
|
||||
int_compression<Torus> *mem_ptr) {
|
||||
cudaSetDevice(gpu_index);
|
||||
auto params = mem_ptr->compression_params;
|
||||
|
||||
@@ -105,7 +105,7 @@ __host__ void host_integer_compress(cudaStream_t *streams,
|
||||
check_cuda_error(cudaGetLastError());
|
||||
|
||||
host_pack<Torus>(streams[0], gpu_indexes[0], glwe_array_out,
|
||||
tmp_glwe_array_out, num_glwes, body_count, mem_ptr);
|
||||
tmp_glwe_array_out, body_count, mem_ptr);
|
||||
}
|
||||
|
||||
template <typename Torus>
|
||||
|
||||
@@ -80,11 +80,11 @@ use crate::core_crypto::prelude::*;
|
||||
#[derive(Clone, serde::Serialize, serde::Deserialize, Versionize)]
|
||||
#[versionize(CompressedModulusSwitchedGlweCiphertextVersions)]
|
||||
pub struct CompressedModulusSwitchedGlweCiphertext<Scalar: UnsignedInteger> {
|
||||
packed_integers: PackedIntegers<Scalar>,
|
||||
glwe_dimension: GlweDimension,
|
||||
polynomial_size: PolynomialSize,
|
||||
bodies_count: LweCiphertextCount,
|
||||
uncompressed_ciphertext_modulus: CiphertextModulus<Scalar>,
|
||||
pub(crate) packed_integers: PackedIntegers<Scalar>,
|
||||
pub(crate) glwe_dimension: GlweDimension,
|
||||
pub(crate) polynomial_size: PolynomialSize,
|
||||
pub(crate) bodies_count: LweCiphertextCount,
|
||||
pub(crate) uncompressed_ciphertext_modulus: CiphertextModulus<Scalar>,
|
||||
}
|
||||
|
||||
impl<Scalar: UnsignedTorus> CompressedModulusSwitchedGlweCiphertext<Scalar> {
|
||||
|
||||
@@ -99,7 +99,7 @@ impl CompressedCiphertextListBuilder {
|
||||
#[versionize(CompressedCiphertextListVersions)]
|
||||
pub struct CompressedCiphertextList {
|
||||
pub(crate) packed_list: ShortintCompressedCiphertextList,
|
||||
info: Vec<DataKind>,
|
||||
pub(crate) info: Vec<DataKind>,
|
||||
}
|
||||
|
||||
impl CompressedCiphertextList {
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
use crate::core_crypto::entities::packed_integers::PackedIntegers;
|
||||
use crate::core_crypto::gpu::CudaStreams;
|
||||
use crate::integer::ciphertext::DataKind;
|
||||
use crate::core_crypto::prelude::compressed_modulus_switched_glwe_ciphertext::CompressedModulusSwitchedGlweCiphertext;
|
||||
use crate::core_crypto::prelude::{CiphertextCount, ContiguousEntityContainer, LweCiphertextCount};
|
||||
use crate::integer::ciphertext::{CompressedCiphertextList, DataKind};
|
||||
use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
|
||||
use crate::integer::gpu::ciphertext::{
|
||||
CudaRadixCiphertext, CudaSignedRadixCiphertext, CudaUnsignedRadixCiphertext,
|
||||
@@ -7,6 +10,8 @@ use crate::integer::gpu::ciphertext::{
|
||||
use crate::integer::gpu::list_compression::server_keys::{
|
||||
CudaCompressionKey, CudaDecompressionKey, CudaPackedGlweCiphertext,
|
||||
};
|
||||
use crate::shortint::ciphertext::CompressedCiphertextList as ShortintCompressedCiphertextList;
|
||||
use itertools::Itertools;
|
||||
|
||||
pub struct CudaCompressedCiphertextList {
|
||||
pub(crate) packed_list: CudaPackedGlweCiphertext,
|
||||
@@ -45,6 +50,60 @@ impl CudaCompressedCiphertextList {
|
||||
streams,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn to_compressed_ciphertext_list(&self, streams: &CudaStreams) -> CompressedCiphertextList {
|
||||
let glwe_list = self
|
||||
.packed_list
|
||||
.glwe_ciphertext_list
|
||||
.to_glwe_ciphertext_list(streams);
|
||||
let ciphertext_modulus = self.packed_list.glwe_ciphertext_list.ciphertext_modulus();
|
||||
|
||||
let first_element = self.packed_list.block_info.first().unwrap();
|
||||
let message_modulus = first_element.message_modulus;
|
||||
let carry_modulus = first_element.carry_modulus;
|
||||
let pbs_order = first_element.pbs_order;
|
||||
let lwe_per_glwe = self.packed_list.lwe_per_glwe;
|
||||
let log_modulus = self.packed_list.storage_log_modulus;
|
||||
|
||||
let initial_len = self.packed_list.initial_len;
|
||||
let number_bits_to_pack = initial_len * log_modulus.0;
|
||||
let len = number_bits_to_pack.div_ceil(u64::BITS as usize);
|
||||
|
||||
let modulus_switched_glwe_ciphertext_list = glwe_list
|
||||
.iter()
|
||||
.map(|x| {
|
||||
let glwe_dimension = x.glwe_size().to_glwe_dimension();
|
||||
let polynomial_size = x.polynomial_size();
|
||||
CompressedModulusSwitchedGlweCiphertext {
|
||||
packed_integers: PackedIntegers {
|
||||
packed_coeffs: x.into_container()[0..len].to_vec(),
|
||||
log_modulus: self.packed_list.storage_log_modulus,
|
||||
initial_len,
|
||||
},
|
||||
glwe_dimension,
|
||||
polynomial_size,
|
||||
bodies_count: LweCiphertextCount(self.packed_list.bodies_count),
|
||||
uncompressed_ciphertext_modulus: ciphertext_modulus,
|
||||
}
|
||||
})
|
||||
.collect_vec();
|
||||
|
||||
let count = CiphertextCount(self.packed_list.bodies_count);
|
||||
let packed_list = ShortintCompressedCiphertextList {
|
||||
modulus_switched_glwe_ciphertext_list,
|
||||
ciphertext_modulus,
|
||||
message_modulus,
|
||||
carry_modulus,
|
||||
pbs_order,
|
||||
lwe_per_glwe,
|
||||
count,
|
||||
};
|
||||
|
||||
CompressedCiphertextList {
|
||||
packed_list: packed_list,
|
||||
info: self.info.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub trait CudaCompressible {
|
||||
@@ -136,8 +195,9 @@ impl CudaCompressedCiphertextListBuilder {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::integer::ciphertext::CompressedCiphertextListBuilder;
|
||||
use crate::integer::gpu::gen_keys_radix_gpu;
|
||||
use crate::integer::ClientKey;
|
||||
use crate::integer::{BooleanBlock, ClientKey, RadixCiphertext, SignedRadixCiphertext};
|
||||
use crate::shortint::parameters::list_compression::COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64;
|
||||
use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64;
|
||||
|
||||
@@ -198,4 +258,82 @@ mod tests {
|
||||
assert!(decrypted);
|
||||
}
|
||||
}
|
||||
#[test]
|
||||
fn test_gpu_compressed_ciphertext_conversion_to_cpu() {
|
||||
let cks = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64);
|
||||
|
||||
let private_compression_key =
|
||||
cks.new_compression_private_key(COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64);
|
||||
|
||||
let streams = CudaStreams::new_multi_gpu();
|
||||
|
||||
let num_blocks = 32;
|
||||
let (radix_cks, _) = gen_keys_radix_gpu(
|
||||
PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64,
|
||||
num_blocks,
|
||||
&streams,
|
||||
);
|
||||
let (compressed_compression_key, compressed_decompression_key) =
|
||||
radix_cks.new_compressed_compression_decompression_keys(&private_compression_key);
|
||||
|
||||
let cuda_compression_key = compressed_compression_key.decompress_to_cuda(&streams);
|
||||
|
||||
let compression_key = compressed_compression_key.decompress();
|
||||
let decompression_key = compressed_decompression_key.decompress();
|
||||
|
||||
for _ in 0..NB_TESTS {
|
||||
let ct1 = radix_cks.encrypt(3_u32);
|
||||
let ct2 = radix_cks.encrypt_signed(-2);
|
||||
let ct3 = radix_cks.encrypt_bool(true);
|
||||
|
||||
// Copy to GPU
|
||||
let d_ct1 = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct1, &streams);
|
||||
let d_ct2 = CudaSignedRadixCiphertext::from_signed_radix_ciphertext(&ct2, &streams);
|
||||
let d_ct3 = CudaBooleanBlock::from_boolean_block(&ct3, &streams);
|
||||
|
||||
let cuda_compressed = CudaCompressedCiphertextListBuilder::new()
|
||||
.push(d_ct1, &streams)
|
||||
.push(d_ct2, &streams)
|
||||
.push(d_ct3, &streams)
|
||||
.build(&cuda_compression_key, &streams);
|
||||
|
||||
let reference_compressed = CompressedCiphertextListBuilder::new()
|
||||
.push(ct1)
|
||||
.push(ct2)
|
||||
.push(ct3)
|
||||
.build(&compression_key);
|
||||
|
||||
let converted_compressed = cuda_compressed.to_compressed_ciphertext_list(&streams);
|
||||
|
||||
let decompressed1: RadixCiphertext = converted_compressed
|
||||
.get(0, &decompression_key)
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
let reference_decompressed1 = reference_compressed
|
||||
.get(0, &decompression_key)
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
assert_eq!(decompressed1, reference_decompressed1);
|
||||
|
||||
let decompressed2: SignedRadixCiphertext = converted_compressed
|
||||
.get(1, &decompression_key)
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
let reference_decompressed2 = reference_compressed
|
||||
.get(1, &decompression_key)
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
assert_eq!(decompressed2, reference_decompressed2);
|
||||
|
||||
let decompressed3: BooleanBlock = converted_compressed
|
||||
.get(2, &decompression_key)
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
let reference_decompressed3 = reference_compressed
|
||||
.get(2, &decompression_key)
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
assert_eq!(decompressed3, reference_decompressed3);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,6 +32,8 @@ pub struct CudaPackedGlweCiphertext {
|
||||
pub block_info: Vec<CudaBlockInfo>,
|
||||
pub bodies_count: usize,
|
||||
pub storage_log_modulus: CiphertextModulusLog,
|
||||
pub lwe_per_glwe: LweCiphertextCount,
|
||||
pub initial_len: usize,
|
||||
}
|
||||
|
||||
impl CudaCompressionKey {
|
||||
@@ -126,10 +128,12 @@ impl CudaCompressionKey {
|
||||
.map(|x| x.d_blocks.lwe_ciphertext_count().0)
|
||||
.sum();
|
||||
|
||||
let num_glwes = num_lwes.div_ceil(self.lwe_per_glwe.0);
|
||||
|
||||
let mut output_glwe = CudaGlweCiphertextList::new(
|
||||
compress_glwe_size.to_glwe_dimension(),
|
||||
compress_polynomial_size,
|
||||
GlweCiphertextCount(ciphertexts.len()),
|
||||
GlweCiphertextCount(num_glwes),
|
||||
ciphertext_modulus,
|
||||
streams,
|
||||
);
|
||||
@@ -159,11 +163,16 @@ impl CudaCompressionKey {
|
||||
info
|
||||
};
|
||||
|
||||
let initial_len =
|
||||
compress_glwe_size.to_glwe_dimension().0 * compress_polynomial_size.0 + num_lwes;
|
||||
|
||||
CudaPackedGlweCiphertext {
|
||||
glwe_ciphertext_list: output_glwe,
|
||||
block_info: info,
|
||||
bodies_count: num_lwes,
|
||||
storage_log_modulus: self.storage_log_modulus,
|
||||
lwe_per_glwe: LweCiphertextCount(compress_polynomial_size.0),
|
||||
initial_len,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user