diff --git a/tfhe/src/integer/backward_compatibility/ciphertext/mod.rs b/tfhe/src/integer/backward_compatibility/ciphertext/mod.rs index f4e01140b..ac662abf3 100644 --- a/tfhe/src/integer/backward_compatibility/ciphertext/mod.rs +++ b/tfhe/src/integer/backward_compatibility/ciphertext/mod.rs @@ -5,6 +5,7 @@ use crate::integer::ciphertext::{ CompressedModulusSwitchedSignedRadixCiphertext, DataKind, SquashedNoiseBooleanBlock, SquashedNoiseRadixCiphertext, SquashedNoiseSignedRadixCiphertext, }; +use crate::integer::server_key::CompressedKVStore; use crate::integer::BooleanBlock; #[cfg(feature = "zk-pok")] use crate::integer::ProvenCompactCiphertextList; @@ -148,3 +149,8 @@ pub enum SquashedNoiseSignedRadixCiphertextVersions { pub enum SquashedNoiseBooleanBlockVersions { V0(SquashedNoiseBooleanBlock), } + +#[derive(VersionsDispatch)] +pub enum CompressedKVStoreVersions { + V0(CompressedKVStore), +} diff --git a/tfhe/src/integer/server_key/mod.rs b/tfhe/src/integer/server_key/mod.rs index d1212d681..761fbf848 100644 --- a/tfhe/src/integer/server_key/mod.rs +++ b/tfhe/src/integer/server_key/mod.rs @@ -7,7 +7,7 @@ pub(crate) mod crt; mod crt_parallel; pub(crate) mod radix; pub(crate) mod radix_parallel; -pub use radix_parallel::kv_store::KVStore; +pub use radix_parallel::kv_store::{CompressedKVStore, KVStore}; use super::backward_compatibility::server_key::{CompressedServerKeyVersions, ServerKeyVersions}; use crate::conformance::ParameterSetConformant; diff --git a/tfhe/src/integer/server_key/radix_parallel/kv_store.rs b/tfhe/src/integer/server_key/radix_parallel/kv_store.rs index 590e7987f..b13fc0b64 100644 --- a/tfhe/src/integer/server_key/radix_parallel/kv_store.rs +++ b/tfhe/src/integer/server_key/radix_parallel/kv_store.rs @@ -1,18 +1,37 @@ +use crate::integer::backward_compatibility::ciphertext::CompressedKVStoreVersions; use crate::integer::block_decomposition::{Decomposable, DecomposableInto}; +use crate::integer::ciphertext::{ + CompressedCiphertextList, CompressedCiphertextListBuilder, Compressible, Expandable, +}; +use crate::integer::compression_keys::{CompressionKey, DecompressionKey}; use crate::integer::prelude::ServerKeyDefaultCMux; use crate::integer::{BooleanBlock, IntegerRadixCiphertext, ServerKey}; use crate::prelude::CastInto; use rayon::prelude::*; +use serde::{Deserialize, Serialize}; use std::collections::HashMap; +use std::fmt::Display; use std::hash::Hash; use std::num::NonZeroUsize; +use tfhe_versionable::Versionize; +/// The KVStore is a specialized encrypted HashMap +/// +/// * Keys are clear numbers +/// * Values are RadixCiphertext or SignedRadixCiphertext +/// +/// It supports getting/modifying existing pairs of (key,value) +/// using an encrypted key. +/// +/// +/// To serialize a KVStore it must first be compressed with [KVStore::compress] pub struct KVStore { data: HashMap, block_count: Option, } impl KVStore { + /// Creates an empty KVStore pub fn new() -> Self { Self { data: HashMap::new(), @@ -20,6 +39,10 @@ impl KVStore { } } + /// Returns the value stored for the key if any + /// + /// Key is in clear, see [ServerKey::kv_store_get] if you wish to + /// query using an encrypted key pub fn get(&self, key: &Key) -> Option<&Ct> where Key: Eq + Hash, @@ -41,7 +64,7 @@ impl KVStore { /// values stored pub fn insert(&mut self, key: Key, value: Ct) -> Option where - Key: PartialEq + Ord + Eq + Hash, + Key: PartialEq + Eq + Hash, Ct: IntegerRadixCiphertext, { let n_blocks = value.blocks().len(); @@ -61,10 +84,12 @@ impl KVStore { self.data.insert(key, value) } + /// Returns the number of key-value pairs currently stored pub fn len(&self) -> usize { self.data.len() } + /// Returns whether the store is empty pub fn is_empty(&self) -> bool { self.data.is_empty() } @@ -352,3 +377,246 @@ impl ServerKey { (new_value, check_block) } } + +impl KVStore +where + Key: Copy, + Ct: Compressible + Clone, +{ + /// Compress the KVStore to be able to serialize it + pub fn compress(&self, compression_key: &CompressionKey) -> CompressedKVStore { + let mut builder = CompressedCiphertextListBuilder::new(); + let mut keys = Vec::with_capacity(self.data.len()); + for (key, value) in self.data.iter() { + keys.push(*key); + builder.push(value.clone()); + } + + let values = builder.build(compression_key); + + CompressedKVStore { keys, values } + } +} + +/// Compressed KVStore +/// +/// This type is the serializable and deserializable form of a KVStore +#[derive(Serialize, Deserialize, Versionize)] +#[versionize(CompressedKVStoreVersions)] +pub struct CompressedKVStore { + keys: Vec, + values: CompressedCiphertextList, +} + +impl CompressedKVStore +where + Key: Copy + Display + Eq + Hash, +{ + /// Decompressed the KVStore + /// + /// Returns an error if: + /// * A key does not have a corresponding value + /// * A value (which is a radix ciphertext) does not have the same number of blocks as the + /// others. + /// + /// Both these errors indicate corrupted or malformed data + pub fn decompress( + &self, + decompression_key: &DecompressionKey, + ) -> crate::Result> + where + Ct: Expandable + IntegerRadixCiphertext, + { + let mut block_count = None; + let mut store = KVStore::new(); + for (i, key) in self.keys.iter().enumerate() { + let value: Ct = self + .values + .get(i, decompression_key)? + .ok_or_else(|| crate::error!("Missing value for key '{key}'"))?; + + let n = *block_count.get_or_insert_with(|| value.blocks().len()); + + if n != value.blocks().len() { + return Err(crate::error!( + "The value for key {key} does not have the same number \ + of blocks as other values. {} instead of {n}", + value.blocks().len() + )); + } + + let _ = store.insert(*key, value); + } + + Ok(store) + } +} + +macro_rules! impl_named_for_kv_store { + ($Key:ty) => { + impl crate::named::Named for CompressedKVStore<$Key> { + const NAME: &'static str = + concat!("integer::CompressedKVStore<", stringify!($Key), ">"); + } + }; +} + +impl_named_for_kv_store!(u8); +impl_named_for_kv_store!(u16); +impl_named_for_kv_store!(u32); +impl_named_for_kv_store!(u64); +impl_named_for_kv_store!(u128); +impl_named_for_kv_store!(i8); +impl_named_for_kv_store!(i16); +impl_named_for_kv_store!(i32); +impl_named_for_kv_store!(i64); +impl_named_for_kv_store!(i128); + +#[cfg(test)] +mod tests { + use rand::Rng; + + use super::*; + use crate::integer::{ + gen_keys, ClientKey, IntegerKeyKind, RadixCiphertext, SignedRadixCiphertext, + }; + use crate::shortint::parameters::test_params::{ + TEST_COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128, + TEST_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128, + }; + use crate::shortint::ShortintParameterSet; + + fn assert_store_unsigned_matches( + clear_store: &HashMap, + kv_store: &KVStore, + cks: &ClientKey, + ) { + assert_eq!( + clear_store.len(), + kv_store.len(), + "Clear and Encrypted stores do no have the same number of pairs" + ); + + for (key, value) in clear_store { + let ct = kv_store + .get(key) + .expect("Missing entry in decompressed KVStore"); + + let decrypted: u64 = cks.decrypt_radix(ct); + assert_eq!( + *value, decrypted, + "Invalid value stored for key '{key}', expected '{value}' got '{decrypted}'" + ); + } + } + + #[test] + fn test_compression_serialization_unsigned() { + let params = TEST_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128.into(); + + let (cks, _) = gen_keys::(params, IntegerKeyKind::Radix); + + let private_compression_key = cks + .new_compression_private_key(TEST_COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128); + + let (compression_key, decompression_key) = + cks.new_compression_decompression_keys(&private_compression_key); + + let num_blocks = 32; + let num_keys = 100; + + let mut rng = rand::thread_rng(); + + let mut clear_store = HashMap::new(); + let mut kv_store = KVStore::new(); + for _ in 0..num_keys { + let key = rng.gen::(); + let value = rng.gen::(); + + let ct = cks.encrypt_radix(value, num_blocks); + + let _ = clear_store.insert(key, value); + kv_store.insert(key, ct); + } + + assert_store_unsigned_matches(&clear_store, &kv_store, &cks); + + let compressed = kv_store.compress(&compression_key); + let kv_store = compressed.decompress(&decompression_key).unwrap(); + assert_store_unsigned_matches(&clear_store, &kv_store, &cks); + + let mut data = vec![]; + crate::safe_serialization::safe_serialize(&compressed, &mut data, 1 << 20).unwrap(); + let compressed: CompressedKVStore = + crate::safe_serialization::safe_deserialize(data.as_slice(), 1 << 20).unwrap(); + let kv_store = compressed.decompress(&decompression_key).unwrap(); + assert_store_unsigned_matches(&clear_store, &kv_store, &cks); + } + + fn assert_store_signed_matches( + clear_store: &HashMap, + kv_store: &KVStore, + cks: &ClientKey, + ) { + assert_eq!( + clear_store.len(), + kv_store.len(), + "Clear and Encrypted stores do no have the same number of pairs" + ); + + for (key, value) in clear_store { + let ct = kv_store + .get(key) + .expect("Missing entry in decompressed KVStore"); + + let decrypted: i64 = cks.decrypt_signed_radix(ct); + assert_eq!( + *value, decrypted, + "Invalid value stored for key '{key}', expected '{value}' got '{decrypted}'" + ); + } + } + + #[test] + fn test_compression_serialization_signed() { + let params = TEST_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128.into(); + + let (cks, _) = gen_keys::(params, IntegerKeyKind::Radix); + + let private_compression_key = cks + .new_compression_private_key(TEST_COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128); + + let (compression_key, decompression_key) = + cks.new_compression_decompression_keys(&private_compression_key); + + let num_blocks = 32; + let num_keys = 100; + + let mut rng = rand::thread_rng(); + + let mut clear_store = HashMap::new(); + let mut kv_store = KVStore::new(); + for _ in 0..num_keys { + let key = rng.gen::(); + let value = rng.gen::(); + + let ct = cks.encrypt_signed_radix(value, num_blocks); + + let _ = clear_store.insert(key, value); + kv_store.insert(key, ct); + } + + assert_store_signed_matches(&clear_store, &kv_store, &cks); + + let compressed = kv_store.compress(&compression_key); + let kv_store = compressed.decompress(&decompression_key).unwrap(); + assert_store_signed_matches(&clear_store, &kv_store, &cks); + + let mut data = vec![]; + crate::safe_serialization::safe_serialize(&compressed, &mut data, 1 << 20).unwrap(); + let compressed: CompressedKVStore = + crate::safe_serialization::safe_deserialize(data.as_slice(), 1 << 20).unwrap(); + let kv_store = compressed.decompress(&decompression_key).unwrap(); + assert_store_signed_matches(&clear_store, &kv_store, &cks); + } +} diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_kv_store.rs b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_kv_store.rs index 0dfac00b5..51e1cf9f3 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_kv_store.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_kv_store.rs @@ -9,11 +9,101 @@ use crate::shortint::parameters::{TestParameters, *}; use std::collections::BTreeMap; use std::sync::Arc; -create_parameterized_test!(integer_default_kv_store_add); -create_parameterized_test!(integer_default_kv_store_sub); -create_parameterized_test!(integer_default_kv_store_mul); -create_parameterized_test!(integer_default_kv_store_get_update); -create_parameterized_test!(integer_default_kv_store_map); +create_parameterized_test!( + integer_default_kv_store_add + { + coverage => { + COVERAGE_PARAM_MESSAGE_2_CARRY_2_KS_PBS, + COVERAGE_PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS + }, + no_coverage => { + TEST_PARAM_MESSAGE_1_CARRY_1_KS_PBS_GAUSSIAN_2M128, + PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128, + TEST_PARAM_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M128, + // 2M128 is too slow for 4_4, it is estimated to be 2x slower + TEST_PARAM_MESSAGE_4_CARRY_4_KS_PBS_GAUSSIAN_2M64, + TEST_PARAM_MULTI_BIT_GROUP_2_MESSAGE_1_CARRY_1_KS_PBS_GAUSSIAN_2M64, + TEST_PARAM_MULTI_BIT_GROUP_2_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64, + TEST_PARAM_MULTI_BIT_GROUP_2_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64, + } + } +); +create_parameterized_test!( + integer_default_kv_store_sub + { + coverage => { + COVERAGE_PARAM_MESSAGE_2_CARRY_2_KS_PBS, + COVERAGE_PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS + }, + no_coverage => { + TEST_PARAM_MESSAGE_1_CARRY_1_KS_PBS_GAUSSIAN_2M128, + PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128, + TEST_PARAM_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M128, + // 2M128 is too slow for 4_4, it is estimated to be 2x slower + TEST_PARAM_MESSAGE_4_CARRY_4_KS_PBS_GAUSSIAN_2M64, + TEST_PARAM_MULTI_BIT_GROUP_2_MESSAGE_1_CARRY_1_KS_PBS_GAUSSIAN_2M64, + TEST_PARAM_MULTI_BIT_GROUP_2_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64, + TEST_PARAM_MULTI_BIT_GROUP_2_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64, + } + } +); +create_parameterized_test!( + integer_default_kv_store_mul + { + coverage => { + COVERAGE_PARAM_MESSAGE_2_CARRY_2_KS_PBS, + COVERAGE_PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS + }, + no_coverage => { + TEST_PARAM_MESSAGE_1_CARRY_1_KS_PBS_GAUSSIAN_2M128, + PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128, + TEST_PARAM_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M128, + // 2M128 is too slow for 4_4, it is estimated to be 2x slower + TEST_PARAM_MESSAGE_4_CARRY_4_KS_PBS_GAUSSIAN_2M64, + TEST_PARAM_MULTI_BIT_GROUP_2_MESSAGE_1_CARRY_1_KS_PBS_GAUSSIAN_2M64, + TEST_PARAM_MULTI_BIT_GROUP_2_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64, + TEST_PARAM_MULTI_BIT_GROUP_2_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64, + } + } +); +create_parameterized_test!( + integer_default_kv_store_get_update + { + coverage => { + COVERAGE_PARAM_MESSAGE_2_CARRY_2_KS_PBS, + COVERAGE_PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS + }, + no_coverage => { + TEST_PARAM_MESSAGE_1_CARRY_1_KS_PBS_GAUSSIAN_2M128, + PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128, + TEST_PARAM_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M128, + // 2M128 is too slow for 4_4, it is estimated to be 2x slower + TEST_PARAM_MESSAGE_4_CARRY_4_KS_PBS_GAUSSIAN_2M64, + TEST_PARAM_MULTI_BIT_GROUP_2_MESSAGE_1_CARRY_1_KS_PBS_GAUSSIAN_2M64, + TEST_PARAM_MULTI_BIT_GROUP_2_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64, + TEST_PARAM_MULTI_BIT_GROUP_2_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64, + } + } +); +create_parameterized_test!( + integer_default_kv_store_map + { + coverage => { + COVERAGE_PARAM_MESSAGE_2_CARRY_2_KS_PBS, + COVERAGE_PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS + }, + no_coverage => { + TEST_PARAM_MESSAGE_1_CARRY_1_KS_PBS_GAUSSIAN_2M128, + PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128, + TEST_PARAM_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M128, + // 2M128 is too slow for 4_4, it is estimated to be 2x slower + TEST_PARAM_MESSAGE_4_CARRY_4_KS_PBS_GAUSSIAN_2M64, + TEST_PARAM_MULTI_BIT_GROUP_2_MESSAGE_1_CARRY_1_KS_PBS_GAUSSIAN_2M64, + TEST_PARAM_MULTI_BIT_GROUP_2_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64, + TEST_PARAM_MULTI_BIT_GROUP_2_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64, + } + } +); fn integer_default_kv_store_add(params: impl Into) { let executor = CpuFunctionExecutor::new(&ServerKey::kv_store_add_to_slot);