fix(compression): update compression parameters, fix compression on GPU and improve test

- the new compression parameters went through a noise check to verify constraints
- CPU and GPU compression tests are improved and the same
- implement Debug, Eq, PartialEq to CompressedCiphertextList
- fix gpu compression when a radix ciphertext is split through more than one compact GLWE
This commit is contained in:
Pedro Alves
2024-10-04 12:14:40 -03:00
committed by Agnès Leroy
parent c2aae980ae
commit e376049e0f
13 changed files with 362 additions and 211 deletions

View File

@@ -563,6 +563,13 @@ test_integer_gpu: install_rs_build_toolchain
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --doc --profile $(CARGO_PROFILE) \
--features=$(TARGET_ARCH_FEATURE),integer,gpu -p $(TFHE_SPEC) -- integer::gpu::server_key::
.PHONY: test_integer_compression
test_integer_compression: install_rs_build_toolchain
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \
--features=$(TARGET_ARCH_FEATURE),integer -p $(TFHE_SPEC) -- integer::ciphertext::compressed_ciphertext_list::tests::
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --doc --profile $(CARGO_PROFILE) \
--features=$(TARGET_ARCH_FEATURE),integer -p $(TFHE_SPEC) -- integer::ciphertext::compress
.PHONY: test_integer_compression_gpu
test_integer_compression_gpu: install_rs_build_toolchain
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --profile $(CARGO_PROFILE) \

View File

@@ -61,10 +61,6 @@ __host__ void host_pack(cudaStream_t stream, uint32_t gpu_index,
auto out_len = (number_bits_to_pack + nbits - 1) / nbits;
// Last GLWE
auto last_body_count = num_lwes % compression_params.polynomial_size;
in_len =
compression_params.glwe_dimension * compression_params.polynomial_size +
last_body_count;
number_bits_to_pack = in_len * log_modulus;
auto last_out_len = (number_bits_to_pack + nbits - 1) / nbits;
@@ -75,10 +71,6 @@ __host__ void host_pack(cudaStream_t stream, uint32_t gpu_index,
dim3 grid(num_blocks);
dim3 threads(num_threads);
cuda_memset_async(array_out, 0,
num_glwes * (compression_params.glwe_dimension + 1) *
compression_params.polynomial_size * sizeof(Torus),
stream, gpu_index);
pack<Torus><<<grid, threads, 0, stream>>>(array_out, array_in, log_modulus,
num_coeffs, in_len, out_len);
check_cuda_error(cudaGetLastError());
@@ -294,7 +286,7 @@ host_integer_decompress(cudaStream_t *streams, uint32_t *gpu_indexes,
compression_params.glwe_dimension,
compression_params.polynomial_size);
d_indexes_array_chunk += num_lwes;
extracted_lwe += lwe_accumulator_size;
extracted_lwe += num_lwes * lwe_accumulator_size;
current_idx = last_idx;
}

View File

@@ -59,16 +59,19 @@ The following example shows how to compress and decompress a list containing 4 m
```rust
use tfhe::prelude::*;
use tfhe::shortint::parameters::{COMP_PARAM_MESSAGE_2_CARRY_2, PARAM_MESSAGE_2_CARRY_2};
use tfhe::shortint::parameters::{
COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
};
use tfhe::{
set_server_key, CompressedCiphertextList, CompressedCiphertextListBuilder, FheBool,
FheInt64, FheUint16, FheUint2, FheUint32,
};
fn main() {
let config = tfhe::ConfigBuilder::with_custom_parameters(PARAM_MESSAGE_2_CARRY_2)
.enable_compression(COMP_PARAM_MESSAGE_2_CARRY_2)
.build();
let config =
tfhe::ConfigBuilder::with_custom_parameters(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64)
.enable_compression(COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64)
.build();
let ck = tfhe::ClientKey::generate(config);
let sk = tfhe::ServerKey::new(&ck);

View File

@@ -197,16 +197,19 @@ The following example shows how to compress and decompress a list containing 4 m
```rust
use tfhe::prelude::*;
use tfhe::shortint::parameters::{COMP_PARAM_MESSAGE_2_CARRY_2, PARAM_MESSAGE_2_CARRY_2};
use tfhe::shortint::parameters::{
COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64, PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
};
use tfhe::{
set_server_key, CompressedCiphertextList, CompressedCiphertextListBuilder, FheBool,
FheInt64, FheUint16, FheUint2, FheUint32,
};
fn main() {
let config = tfhe::ConfigBuilder::with_custom_parameters(PARAM_MESSAGE_2_CARRY_2)
.enable_compression(COMP_PARAM_MESSAGE_2_CARRY_2)
.build();
let config =
tfhe::ConfigBuilder::with_custom_parameters(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64)
.enable_compression(COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64)
.build();
let ck = tfhe::ClientKey::generate(config);
let compressed_server_key = tfhe::CompressedServerKey::new(&ck);

View File

@@ -77,7 +77,7 @@ use crate::core_crypto::prelude::*;
/// );
/// }
/// ```
#[derive(Clone, serde::Serialize, serde::Deserialize, Versionize)]
#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize, Versionize)]
#[versionize(CompressedModulusSwitchedGlweCiphertextVersions)]
pub struct CompressedModulusSwitchedGlweCiphertext<Scalar: UnsignedInteger> {
pub(crate) packed_integers: PackedIntegers<Scalar>,

View File

@@ -4,7 +4,7 @@ use crate::conformance::ParameterSetConformant;
use crate::core_crypto::backward_compatibility::entities::packed_integers::PackedIntegersVersions;
use crate::core_crypto::prelude::*;
#[derive(Clone, serde::Serialize, serde::Deserialize, Versionize)]
#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize, Versionize)]
#[versionize(PackedIntegersVersions)]
pub struct PackedIntegers<Scalar: UnsignedInteger> {
pub(crate) packed_coeffs: Vec<Scalar>,

View File

@@ -1,6 +1,7 @@
use crate::prelude::*;
use crate::shortint::parameters::compact_public_key_only::p_fail_2_minus_64::ks_pbs::PARAM_PKE_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
use crate::shortint::parameters::key_switching::p_fail_2_minus_64::ks_pbs::PARAM_KEYSWITCH_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
use crate::shortint::parameters::list_compression::COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
use crate::shortint::parameters::*;
use crate::shortint::ClassicPBSParameters;
use crate::{
@@ -20,7 +21,7 @@ fn test_tag_propagation_cpu() {
PARAM_PKE_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
PARAM_KEYSWITCH_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
)),
Some(COMP_PARAM_MESSAGE_2_CARRY_2),
Some(COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64),
)
}
@@ -139,9 +140,9 @@ fn test_tag_propagation_zk_pok() {
fn test_tag_propagation_gpu() {
test_tag_propagation(
Device::CudaGpu,
PARAM_MESSAGE_2_CARRY_2,
PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
None,
Some(COMP_PARAM_MESSAGE_2_CARRY_2),
Some(COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64),
)
}

View File

@@ -95,7 +95,7 @@ impl CompressedCiphertextListBuilder {
}
}
#[derive(Clone, Serialize, Deserialize, Versionize)]
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Versionize)]
#[versionize(CompressedCiphertextListVersions)]
pub struct CompressedCiphertextList {
pub(crate) packed_list: ShortintCompressedCiphertextList,
@@ -153,13 +153,22 @@ impl CompressedCiphertextList {
#[cfg(test)]
mod tests {
use super::*;
use crate::integer::ClientKey;
use crate::integer::{gen_keys, IntegerKeyKind};
use crate::shortint::parameters::list_compression::COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
use itertools::Itertools;
use rand::Rng;
const NB_TESTS: usize = 10;
const NB_OPERATOR_TESTS: usize = 10;
#[test]
fn test_heterogeneous_ciphertext_compression_ci_run_filter() {
let cks = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64);
fn test_ciphertext_compression() {
const NUM_BLOCKS: usize = 32;
let (cks, sks) = gen_keys(
PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
IntegerKeyKind::Radix,
);
let private_compression_key =
cks.new_compression_private_key(COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64);
@@ -167,32 +176,170 @@ mod tests {
let (compression_key, decompression_key) =
cks.new_compression_decompression_keys(&private_compression_key);
let ct1 = cks.encrypt_radix(3_u32, 16);
const MAX_NB_MESSAGES: usize = 2 * COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64
.lwe_per_glwe
.0
/ NUM_BLOCKS;
let ct2 = cks.encrypt_signed_radix(-2, 16);
let mut rng = rand::thread_rng();
let ct3 = cks.encrypt_bool(true);
let message_modulus: u128 = cks.parameters().message_modulus().0 as u128;
let compressed = CompressedCiphertextListBuilder::new()
.push(ct1)
.push(ct2)
.push(ct3)
.build(&compression_key);
for _ in 0..NB_TESTS {
// Unsigned
let modulus = message_modulus.pow(NUM_BLOCKS as u32);
for _ in 0..NB_OPERATOR_TESTS {
let nb_messages = rng.gen_range(1..=MAX_NB_MESSAGES as u64);
let messages = (0..nb_messages)
.map(|_| rng.gen::<u128>() % modulus)
.collect::<Vec<_>>();
let decompressed1 = compressed.get(0, &decompression_key).unwrap().unwrap();
let cts = messages
.iter()
.map(|message| cks.encrypt_radix(*message, NUM_BLOCKS))
.collect_vec();
let decrypted: u32 = cks.decrypt_radix(&decompressed1);
let mut builder = CompressedCiphertextListBuilder::new();
assert_eq!(decrypted, 3_u32);
for ct in cts {
let and_ct = sks.bitand_parallelized(&ct, &ct);
builder.push(and_ct);
}
let decompressed2 = compressed.get(1, &decompression_key).unwrap().unwrap();
let compressed = builder.build(&compression_key);
let decrypted2: i32 = cks.decrypt_signed_radix(&decompressed2);
for (i, message) in messages.iter().enumerate() {
let decompressed = compressed.get(i, &decompression_key).unwrap().unwrap();
let decrypted: u128 = cks.decrypt_radix(&decompressed);
assert_eq!(decrypted, *message);
}
}
assert_eq!(decrypted2, -2);
// Signed
let modulus = message_modulus.pow((NUM_BLOCKS - 1) as u32) as i128;
for _ in 0..NB_OPERATOR_TESTS {
let nb_messages = rng.gen_range(1..=MAX_NB_MESSAGES as u64);
let messages = (0..nb_messages)
.map(|_| rng.gen::<i128>() % modulus)
.collect::<Vec<_>>();
let decompressed3 = compressed.get(2, &decompression_key).unwrap().unwrap();
let cts = messages
.iter()
.map(|message| cks.encrypt_signed_radix(*message, NUM_BLOCKS))
.collect_vec();
assert!(cks.decrypt_bool(&decompressed3));
let mut builder = CompressedCiphertextListBuilder::new();
for ct in cts {
let and_ct = sks.bitand_parallelized(&ct, &ct);
builder.push(and_ct);
}
let compressed = builder.build(&compression_key);
for (i, message) in messages.iter().enumerate() {
let decompressed = compressed.get(i, &decompression_key).unwrap().unwrap();
let decrypted: i128 = cks.decrypt_signed_radix(&decompressed);
assert_eq!(decrypted, *message);
}
}
// Boolean
for _ in 0..NB_OPERATOR_TESTS {
let nb_messages = rng.gen_range(1..=MAX_NB_MESSAGES as u64);
let messages = (0..nb_messages)
.map(|_| rng.gen::<i64>() % 2 != 0)
.collect::<Vec<_>>();
let cts = messages
.iter()
.map(|message| cks.encrypt_bool(*message))
.collect_vec();
let mut builder = CompressedCiphertextListBuilder::new();
for ct in cts {
let and_ct = sks.boolean_bitand(&ct, &ct);
builder.push(and_ct);
}
let compressed = builder.build(&compression_key);
for (i, message) in messages.iter().enumerate() {
let decompressed = compressed.get(i, &decompression_key).unwrap().unwrap();
let decrypted = cks.decrypt_bool(&decompressed);
assert_eq!(decrypted, *message);
}
}
// Hybrid
enum MessageType {
Unsigned(u128),
Signed(i128),
Boolean(bool),
}
for _ in 0..NB_OPERATOR_TESTS {
let mut builder = CompressedCiphertextListBuilder::new();
let nb_messages = rng.gen_range(1..=MAX_NB_MESSAGES as u64);
let mut messages = vec![];
for _ in 0..nb_messages {
let case_selector = rng.gen_range(0..3);
match case_selector {
0 => {
// Unsigned
let modulus = message_modulus.pow(NUM_BLOCKS as u32);
let message = rng.gen::<u128>() % modulus;
let ct = cks.encrypt_radix(message, NUM_BLOCKS);
let and_ct = sks.bitand_parallelized(&ct, &ct);
builder.push(and_ct);
messages.push(MessageType::Unsigned(message));
}
1 => {
// Signed
let modulus = message_modulus.pow((NUM_BLOCKS - 1) as u32) as i128;
let message = rng.gen::<i128>() % modulus;
let ct = cks.encrypt_signed_radix(message, NUM_BLOCKS);
let and_ct = sks.bitand_parallelized(&ct, &ct);
builder.push(and_ct);
messages.push(MessageType::Signed(message));
}
_ => {
// Boolean
let message = rng.gen::<i64>() % 2 != 0;
let ct = cks.encrypt_bool(message);
let and_ct = sks.boolean_bitand(&ct, &ct);
builder.push(and_ct);
messages.push(MessageType::Boolean(message));
}
}
}
let compressed = builder.build(&compression_key);
for (i, val) in messages.iter().enumerate() {
match val {
MessageType::Unsigned(message) => {
let decompressed =
compressed.get(i, &decompression_key).unwrap().unwrap();
let decrypted: u128 = cks.decrypt_radix(&decompressed);
assert_eq!(decrypted, *message);
}
MessageType::Signed(message) => {
let decompressed =
compressed.get(i, &decompression_key).unwrap().unwrap();
let decrypted: i128 = cks.decrypt_signed_radix(&decompressed);
assert_eq!(decrypted, *message);
}
MessageType::Boolean(message) => {
let decompressed =
compressed.get(i, &decompression_key).unwrap().unwrap();
let decrypted = cks.decrypt_bool(&decompressed);
assert_eq!(decrypted, *message);
}
}
}
}
}
}
}

View File

@@ -132,79 +132,84 @@ impl CudaCompressedCiphertextList {
/// use tfhe::shortint::parameters::list_compression::COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
/// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
///
/// let cks = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64);
/// let num_blocks = 32;
/// let streams = CudaStreams::new_multi_gpu();
///
/// let private_compression_key =
/// cks.new_compression_private_key(COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64);
/// let (radix_cks, _) = gen_keys_radix_gpu(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
/// num_blocks,
/// &streams,
/// );
/// let cks = radix_cks.as_ref();
///
/// let streams = CudaStreams::new_multi_gpu();
/// let private_compression_key =
/// cks.new_compression_private_key(COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64);
///
/// let num_blocks = 32;
/// let (radix_cks, _) = gen_keys_radix_gpu(
/// PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
/// num_blocks,
/// &streams,
/// );
/// let (compressed_compression_key, compressed_decompression_key) =
/// radix_cks.new_compressed_compression_decompression_keys(&private_compression_key);
/// let (cuda_compression_key, cuda_decompression_key) =
/// radix_cks.new_cuda_compression_decompression_keys(&private_compression_key, &streams);
///
/// let cuda_compression_key = compressed_compression_key.decompress_to_cuda(&streams);
/// let private_compression_key =
/// cks.new_compression_private_key(COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64);
///
/// let compression_key = compressed_compression_key.decompress();
/// let decompression_key = compressed_decompression_key.decompress();
/// let (compressed_compression_key, compressed_decompression_key) =
/// radix_cks.new_compressed_compression_decompression_keys(&private_compression_key);
///
/// let ct1 = radix_cks.encrypt(3_u32);
/// let ct2 = radix_cks.encrypt_signed(-2);
/// let ct3 = radix_cks.encrypt_bool(true);
/// let cuda_compression_key = compressed_compression_key.decompress_to_cuda(&streams);
///
/// /// Copy to GPU
/// let d_ct1 = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct1, &streams);
/// let d_ct2 = CudaSignedRadixCiphertext::from_signed_radix_ciphertext(&ct2, &streams);
/// let d_ct3 = CudaBooleanBlock::from_boolean_block(&ct3, &streams);
/// let compression_key = compressed_compression_key.decompress();
/// let decompression_key = compressed_decompression_key.decompress();
///
/// let cuda_compressed = CudaCompressedCiphertextListBuilder::new()
/// .push(d_ct1, &streams)
/// .push(d_ct2, &streams)
/// .push(d_ct3, &streams)
/// .build(&cuda_compression_key, &streams);
/// let ct1 = radix_cks.encrypt(3_u32);
/// let ct2 = radix_cks.encrypt_signed(-2);
/// let ct3 = radix_cks.encrypt_bool(true);
///
/// let reference_compressed = CompressedCiphertextListBuilder::new()
/// .push(ct1)
/// .push(ct2)
/// .push(ct3)
/// .build(&compression_key);
/// /// Copy to GPU
/// let d_ct1 = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct1, &streams);
/// let d_ct2 = CudaSignedRadixCiphertext::from_signed_radix_ciphertext(&ct2, &streams);
/// let d_ct3 = CudaBooleanBlock::from_boolean_block(&ct3, &streams);
///
/// let converted_compressed = cuda_compressed.to_compressed_ciphertext_list(&streams);
/// let cuda_compressed = CudaCompressedCiphertextListBuilder::new()
/// .push(d_ct1, &streams)
/// .push(d_ct2, &streams)
/// .push(d_ct3, &streams)
/// .build(&cuda_compression_key, &streams);
///
/// let decompressed1: RadixCiphertext = converted_compressed
/// .get(0, &decompression_key)
/// .unwrap()
/// .unwrap();
/// let reference_decompressed1 = reference_compressed
/// .get(0, &decompression_key)
/// .unwrap()
/// .unwrap();
/// assert_eq!(decompressed1, reference_decompressed1);
/// let reference_compressed = CompressedCiphertextListBuilder::new()
/// .push(ct1)
/// .push(ct2)
/// .push(ct3)
/// .build(&compression_key);
///
/// let decompressed2: SignedRadixCiphertext = converted_compressed
/// .get(1, &decompression_key)
/// .unwrap()
/// .unwrap();
/// let reference_decompressed2 = reference_compressed
/// .get(1, &decompression_key)
/// .unwrap()
/// .unwrap();
/// assert_eq!(decompressed2, reference_decompressed2);
/// let converted_compressed = cuda_compressed.to_compressed_ciphertext_list(&streams);
///
/// let decompressed3: BooleanBlock = converted_compressed
/// .get(2, &decompression_key)
/// .unwrap()
/// .unwrap();
/// let reference_decompressed3 = reference_compressed
/// .get(2, &decompression_key)
/// .unwrap()
/// .unwrap();
/// assert_eq!(decompressed3, reference_decompressed3);
/// let decompressed1: RadixCiphertext = converted_compressed
/// .get(0, &decompression_key)
/// .unwrap()
/// .unwrap();
/// let reference_decompressed1 = reference_compressed
/// .get(0, &decompression_key)
/// .unwrap()
/// .unwrap();
/// assert_eq!(decompressed1, reference_decompressed1);
///
/// let decompressed2: SignedRadixCiphertext = converted_compressed
/// .get(1, &decompression_key)
/// .unwrap()
/// .unwrap();
/// let reference_decompressed2 = reference_compressed
/// .get(1, &decompression_key)
/// .unwrap()
/// .unwrap();
/// assert_eq!(decompressed2, reference_decompressed2);
///
/// let decompressed3: BooleanBlock = converted_compressed
/// .get(2, &decompression_key)
/// .unwrap()
/// .unwrap();
/// let reference_decompressed3 = reference_compressed
/// .get(2, &decompression_key)
/// .unwrap()
/// .unwrap();
/// assert_eq!(decompressed3, reference_decompressed3);
/// ```
pub fn to_compressed_ciphertext_list(&self, streams: &CudaStreams) -> CompressedCiphertextList {
let glwe_list = self
@@ -261,8 +266,8 @@ impl CudaCompressedCiphertextList {
}
impl CompressedCiphertextList {
/// ```rust
/// use tfhe::core_crypto::gpu::CudaStreams;
///```rust
/// use tfhe::core_crypto::gpu::CudaStreams;
/// use tfhe::integer::ciphertext::CompressedCiphertextListBuilder;
/// use tfhe::integer::ClientKey;
/// use tfhe::integer::gpu::ciphertext::{CudaSignedRadixCiphertext, CudaUnsignedRadixCiphertext};
@@ -271,62 +276,64 @@ impl CompressedCiphertextList {
/// use tfhe::shortint::parameters::list_compression::COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
/// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
///
/// let cks = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64);
/// let num_blocks = 32;
/// let streams = CudaStreams::new_multi_gpu();
///
/// let private_compression_key =
/// cks.new_compression_private_key(COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64);
/// let (radix_cks, _) = gen_keys_radix_gpu(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
/// num_blocks,
/// &streams,
/// );
/// let cks = radix_cks.as_ref();
///
/// let streams = CudaStreams::new_multi_gpu();
/// let private_compression_key =
/// cks.new_compression_private_key(COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64);
///
/// let num_blocks = 32;
/// let (radix_cks, _) = gen_keys_radix_gpu(
/// PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
/// num_blocks,
/// &streams,
/// );
/// let (compressed_compression_key, compressed_decompression_key) =
/// radix_cks.new_compressed_compression_decompression_keys(&private_compression_key);
/// let (compressed_compression_key, compressed_decompression_key) =
/// radix_cks.new_compressed_compression_decompression_keys(&private_compression_key);
///
/// let cuda_decompression_key =
/// compressed_decompression_key.decompress_to_cuda(
/// radix_cks.parameters().glwe_dimension(),
/// radix_cks.parameters().polynomial_size(),
/// radix_cks.parameters().message_modulus(),
/// radix_cks.parameters().carry_modulus(),
/// radix_cks.parameters().ciphertext_modulus(),
/// &streams);
/// let cuda_decompression_key = compressed_decompression_key.decompress_to_cuda(
/// radix_cks.parameters().glwe_dimension(),
/// radix_cks.parameters().polynomial_size(),
/// radix_cks.parameters().message_modulus(),
/// radix_cks.parameters().carry_modulus(),
/// radix_cks.parameters().ciphertext_modulus(),
/// &streams
/// );
///
/// let compression_key = compressed_compression_key.decompress();
/// let compression_key = compressed_compression_key.decompress();
///
/// let ct1 = radix_cks.encrypt(3_u32);
/// let ct2 = radix_cks.encrypt_signed(-2);
/// let ct3 = radix_cks.encrypt_bool(true);
/// let ct1 = radix_cks.encrypt(3_u32);
/// let ct2 = radix_cks.encrypt_signed(-2);
/// let ct3 = radix_cks.encrypt_bool(true);
///
/// let compressed = CompressedCiphertextListBuilder::new()
/// .push(ct1)
/// .push(ct2)
/// .push(ct3)
/// .build(&compression_key);
/// let compressed = CompressedCiphertextListBuilder::new()
/// .push(ct1)
/// .push(ct2)
/// .push(ct3)
/// .build(&compression_key);
///
/// let cuda_compressed = compressed.to_cuda_compressed_ciphertext_list(&streams);
/// let cuda_compressed = compressed.to_cuda_compressed_ciphertext_list(&streams);
/// let recovered_cuda_compressed = cuda_compressed.to_compressed_ciphertext_list(&streams);
///
/// let d_decompressed1: CudaUnsignedRadixCiphertext =
/// cuda_compressed.get(0, &cuda_decompression_key, &streams).unwrap().unwrap();
/// let decompressed1 = d_decompressed1.to_radix_ciphertext(&streams);
/// let decrypted: u32 = radix_cks.decrypt(&decompressed1);
/// assert_eq!(decrypted, 3_u32);
/// assert_eq!(recovered_cuda_compressed, compressed);
///
/// let d_decompressed2: CudaSignedRadixCiphertext =
/// cuda_compressed.get(1, &cuda_decompression_key, &streams).unwrap().unwrap();
/// let decompressed2 = d_decompressed2.to_signed_radix_ciphertext(&streams);
/// let decrypted: i32 = radix_cks.decrypt_signed(&decompressed2);
/// assert_eq!(decrypted, -2);
/// let d_decompressed1: CudaUnsignedRadixCiphertext =
/// cuda_compressed.get(0, &cuda_decompression_key, &streams).unwrap().unwrap();
/// let decompressed1 = d_decompressed1.to_radix_ciphertext(&streams);
/// let decrypted: u32 = radix_cks.decrypt(&decompressed1);
/// assert_eq!(decrypted, 3_u32);
///
/// let d_decompressed3: CudaBooleanBlock =
/// cuda_compressed.get(2, &cuda_decompression_key, &streams).unwrap().unwrap();
/// let decompressed3 = d_decompressed3.to_boolean_block(&streams);
/// let decrypted = radix_cks.decrypt_bool(&decompressed3);
/// assert!(decrypted);
/// let d_decompressed2: CudaSignedRadixCiphertext =
/// cuda_compressed.get(1, &cuda_decompression_key, &streams).unwrap().unwrap();
/// let decompressed2 = d_decompressed2.to_signed_radix_ciphertext(&streams);
/// let decrypted: i32 = radix_cks.decrypt_signed(&decompressed2);
/// assert_eq!(decrypted, -2);
///
/// let d_decompressed3: CudaBooleanBlock =
/// cuda_compressed.get(2, &cuda_decompression_key, &streams).unwrap().unwrap();
/// let decompressed3 = d_decompressed3.to_boolean_block(&streams);
/// let decrypted = radix_cks.decrypt_bool(&decompressed3);
/// assert!(decrypted);
/// ```
pub fn to_cuda_compressed_ciphertext_list(
&self,
@@ -513,7 +520,6 @@ impl<'de> serde::Deserialize<'de> for CudaCompressedCiphertextList {
mod tests {
use super::*;
use crate::integer::gpu::gen_keys_radix_gpu;
use crate::integer::ClientKey;
use crate::shortint::parameters::list_compression::COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64;
use rand::Rng;
@@ -523,31 +529,36 @@ mod tests {
#[test]
fn test_gpu_ciphertext_compression() {
let cks = ClientKey::new(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64);
const NUM_BLOCKS: usize = 32;
let streams = CudaStreams::new_multi_gpu();
let (radix_cks, sks) = gen_keys_radix_gpu(
PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
NUM_BLOCKS,
&streams,
);
let cks = radix_cks.as_ref();
let private_compression_key =
cks.new_compression_private_key(COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64);
let streams = CudaStreams::new_multi_gpu();
let num_blocks = 32;
let (radix_cks, _) = gen_keys_radix_gpu(
PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
num_blocks,
&streams,
);
let (cuda_compression_key, cuda_decompression_key) =
radix_cks.new_cuda_compression_decompression_keys(&private_compression_key, &streams);
const MAX_NB_MESSAGES: usize = 2 * COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64
.lwe_per_glwe
.0
/ NUM_BLOCKS;
let mut rng = rand::thread_rng();
let message_modulus: u128 = cks.parameters().message_modulus().0 as u128;
for _ in 0..NB_TESTS {
// Unsigned
let modulus = message_modulus.pow(num_blocks as u32);
let modulus = message_modulus.pow(NUM_BLOCKS as u32);
for _ in 0..NB_OPERATOR_TESTS {
let nb_messages = 1 + (rng.gen::<u64>() % 6);
let nb_messages = rng.gen_range(1..=MAX_NB_MESSAGES as u64);
let messages = (0..nb_messages)
.map(|_| rng.gen::<u128>() % modulus)
.collect::<Vec<_>>();
@@ -563,7 +574,8 @@ mod tests {
let mut builder = CudaCompressedCiphertextListBuilder::new();
for d_ct in d_cts {
builder.push(d_ct, &streams);
let d_and_ct = sks.bitand(&d_ct, &d_ct, &streams);
builder.push(d_and_ct, &streams);
}
let cuda_compressed = builder.build(&cuda_compression_key, &streams);
@@ -580,9 +592,9 @@ mod tests {
}
// Signed
let modulus = message_modulus.pow((num_blocks - 1) as u32) as i128;
let modulus = message_modulus.pow((NUM_BLOCKS - 1) as u32) as i128;
for _ in 0..NB_OPERATOR_TESTS {
let nb_messages = 1 + (rng.gen::<u64>() % 6);
let nb_messages = rng.gen_range(1..=MAX_NB_MESSAGES as u64);
let messages = (0..nb_messages)
.map(|_| rng.gen::<i128>() % modulus)
.collect::<Vec<_>>();
@@ -598,7 +610,8 @@ mod tests {
let mut builder = CudaCompressedCiphertextListBuilder::new();
for d_ct in d_cts {
builder.push(d_ct, &streams);
let d_and_ct = sks.bitand(&d_ct, &d_ct, &streams);
builder.push(d_and_ct, &streams);
}
let cuda_compressed = builder.build(&cuda_compression_key, &streams);
@@ -616,7 +629,7 @@ mod tests {
// Boolean
for _ in 0..NB_OPERATOR_TESTS {
let nb_messages = 1 + (rng.gen::<u64>() % 6);
let nb_messages = rng.gen_range(1..=MAX_NB_MESSAGES as u64);
let messages = (0..nb_messages)
.map(|_| rng.gen::<i64>() % 2 != 0)
.collect::<Vec<_>>();
@@ -631,8 +644,12 @@ mod tests {
let mut builder = CudaCompressedCiphertextListBuilder::new();
for d_ct in d_cts {
builder.push(d_ct, &streams);
for d_boolean_ct in d_cts {
let d_ct = d_boolean_ct.0;
let d_and_ct = sks.bitand(&d_ct, &d_ct, &streams);
let d_and_boolean_ct =
CudaBooleanBlock::from_cuda_radix_ciphertext(d_and_ct.ciphertext);
builder.push(d_and_boolean_ct, &streams);
}
let cuda_compressed = builder.build(&cuda_compression_key, &streams);
@@ -657,38 +674,44 @@ mod tests {
for _ in 0..NB_OPERATOR_TESTS {
let mut builder = CudaCompressedCiphertextListBuilder::new();
let nb_messages = 1 + (rng.gen::<u64>() % 6);
let nb_messages = rng.gen_range(1..=MAX_NB_MESSAGES as u64);
let mut messages = vec![];
for _ in 0..nb_messages {
let case_selector = rng.gen_range(0..3);
match case_selector {
0 => {
// Unsigned
let modulus = message_modulus.pow(num_blocks as u32);
let modulus = message_modulus.pow(NUM_BLOCKS as u32);
let message = rng.gen::<u128>() % modulus;
let ct = radix_cks.encrypt(message);
let d_ct =
CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct, &streams);
builder.push(d_ct, &streams);
let d_and_ct = sks.bitand(&d_ct, &d_ct, &streams);
builder.push(d_and_ct, &streams);
messages.push(MessageType::Unsigned(message));
}
1 => {
// Signed
let modulus = message_modulus.pow((num_blocks - 1) as u32) as i128;
let modulus = message_modulus.pow((NUM_BLOCKS - 1) as u32) as i128;
let message = rng.gen::<i128>() % modulus;
let ct = radix_cks.encrypt_signed(message);
let d_ct = CudaSignedRadixCiphertext::from_signed_radix_ciphertext(
&ct, &streams,
);
builder.push(d_ct, &streams);
let d_and_ct = sks.bitand(&d_ct, &d_ct, &streams);
builder.push(d_and_ct, &streams);
messages.push(MessageType::Signed(message));
}
_ => {
// Boolean
let message = rng.gen::<i64>() % 2 != 0;
let ct = radix_cks.encrypt_bool(message);
let d_ct = CudaBooleanBlock::from_boolean_block(&ct, &streams);
builder.push(d_ct, &streams);
let d_boolean_ct = CudaBooleanBlock::from_boolean_block(&ct, &streams);
let d_ct = d_boolean_ct.0;
let d_and_ct = sks.bitand(&d_ct, &d_ct, &streams);
let d_and_boolean_ct =
CudaBooleanBlock::from_cuda_radix_ciphertext(d_and_ct.ciphertext);
builder.push(d_and_boolean_ct, &streams);
messages.push(MessageType::Boolean(message));
}
}

View File

@@ -7,7 +7,7 @@ use crate::shortint::backward_compatibility::ciphertext::CompressedCiphertextLis
use crate::shortint::parameters::CompressedCiphertextConformanceParams;
use crate::shortint::{CarryModulus, MessageModulus};
#[derive(Clone, serde::Serialize, serde::Deserialize, Versionize)]
#[derive(Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize, Versionize)]
#[versionize(CompressedCiphertextListVersions)]
pub struct CompressedCiphertextList {
pub modulus_switched_glwe_ciphertext_list: Vec<CompressedModulusSwitchedGlweCiphertext<u64>>,

View File

@@ -290,15 +290,8 @@ impl NamedParam for ShortintKeySwitchingParameters {
impl NamedParam for CompressionParameters {
fn name(&self) -> String {
named_params_impl!(expose
COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64,
COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64
);
named_params_impl!(
{
*self;
Self
} == (COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64)
);
named_params_impl!(
{

View File

@@ -1,6 +1,6 @@
use tfhe_versionable::Versionize;
use crate::core_crypto::prelude::{CiphertextModulusLog, LweCiphertextCount, StandardDev};
use crate::core_crypto::prelude::{CiphertextModulusLog, LweCiphertextCount};
use crate::shortint::backward_compatibility::parameters::list_compression::CompressionParametersVersions;
use crate::shortint::parameters::{
DecompositionBaseLog, DecompositionLevelCount, DynamicDistribution, GlweDimension,
@@ -25,27 +25,12 @@ pub struct CompressionParameters {
pub const COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64: CompressionParameters =
CompressionParameters {
br_level: DecompositionLevelCount(1),
br_base_log: DecompositionBaseLog(25),
packing_ks_level: DecompositionLevelCount(2),
packing_ks_base_log: DecompositionBaseLog(8),
br_base_log: DecompositionBaseLog(23),
packing_ks_level: DecompositionLevelCount(4),
packing_ks_base_log: DecompositionBaseLog(4),
packing_ks_polynomial_size: PolynomialSize(256),
packing_ks_glwe_dimension: GlweDimension(5),
packing_ks_glwe_dimension: GlweDimension(4),
lwe_per_glwe: LweCiphertextCount(256),
storage_log_modulus: CiphertextModulusLog(11),
packing_ks_key_noise_distribution: DynamicDistribution::new_t_uniform(36),
};
pub const COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64: CompressionParameters =
CompressionParameters {
br_level: DecompositionLevelCount(1),
br_base_log: DecompositionBaseLog(25),
packing_ks_level: DecompositionLevelCount(2),
packing_ks_base_log: DecompositionBaseLog(8),
packing_ks_polynomial_size: PolynomialSize(256),
packing_ks_glwe_dimension: GlweDimension(5),
lwe_per_glwe: LweCiphertextCount(256),
storage_log_modulus: CiphertextModulusLog(11),
packing_ks_key_noise_distribution: DynamicDistribution::new_gaussian_from_std_dev(
StandardDev(1.6173527465097522e-09),
),
storage_log_modulus: CiphertextModulusLog(12),
packing_ks_key_noise_distribution: DynamicDistribution::new_t_uniform(42),
};

View File

@@ -42,7 +42,9 @@ pub use crate::shortint::parameters::classic::gaussian::p_fail_2_minus_64::ks_pb
pub use crate::shortint::parameters::classic::gaussian::p_fail_2_minus_64::pbs_ks::*;
pub use crate::shortint::parameters::classic::tuniform::p_fail_2_minus_64::ks_pbs::*;
pub use crate::shortint::parameters::classic::tuniform::p_fail_2_minus_64::pbs_ks::*;
pub use crate::shortint::parameters::list_compression::CompressionParameters;
pub use crate::shortint::parameters::list_compression::{
CompressionParameters, COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64,
};
pub use compact_public_key_only::{
CastingFunctionsOwned, CastingFunctionsView, CompactCiphertextListExpansionKind,
CompactPublicKeyEncryptionParameters, ShortintCompactCiphertextListCastingMode,
@@ -887,8 +889,3 @@ pub const PARAM_SMALL_MESSAGE_1_CARRY_1: ClassicPBSParameters = PARAM_MESSAGE_1_
pub const PARAM_SMALL_MESSAGE_2_CARRY_2: ClassicPBSParameters = PARAM_MESSAGE_2_CARRY_2_PBS_KS;
pub const PARAM_SMALL_MESSAGE_3_CARRY_3: ClassicPBSParameters = PARAM_MESSAGE_3_CARRY_3_PBS_KS;
pub const PARAM_SMALL_MESSAGE_4_CARRY_4: ClassicPBSParameters = PARAM_MESSAGE_4_CARRY_4_PBS_KS;
pub const COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS: CompressionParameters =
list_compression::COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64;
pub const COMP_PARAM_MESSAGE_2_CARRY_2: CompressionParameters = COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS;