refactor: use BTreeMap as internals of KVStore

This is to make the order of the key and value lists
deterministic when compressing
This commit is contained in:
Thomas Montaigu
2025-10-08 11:31:26 +02:00
committed by tmontaigu
parent eb03158e6e
commit 6869214e15
3 changed files with 46 additions and 36 deletions

View File

@@ -1,7 +1,6 @@
use benchmark::utilities::{hlapi_throughput_num_ops, write_to_json, BenchmarkType, OperatorType};
use criterion::{black_box, Criterion, Throughput};
use rand::prelude::*;
use std::hash::Hash;
use std::marker::PhantomData;
use std::ops::*;
use tfhe::core_crypto::prelude::Numeric;
@@ -286,7 +285,7 @@ where
fn bench_kv_store<Key, FheKey, Value>(c: &mut Criterion, cks: &ClientKey, num_elements: usize)
where
rand::distributions::Standard: rand::distributions::Distribution<Key>,
Key: Numeric + DecomposableInto<u64> + Eq + Hash + CastInto<usize> + TypeDisplay,
Key: Numeric + DecomposableInto<u64> + Ord + CastInto<usize> + TypeDisplay,
Value: FheEncrypt<u128, ClientKey> + FheIntegerType + Clone + Send + Sync + TypeDisplay,
Value::Id: FheUintId,
FheKey: FheEncrypt<Key, ClientKey> + FheIntegerType + Send + Sync,

View File

@@ -13,7 +13,6 @@ use crate::integer::server_key::{
use crate::prelude::CastInto;
use crate::{FheBool, IntegerId, ReRandomizationMetadata, Tag};
use std::fmt::Display;
use std::hash::Hash;
#[derive(Clone)]
enum InnerKVStore<Key, T>
@@ -78,7 +77,7 @@ where
/// Returns the old value if there was any
pub fn insert_with_clear_key(&mut self, key: Key, value: T) -> Option<T>
where
Key: Eq + Hash,
Key: Ord,
{
#[allow(unreachable_patterns)]
global_state::with_internal_keys(|server_key| match (server_key, &mut self.inner) {
@@ -116,7 +115,7 @@ where
/// if its not present
pub fn update_with_clear_key(&mut self, key: &Key, value: T) -> Option<T>
where
Key: Eq + Hash,
Key: Ord,
{
#[allow(unreachable_patterns)]
global_state::with_internal_keys(|server_key| match (server_key, &mut self.inner) {
@@ -155,7 +154,7 @@ where
/// be set when calling this function is order to set the Tag of the resulting ciphertext
pub fn remove_with_clear_key(&mut self, key: &Key) -> Option<T>
where
Key: Eq + Hash,
Key: Ord,
{
#[allow(unreachable_patterns)]
global_state::with_internal_keys(|server_key| match (server_key, &mut self.inner) {
@@ -191,7 +190,7 @@ where
/// be set when calling this function is order to set the Tag of the resulting ciphertext
pub fn get_with_clear_key(&self, key: &Key) -> Option<T>
where
Key: Eq + Hash,
Key: Ord,
{
#[allow(unreachable_patterns)]
global_state::with_internal_keys(|server_key| match (server_key, &self.inner) {
@@ -227,7 +226,7 @@ where
impl<Key, T> KVStore<Key, T>
where
Key: Decomposable + CastInto<usize> + Hash + Eq,
Key: Decomposable + CastInto<usize> + Ord,
T: FheIntegerType,
{
/// Gets the value corresponding to the encrypted key.
@@ -382,7 +381,7 @@ where
/// Compressed the KVStore, making it serializable
pub fn compress(&self) -> crate::Result<CompressedKVStore<Key, T>>
where
Key: Copy + Display + Eq + Hash,
Key: Copy + Display + Ord,
<T::Id as IntegerId>::InnerCpu: Compressible + Clone,
{
#[allow(unreachable_patterns)]
@@ -475,7 +474,7 @@ where
pub fn decompress(&self) -> crate::Result<KVStore<Key, Value>>
where
<Value::Id as IntegerId>::InnerCpu: Expandable,
Key: Copy + Display + Eq + Hash,
Key: Copy + Display + Ord,
{
global_state::try_with_internal_keys(|key| match key {
Some(InternalServerKey::Cpu(cpu_key)) => {
@@ -515,21 +514,19 @@ where
#[cfg(test)]
mod test {
use std::collections::HashMap;
use std::hash::Hash;
use crate::core_crypto::prelude::Numeric;
use crate::high_level_api::kv_store::CompressedKVStore;
use crate::prelude::*;
use crate::{ClientKey, FheInt32, FheIntegerType, FheUint32, FheUint64, FheUint8, KVStore};
use rand::prelude::*;
use std::collections::BTreeMap;
fn create_kv_store<K, V, FheType>(
num_keys: usize,
ck: &ClientKey,
) -> (KVStore<K, FheType>, HashMap<K, V>)
) -> (KVStore<K, FheType>, BTreeMap<K, V>)
where
K: Numeric + CastInto<usize> + Hash + Eq,
K: Numeric + CastInto<usize> + Ord,
V: Numeric,
rand::distributions::Standard:
rand::distributions::Distribution<K> + rand::distributions::Distribution<V>,
@@ -539,7 +536,7 @@ mod test {
let mut rng = rand::thread_rng();
let mut kv_store = KVStore::new();
let mut clear_store = HashMap::new();
let mut clear_store = BTreeMap::new();
while kv_store.len() != num_keys {
let k = rng.gen::<K>();
let v = rng.gen::<V>();

View File

@@ -9,9 +9,8 @@ use crate::integer::{BooleanBlock, IntegerRadixCiphertext, ServerKey};
use crate::prelude::CastInto;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::collections::BTreeMap;
use std::fmt::Display;
use std::hash::Hash;
use std::marker::PhantomData;
use std::num::NonZeroUsize;
use tfhe_versionable::Versionize;
@@ -28,7 +27,7 @@ use tfhe_versionable::Versionize;
/// To serialize a KVStore it must first be compressed with [KVStore::compress]
#[derive(Clone)]
pub struct KVStore<Key, Ct> {
data: HashMap<Key, Ct>,
data: BTreeMap<Key, Ct>,
block_count: Option<NonZeroUsize>,
}
@@ -36,7 +35,7 @@ impl<Key, Ct> KVStore<Key, Ct> {
/// Creates an empty KVStore
pub fn new() -> Self {
Self {
data: HashMap::new(),
data: BTreeMap::new(),
block_count: None,
}
}
@@ -47,7 +46,7 @@ impl<Key, Ct> KVStore<Key, Ct> {
/// query using an encrypted key
pub fn get(&self, key: &Key) -> Option<&Ct>
where
Key: Eq + Hash,
Key: Ord,
{
self.data.get(key)
}
@@ -58,7 +57,7 @@ impl<Key, Ct> KVStore<Key, Ct> {
/// query using an encrypted key
pub fn get_mut(&mut self, key: &Key) -> Option<&mut Ct>
where
Key: Eq + Hash,
Key: Ord,
{
self.data.get_mut(key)
}
@@ -77,7 +76,7 @@ impl<Key, Ct> KVStore<Key, Ct> {
/// values stored
pub fn insert(&mut self, key: Key, value: Ct) -> Option<Ct>
where
Key: PartialEq + Eq + Hash,
Key: Ord,
Ct: IntegerRadixCiphertext,
{
let n_blocks = value.blocks().len();
@@ -100,7 +99,7 @@ impl<Key, Ct> KVStore<Key, Ct> {
/// Removes a key-value pair.
pub fn remove(&mut self, key: &Key) -> Option<Ct>
where
Key: Eq + Hash,
Key: Ord,
{
self.data.remove(key)
}
@@ -108,7 +107,7 @@ impl<Key, Ct> KVStore<Key, Ct> {
/// Returns the value associated to the key given in clear
pub fn clear_get(&self, key: &Key) -> Option<&Ct>
where
Key: Eq + Hash,
Key: Ord,
{
self.data.get(key)
}
@@ -125,7 +124,7 @@ impl<Key, Ct> KVStore<Key, Ct> {
pub fn iter(&self) -> impl Iterator<Item = (&Key, &Ct)>
where
Key: Eq + Hash + Sync,
Key: Ord,
Ct: Send,
{
self.data.iter()
@@ -133,7 +132,7 @@ impl<Key, Ct> KVStore<Key, Ct> {
fn par_iter_keys(&self) -> impl ParallelIterator<Item = &Key>
where
Key: Send + Sync + Hash + Eq,
Key: Send + Sync + Ord,
Ct: Send + Sync,
{
self.data.par_iter().map(|(k, _)| k)
@@ -153,6 +152,21 @@ where
}
}
// # Impl Note
//
// In a few places we need to do parallel iteration over the BTreeMap entries, zipped with some
// BooleanBlock However, BTreeMap's par_iter does not impl rayon::IndexedParallelIterator
// which means it has no zip. So we resort to collecting in a Vec.
// (Also, internally BTreeMap's par_iter already seems to be using a Vec<(&Key, &Value)>)
//
// We chose collecting instead or using par_bride over the zipped sequential iterators
// as its advertised as less efficient, and we can afford the cost of the clone (both memory wise
// and compute wise) however, chances are that both impl would be fine as the real cost of compute
// is in the FHE ops.
//
// Also, one important point is that par_iter_bridge may not keep iteration order
// `The resulting iterator is not guaranteed to keep the order of the original iterator` (from rayon
// docs) which is a problem for us as we need determinisn
impl ServerKey {
/// Implementation of the get function that additionally returns the Vec of selectors
/// so it can be reused to avoid re-computing it.
@@ -163,7 +177,7 @@ impl ServerKey {
) -> (Ct, BooleanBlock, Vec<BooleanBlock>)
where
Ct: IntegerRadixCiphertext,
Key: Decomposable + CastInto<usize> + Hash + Eq,
Key: Decomposable + CastInto<usize> + Ord,
{
let selectors =
self.compute_equality_selectors(encrypted_key, map.par_iter_keys().copied());
@@ -209,7 +223,7 @@ impl ServerKey {
) -> (Ct, BooleanBlock)
where
Ct: IntegerRadixCiphertext,
Key: Decomposable + CastInto<usize> + Hash + Eq,
Key: Decomposable + CastInto<usize> + Ord,
{
let (result, check_block, _selectors) = self.kv_store_get_impl(map, encrypted_key);
(result, check_block)
@@ -232,7 +246,7 @@ impl ServerKey {
) -> BooleanBlock
where
Ct: IntegerRadixCiphertext,
Key: Decomposable + CastInto<usize> + Hash + Eq,
Key: Decomposable + CastInto<usize> + Ord,
{
let selectors =
self.compute_equality_selectors(encrypted_key, map.par_iter_keys().copied());
@@ -272,7 +286,7 @@ impl ServerKey {
) -> (Ct, Ct, BooleanBlock)
where
Ct: IntegerRadixCiphertext,
Key: Decomposable + CastInto<usize> + Hash + Eq,
Key: Decomposable + CastInto<usize> + Ord,
F: Fn(Ct) -> Ct,
{
let (old_value, check_block, selectors) = self.kv_store_get_impl(map, encrypted_key);
@@ -348,7 +362,7 @@ where
decompression_key: &DecompressionKey,
) -> crate::Result<KVStore<Key, Value>>
where
Key: Copy + Display + Eq + Hash,
Key: Copy + Display + Ord,
{
if Value::IS_SIGNED != self.is_signed {
let requested = if Value::IS_SIGNED { "Signed" } else { "" };
@@ -432,7 +446,7 @@ mod tests {
use crate::shortint::ShortintParameterSet;
fn assert_store_unsigned_matches(
clear_store: &HashMap<u32, u64>,
clear_store: &BTreeMap<u32, u64>,
kv_store: &KVStore<u32, RadixCiphertext>,
cks: &ClientKey,
) {
@@ -472,7 +486,7 @@ mod tests {
let mut rng = rand::thread_rng();
let mut clear_store = HashMap::new();
let mut clear_store = BTreeMap::new();
let mut kv_store = KVStore::new();
for _ in 0..num_keys {
let key = rng.gen::<u32>();
@@ -499,7 +513,7 @@ mod tests {
}
fn assert_store_signed_matches(
clear_store: &HashMap<u32, i64>,
clear_store: &BTreeMap<u32, i64>,
kv_store: &KVStore<u32, SignedRadixCiphertext>,
cks: &ClientKey,
) {
@@ -539,7 +553,7 @@ mod tests {
let mut rng = rand::thread_rng();
let mut clear_store = HashMap::new();
let mut clear_store = BTreeMap::new();
let mut kv_store = KVStore::new();
for _ in 0..num_keys {
let key = rng.gen::<u32>();