mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-08 22:28:01 -05:00
feat(integer): add KVStore compression and serialization
This commit is contained in:
committed by
tmontaigu
parent
4a73b7bb4b
commit
9f54777ee1
@@ -5,6 +5,7 @@ use crate::integer::ciphertext::{
|
||||
CompressedModulusSwitchedSignedRadixCiphertext, DataKind, SquashedNoiseBooleanBlock,
|
||||
SquashedNoiseRadixCiphertext, SquashedNoiseSignedRadixCiphertext,
|
||||
};
|
||||
use crate::integer::server_key::CompressedKVStore;
|
||||
use crate::integer::BooleanBlock;
|
||||
#[cfg(feature = "zk-pok")]
|
||||
use crate::integer::ProvenCompactCiphertextList;
|
||||
@@ -148,3 +149,8 @@ pub enum SquashedNoiseSignedRadixCiphertextVersions {
|
||||
pub enum SquashedNoiseBooleanBlockVersions {
|
||||
V0(SquashedNoiseBooleanBlock),
|
||||
}
|
||||
|
||||
#[derive(VersionsDispatch)]
|
||||
pub enum CompressedKVStoreVersions<K> {
|
||||
V0(CompressedKVStore<K>),
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ pub(crate) mod crt;
|
||||
mod crt_parallel;
|
||||
pub(crate) mod radix;
|
||||
pub(crate) mod radix_parallel;
|
||||
pub use radix_parallel::kv_store::KVStore;
|
||||
pub use radix_parallel::kv_store::{CompressedKVStore, KVStore};
|
||||
|
||||
use super::backward_compatibility::server_key::{CompressedServerKeyVersions, ServerKeyVersions};
|
||||
use crate::conformance::ParameterSetConformant;
|
||||
|
||||
@@ -1,18 +1,37 @@
|
||||
use crate::integer::backward_compatibility::ciphertext::CompressedKVStoreVersions;
|
||||
use crate::integer::block_decomposition::{Decomposable, DecomposableInto};
|
||||
use crate::integer::ciphertext::{
|
||||
CompressedCiphertextList, CompressedCiphertextListBuilder, Compressible, Expandable,
|
||||
};
|
||||
use crate::integer::compression_keys::{CompressionKey, DecompressionKey};
|
||||
use crate::integer::prelude::ServerKeyDefaultCMux;
|
||||
use crate::integer::{BooleanBlock, IntegerRadixCiphertext, ServerKey};
|
||||
use crate::prelude::CastInto;
|
||||
use rayon::prelude::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::Display;
|
||||
use std::hash::Hash;
|
||||
use std::num::NonZeroUsize;
|
||||
use tfhe_versionable::Versionize;
|
||||
|
||||
/// The KVStore is a specialized encrypted HashMap
|
||||
///
|
||||
/// * Keys are clear numbers
|
||||
/// * Values are RadixCiphertext or SignedRadixCiphertext
|
||||
///
|
||||
/// It supports getting/modifying existing pairs of (key,value)
|
||||
/// using an encrypted key.
|
||||
///
|
||||
///
|
||||
/// To serialize a KVStore it must first be compressed with [KVStore::compress]
|
||||
pub struct KVStore<Key, Ct> {
|
||||
data: HashMap<Key, Ct>,
|
||||
block_count: Option<NonZeroUsize>,
|
||||
}
|
||||
|
||||
impl<Key, Ct> KVStore<Key, Ct> {
|
||||
/// Creates an empty KVStore
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
data: HashMap::new(),
|
||||
@@ -20,6 +39,10 @@ impl<Key, Ct> KVStore<Key, Ct> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the value stored for the key if any
|
||||
///
|
||||
/// Key is in clear, see [ServerKey::kv_store_get] if you wish to
|
||||
/// query using an encrypted key
|
||||
pub fn get(&self, key: &Key) -> Option<&Ct>
|
||||
where
|
||||
Key: Eq + Hash,
|
||||
@@ -41,7 +64,7 @@ impl<Key, Ct> KVStore<Key, Ct> {
|
||||
/// values stored
|
||||
pub fn insert(&mut self, key: Key, value: Ct) -> Option<Ct>
|
||||
where
|
||||
Key: PartialEq + Ord + Eq + Hash,
|
||||
Key: PartialEq + Eq + Hash,
|
||||
Ct: IntegerRadixCiphertext,
|
||||
{
|
||||
let n_blocks = value.blocks().len();
|
||||
@@ -61,10 +84,12 @@ impl<Key, Ct> KVStore<Key, Ct> {
|
||||
self.data.insert(key, value)
|
||||
}
|
||||
|
||||
/// Returns the number of key-value pairs currently stored
|
||||
pub fn len(&self) -> usize {
|
||||
self.data.len()
|
||||
}
|
||||
|
||||
/// Returns whether the store is empty
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.data.is_empty()
|
||||
}
|
||||
@@ -352,3 +377,246 @@ impl ServerKey {
|
||||
(new_value, check_block)
|
||||
}
|
||||
}
|
||||
|
||||
impl<Key, Ct> KVStore<Key, Ct>
|
||||
where
|
||||
Key: Copy,
|
||||
Ct: Compressible + Clone,
|
||||
{
|
||||
/// Compress the KVStore to be able to serialize it
|
||||
pub fn compress(&self, compression_key: &CompressionKey) -> CompressedKVStore<Key> {
|
||||
let mut builder = CompressedCiphertextListBuilder::new();
|
||||
let mut keys = Vec::with_capacity(self.data.len());
|
||||
for (key, value) in self.data.iter() {
|
||||
keys.push(*key);
|
||||
builder.push(value.clone());
|
||||
}
|
||||
|
||||
let values = builder.build(compression_key);
|
||||
|
||||
CompressedKVStore { keys, values }
|
||||
}
|
||||
}
|
||||
|
||||
/// Compressed KVStore
|
||||
///
|
||||
/// This type is the serializable and deserializable form of a KVStore
|
||||
#[derive(Serialize, Deserialize, Versionize)]
|
||||
#[versionize(CompressedKVStoreVersions)]
|
||||
pub struct CompressedKVStore<Key> {
|
||||
keys: Vec<Key>,
|
||||
values: CompressedCiphertextList,
|
||||
}
|
||||
|
||||
impl<Key> CompressedKVStore<Key>
|
||||
where
|
||||
Key: Copy + Display + Eq + Hash,
|
||||
{
|
||||
/// Decompressed the KVStore
|
||||
///
|
||||
/// Returns an error if:
|
||||
/// * A key does not have a corresponding value
|
||||
/// * A value (which is a radix ciphertext) does not have the same number of blocks as the
|
||||
/// others.
|
||||
///
|
||||
/// Both these errors indicate corrupted or malformed data
|
||||
pub fn decompress<Ct>(
|
||||
&self,
|
||||
decompression_key: &DecompressionKey,
|
||||
) -> crate::Result<KVStore<Key, Ct>>
|
||||
where
|
||||
Ct: Expandable + IntegerRadixCiphertext,
|
||||
{
|
||||
let mut block_count = None;
|
||||
let mut store = KVStore::new();
|
||||
for (i, key) in self.keys.iter().enumerate() {
|
||||
let value: Ct = self
|
||||
.values
|
||||
.get(i, decompression_key)?
|
||||
.ok_or_else(|| crate::error!("Missing value for key '{key}'"))?;
|
||||
|
||||
let n = *block_count.get_or_insert_with(|| value.blocks().len());
|
||||
|
||||
if n != value.blocks().len() {
|
||||
return Err(crate::error!(
|
||||
"The value for key {key} does not have the same number \
|
||||
of blocks as other values. {} instead of {n}",
|
||||
value.blocks().len()
|
||||
));
|
||||
}
|
||||
|
||||
let _ = store.insert(*key, value);
|
||||
}
|
||||
|
||||
Ok(store)
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! impl_named_for_kv_store {
|
||||
($Key:ty) => {
|
||||
impl crate::named::Named for CompressedKVStore<$Key> {
|
||||
const NAME: &'static str =
|
||||
concat!("integer::CompressedKVStore<", stringify!($Key), ">");
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl_named_for_kv_store!(u8);
|
||||
impl_named_for_kv_store!(u16);
|
||||
impl_named_for_kv_store!(u32);
|
||||
impl_named_for_kv_store!(u64);
|
||||
impl_named_for_kv_store!(u128);
|
||||
impl_named_for_kv_store!(i8);
|
||||
impl_named_for_kv_store!(i16);
|
||||
impl_named_for_kv_store!(i32);
|
||||
impl_named_for_kv_store!(i64);
|
||||
impl_named_for_kv_store!(i128);
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use rand::Rng;
|
||||
|
||||
use super::*;
|
||||
use crate::integer::{
|
||||
gen_keys, ClientKey, IntegerKeyKind, RadixCiphertext, SignedRadixCiphertext,
|
||||
};
|
||||
use crate::shortint::parameters::test_params::{
|
||||
TEST_COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
|
||||
TEST_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
|
||||
};
|
||||
use crate::shortint::ShortintParameterSet;
|
||||
|
||||
fn assert_store_unsigned_matches(
|
||||
clear_store: &HashMap<u32, u64>,
|
||||
kv_store: &KVStore<u32, RadixCiphertext>,
|
||||
cks: &ClientKey,
|
||||
) {
|
||||
assert_eq!(
|
||||
clear_store.len(),
|
||||
kv_store.len(),
|
||||
"Clear and Encrypted stores do no have the same number of pairs"
|
||||
);
|
||||
|
||||
for (key, value) in clear_store {
|
||||
let ct = kv_store
|
||||
.get(key)
|
||||
.expect("Missing entry in decompressed KVStore");
|
||||
|
||||
let decrypted: u64 = cks.decrypt_radix(ct);
|
||||
assert_eq!(
|
||||
*value, decrypted,
|
||||
"Invalid value stored for key '{key}', expected '{value}' got '{decrypted}'"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_compression_serialization_unsigned() {
|
||||
let params = TEST_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128.into();
|
||||
|
||||
let (cks, _) = gen_keys::<ShortintParameterSet>(params, IntegerKeyKind::Radix);
|
||||
|
||||
let private_compression_key = cks
|
||||
.new_compression_private_key(TEST_COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128);
|
||||
|
||||
let (compression_key, decompression_key) =
|
||||
cks.new_compression_decompression_keys(&private_compression_key);
|
||||
|
||||
let num_blocks = 32;
|
||||
let num_keys = 100;
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
let mut clear_store = HashMap::new();
|
||||
let mut kv_store = KVStore::new();
|
||||
for _ in 0..num_keys {
|
||||
let key = rng.gen::<u32>();
|
||||
let value = rng.gen::<u64>();
|
||||
|
||||
let ct = cks.encrypt_radix(value, num_blocks);
|
||||
|
||||
let _ = clear_store.insert(key, value);
|
||||
kv_store.insert(key, ct);
|
||||
}
|
||||
|
||||
assert_store_unsigned_matches(&clear_store, &kv_store, &cks);
|
||||
|
||||
let compressed = kv_store.compress(&compression_key);
|
||||
let kv_store = compressed.decompress(&decompression_key).unwrap();
|
||||
assert_store_unsigned_matches(&clear_store, &kv_store, &cks);
|
||||
|
||||
let mut data = vec![];
|
||||
crate::safe_serialization::safe_serialize(&compressed, &mut data, 1 << 20).unwrap();
|
||||
let compressed: CompressedKVStore<u32> =
|
||||
crate::safe_serialization::safe_deserialize(data.as_slice(), 1 << 20).unwrap();
|
||||
let kv_store = compressed.decompress(&decompression_key).unwrap();
|
||||
assert_store_unsigned_matches(&clear_store, &kv_store, &cks);
|
||||
}
|
||||
|
||||
fn assert_store_signed_matches(
|
||||
clear_store: &HashMap<u32, i64>,
|
||||
kv_store: &KVStore<u32, SignedRadixCiphertext>,
|
||||
cks: &ClientKey,
|
||||
) {
|
||||
assert_eq!(
|
||||
clear_store.len(),
|
||||
kv_store.len(),
|
||||
"Clear and Encrypted stores do no have the same number of pairs"
|
||||
);
|
||||
|
||||
for (key, value) in clear_store {
|
||||
let ct = kv_store
|
||||
.get(key)
|
||||
.expect("Missing entry in decompressed KVStore");
|
||||
|
||||
let decrypted: i64 = cks.decrypt_signed_radix(ct);
|
||||
assert_eq!(
|
||||
*value, decrypted,
|
||||
"Invalid value stored for key '{key}', expected '{value}' got '{decrypted}'"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_compression_serialization_signed() {
|
||||
let params = TEST_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128.into();
|
||||
|
||||
let (cks, _) = gen_keys::<ShortintParameterSet>(params, IntegerKeyKind::Radix);
|
||||
|
||||
let private_compression_key = cks
|
||||
.new_compression_private_key(TEST_COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128);
|
||||
|
||||
let (compression_key, decompression_key) =
|
||||
cks.new_compression_decompression_keys(&private_compression_key);
|
||||
|
||||
let num_blocks = 32;
|
||||
let num_keys = 100;
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
let mut clear_store = HashMap::new();
|
||||
let mut kv_store = KVStore::new();
|
||||
for _ in 0..num_keys {
|
||||
let key = rng.gen::<u32>();
|
||||
let value = rng.gen::<i64>();
|
||||
|
||||
let ct = cks.encrypt_signed_radix(value, num_blocks);
|
||||
|
||||
let _ = clear_store.insert(key, value);
|
||||
kv_store.insert(key, ct);
|
||||
}
|
||||
|
||||
assert_store_signed_matches(&clear_store, &kv_store, &cks);
|
||||
|
||||
let compressed = kv_store.compress(&compression_key);
|
||||
let kv_store = compressed.decompress(&decompression_key).unwrap();
|
||||
assert_store_signed_matches(&clear_store, &kv_store, &cks);
|
||||
|
||||
let mut data = vec![];
|
||||
crate::safe_serialization::safe_serialize(&compressed, &mut data, 1 << 20).unwrap();
|
||||
let compressed: CompressedKVStore<u32> =
|
||||
crate::safe_serialization::safe_deserialize(data.as_slice(), 1 << 20).unwrap();
|
||||
let kv_store = compressed.decompress(&decompression_key).unwrap();
|
||||
assert_store_signed_matches(&clear_store, &kv_store, &cks);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,11 +9,101 @@ use crate::shortint::parameters::{TestParameters, *};
|
||||
use std::collections::BTreeMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
create_parameterized_test!(integer_default_kv_store_add);
|
||||
create_parameterized_test!(integer_default_kv_store_sub);
|
||||
create_parameterized_test!(integer_default_kv_store_mul);
|
||||
create_parameterized_test!(integer_default_kv_store_get_update);
|
||||
create_parameterized_test!(integer_default_kv_store_map);
|
||||
create_parameterized_test!(
|
||||
integer_default_kv_store_add
|
||||
{
|
||||
coverage => {
|
||||
COVERAGE_PARAM_MESSAGE_2_CARRY_2_KS_PBS,
|
||||
COVERAGE_PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS
|
||||
},
|
||||
no_coverage => {
|
||||
TEST_PARAM_MESSAGE_1_CARRY_1_KS_PBS_GAUSSIAN_2M128,
|
||||
PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
|
||||
TEST_PARAM_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M128,
|
||||
// 2M128 is too slow for 4_4, it is estimated to be 2x slower
|
||||
TEST_PARAM_MESSAGE_4_CARRY_4_KS_PBS_GAUSSIAN_2M64,
|
||||
TEST_PARAM_MULTI_BIT_GROUP_2_MESSAGE_1_CARRY_1_KS_PBS_GAUSSIAN_2M64,
|
||||
TEST_PARAM_MULTI_BIT_GROUP_2_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64,
|
||||
TEST_PARAM_MULTI_BIT_GROUP_2_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64,
|
||||
}
|
||||
}
|
||||
);
|
||||
create_parameterized_test!(
|
||||
integer_default_kv_store_sub
|
||||
{
|
||||
coverage => {
|
||||
COVERAGE_PARAM_MESSAGE_2_CARRY_2_KS_PBS,
|
||||
COVERAGE_PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS
|
||||
},
|
||||
no_coverage => {
|
||||
TEST_PARAM_MESSAGE_1_CARRY_1_KS_PBS_GAUSSIAN_2M128,
|
||||
PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
|
||||
TEST_PARAM_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M128,
|
||||
// 2M128 is too slow for 4_4, it is estimated to be 2x slower
|
||||
TEST_PARAM_MESSAGE_4_CARRY_4_KS_PBS_GAUSSIAN_2M64,
|
||||
TEST_PARAM_MULTI_BIT_GROUP_2_MESSAGE_1_CARRY_1_KS_PBS_GAUSSIAN_2M64,
|
||||
TEST_PARAM_MULTI_BIT_GROUP_2_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64,
|
||||
TEST_PARAM_MULTI_BIT_GROUP_2_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64,
|
||||
}
|
||||
}
|
||||
);
|
||||
create_parameterized_test!(
|
||||
integer_default_kv_store_mul
|
||||
{
|
||||
coverage => {
|
||||
COVERAGE_PARAM_MESSAGE_2_CARRY_2_KS_PBS,
|
||||
COVERAGE_PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS
|
||||
},
|
||||
no_coverage => {
|
||||
TEST_PARAM_MESSAGE_1_CARRY_1_KS_PBS_GAUSSIAN_2M128,
|
||||
PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
|
||||
TEST_PARAM_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M128,
|
||||
// 2M128 is too slow for 4_4, it is estimated to be 2x slower
|
||||
TEST_PARAM_MESSAGE_4_CARRY_4_KS_PBS_GAUSSIAN_2M64,
|
||||
TEST_PARAM_MULTI_BIT_GROUP_2_MESSAGE_1_CARRY_1_KS_PBS_GAUSSIAN_2M64,
|
||||
TEST_PARAM_MULTI_BIT_GROUP_2_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64,
|
||||
TEST_PARAM_MULTI_BIT_GROUP_2_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64,
|
||||
}
|
||||
}
|
||||
);
|
||||
create_parameterized_test!(
|
||||
integer_default_kv_store_get_update
|
||||
{
|
||||
coverage => {
|
||||
COVERAGE_PARAM_MESSAGE_2_CARRY_2_KS_PBS,
|
||||
COVERAGE_PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS
|
||||
},
|
||||
no_coverage => {
|
||||
TEST_PARAM_MESSAGE_1_CARRY_1_KS_PBS_GAUSSIAN_2M128,
|
||||
PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
|
||||
TEST_PARAM_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M128,
|
||||
// 2M128 is too slow for 4_4, it is estimated to be 2x slower
|
||||
TEST_PARAM_MESSAGE_4_CARRY_4_KS_PBS_GAUSSIAN_2M64,
|
||||
TEST_PARAM_MULTI_BIT_GROUP_2_MESSAGE_1_CARRY_1_KS_PBS_GAUSSIAN_2M64,
|
||||
TEST_PARAM_MULTI_BIT_GROUP_2_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64,
|
||||
TEST_PARAM_MULTI_BIT_GROUP_2_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64,
|
||||
}
|
||||
}
|
||||
);
|
||||
create_parameterized_test!(
|
||||
integer_default_kv_store_map
|
||||
{
|
||||
coverage => {
|
||||
COVERAGE_PARAM_MESSAGE_2_CARRY_2_KS_PBS,
|
||||
COVERAGE_PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS
|
||||
},
|
||||
no_coverage => {
|
||||
TEST_PARAM_MESSAGE_1_CARRY_1_KS_PBS_GAUSSIAN_2M128,
|
||||
PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128,
|
||||
TEST_PARAM_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M128,
|
||||
// 2M128 is too slow for 4_4, it is estimated to be 2x slower
|
||||
TEST_PARAM_MESSAGE_4_CARRY_4_KS_PBS_GAUSSIAN_2M64,
|
||||
TEST_PARAM_MULTI_BIT_GROUP_2_MESSAGE_1_CARRY_1_KS_PBS_GAUSSIAN_2M64,
|
||||
TEST_PARAM_MULTI_BIT_GROUP_2_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64,
|
||||
TEST_PARAM_MULTI_BIT_GROUP_2_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64,
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
fn integer_default_kv_store_add(params: impl Into<TestParameters>) {
|
||||
let executor = CpuFunctionExecutor::new(&ServerKey::kv_store_add_to_slot);
|
||||
|
||||
Reference in New Issue
Block a user