mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-08 22:28:01 -05:00
refactor: use BTreeMap as internals of KVStore
This is to make the order of the key and value lists deterministic when compressing
This commit is contained in:
committed by
tmontaigu
parent
eb03158e6e
commit
6869214e15
@@ -1,7 +1,6 @@
|
||||
use benchmark::utilities::{hlapi_throughput_num_ops, write_to_json, BenchmarkType, OperatorType};
|
||||
use criterion::{black_box, Criterion, Throughput};
|
||||
use rand::prelude::*;
|
||||
use std::hash::Hash;
|
||||
use std::marker::PhantomData;
|
||||
use std::ops::*;
|
||||
use tfhe::core_crypto::prelude::Numeric;
|
||||
@@ -286,7 +285,7 @@ where
|
||||
fn bench_kv_store<Key, FheKey, Value>(c: &mut Criterion, cks: &ClientKey, num_elements: usize)
|
||||
where
|
||||
rand::distributions::Standard: rand::distributions::Distribution<Key>,
|
||||
Key: Numeric + DecomposableInto<u64> + Eq + Hash + CastInto<usize> + TypeDisplay,
|
||||
Key: Numeric + DecomposableInto<u64> + Ord + CastInto<usize> + TypeDisplay,
|
||||
Value: FheEncrypt<u128, ClientKey> + FheIntegerType + Clone + Send + Sync + TypeDisplay,
|
||||
Value::Id: FheUintId,
|
||||
FheKey: FheEncrypt<Key, ClientKey> + FheIntegerType + Send + Sync,
|
||||
|
||||
@@ -13,7 +13,6 @@ use crate::integer::server_key::{
|
||||
use crate::prelude::CastInto;
|
||||
use crate::{FheBool, IntegerId, ReRandomizationMetadata, Tag};
|
||||
use std::fmt::Display;
|
||||
use std::hash::Hash;
|
||||
|
||||
#[derive(Clone)]
|
||||
enum InnerKVStore<Key, T>
|
||||
@@ -78,7 +77,7 @@ where
|
||||
/// Returns the old value if there was any
|
||||
pub fn insert_with_clear_key(&mut self, key: Key, value: T) -> Option<T>
|
||||
where
|
||||
Key: Eq + Hash,
|
||||
Key: Ord,
|
||||
{
|
||||
#[allow(unreachable_patterns)]
|
||||
global_state::with_internal_keys(|server_key| match (server_key, &mut self.inner) {
|
||||
@@ -116,7 +115,7 @@ where
|
||||
/// if its not present
|
||||
pub fn update_with_clear_key(&mut self, key: &Key, value: T) -> Option<T>
|
||||
where
|
||||
Key: Eq + Hash,
|
||||
Key: Ord,
|
||||
{
|
||||
#[allow(unreachable_patterns)]
|
||||
global_state::with_internal_keys(|server_key| match (server_key, &mut self.inner) {
|
||||
@@ -155,7 +154,7 @@ where
|
||||
/// be set when calling this function is order to set the Tag of the resulting ciphertext
|
||||
pub fn remove_with_clear_key(&mut self, key: &Key) -> Option<T>
|
||||
where
|
||||
Key: Eq + Hash,
|
||||
Key: Ord,
|
||||
{
|
||||
#[allow(unreachable_patterns)]
|
||||
global_state::with_internal_keys(|server_key| match (server_key, &mut self.inner) {
|
||||
@@ -191,7 +190,7 @@ where
|
||||
/// be set when calling this function is order to set the Tag of the resulting ciphertext
|
||||
pub fn get_with_clear_key(&self, key: &Key) -> Option<T>
|
||||
where
|
||||
Key: Eq + Hash,
|
||||
Key: Ord,
|
||||
{
|
||||
#[allow(unreachable_patterns)]
|
||||
global_state::with_internal_keys(|server_key| match (server_key, &self.inner) {
|
||||
@@ -227,7 +226,7 @@ where
|
||||
|
||||
impl<Key, T> KVStore<Key, T>
|
||||
where
|
||||
Key: Decomposable + CastInto<usize> + Hash + Eq,
|
||||
Key: Decomposable + CastInto<usize> + Ord,
|
||||
T: FheIntegerType,
|
||||
{
|
||||
/// Gets the value corresponding to the encrypted key.
|
||||
@@ -382,7 +381,7 @@ where
|
||||
/// Compressed the KVStore, making it serializable
|
||||
pub fn compress(&self) -> crate::Result<CompressedKVStore<Key, T>>
|
||||
where
|
||||
Key: Copy + Display + Eq + Hash,
|
||||
Key: Copy + Display + Ord,
|
||||
<T::Id as IntegerId>::InnerCpu: Compressible + Clone,
|
||||
{
|
||||
#[allow(unreachable_patterns)]
|
||||
@@ -475,7 +474,7 @@ where
|
||||
pub fn decompress(&self) -> crate::Result<KVStore<Key, Value>>
|
||||
where
|
||||
<Value::Id as IntegerId>::InnerCpu: Expandable,
|
||||
Key: Copy + Display + Eq + Hash,
|
||||
Key: Copy + Display + Ord,
|
||||
{
|
||||
global_state::try_with_internal_keys(|key| match key {
|
||||
Some(InternalServerKey::Cpu(cpu_key)) => {
|
||||
@@ -515,21 +514,19 @@ where
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use std::collections::HashMap;
|
||||
use std::hash::Hash;
|
||||
|
||||
use crate::core_crypto::prelude::Numeric;
|
||||
use crate::high_level_api::kv_store::CompressedKVStore;
|
||||
use crate::prelude::*;
|
||||
use crate::{ClientKey, FheInt32, FheIntegerType, FheUint32, FheUint64, FheUint8, KVStore};
|
||||
use rand::prelude::*;
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
fn create_kv_store<K, V, FheType>(
|
||||
num_keys: usize,
|
||||
ck: &ClientKey,
|
||||
) -> (KVStore<K, FheType>, HashMap<K, V>)
|
||||
) -> (KVStore<K, FheType>, BTreeMap<K, V>)
|
||||
where
|
||||
K: Numeric + CastInto<usize> + Hash + Eq,
|
||||
K: Numeric + CastInto<usize> + Ord,
|
||||
V: Numeric,
|
||||
rand::distributions::Standard:
|
||||
rand::distributions::Distribution<K> + rand::distributions::Distribution<V>,
|
||||
@@ -539,7 +536,7 @@ mod test {
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
let mut kv_store = KVStore::new();
|
||||
let mut clear_store = HashMap::new();
|
||||
let mut clear_store = BTreeMap::new();
|
||||
while kv_store.len() != num_keys {
|
||||
let k = rng.gen::<K>();
|
||||
let v = rng.gen::<V>();
|
||||
|
||||
@@ -9,9 +9,8 @@ use crate::integer::{BooleanBlock, IntegerRadixCiphertext, ServerKey};
|
||||
use crate::prelude::CastInto;
|
||||
use rayon::prelude::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::collections::BTreeMap;
|
||||
use std::fmt::Display;
|
||||
use std::hash::Hash;
|
||||
use std::marker::PhantomData;
|
||||
use std::num::NonZeroUsize;
|
||||
use tfhe_versionable::Versionize;
|
||||
@@ -28,7 +27,7 @@ use tfhe_versionable::Versionize;
|
||||
/// To serialize a KVStore it must first be compressed with [KVStore::compress]
|
||||
#[derive(Clone)]
|
||||
pub struct KVStore<Key, Ct> {
|
||||
data: HashMap<Key, Ct>,
|
||||
data: BTreeMap<Key, Ct>,
|
||||
block_count: Option<NonZeroUsize>,
|
||||
}
|
||||
|
||||
@@ -36,7 +35,7 @@ impl<Key, Ct> KVStore<Key, Ct> {
|
||||
/// Creates an empty KVStore
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
data: HashMap::new(),
|
||||
data: BTreeMap::new(),
|
||||
block_count: None,
|
||||
}
|
||||
}
|
||||
@@ -47,7 +46,7 @@ impl<Key, Ct> KVStore<Key, Ct> {
|
||||
/// query using an encrypted key
|
||||
pub fn get(&self, key: &Key) -> Option<&Ct>
|
||||
where
|
||||
Key: Eq + Hash,
|
||||
Key: Ord,
|
||||
{
|
||||
self.data.get(key)
|
||||
}
|
||||
@@ -58,7 +57,7 @@ impl<Key, Ct> KVStore<Key, Ct> {
|
||||
/// query using an encrypted key
|
||||
pub fn get_mut(&mut self, key: &Key) -> Option<&mut Ct>
|
||||
where
|
||||
Key: Eq + Hash,
|
||||
Key: Ord,
|
||||
{
|
||||
self.data.get_mut(key)
|
||||
}
|
||||
@@ -77,7 +76,7 @@ impl<Key, Ct> KVStore<Key, Ct> {
|
||||
/// values stored
|
||||
pub fn insert(&mut self, key: Key, value: Ct) -> Option<Ct>
|
||||
where
|
||||
Key: PartialEq + Eq + Hash,
|
||||
Key: Ord,
|
||||
Ct: IntegerRadixCiphertext,
|
||||
{
|
||||
let n_blocks = value.blocks().len();
|
||||
@@ -100,7 +99,7 @@ impl<Key, Ct> KVStore<Key, Ct> {
|
||||
/// Removes a key-value pair.
|
||||
pub fn remove(&mut self, key: &Key) -> Option<Ct>
|
||||
where
|
||||
Key: Eq + Hash,
|
||||
Key: Ord,
|
||||
{
|
||||
self.data.remove(key)
|
||||
}
|
||||
@@ -108,7 +107,7 @@ impl<Key, Ct> KVStore<Key, Ct> {
|
||||
/// Returns the value associated to the key given in clear
|
||||
pub fn clear_get(&self, key: &Key) -> Option<&Ct>
|
||||
where
|
||||
Key: Eq + Hash,
|
||||
Key: Ord,
|
||||
{
|
||||
self.data.get(key)
|
||||
}
|
||||
@@ -125,7 +124,7 @@ impl<Key, Ct> KVStore<Key, Ct> {
|
||||
|
||||
pub fn iter(&self) -> impl Iterator<Item = (&Key, &Ct)>
|
||||
where
|
||||
Key: Eq + Hash + Sync,
|
||||
Key: Ord,
|
||||
Ct: Send,
|
||||
{
|
||||
self.data.iter()
|
||||
@@ -133,7 +132,7 @@ impl<Key, Ct> KVStore<Key, Ct> {
|
||||
|
||||
fn par_iter_keys(&self) -> impl ParallelIterator<Item = &Key>
|
||||
where
|
||||
Key: Send + Sync + Hash + Eq,
|
||||
Key: Send + Sync + Ord,
|
||||
Ct: Send + Sync,
|
||||
{
|
||||
self.data.par_iter().map(|(k, _)| k)
|
||||
@@ -153,6 +152,21 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
// # Impl Note
|
||||
//
|
||||
// In a few places we need to do parallel iteration over the BTreeMap entries, zipped with some
|
||||
// BooleanBlock However, BTreeMap's par_iter does not impl rayon::IndexedParallelIterator
|
||||
// which means it has no zip. So we resort to collecting in a Vec.
|
||||
// (Also, internally BTreeMap's par_iter already seems to be using a Vec<(&Key, &Value)>)
|
||||
//
|
||||
// We chose collecting instead or using par_bride over the zipped sequential iterators
|
||||
// as its advertised as less efficient, and we can afford the cost of the clone (both memory wise
|
||||
// and compute wise) however, chances are that both impl would be fine as the real cost of compute
|
||||
// is in the FHE ops.
|
||||
//
|
||||
// Also, one important point is that par_iter_bridge may not keep iteration order
|
||||
// `The resulting iterator is not guaranteed to keep the order of the original iterator` (from rayon
|
||||
// docs) which is a problem for us as we need determinisn
|
||||
impl ServerKey {
|
||||
/// Implementation of the get function that additionally returns the Vec of selectors
|
||||
/// so it can be reused to avoid re-computing it.
|
||||
@@ -163,7 +177,7 @@ impl ServerKey {
|
||||
) -> (Ct, BooleanBlock, Vec<BooleanBlock>)
|
||||
where
|
||||
Ct: IntegerRadixCiphertext,
|
||||
Key: Decomposable + CastInto<usize> + Hash + Eq,
|
||||
Key: Decomposable + CastInto<usize> + Ord,
|
||||
{
|
||||
let selectors =
|
||||
self.compute_equality_selectors(encrypted_key, map.par_iter_keys().copied());
|
||||
@@ -209,7 +223,7 @@ impl ServerKey {
|
||||
) -> (Ct, BooleanBlock)
|
||||
where
|
||||
Ct: IntegerRadixCiphertext,
|
||||
Key: Decomposable + CastInto<usize> + Hash + Eq,
|
||||
Key: Decomposable + CastInto<usize> + Ord,
|
||||
{
|
||||
let (result, check_block, _selectors) = self.kv_store_get_impl(map, encrypted_key);
|
||||
(result, check_block)
|
||||
@@ -232,7 +246,7 @@ impl ServerKey {
|
||||
) -> BooleanBlock
|
||||
where
|
||||
Ct: IntegerRadixCiphertext,
|
||||
Key: Decomposable + CastInto<usize> + Hash + Eq,
|
||||
Key: Decomposable + CastInto<usize> + Ord,
|
||||
{
|
||||
let selectors =
|
||||
self.compute_equality_selectors(encrypted_key, map.par_iter_keys().copied());
|
||||
@@ -272,7 +286,7 @@ impl ServerKey {
|
||||
) -> (Ct, Ct, BooleanBlock)
|
||||
where
|
||||
Ct: IntegerRadixCiphertext,
|
||||
Key: Decomposable + CastInto<usize> + Hash + Eq,
|
||||
Key: Decomposable + CastInto<usize> + Ord,
|
||||
F: Fn(Ct) -> Ct,
|
||||
{
|
||||
let (old_value, check_block, selectors) = self.kv_store_get_impl(map, encrypted_key);
|
||||
@@ -348,7 +362,7 @@ where
|
||||
decompression_key: &DecompressionKey,
|
||||
) -> crate::Result<KVStore<Key, Value>>
|
||||
where
|
||||
Key: Copy + Display + Eq + Hash,
|
||||
Key: Copy + Display + Ord,
|
||||
{
|
||||
if Value::IS_SIGNED != self.is_signed {
|
||||
let requested = if Value::IS_SIGNED { "Signed" } else { "" };
|
||||
@@ -432,7 +446,7 @@ mod tests {
|
||||
use crate::shortint::ShortintParameterSet;
|
||||
|
||||
fn assert_store_unsigned_matches(
|
||||
clear_store: &HashMap<u32, u64>,
|
||||
clear_store: &BTreeMap<u32, u64>,
|
||||
kv_store: &KVStore<u32, RadixCiphertext>,
|
||||
cks: &ClientKey,
|
||||
) {
|
||||
@@ -472,7 +486,7 @@ mod tests {
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
let mut clear_store = HashMap::new();
|
||||
let mut clear_store = BTreeMap::new();
|
||||
let mut kv_store = KVStore::new();
|
||||
for _ in 0..num_keys {
|
||||
let key = rng.gen::<u32>();
|
||||
@@ -499,7 +513,7 @@ mod tests {
|
||||
}
|
||||
|
||||
fn assert_store_signed_matches(
|
||||
clear_store: &HashMap<u32, i64>,
|
||||
clear_store: &BTreeMap<u32, i64>,
|
||||
kv_store: &KVStore<u32, SignedRadixCiphertext>,
|
||||
cks: &ClientKey,
|
||||
) {
|
||||
@@ -539,7 +553,7 @@ mod tests {
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
let mut clear_store = HashMap::new();
|
||||
let mut clear_store = BTreeMap::new();
|
||||
let mut kv_store = KVStore::new();
|
||||
for _ in 0..num_keys {
|
||||
let key = rng.gen::<u32>();
|
||||
|
||||
Reference in New Issue
Block a user