From e523fd2cb6da3aca1136bd1899da765fe71239e5 Mon Sep 17 00:00:00 2001 From: Thomas Montaigu Date: Mon, 29 Sep 2025 11:34:47 +0200 Subject: [PATCH] feat: add KVStore to the high level api * Added Value type name to crate::integer::KVStore impl of Named trait as well as a bool to check we deserialize the correct value type (Radix vs SignedRadix) * Add KVStore to high_level_api * Add KVStore hlapi benches * Remove specialized `[add,mul,sub]_to_slot` as `map` is now the intended API. - mul_to_slot was way slower than using `map` - add/mul_to_slot were a bit faster (~5% latency-wise), but returned less information (no old_value, no new_value, no boolean to check) if the key matched - Some known improvement can be made to map, which should result in it being better than add/sub_to_slot * Add FheIntegerType trait to make the KVStore generic over FheUint/FheInt, and should make GPU integration "easy" --- Makefile | 6 +- tfhe-benchmark/Cargo.toml | 2 +- .../benches/high_level_api/bench.rs | 196 ++++- tfhe-benchmark/src/utilities.rs | 35 + .../backward_compatibility/kv_store.rs | 11 + .../backward_compatibility/mod.rs | 1 + tfhe/src/high_level_api/errors.rs | 46 +- tfhe/src/high_level_api/integers/mod.rs | 40 +- .../high_level_api/integers/signed/base.rs | 35 +- .../high_level_api/integers/signed/static_.rs | 9 + .../high_level_api/integers/unsigned/base.rs | 35 +- .../integers/unsigned/static_.rs | 12 + tfhe/src/high_level_api/kv_store.rs | 759 ++++++++++++++++++ tfhe/src/high_level_api/mod.rs | 7 +- .../backward_compatibility/ciphertext/mod.rs | 4 +- .../server_key/radix_parallel/kv_store.rs | 231 ++---- .../tests_unsigned/test_kv_store.rs | 272 +------ 17 files changed, 1271 insertions(+), 430 deletions(-) create mode 100644 tfhe/src/high_level_api/backward_compatibility/kv_store.rs create mode 100644 tfhe/src/high_level_api/kv_store.rs diff --git a/Makefile b/Makefile index 998f1df71..96564fd85 100644 --- a/Makefile +++ b/Makefile @@ -1507,13 +1507,13 @@ bench_web_js_api_parallel_firefox_ci: setup_venv bench_hlapi: install_rs_check_toolchain RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_CHECK_TOOLCHAIN) bench \ --bench hlapi \ - --features=integer,internal-keycache,nightly-avx512 -p tfhe-benchmark -- + --features=integer,internal-keycache,nightly-avx512,pbs-stats -p tfhe-benchmark -- .PHONY: bench_hlapi_gpu # Run benchmarks for integer operations on GPU bench_hlapi_gpu: install_rs_check_toolchain RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_CHECK_TOOLCHAIN) bench \ --bench hlapi \ - --features=integer,gpu,internal-keycache,nightly-avx512 -p tfhe-benchmark -- + --features=integer,gpu,internal-keycache,nightly-avx512,pbs-stats -p tfhe-benchmark -- .PHONY: bench_hlapi_hpu # Run benchmarks for HLAPI operations on HPU bench_hlapi_hpu: install_rs_check_toolchain @@ -1522,7 +1522,7 @@ bench_hlapi_hpu: install_rs_check_toolchain RUSTFLAGS="$(RUSTFLAGS)" \ cargo $(CARGO_RS_CHECK_TOOLCHAIN) bench \ --bench hlapi \ - --features=integer,internal-keycache,hpu,hpu-v80 -p tfhe-benchmark -- + --features=integer,internal-keycache,hpu,hpu-v80,pbs-stats -p tfhe-benchmark -- .PHONY: bench_hlapi_erc20 # Run benchmarks for ERC20 operations bench_hlapi_erc20: install_rs_check_toolchain diff --git a/tfhe-benchmark/Cargo.toml b/tfhe-benchmark/Cargo.toml index 1684f6268..c1e81d412 100644 --- a/tfhe-benchmark/Cargo.toml +++ b/tfhe-benchmark/Cargo.toml @@ -70,7 +70,7 @@ required-features = ["shortint", "internal-keycache"] name = "hlapi" path = "benches/high_level_api/bench.rs" harness = false -required-features = ["integer", "internal-keycache"] +required-features = ["integer", "internal-keycache", "pbs-stats"] [[bench]] name = "hlapi-erc20" diff --git a/tfhe-benchmark/benches/high_level_api/bench.rs b/tfhe-benchmark/benches/high_level_api/bench.rs index 384d6fd56..c02f45b67 100644 --- a/tfhe-benchmark/benches/high_level_api/bench.rs +++ b/tfhe-benchmark/benches/high_level_api/bench.rs @@ -1,14 +1,22 @@ -use benchmark::utilities::{write_to_json, OperatorType}; -use criterion::{black_box, Criterion}; +use benchmark::utilities::{hlapi_throughput_num_ops, write_to_json, BenchmarkType, OperatorType}; +use criterion::{black_box, Criterion, Throughput}; use rand::prelude::*; +use std::hash::Hash; +use std::marker::PhantomData; use std::ops::*; +use tfhe::core_crypto::prelude::Numeric; +use tfhe::integer::block_decomposition::DecomposableInto; use tfhe::keycache::NamedParam; +use tfhe::named::Named; use tfhe::prelude::*; use tfhe::{ - ClientKey, CompressedServerKey, FheUint10, FheUint12, FheUint128, FheUint14, FheUint16, - FheUint2, FheUint32, FheUint4, FheUint6, FheUint64, FheUint8, + ClientKey, CompressedServerKey, FheIntegerType, FheUint10, FheUint12, FheUint128, FheUint14, + FheUint16, FheUint2, FheUint32, FheUint4, FheUint6, FheUint64, FheUint8, FheUintId, IntegerId, + KVStore, }; +use rayon::prelude::*; + fn bench_fhe_type( c: &mut Criterion, client_key: &ClientKey, @@ -225,6 +233,170 @@ bench_type!(FheUint32); bench_type!(FheUint64); bench_type!(FheUint128); +trait TypeDisplay { + fn fmt(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let name = std::any::type_name::(); + let pos = name.rfind(":").map_or(0, |p| p + 1); + write!(f, "{}", &name[pos..]) + } +} + +impl TypeDisplay for u8 {} +impl TypeDisplay for u16 {} +impl TypeDisplay for u32 {} +impl TypeDisplay for u64 {} +impl TypeDisplay for u128 {} + +impl TypeDisplay for tfhe::FheUint { + fn fmt(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write_fhe_type_name::(f) + } +} + +impl TypeDisplay for tfhe::FheInt { + fn fmt(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write_fhe_type_name::(f) + } +} + +struct TypeDisplayer(PhantomData); + +impl Default for TypeDisplayer { + fn default() -> Self { + Self(PhantomData) + } +} + +impl std::fmt::Display for TypeDisplayer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + T::fmt(f) + } +} + +fn write_fhe_type_name<'a, FheType>(f: &mut std::fmt::Formatter<'a>) -> std::fmt::Result +where + FheType: FheIntegerType + Named, +{ + let full_name = FheType::NAME; + let i = full_name.rfind(":").map_or(0, |p| p + 1); + + write!(f, "{}{}", &full_name[i..], FheType::Id::num_bits()) +} + +fn bench_kv_store(c: &mut Criterion, cks: &ClientKey, num_elements: usize) +where + rand::distributions::Standard: rand::distributions::Distribution, + Key: Numeric + DecomposableInto + Eq + Hash + CastInto + TypeDisplay, + Value: FheEncrypt + FheIntegerType + Clone + Send + Sync + TypeDisplay, + Value::Id: FheUintId, + FheKey: FheEncrypt + FheIntegerType + Send + Sync, + FheKey::Id: FheUintId, +{ + let mut kv_store = KVStore::new(); + let mut rng = rand::thread_rng(); + + let format_id_bench = |op_name: &str| -> String { + format!( + "KVStore::<{}, {}>::{op_name}/{num_elements}", + TypeDisplayer::::default(), + TypeDisplayer::::default(), + ) + }; + + match BenchmarkType::from_env().unwrap() { + BenchmarkType::Latency => { + while kv_store.len() != num_elements { + let key = rng.gen::(); + let value = rng.gen::(); + + let encrypted_value = Value::encrypt(value, cks); + kv_store.insert_with_clear_key(key, encrypted_value); + } + + let key = rng.gen::(); + let encrypted_key = FheKey::encrypt(key, cks); + + let value = rng.gen::(); + let value_to_add = Value::encrypt(value, cks); + + c.bench_function(&format_id_bench("Get"), |b| { + b.iter(|| { + let _ = kv_store.get(&encrypted_key); + }) + }); + + c.bench_function(&format_id_bench("Update"), |b| { + b.iter(|| { + let _ = kv_store.update(&encrypted_key, &value_to_add); + }) + }); + + c.bench_function(&format_id_bench("Map"), |b| { + b.iter(|| { + kv_store.map(&encrypted_key, |v| v); + }) + }); + } + BenchmarkType::Throughput => { + while kv_store.len() != num_elements { + let key = rng.gen::(); + let value = rng.gen::(); + + let encrypted_value = Value::encrypt(value, cks); + kv_store.insert_with_clear_key(key, encrypted_value); + } + + let key = rng.gen::(); + let encrypted_key = FheKey::encrypt(key, cks); + + let value = rng.gen::(); + let value_to_add = Value::encrypt(value, cks); + + let factor = hlapi_throughput_num_ops( + || { + kv_store.map(&encrypted_key, |v| v); + }, + cks, + ); + + let mut kv_stores = vec![]; + for _ in 0..factor.saturating_sub(1) { + kv_stores.push(kv_store.clone()); + } + kv_stores.push(kv_store); + + let mut group = c.benchmark_group("KVStore Throughput"); + group.throughput(Throughput::Elements(kv_stores.len() as u64)); + + group.bench_function(format_id_bench("Map"), |b| { + b.iter(|| { + kv_stores.par_iter_mut().for_each(|kv_store| { + kv_store.map(&encrypted_key, |v| v); + }) + }) + }); + + group.bench_function(format_id_bench("Update"), |b| { + b.iter(|| { + kv_stores.par_iter_mut().for_each(|kv_store| { + kv_store.update(&encrypted_key, &value_to_add); + }) + }) + }); + + group.bench_function(format_id_bench("Get"), |b| { + b.iter(|| { + kv_stores.par_iter_mut().for_each(|kv_store| { + kv_store.get(&encrypted_key); + }) + }) + }); + + group.finish(); + } + } +} + fn main() { #[cfg(feature = "hpu")] let cks = { @@ -256,7 +428,9 @@ fn main() { let cks = ClientKey::generate(config); let compressed_sks = CompressedServerKey::new(&cks); - set_server_key(compressed_sks.decompress()); + let sks = compressed_sks.decompress(); + rayon::broadcast(|_| set_server_key(sks.clone())); + set_server_key(sks); cks }; @@ -274,5 +448,17 @@ fn main() { bench_fhe_uint64(&mut c, &cks); bench_fhe_uint128(&mut c, &cks); + for pow in 1..=10 { + bench_kv_store::(&mut c, &cks, 1 << pow); + } + + for pow in 1..=10 { + bench_kv_store::(&mut c, &cks, 1 << pow); + } + + for pow in 1..=10 { + bench_kv_store::(&mut c, &cks, 1 << pow); + } + c.final_summary(); } diff --git a/tfhe-benchmark/src/utilities.rs b/tfhe-benchmark/src/utilities.rs index a8f9c8590..614a3329a 100644 --- a/tfhe-benchmark/src/utilities.rs +++ b/tfhe-benchmark/src/utilities.rs @@ -5,6 +5,8 @@ use std::{env, fs}; #[cfg(feature = "gpu")] use tfhe::core_crypto::gpu::{get_number_of_gpus, get_number_of_sms}; use tfhe::core_crypto::prelude::*; +#[cfg(feature = "integer")] +use tfhe::prelude::*; #[cfg(feature = "boolean")] pub mod boolean_utils { @@ -466,6 +468,39 @@ pub fn throughput_num_threads(num_block: usize, op_pbs_count: u64) -> u64 { } } +// Given an `Op` this returns how many more ops should be done in parallel +// to saturate the CPU and have a better throughput measurement +#[cfg(feature = "integer")] +pub fn hlapi_throughput_num_ops(op: Op, cks: &tfhe::ClientKey) -> usize +where + Op: FnOnce(), +{ + tfhe::reset_pbs_count(); + let t = std::time::Instant::now(); + op(); + let time_for_op = t.elapsed(); + let pbs_count_for_op = tfhe::get_pbs_count(); + + let a = tfhe::FheBool::encrypt(true, cks); + let b = tfhe::FheBool::encrypt(true, cks); + let t = std::time::Instant::now(); + let _ = a & b; + let time_for_single_pbs = t.elapsed(); + + // Round-up with nano seconds + let pbs_time_in_ms = + time_for_single_pbs.as_millis() + u128::from(time_for_single_pbs.as_nanos() != 0); + + // Theoretical time if the op was just 1 layer of PBS all in parallel + let time_if_full_occupancy = + pbs_count_for_op.div_ceil(rayon::current_num_threads() as u64) as u128 * pbs_time_in_ms; + + // Then find how many ops we should do to have full occupancy + let factor = time_for_op.as_millis().div_ceil(time_if_full_occupancy); + + factor as usize +} + #[cfg(feature = "gpu")] mod cuda_utils { use tfhe::core_crypto::entities::{ diff --git a/tfhe/src/high_level_api/backward_compatibility/kv_store.rs b/tfhe/src/high_level_api/backward_compatibility/kv_store.rs new file mode 100644 index 000000000..297d38f41 --- /dev/null +++ b/tfhe/src/high_level_api/backward_compatibility/kv_store.rs @@ -0,0 +1,11 @@ +use crate::high_level_api::kv_store::CompressedKVStore; +use crate::FheIntegerType; +use tfhe_versionable::VersionsDispatch; + +#[derive(VersionsDispatch)] +pub enum CompressedKVStoreVersions +where + Value: FheIntegerType, +{ + V0(CompressedKVStore), +} diff --git a/tfhe/src/high_level_api/backward_compatibility/mod.rs b/tfhe/src/high_level_api/backward_compatibility/mod.rs index 5241b96de..9d68eef85 100644 --- a/tfhe/src/high_level_api/backward_compatibility/mod.rs +++ b/tfhe/src/high_level_api/backward_compatibility/mod.rs @@ -9,6 +9,7 @@ pub mod config; pub mod cpk_re_randomization; pub mod integers; pub mod keys; +pub mod kv_store; #[cfg(feature = "strings")] pub mod strings; pub mod tag; diff --git a/tfhe/src/high_level_api/errors.rs b/tfhe/src/high_level_api/errors.rs index a92909272..6b84f8c8b 100644 --- a/tfhe/src/high_level_api/errors.rs +++ b/tfhe/src/high_level_api/errors.rs @@ -81,10 +81,52 @@ impl Display for UninitializedReRandKey { } } -impl std::error::Error for UninitializedReRandKey {} - impl From for Error { fn from(value: UninitializedReRandKey) -> Self { Self::new(format!("{value}")) } } + +impl std::error::Error for UninitializedReRandKey {} + +#[derive(Debug)] +pub struct UninitializedCompressionKey; + +impl Display for UninitializedCompressionKey { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Compression key is not set in server key, \ + did you forget to call `enable_compression` when building your Config?", + ) + } +} + +impl std::error::Error for UninitializedCompressionKey {} + +impl From for Error { + fn from(value: UninitializedCompressionKey) -> Self { + Self::new(format!("{value}")) + } +} + +#[derive(Debug)] +pub struct UninitializedDecompressionKey; + +impl Display for UninitializedDecompressionKey { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Decompression key is not set in server key, \ + did you forget to call `enable_compression` when building your Config?", + ) + } +} + +impl std::error::Error for UninitializedDecompressionKey {} + +impl From for Error { + fn from(value: UninitializedDecompressionKey) -> Self { + Self::new(format!("{value}")) + } +} diff --git a/tfhe/src/high_level_api/integers/mod.rs b/tfhe/src/high_level_api/integers/mod.rs index 44656d54d..12caf7238 100644 --- a/tfhe/src/high_level_api/integers/mod.rs +++ b/tfhe/src/high_level_api/integers/mod.rs @@ -30,19 +30,23 @@ expand_pub_use_fhe_type!( }; ); +use crate::prelude::Tagged; +use crate::ReRandomizationMetadata; pub(in crate::high_level_api) use signed::{ - CompressedSignedRadixCiphertext, FheIntId, InnerSquashedNoiseSignedRadixCiphertextVersionOwned, + CompressedSignedRadixCiphertext, InnerSquashedNoiseSignedRadixCiphertextVersionOwned, SignedRadixCiphertextVersionOwned, }; pub(in crate::high_level_api) use unsigned::{ - CompressedRadixCiphertext, FheUintId, InnerSquashedNoiseRadixCiphertextVersionOwned, + CompressedRadixCiphertext, InnerSquashedNoiseRadixCiphertextVersionOwned, RadixCiphertextVersionOwned as UnsignedRadixCiphertextVersionOwned, }; // These are pub-exported so that their doc can appear in generated rust docs +use crate::high_level_api::details::MaybeCloned; use crate::high_level_api::traits::FheId; use crate::shortint::MessageModulus; -pub use signed::{CompressedFheInt, FheInt, SquashedNoiseFheInt}; -pub use unsigned::{CompressedFheUint, FheUint, SquashedNoiseFheUint}; +use crate::Tag; +pub use signed::{CompressedFheInt, FheInt, FheIntId, SquashedNoiseFheInt}; +pub use unsigned::{CompressedFheUint, FheUint, FheUintId, SquashedNoiseFheUint}; pub mod oprf; pub(super) mod signed; @@ -52,9 +56,37 @@ pub(super) mod unsigned; // The 'static restrains implementor from holding non-static refs // which is ok as it is meant to be impld by zero sized types. pub trait IntegerId: FheId + 'static { + type InnerCpu: crate::integer::IntegerRadixCiphertext; + + type InnerGpu; + + type InnerHpu; + fn num_bits() -> usize; fn num_blocks(message_modulus: MessageModulus) -> usize { Self::num_bits() / message_modulus.0.ilog2() as usize } } + +mod private { + pub trait Sealed {} + + impl Sealed for crate::high_level_api::FheUint where Id: super::FheUintId {} + + impl Sealed for crate::high_level_api::FheInt where Id: super::FheIntId {} +} + +pub trait FheIntegerType: Tagged + private::Sealed { + type Id: IntegerId; + + fn on_cpu(&self) -> MaybeCloned<'_, ::InnerCpu>; + + fn into_cpu(self) -> ::InnerCpu; + + fn from_cpu( + inner: ::InnerCpu, + tag: Tag, + re_randomization_metadata: ReRandomizationMetadata, + ) -> Self; +} diff --git a/tfhe/src/high_level_api/integers/signed/base.rs b/tfhe/src/high_level_api/integers/signed/base.rs index ea48644a2..4f755d5b0 100644 --- a/tfhe/src/high_level_api/integers/signed/base.rs +++ b/tfhe/src/high_level_api/integers/signed/base.rs @@ -4,9 +4,10 @@ use super::inner::SignedRadixCiphertext; use crate::backward_compatibility::integers::FheIntVersions; use crate::conformance::ParameterSetConformant; use crate::core_crypto::prelude::SignedNumeric; +use crate::high_level_api::details::MaybeCloned; use crate::high_level_api::errors::UninitializedReRandKey; use crate::high_level_api::global_state; -use crate::high_level_api::integers::{FheUint, FheUintId, IntegerId}; +use crate::high_level_api::integers::{FheIntegerType, FheUint, FheUintId, IntegerId}; use crate::high_level_api::keys::{CompactPublicKey, InternalServerKey}; use crate::high_level_api::re_randomization::ReRandomizationMetadata; use crate::high_level_api::traits::{ReRandomize, Tagged}; @@ -20,7 +21,14 @@ use crate::shortint::PBSParameters; use crate::{Device, FheBool, ServerKey, Tag}; use std::marker::PhantomData; -pub trait FheIntId: IntegerId {} +#[cfg(not(feature = "gpu"))] +type ExpectedInnerGpu = (); +#[cfg(feature = "gpu")] +type ExpectedInnerGpu = crate::integer::gpu::ciphertext::CudaSignedRadixCiphertext; +pub trait FheIntId: + IntegerId +{ +} /// A Generic FHE signed integer /// @@ -107,6 +115,29 @@ where } } +impl FheIntegerType for FheInt +where + Id: FheIntId, +{ + type Id = Id; + + fn on_cpu(&self) -> MaybeCloned<'_, ::InnerCpu> { + self.ciphertext.on_cpu() + } + + fn into_cpu(self) -> ::InnerCpu { + self.ciphertext.into_cpu() + } + + fn from_cpu( + inner: ::InnerCpu, + tag: Tag, + re_randomization_metadata: ReRandomizationMetadata, + ) -> Self { + Self::new(inner, tag, re_randomization_metadata) + } +} + impl FheInt where Id: FheIntId, diff --git a/tfhe/src/high_level_api/integers/signed/static_.rs b/tfhe/src/high_level_api/integers/signed/static_.rs index 4a1cba4ea..7f1695721 100644 --- a/tfhe/src/high_level_api/integers/signed/static_.rs +++ b/tfhe/src/high_level_api/integers/signed/static_.rs @@ -20,6 +20,15 @@ macro_rules! static_int_type { pub struct []; impl IntegerId for [] { + type InnerCpu = crate::integer::SignedRadixCiphertext; + + #[cfg(not(feature = "gpu"))] + type InnerGpu = (); + #[cfg(feature = "gpu")] + type InnerGpu = crate::integer::gpu::ciphertext::CudaSignedRadixCiphertext; + + type InnerHpu = (); + fn num_bits() -> usize { $num_bits } diff --git a/tfhe/src/high_level_api/integers/unsigned/base.rs b/tfhe/src/high_level_api/integers/unsigned/base.rs index 61612b182..2ecf60dcc 100644 --- a/tfhe/src/high_level_api/integers/unsigned/base.rs +++ b/tfhe/src/high_level_api/integers/unsigned/base.rs @@ -4,9 +4,10 @@ use super::inner::RadixCiphertext; use crate::backward_compatibility::integers::FheUintVersions; use crate::conformance::ParameterSetConformant; use crate::core_crypto::prelude::{CastFrom, UnsignedInteger, UnsignedNumeric}; +use crate::high_level_api::details::MaybeCloned; use crate::high_level_api::errors::UninitializedReRandKey; use crate::high_level_api::integers::signed::{FheInt, FheIntId}; -use crate::high_level_api::integers::IntegerId; +use crate::high_level_api::integers::{FheIntegerType, IntegerId}; use crate::high_level_api::keys::{CompactPublicKey, InternalServerKey}; use crate::high_level_api::re_randomization::ReRandomizationMetadata; use crate::high_level_api::traits::{FheWait, ReRandomize, Tagged}; @@ -68,7 +69,14 @@ impl std::fmt::Display for GenericIntegerBlockError { } } -pub trait FheUintId: IntegerId {} +#[cfg(not(feature = "gpu"))] +type ExpectedInnerGpu = (); +#[cfg(feature = "gpu")] +type ExpectedInnerGpu = crate::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext; +pub trait FheUintId: + IntegerId +{ +} /// A Generic FHE unsigned integer /// @@ -144,6 +152,29 @@ impl Named for FheUint { const NAME: &'static str = "high_level_api::FheUint"; } +impl FheIntegerType for FheUint +where + Id: FheUintId, +{ + type Id = Id; + + fn on_cpu(&self) -> MaybeCloned<'_, ::InnerCpu> { + self.ciphertext.on_cpu() + } + + fn into_cpu(self) -> ::InnerCpu { + self.ciphertext.into_cpu() + } + + fn from_cpu( + inner: ::InnerCpu, + tag: Tag, + re_randomization_metadata: ReRandomizationMetadata, + ) -> Self { + Self::new(inner, tag, re_randomization_metadata) + } +} + impl Tagged for FheUint where Id: FheUintId, diff --git a/tfhe/src/high_level_api/integers/unsigned/static_.rs b/tfhe/src/high_level_api/integers/unsigned/static_.rs index 899170f27..b14cbca94 100644 --- a/tfhe/src/high_level_api/integers/unsigned/static_.rs +++ b/tfhe/src/high_level_api/integers/unsigned/static_.rs @@ -22,6 +22,18 @@ macro_rules! static_int_type { pub struct []; impl IntegerId for [] { + type InnerCpu = crate::integer::RadixCiphertext; + + #[cfg(not(feature = "gpu"))] + type InnerGpu = (); + #[cfg(feature = "gpu")] + type InnerGpu = crate::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext; + + #[cfg(not(feature = "hpu"))] + type InnerHpu = (); + #[cfg(feature = "hpu")] + type InnerHpu = crate::integer::hpu::ciphertext::HpuRadixCiphertext; + fn num_bits() -> usize { $num_bits } diff --git a/tfhe/src/high_level_api/kv_store.rs b/tfhe/src/high_level_api/kv_store.rs new file mode 100644 index 000000000..9e129ba24 --- /dev/null +++ b/tfhe/src/high_level_api/kv_store.rs @@ -0,0 +1,759 @@ +use serde::{Deserialize, Serialize}; +use tfhe_versionable::Versionize; + +use crate::backward_compatibility::kv_store::CompressedKVStoreVersions; +use crate::high_level_api::global_state; +use crate::high_level_api::integers::FheIntegerType; +use crate::high_level_api::keys::InternalServerKey; +use crate::integer::block_decomposition::Decomposable; +use crate::integer::ciphertext::{Compressible, Expandable}; +use crate::integer::server_key::{ + CompressedKVStore as CompressedIntegerKVStore, KVStore as IntegerKVStore, +}; +use crate::prelude::CastInto; +use crate::{FheBool, IntegerId, ReRandomizationMetadata, Tag}; +use std::fmt::Display; +use std::hash::Hash; + +#[derive(Clone)] +enum InnerKVStore +where + T: FheIntegerType, +{ + Cpu(IntegerKVStore::InnerCpu>), +} + +/// The KVStore is a specialized encrypted HashMap +/// +/// * Keys are clear numbers +/// * Values are FheInt or FheUint +/// +/// This stores allows to insert, removed, get using clear keys. +/// It also allows to do some operations using encrypted keys. +/// +/// To serialize a KVStore it must first be compressed with [KVStore::compress] +/// +/// # Tag System +/// +/// Ciphertexts inserted into the KVStore will drop their tag. +/// Operations on the KVStore that return a ciphertext will set a tag +/// using the currently set server key. +/// Even operations that do not require FHE operations will require +/// a server key to be set in order to set the tag +#[derive(Clone)] +pub struct KVStore +where + T: FheIntegerType, +{ + inner: InnerKVStore, +} + +impl KVStore +where + T: FheIntegerType, +{ + /// Creates a new empty `KVStore`. + pub fn new() -> Self { + Self { + inner: InnerKVStore::Cpu(IntegerKVStore::new()), + } + } + + /// Returns the number of key-value pairs in the store. + pub fn len(&self) -> usize { + match &self.inner { + InnerKVStore::Cpu(kvstore) => kvstore.len(), + } + } + + /// Returns `true` if the store contains no key-value pairs + pub fn is_empty(&self) -> bool { + match &self.inner { + InnerKVStore::Cpu(kvstore) => kvstore.is_empty(), + } + } + + /// Inserts a key-value pair. + /// + /// Returns the old value if there was any + pub fn insert_with_clear_key(&mut self, key: Key, value: T) -> Option + where + Key: Eq + Hash, + { + #[allow(unreachable_patterns)] + global_state::with_internal_keys(|server_key| match (server_key, &mut self.inner) { + (InternalServerKey::Cpu(cpu_key), InnerKVStore::Cpu(inner_store)) => { + let inner = inner_store.insert(key, value.into_cpu())?; + Some(T::from_cpu( + inner, + cpu_key.tag.clone(), + ReRandomizationMetadata::default(), + )) + } + #[cfg(feature = "gpu")] + (InternalServerKey::Cuda(_cuda_key), _) => { + panic!("GPU does not support KVStore yet") + } + #[cfg(feature = "hpu")] + (InternalServerKey::Hpu(_device), _) => { + panic!("HPU does not support KVStore yet") + } + _ => panic!("The KVStore's current backend does not match the current key backend"), + }) + } + + /// Updates the value in a key-value pair. + /// + /// Returns the old value if there was any + /// Returns None if the key had no previous value + /// + /// If your key is encrypted see [Self::update] + /// + /// + /// # Note + /// + /// Contraty to [Self::insert_with_clear_key], this does not insert the key,value pair + /// if its not present + pub fn update_with_clear_key(&mut self, key: &Key, value: T) -> Option + where + Key: Eq + Hash, + { + #[allow(unreachable_patterns)] + global_state::with_internal_keys(|server_key| match (server_key, &mut self.inner) { + (InternalServerKey::Cpu(cpu_key), InnerKVStore::Cpu(inner_store)) => { + inner_store.get_mut(key).map_or_else( + || None, + |old_value_ref| { + let old_value = std::mem::replace(old_value_ref, value.into_cpu()); + Some(T::from_cpu( + old_value, + cpu_key.tag.clone(), + ReRandomizationMetadata::default(), + )) + }, + ) + } + #[cfg(feature = "gpu")] + (InternalServerKey::Cuda(_cuda_key), _) => { + panic!("GPU does not support KVStore yet") + } + #[cfg(feature = "hpu")] + (InternalServerKey::Hpu(_device), _) => { + panic!("HPU does not support KVStore yet") + } + _ => panic!("The KVStore's current backend does not match the current key backend"), + }) + } + + /// Removes a key-value pair. + /// + /// Returns Some(_) if the key was present, None otherwise + /// + /// # Note + /// + /// Even though no FHE computations are done, a server key must + /// be set when calling this function is order to set the Tag of the resulting ciphertext + pub fn remove_with_clear_key(&mut self, key: &Key) -> Option + where + Key: Eq + Hash, + { + #[allow(unreachable_patterns)] + global_state::with_internal_keys(|server_key| match (server_key, &mut self.inner) { + (InternalServerKey::Cpu(cpu_key), InnerKVStore::Cpu(inner_store)) => { + let inner = inner_store.remove(key)?; + Some(T::from_cpu( + inner, + cpu_key.tag.clone(), + ReRandomizationMetadata::default(), + )) + } + #[cfg(feature = "gpu")] + (InternalServerKey::Cuda(_cuda_key), _) => { + panic!("GPU does not support KVStore yet") + } + #[cfg(feature = "hpu")] + (InternalServerKey::Hpu(_device), _) => { + panic!("HPU does not support KVStore yet") + } + _ => panic!("The KVStore's current backend does not match the current key backend"), + }) + } + + /// Returns the value associated to a key. + /// + /// Returns Some(_) if the key was present, None otherwise + /// + /// If your key is encrypted see [Self::get] + /// + /// # Note + /// + /// Even though no FHE computations are done, a server key must + /// be set when calling this function is order to set the Tag of the resulting ciphertext + pub fn get_with_clear_key(&self, key: &Key) -> Option + where + Key: Eq + Hash, + { + #[allow(unreachable_patterns)] + global_state::with_internal_keys(|server_key| match (server_key, &self.inner) { + (InternalServerKey::Cpu(cpu_key), InnerKVStore::Cpu(inner_store)) => { + let inner = inner_store.get(key)?; + Some(T::from_cpu( + inner.clone(), + cpu_key.tag.clone(), + ReRandomizationMetadata::default(), + )) + } + #[cfg(feature = "gpu")] + (InternalServerKey::Cuda(_cuda_key), _) => { + panic!("GPU does not support KVStore yet") + } + #[cfg(feature = "hpu")] + (InternalServerKey::Hpu(_device), _) => { + panic!("HPU does not support KVStore yet") + } + _ => panic!("The KVStore's current backend does not match the current key backend"), + }) + } +} + +impl Default for KVStore +where + T: FheIntegerType, +{ + fn default() -> Self { + Self::new() + } +} + +impl KVStore +where + Key: Decomposable + CastInto + Hash + Eq, + T: FheIntegerType, +{ + /// Gets the value corresponding to the encrypted key. + /// + /// Returns the encrypted value and an encrypted boolean. + /// The boolean is an encryption of true if the key was present, + /// thus the value is meaningful. + /// + /// If your key is clear see [Self::get_with_clear_key] + pub fn get(&self, encrypted_key: &EK) -> (T, FheBool) + where + EK: FheIntegerType, + EK::Id: IntegerId< + InnerCpu = ::InnerCpu, + InnerGpu = ::InnerGpu, + >, + { + #[allow(unreachable_patterns)] + global_state::with_internal_keys(|key| match (key, &self.inner) { + (InternalServerKey::Cpu(cpu_key), InnerKVStore::Cpu(inner_store)) => { + let (inner_ct, inner_bool) = cpu_key + .pbs_key() + .kv_store_get(inner_store, &*encrypted_key.on_cpu()); + ( + T::from_cpu( + inner_ct, + cpu_key.tag.clone(), + ReRandomizationMetadata::default(), + ), + FheBool::new( + inner_bool, + cpu_key.tag.clone(), + ReRandomizationMetadata::default(), + ), + ) + } + #[cfg(feature = "gpu")] + (InternalServerKey::Cuda(_cuda_key), _) => { + panic!("GPU does not support KVStore yet") + } + #[cfg(feature = "hpu")] + (InternalServerKey::Hpu(_device), _) => { + panic!("HPU does not support KVStore yet") + } + _ => panic!("The KVStore's current backend does not match the current key backend"), + }) + } + + /// Replaces the value corresponding to the encrypted key. + /// + /// i.e. `kvstore[encrypted_value] = new_value` + /// + /// The boolean is an encryption of true if the key was present, + /// thus the value is was replaced. + /// + /// If your key is clear see [Self::update_with_clear_key] + pub fn update(&mut self, encrypted_key: &EK, new_value: &T) -> FheBool + where + EK: FheIntegerType, + EK::Id: IntegerId< + InnerCpu = ::InnerCpu, + InnerGpu = ::InnerGpu, + >, + { + #[allow(unreachable_patterns)] + global_state::with_internal_keys(|key| match (key, &mut self.inner) { + (InternalServerKey::Cpu(cpu_key), InnerKVStore::Cpu(inner_store)) => { + let inner = cpu_key.pbs_key().kv_store_update( + inner_store, + &*encrypted_key.on_cpu(), + &*new_value.on_cpu(), + ); + FheBool::new( + inner, + cpu_key.tag.clone(), + ReRandomizationMetadata::default(), + ) + } + #[cfg(feature = "gpu")] + (InternalServerKey::Cuda(_cuda_key), _) => { + panic!("GPU does not support KVStore yet") + } + #[cfg(feature = "hpu")] + (InternalServerKey::Hpu(_device), _) => { + panic!("HPU does not support KVStore yet") + } + _ => panic!("The KVStore's current backend does not match the current key backend"), + }) + } + + /// Replaces the value corresponding to the encrypted key, with the + /// result of applying the function to the current value. + /// + /// i.e. `kvstore[encrypted_value] = func(kvstore[encrypted_value])` + /// + /// Returns (old_value, new_value, check) + /// + /// The `check` boolean is an encryption of true if the key was present, + /// thus the value is was replaced. + pub fn map(&mut self, encrypted_key: &EK, func: F) -> (T, T, FheBool) + where + EK: FheIntegerType, + EK::Id: IntegerId< + InnerCpu = ::InnerCpu, + InnerGpu = ::InnerGpu, + >, + F: Fn(T) -> T, + { + #[allow(unreachable_patterns)] + global_state::with_internal_keys(|key| match (key, &mut self.inner) { + (InternalServerKey::Cpu(cpu_key), InnerKVStore::Cpu(inner_store)) => { + let (inner_old, inner_new, inner_bool) = cpu_key.pbs_key().kv_store_map( + inner_store, + &*encrypted_key.on_cpu(), + |radix| { + let wrapped = + T::from_cpu(radix, Tag::default(), ReRandomizationMetadata::default()); + let wrapped_result = func(wrapped); + wrapped_result.into_cpu() + }, + ); + ( + T::from_cpu( + inner_old, + cpu_key.tag.clone(), + ReRandomizationMetadata::default(), + ), + T::from_cpu( + inner_new, + cpu_key.tag.clone(), + ReRandomizationMetadata::default(), + ), + FheBool::new( + inner_bool, + cpu_key.tag.clone(), + ReRandomizationMetadata::default(), + ), + ) + } + #[cfg(feature = "gpu")] + (InternalServerKey::Cuda(_cuda_key), _) => { + panic!("GPU does not support KVStore yet") + } + #[cfg(feature = "hpu")] + (InternalServerKey::Hpu(_device), _) => { + panic!("HPU does not support KVStore yet") + } + _ => panic!("The KVStore's current backend does not match the current key backend"), + }) + } + + /// Compressed the KVStore, making it serializable + pub fn compress(&self) -> crate::Result> + where + Key: Copy + Display + Eq + Hash, + ::InnerCpu: Compressible + Clone, + { + #[allow(unreachable_patterns)] + global_state::with_internal_keys(|key| match (key, &self.inner) { + (InternalServerKey::Cpu(cpu_key), InnerKVStore::Cpu(inner_store)) => { + let comp_key = cpu_key + .key + .compression_key + .as_ref() + .ok_or(crate::high_level_api::errors::UninitializedCompressionKey)?; + let compressed_inner = inner_store.compress(comp_key); + Ok(CompressedKVStore { + inner: compressed_inner, + }) + } + #[cfg(feature = "gpu")] + (InternalServerKey::Cuda(_cuda_key), _) => { + panic!("GPU does not support KVStore yet") + } + #[cfg(feature = "hpu")] + (InternalServerKey::Hpu(_device), _) => { + panic!("HPU does not support KVStore yet") + } + _ => panic!("The KVStore's current backend does not match the current key backend"), + }) + } +} + +/// Compressed KVStore +/// +/// This type is the serializable and deserializable form of a KVStore +#[derive(Serialize, Deserialize, Versionize)] +#[versionize(CompressedKVStoreVersions)] +pub struct CompressedKVStore +where + Value: FheIntegerType, +{ + inner: CompressedIntegerKVStore::InnerCpu>, +} + +macro_rules! impl_named_for_kv_store { + ($Key:ty) => { + impl crate::named::Named for CompressedKVStore<$Key, crate::high_level_api::FheUint> + where + Id: crate::high_level_api::FheUintId, + { + const NAME: &'static str = concat!( + "high_level_api::CompressedKVStore<", + stringify!($Key), + ", high_level_api::FheUint>" + ); + } + + impl crate::named::Named for CompressedKVStore<$Key, crate::high_level_api::FheInt> + where + Id: crate::high_level_api::FheIntId, + { + const NAME: &'static str = concat!( + "high_level_api::CompressedKVStore<", + stringify!($Key), + ", high_level_api::FheInt>" + ); + } + }; +} + +impl_named_for_kv_store!(u8); +impl_named_for_kv_store!(u16); +impl_named_for_kv_store!(u32); +impl_named_for_kv_store!(u64); +impl_named_for_kv_store!(u128); +impl_named_for_kv_store!(i8); +impl_named_for_kv_store!(i16); +impl_named_for_kv_store!(i32); +impl_named_for_kv_store!(i64); +impl_named_for_kv_store!(i128); + +impl CompressedKVStore +where + Value: FheIntegerType, +{ + /// Decompressed the KVStore + /// + /// Returns an error if: + /// * A key does not have a corresponding value + /// * A value does not have the same number of blocks as the others. + /// * If the requested value type is not compatible with the data stored + /// + /// Both these errors indicate corrupted or malformed data + pub fn decompress(&self) -> crate::Result> + where + ::InnerCpu: Expandable, + Key: Copy + Display + Eq + Hash, + { + global_state::try_with_internal_keys(|key| match key { + Some(InternalServerKey::Cpu(cpu_key)) => { + let decomp_key = cpu_key + .key + .decompression_key + .as_ref() + .ok_or(crate::high_level_api::errors::UninitializedDecompressionKey)?; + let inner_kv_store = self.inner.decompress(decomp_key)?; + + let Some(actual_block_count) = inner_kv_store.blocks_per_radix() else { + return Ok(KVStore::new()); // The KVstore was empty + }; + + let expected_block_count = Value::Id::num_blocks(cpu_key.message_modulus()); + + if actual_block_count.get() != expected_block_count { + return Err(crate::error!("Inconsistent block count in KVStore: expected {expected_block_count} but got {actual_block_count}")); + } + + Ok(KVStore { + inner: InnerKVStore::Cpu(inner_kv_store), + }) + } + #[cfg(feature = "gpu")] + Some(InternalServerKey::Cuda(_cuda_key)) => { + panic!("Decompressing KVStore to GPU is not implemented yet") + } + #[cfg(feature = "hpu")] + Some(InternalServerKey::Hpu(_device)) => { + panic!("Decompressing KVStore to HPU is not implemented yet") + } + None => Err(crate::high_level_api::errors::UninitializedServerKey.into()), + }) + } +} + +#[cfg(test)] +mod test { + use std::collections::HashMap; + use std::hash::Hash; + + use crate::core_crypto::prelude::Numeric; + use crate::high_level_api::kv_store::CompressedKVStore; + use crate::prelude::*; + use crate::{ClientKey, FheInt32, FheIntegerType, FheUint32, FheUint64, FheUint8, KVStore}; + use rand::prelude::*; + + fn create_kv_store( + num_keys: usize, + ck: &ClientKey, + ) -> (KVStore, HashMap) + where + K: Numeric + CastInto + Hash + Eq, + V: Numeric, + rand::distributions::Standard: + rand::distributions::Distribution + rand::distributions::Distribution, + FheType: FheIntegerType + FheEncrypt, + { + assert!((K::MAX).cast_into() >= num_keys); + let mut rng = rand::thread_rng(); + + let mut kv_store = KVStore::new(); + let mut clear_store = HashMap::new(); + while kv_store.len() != num_keys { + let k = rng.gen::(); + let v = rng.gen::(); + + let e_v = FheType::encrypt(v, ck); + + let _ = kv_store.insert_with_clear_key(k, e_v); + let _ = clear_store.insert(k, v); + } + + assert_eq!(kv_store.len(), clear_store.len()); + + (kv_store, clear_store) + } + + fn kv_store_get_test_case(ck: &ClientKey) { + let num_keys = 10; + let num_tests = 10; + + let (kv_store, clear_store) = create_kv_store::(num_keys, ck); + let mut rng = rand::thread_rng(); + + for _ in 0..num_tests { + let k = rng.gen::(); + let e_k = FheUint8::encrypt(k, ck); + + let (e_v, e_is_some) = kv_store.get(&e_k); + let is_some = e_is_some.decrypt(ck); + let v: u32 = e_v.decrypt(ck); + + if let Some(expected_value) = clear_store.get(&k) { + assert_eq!(v, *expected_value); + assert!(is_some); + } else { + assert!(!is_some); + assert_eq!(v, 0); + } + } + } + + fn kv_store_update_test_case(ck: &ClientKey) { + let num_keys = 10; + let num_tests = 10; + + let (mut kv_store, mut clear_store) = create_kv_store::(num_keys, ck); + let mut rng = rand::thread_rng(); + + for _ in 0..num_tests { + let k = rng.gen::(); + let e_k = FheUint8::encrypt(k, ck); + + let new_value = rng.gen::(); + let e_new_value = FheUint32::encrypt(new_value, ck); + + let e_was_updated = kv_store.update(&e_k, &e_new_value); + let was_updated = e_was_updated.decrypt(ck); + + let is_contained = clear_store.contains_key(&k); + if is_contained { + let _ = clear_store.insert(k, new_value); + } + assert_eq!(was_updated, is_contained); + } + + for (k, expected_v) in clear_store.iter() { + let e_k = FheUint8::encrypt(*k, ck); + + let (e_v, e_is_some) = kv_store.get(&e_k); + let is_some = e_is_some.decrypt(ck); + let v: u32 = e_v.decrypt(ck); + assert!(is_some); + assert_eq!(v, *expected_v); + } + } + + fn kv_store_map_test_case(ck: &ClientKey) { + let num_keys = 10; + let num_tests = 10; + + let (mut kv_store, mut clear_store) = create_kv_store::(num_keys, ck); + let mut rng = rand::thread_rng(); + + for _ in 0..num_tests { + let k = rng.gen::(); + let e_k = FheUint8::encrypt(k, ck); + + let expected_new_value = rng.gen::(); + + let (e_old_value, e_new_value, e_was_updated) = + kv_store.map(&e_k, |_old| FheUint32::encrypt(expected_new_value, ck)); + let was_updated = e_was_updated.decrypt(ck); + let new_value: u32 = e_new_value.decrypt(ck); + let old_value: u32 = e_old_value.decrypt(ck); + + if let Some(expected_old_value) = clear_store.get(&k).copied() { + assert_eq!(old_value, expected_old_value); + let _ = clear_store.insert(k, expected_new_value); + assert_eq!(new_value, expected_new_value); + assert!(was_updated); + } else { + assert!(!was_updated); + } + } + + for (k, expected_v) in clear_store.iter() { + let e_k = FheUint8::encrypt(*k, ck); + + let (e_v, e_is_some) = kv_store.get(&e_k); + let is_some = e_is_some.decrypt(ck); + let v: u32 = e_v.decrypt(ck); + assert!(is_some); + assert_eq!(v, *expected_v); + } + } + + fn kv_store_serialization_test_case(ck: &ClientKey) { + let num_keys = 10; + + let (kv_store, clear_store) = create_kv_store::(num_keys, ck); + + let compressed = kv_store.compress().unwrap(); + + let mut data = vec![]; + crate::safe_serialization::safe_serialize(&compressed, &mut data, 1 << 30).unwrap(); + + // Key type is incorrect + let maybe_compressed = crate::safe_serialization::safe_deserialize::< + CompressedKVStore, + >(data.as_slice(), 1 << 30); + // safe_deserialize catch the error + assert!(maybe_compressed.is_err()); + + let maybe_compressed = crate::safe_serialization::safe_deserialize::< + CompressedKVStore, + >(data.as_slice(), 1 << 30); + assert!(maybe_compressed.is_err()); + + // Invalid value types + let compressed = crate::safe_serialization::safe_deserialize::< + CompressedKVStore, + >(data.as_slice(), 1 << 30) + .unwrap(); + assert!(compressed.decompress().is_err()); + + let compressed = crate::safe_serialization::safe_deserialize::< + CompressedKVStore, + >(data.as_slice(), 1 << 30) + .unwrap(); + assert!(compressed.decompress().is_err()); + + let compressed = crate::safe_serialization::safe_deserialize::< + CompressedKVStore, + >(data.as_slice(), 1 << 30) + .unwrap(); + + let kv_store = compressed.decompress().unwrap(); + + for (k, expected_v) in clear_store.iter() { + let e_k = FheUint8::encrypt(*k, ck); + + let (e_v, e_is_some) = kv_store.get(&e_k); + let is_some = e_is_some.decrypt(ck); + let v: u32 = e_v.decrypt(ck); + assert!(is_some); + assert_eq!(v, *expected_v); + } + } + + mod cpu { + use crate::shortint::parameters::COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128; + use crate::{set_server_key, ConfigBuilder}; + + use super::*; + + pub(crate) fn setup_default_cpu() -> ClientKey { + let config = ConfigBuilder::default() + .enable_compression(COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128) + .build(); + + let client_key = ClientKey::generate(config); + let csks = crate::CompressedServerKey::new(&client_key); + let server_key = csks.decompress(); + + set_server_key(server_key); + + client_key + } + + #[test] + fn test_kv_store_get() { + let ck = setup_default_cpu(); + + kv_store_get_test_case(&ck); + } + + #[test] + fn test_kv_store_update() { + let ck = setup_default_cpu(); + + kv_store_update_test_case(&ck); + } + + #[test] + fn test_kv_store_map() { + let ck = setup_default_cpu(); + + kv_store_map_test_case(&ck); + } + + #[test] + fn test_kv_store_serialization() { + let ck = setup_default_cpu(); + + kv_store_serialization_test_case(&ck); + } + } +} diff --git a/tfhe/src/high_level_api/mod.rs b/tfhe/src/high_level_api/mod.rs index 7639b155b..d39804754 100644 --- a/tfhe/src/high_level_api/mod.rs +++ b/tfhe/src/high_level_api/mod.rs @@ -58,8 +58,8 @@ pub use global_state::CustomMultiGpuIndexes; pub use global_state::{set_server_key, unset_server_key, with_server_key_as_context}; pub use integers::{ - CompressedFheInt, CompressedFheUint, FheInt, FheUint, IntegerId, SquashedNoiseFheInt, - SquashedNoiseFheUint, + CompressedFheInt, CompressedFheUint, FheInt, FheIntId, FheIntegerType, FheUint, FheUintId, + IntegerId, SquashedNoiseFheInt, SquashedNoiseFheUint, }; #[cfg(feature = "gpu")] pub use keys::CudaServerKey; @@ -141,6 +141,8 @@ pub use tag::Tag; pub use traits::FheId; pub mod xof_key_set; +pub use kv_store::KVStore; + mod booleans; mod compressed_ciphertext_list; mod config; @@ -160,6 +162,7 @@ mod gpu_utils; pub mod array; pub mod backward_compatibility; mod compact_list; +mod kv_store; mod tag; #[cfg(feature = "gpu")] diff --git a/tfhe/src/integer/backward_compatibility/ciphertext/mod.rs b/tfhe/src/integer/backward_compatibility/ciphertext/mod.rs index ac662abf3..4f74fdd94 100644 --- a/tfhe/src/integer/backward_compatibility/ciphertext/mod.rs +++ b/tfhe/src/integer/backward_compatibility/ciphertext/mod.rs @@ -151,6 +151,6 @@ pub enum SquashedNoiseBooleanBlockVersions { } #[derive(VersionsDispatch)] -pub enum CompressedKVStoreVersions { - V0(CompressedKVStore), +pub enum CompressedKVStoreVersions { + V0(CompressedKVStore), } diff --git a/tfhe/src/integer/server_key/radix_parallel/kv_store.rs b/tfhe/src/integer/server_key/radix_parallel/kv_store.rs index b13fc0b64..05c27966d 100644 --- a/tfhe/src/integer/server_key/radix_parallel/kv_store.rs +++ b/tfhe/src/integer/server_key/radix_parallel/kv_store.rs @@ -1,5 +1,5 @@ use crate::integer::backward_compatibility::ciphertext::CompressedKVStoreVersions; -use crate::integer::block_decomposition::{Decomposable, DecomposableInto}; +use crate::integer::block_decomposition::Decomposable; use crate::integer::ciphertext::{ CompressedCiphertextList, CompressedCiphertextListBuilder, Compressible, Expandable, }; @@ -12,6 +12,7 @@ use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::fmt::Display; use std::hash::Hash; +use std::marker::PhantomData; use std::num::NonZeroUsize; use tfhe_versionable::Versionize; @@ -25,6 +26,7 @@ use tfhe_versionable::Versionize; /// /// /// To serialize a KVStore it must first be compressed with [KVStore::compress] +#[derive(Clone)] pub struct KVStore { data: HashMap, block_count: Option, @@ -50,6 +52,17 @@ impl KVStore { self.data.get(key) } + /// Returns the value stored for the key if any + /// + /// Key is in clear, see [ServerKey::kv_store_get] if you wish to + /// query using an encrypted key + pub fn get_mut(&mut self, key: &Key) -> Option<&mut Ct> + where + Key: Eq + Hash, + { + self.data.get_mut(key) + } + /// Inserts the value for the key /// /// Returns the previous value stored for the key if there was any @@ -84,6 +97,22 @@ impl KVStore { self.data.insert(key, value) } + /// Removes a key-value pair. + pub fn remove(&mut self, key: &Key) -> Option + where + Key: Eq + Hash, + { + self.data.remove(key) + } + + /// Returns the value associated to the key given in clear + pub fn clear_get(&self, key: &Key) -> Option<&Ct> + where + Key: Eq + Hash, + { + self.data.get(key) + } + /// Returns the number of key-value pairs currently stored pub fn len(&self) -> usize { self.data.len() @@ -109,6 +138,10 @@ impl KVStore { { self.data.par_iter().map(|(k, _)| k) } + + pub(crate) fn blocks_per_radix(&self) -> Option { + self.block_count + } } impl Default for KVStore @@ -121,127 +154,6 @@ where } impl ServerKey { - /// Internal function used to perform a binary operation - /// on an entry. - /// - /// `encrypted_key`: The key of the slot - /// `func`: function that receives to arguments: - /// * A boolean block that encrypts `true` if the corresponding key is the same as the - /// `encrypted_key` - /// * a `& mut` to the ciphertext which stores the value - fn kv_store_binary_op_to_slot( - &self, - map: &mut KVStore, - encrypted_key: &Ct, - func: F, - ) where - Ct: IntegerRadixCiphertext, - Key: Decomposable + CastInto + Hash + Eq, - F: Fn(&BooleanBlock, &mut Ct) + Sync + Send, - { - let kv_vec: Vec<(&Key, &mut Ct)> = map.data.iter_mut().collect(); - - // For each clear key, get a boolean ciphertext that tells if it's - // equal to the encrypted key - let selectors = - self.compute_equality_selectors(encrypted_key, kv_vec.par_iter().map(|(k, _v)| **k)); - - kv_vec - .into_par_iter() - .zip(selectors.par_iter()) - .for_each(|((_k, current_ct), selector)| func(selector, current_ct)); - } - - /// Performs an addition on an entry of the store - /// - /// `map[encrypted_key] += value` - /// - /// This finds the value that corresponds to the given `encrypted_key ` - /// and adds `value` to it. - pub fn kv_store_add_to_slot( - &self, - map: &mut KVStore, - encrypted_key: &Ct, - value: &Ct, - ) where - Ct: IntegerRadixCiphertext, - Key: Decomposable + CastInto + Hash + Eq, - { - self.kv_store_binary_op_to_slot(map, encrypted_key, |selector, v| { - let mut ct_to_add = value.clone(); - self.zero_out_if_condition_is_false(&mut ct_to_add, &selector.0); - self.add_assign_parallelized(v, &ct_to_add); - }); - } - - /// Performs an addition by a clear on an entry of the store - /// - /// `map[encrypted_key] += value` - /// - /// This finds the value that corresponds to the given `encrypted_key ` - /// and adds `value` to it. - pub fn kv_store_scalar_add_to_slot( - &self, - map: &mut KVStore, - encrypted_key: &Ct, - value: Clear, - ) where - Ct: IntegerRadixCiphertext, - Key: Decomposable + CastInto + Hash + Eq, - Clear: DecomposableInto, - { - self.kv_store_binary_op_to_slot(map, encrypted_key, |selector, v| { - let ct_to_add = - self.scalar_cmux_parallelized(selector, value, Clear::ZERO, v.blocks().len()); - self.add_assign_parallelized(v, &ct_to_add); - }); - } - - /// Performs a subtraction on an entry of the store - /// - /// `map[encrypted_key] -= value` - /// - /// This finds the value that corresponds to the given `encrypted_key`, - /// and subtracts `value` to it. - pub fn kv_store_sub_to_slot( - &self, - map: &mut KVStore, - encrypted_key: &Ct, - value: &Ct, - ) where - Ct: IntegerRadixCiphertext, - Key: Decomposable + CastInto + Hash + Eq, - { - self.kv_store_binary_op_to_slot(map, encrypted_key, |selector, v| { - let mut ct_to_sub = value.clone(); - self.zero_out_if_condition_is_false(&mut ct_to_sub, &selector.0); - self.sub_assign_parallelized(v, &ct_to_sub); - }); - } - - /// Performs a multiplication on an entry of the store - /// - /// `map[encrypted_key] *= value` - /// - /// This finds the value that corresponds to the given `encrypted_key`, - /// and multiplies it by `value`. - pub fn kv_store_mul_to_slot( - &self, - map: &mut KVStore, - encrypted_key: &Ct, - value: &Ct, - ) where - Ct: IntegerRadixCiphertext, - Key: Decomposable + CastInto + Hash + Eq, - Self: for<'a> ServerKeyDefaultCMux, - { - self.kv_store_binary_op_to_slot(map, encrypted_key, |selector, v| { - let selector = self.boolean_bitnot(selector); - let ct_to_mul = self.if_then_else_parallelized(&selector, 1u64, value); - self.mul_assign_parallelized(v, &ct_to_mul); - }); - } - /// Implementation of the get function that additionally returns the Vec of selectors /// so it can be reused to avoid re-computing it. fn kv_store_get_impl( @@ -350,21 +262,21 @@ impl ServerKey { /// This finds the value that corresponds to the given `encrypted_key`, then /// calls `func` then updates the value stored with the one returned by the `func`. /// - /// Returns the new value and a boolean block that encrypts `true` if an entry for - /// the `encrypted_key` was found. + /// Returns the (old_value, new_value, check_block) where `check_block` encrypts `true` if an + /// entry for the `encrypted_key` was found. pub fn kv_store_map( &self, map: &mut KVStore, encrypted_key: &Ct, func: F, - ) -> (Ct, BooleanBlock) + ) -> (Ct, Ct, BooleanBlock) where Ct: IntegerRadixCiphertext, Key: Decomposable + CastInto + Hash + Eq, F: Fn(Ct) -> Ct, { - let (result, check_block, selectors) = self.kv_store_get_impl(map, encrypted_key); - let new_value = func(result); + let (old_value, check_block, selectors) = self.kv_store_get_impl(map, encrypted_key); + let new_value = func(old_value.clone()); let kv_vec: Vec<(&Key, &mut Ct)> = map.data.iter_mut().collect(); kv_vec @@ -374,17 +286,17 @@ impl ServerKey { *old_value = self.if_then_else_parallelized(s, &new_value, old_value); }); - (new_value, check_block) + (old_value, new_value, check_block) } } impl KVStore where Key: Copy, - Ct: Compressible + Clone, + Ct: IntegerRadixCiphertext + Compressible + Clone, { /// Compress the KVStore to be able to serialize it - pub fn compress(&self, compression_key: &CompressionKey) -> CompressedKVStore { + pub fn compress(&self, compression_key: &CompressionKey) -> CompressedKVStore { let mut builder = CompressedCiphertextListBuilder::new(); let mut keys = Vec::with_capacity(self.data.len()); for (key, value) in self.data.iter() { @@ -394,7 +306,7 @@ where let values = builder.build(compression_key); - CompressedKVStore { keys, values } + CompressedKVStore::new(keys, values) } } @@ -403,34 +315,54 @@ where /// This type is the serializable and deserializable form of a KVStore #[derive(Serialize, Deserialize, Versionize)] #[versionize(CompressedKVStoreVersions)] -pub struct CompressedKVStore { +pub struct CompressedKVStore { keys: Vec, values: CompressedCiphertextList, + is_signed: bool, + _v: PhantomData, } -impl CompressedKVStore +impl CompressedKVStore where - Key: Copy + Display + Eq + Hash, + Value: Expandable + IntegerRadixCiphertext, { + fn new(keys: Vec, compressed_values: CompressedCiphertextList) -> Self { + Self { + keys, + values: compressed_values, + is_signed: Value::IS_SIGNED, + _v: PhantomData, + } + } /// Decompressed the KVStore /// /// Returns an error if: + /// * The requested value type does not have the same signedness as the stored one /// * A key does not have a corresponding value /// * A value (which is a radix ciphertext) does not have the same number of blocks as the /// others. /// /// Both these errors indicate corrupted or malformed data - pub fn decompress( + pub fn decompress( &self, decompression_key: &DecompressionKey, - ) -> crate::Result> + ) -> crate::Result> where - Ct: Expandable + IntegerRadixCiphertext, + Key: Copy + Display + Eq + Hash, { + if Value::IS_SIGNED != self.is_signed { + let requested = if Value::IS_SIGNED { "Signed" } else { "" }; + let stored = if self.is_signed { "Signed" } else { "" }; + return Err(crate::error!( + "Requested value type does not have signed.\ + Requested '{requested}RadixCiphertext' but stored '{stored}RadixCiphertext'" + )); + } + let mut block_count = None; let mut store = KVStore::new(); for (i, key) in self.keys.iter().enumerate() { - let value: Ct = self + let value: Value = self .values .get(i, decompression_key)? .ok_or_else(|| crate::error!("Missing value for key '{key}'"))?; @@ -454,9 +386,22 @@ where macro_rules! impl_named_for_kv_store { ($Key:ty) => { - impl crate::named::Named for CompressedKVStore<$Key> { - const NAME: &'static str = - concat!("integer::CompressedKVStore<", stringify!($Key), ">"); + impl crate::named::Named for CompressedKVStore<$Key, crate::integer::RadixCiphertext> { + const NAME: &'static str = concat!( + "integer::CompressedKVStore<", + stringify!($Key), + ", integer::RadixCiphertext>" + ); + } + + impl crate::named::Named + for CompressedKVStore<$Key, crate::integer::SignedRadixCiphertext> + { + const NAME: &'static str = concat!( + "integer::CompressedKVStore<", + stringify!($Key), + ", integer::SignedRadixCiphertext>" + ); } }; } @@ -547,7 +492,7 @@ mod tests { let mut data = vec![]; crate::safe_serialization::safe_serialize(&compressed, &mut data, 1 << 20).unwrap(); - let compressed: CompressedKVStore = + let compressed: CompressedKVStore = crate::safe_serialization::safe_deserialize(data.as_slice(), 1 << 20).unwrap(); let kv_store = compressed.decompress(&decompression_key).unwrap(); assert_store_unsigned_matches(&clear_store, &kv_store, &cks); @@ -614,7 +559,7 @@ mod tests { let mut data = vec![]; crate::safe_serialization::safe_serialize(&compressed, &mut data, 1 << 20).unwrap(); - let compressed: CompressedKVStore = + let compressed: CompressedKVStore = crate::safe_serialization::safe_deserialize(data.as_slice(), 1 << 20).unwrap(); let kv_store = compressed.decompress(&decompression_key).unwrap(); assert_store_signed_matches(&clear_store, &kv_store, &cks); diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_kv_store.rs b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_kv_store.rs index 51e1cf9f3..4268a7ca2 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_kv_store.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_kv_store.rs @@ -9,63 +9,6 @@ use crate::shortint::parameters::{TestParameters, *}; use std::collections::BTreeMap; use std::sync::Arc; -create_parameterized_test!( - integer_default_kv_store_add - { - coverage => { - COVERAGE_PARAM_MESSAGE_2_CARRY_2_KS_PBS, - COVERAGE_PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS - }, - no_coverage => { - TEST_PARAM_MESSAGE_1_CARRY_1_KS_PBS_GAUSSIAN_2M128, - PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128, - TEST_PARAM_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M128, - // 2M128 is too slow for 4_4, it is estimated to be 2x slower - TEST_PARAM_MESSAGE_4_CARRY_4_KS_PBS_GAUSSIAN_2M64, - TEST_PARAM_MULTI_BIT_GROUP_2_MESSAGE_1_CARRY_1_KS_PBS_GAUSSIAN_2M64, - TEST_PARAM_MULTI_BIT_GROUP_2_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64, - TEST_PARAM_MULTI_BIT_GROUP_2_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64, - } - } -); -create_parameterized_test!( - integer_default_kv_store_sub - { - coverage => { - COVERAGE_PARAM_MESSAGE_2_CARRY_2_KS_PBS, - COVERAGE_PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS - }, - no_coverage => { - TEST_PARAM_MESSAGE_1_CARRY_1_KS_PBS_GAUSSIAN_2M128, - PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128, - TEST_PARAM_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M128, - // 2M128 is too slow for 4_4, it is estimated to be 2x slower - TEST_PARAM_MESSAGE_4_CARRY_4_KS_PBS_GAUSSIAN_2M64, - TEST_PARAM_MULTI_BIT_GROUP_2_MESSAGE_1_CARRY_1_KS_PBS_GAUSSIAN_2M64, - TEST_PARAM_MULTI_BIT_GROUP_2_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64, - TEST_PARAM_MULTI_BIT_GROUP_2_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64, - } - } -); -create_parameterized_test!( - integer_default_kv_store_mul - { - coverage => { - COVERAGE_PARAM_MESSAGE_2_CARRY_2_KS_PBS, - COVERAGE_PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS - }, - no_coverage => { - TEST_PARAM_MESSAGE_1_CARRY_1_KS_PBS_GAUSSIAN_2M128, - PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128, - TEST_PARAM_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M128, - // 2M128 is too slow for 4_4, it is estimated to be 2x slower - TEST_PARAM_MESSAGE_4_CARRY_4_KS_PBS_GAUSSIAN_2M64, - TEST_PARAM_MULTI_BIT_GROUP_2_MESSAGE_1_CARRY_1_KS_PBS_GAUSSIAN_2M64, - TEST_PARAM_MULTI_BIT_GROUP_2_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64, - TEST_PARAM_MULTI_BIT_GROUP_2_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64, - } - } -); create_parameterized_test!( integer_default_kv_store_get_update { @@ -105,21 +48,6 @@ create_parameterized_test!( } ); -fn integer_default_kv_store_add(params: impl Into) { - let executor = CpuFunctionExecutor::new(&ServerKey::kv_store_add_to_slot); - default_kv_store_add_test(params, executor); -} - -fn integer_default_kv_store_sub(params: impl Into) { - let executor = CpuFunctionExecutor::new(&ServerKey::kv_store_sub_to_slot); - default_kv_store_sub_test(params, executor); -} - -fn integer_default_kv_store_mul(params: impl Into) { - let executor = CpuFunctionExecutor::new(&ServerKey::kv_store_mul_to_slot); - default_kv_store_mul_test(params, executor); -} - fn integer_default_kv_store_get_update(params: impl Into) { let get_executor = CpuFunctionExecutor::new(&ServerKey::kv_store_get); let update_executor = CpuFunctionExecutor::new(&ServerKey::kv_store_update); @@ -139,196 +67,10 @@ fn integer_default_kv_store_map(params: impl Into) { pub type KeyType = u8; -fn default_kv_store_add_test(params: P, mut kv_store_add: T) -where - P: Into, - T: for<'a> FunctionExecutor< - ( - &'a mut KVStore, - &'a RadixCiphertext, - &'a RadixCiphertext, - ), - (), - >, -{ - let params = params.into(); - let (cks, mut sks) = KEY_CACHE.get_from_params(params, IntegerKeyKind::Radix); - let cks = RadixClientKey::from((cks, NB_CTXT)); - - sks.set_deterministic_pbs_execution(true); - let sks = Arc::new(sks); - - let nb_blocks_key = get_num_block_for_key(params.message_modulus()); - - // message_modulus^vec_length - let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32); - - kv_store_add.setup(&cks, sks); - - let num_keys = 20usize; - let (mut map, mut clear_store) = create_filled_stores(num_keys, modulus, &cks); - - // Test modifying a key that does not exist - for _ in 0..num_keys.div_ceil(2) { - let key = generate_unused_key(&clear_store); - let encrypted_key = cks.as_ref().encrypt_radix(key, nb_blocks_key); - - let value_to_add = rand::random::() % modulus; - let encrypted_value_to_add: RadixCiphertext = cks.encrypt(value_to_add); - kv_store_add.execute((&mut map, &encrypted_key, &encrypted_value_to_add)); - - panic_if_not_the_same(&map, &clear_store, &cks); - } - - // Test modifying a key that exists - for _ in 0..num_keys.div_ceil(2) { - let key_index = rand::random::() % num_keys; - let key_target = *clear_store.iter().nth(key_index).unwrap().0; - let encrypted_key = cks.as_ref().encrypt_radix(key_target, nb_blocks_key); - - let value_to_add = rand::random::() % modulus; - let encrypted_value_to_add: RadixCiphertext = cks.encrypt(value_to_add); - - kv_store_add.execute((&mut map, &encrypted_key, &encrypted_value_to_add)); - - let new_value = clear_store - .get(&key_target) - .map(|v| v.wrapping_add(value_to_add) % modulus) - .unwrap(); - clear_store.insert(key_target, new_value); - - panic_if_not_properly_updated(&map, &clear_store, key_target, &cks) - } -} - fn get_num_block_for_key(msg_mod: MessageModulus) -> usize { KeyType::BITS.div_ceil(msg_mod.0.ilog2()) as usize } -fn default_kv_store_sub_test(params: P, mut kv_store_sub: T) -where - P: Into, - T: for<'a> FunctionExecutor< - ( - &'a mut KVStore, - &'a RadixCiphertext, - &'a RadixCiphertext, - ), - (), - >, -{ - let params = params.into(); - let (cks, mut sks) = KEY_CACHE.get_from_params(params, IntegerKeyKind::Radix); - let cks = RadixClientKey::from((cks, NB_CTXT)); - - sks.set_deterministic_pbs_execution(true); - let sks = Arc::new(sks); - - let nb_blocks_key = get_num_block_for_key(params.message_modulus()); - - // message_modulus^vec_length - let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32); - - kv_store_sub.setup(&cks, sks); - - let num_keys = 20usize; - let (mut map, mut clear_store) = create_filled_stores(num_keys, modulus, &cks); - - // Test modifying a key that does not exist - for _ in 0..num_keys.div_ceil(2) { - let key = generate_unused_key(&clear_store); - let encrypted_key = cks.as_ref().encrypt_radix(key, nb_blocks_key); - - let value_to_add = rand::random::() % modulus; - let encrypted_value_to_add: RadixCiphertext = cks.encrypt(value_to_add); - kv_store_sub.execute((&mut map, &encrypted_key, &encrypted_value_to_add)); - - panic_if_not_the_same(&map, &clear_store, &cks); - } - - // Test modifying a key that exists - for _ in 0..num_keys.div_ceil(2) { - let key_index = rand::random::() % num_keys; - let key_target = *clear_store.iter().nth(key_index).unwrap().0; - let encrypted_key = cks.as_ref().encrypt_radix(key_target, nb_blocks_key); - - let value_to_sub = rand::random::() % modulus; - let encrypted_value_to_add: RadixCiphertext = cks.encrypt(value_to_sub); - - kv_store_sub.execute((&mut map, &encrypted_key, &encrypted_value_to_add)); - - let new_value = clear_store - .get(&key_target) - .map(|v| v.wrapping_sub(value_to_sub) % modulus) - .unwrap(); - clear_store.insert(key_target, new_value); - - panic_if_not_properly_updated(&map, &clear_store, key_target, &cks) - } -} - -fn default_kv_store_mul_test(params: P, mut kv_store_mul: T) -where - P: Into, - T: for<'a> FunctionExecutor< - ( - &'a mut KVStore, - &'a RadixCiphertext, - &'a RadixCiphertext, - ), - (), - >, -{ - let params = params.into(); - let (cks, mut sks) = KEY_CACHE.get_from_params(params, IntegerKeyKind::Radix); - let cks = RadixClientKey::from((cks, NB_CTXT)); - - sks.set_deterministic_pbs_execution(true); - let sks = Arc::new(sks); - - let nb_blocks_key = get_num_block_for_key(params.message_modulus()); - - // message_modulus^vec_length - let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32); - - kv_store_mul.setup(&cks, sks); - - let num_keys = 20usize; - let (mut map, mut clear_store) = create_filled_stores(num_keys, modulus, &cks); - - // Test modifying a key that does not exist - for _ in 0..num_keys.div_ceil(2) { - let key = generate_unused_key(&clear_store); - let encrypted_key = cks.as_ref().encrypt_radix(key, nb_blocks_key); - - let value_to_add = rand::random::() % modulus; - let encrypted_value_to_add: RadixCiphertext = cks.encrypt(value_to_add); - kv_store_mul.execute((&mut map, &encrypted_key, &encrypted_value_to_add)); - - panic_if_not_the_same(&map, &clear_store, &cks); - } - - // Test modifying a key that exists - for _ in 0..num_keys.div_ceil(2) { - let key_index = rand::random::() % num_keys; - let key_target = *clear_store.iter().nth(key_index).unwrap().0; - let encrypted_key = cks.as_ref().encrypt_radix(key_target, nb_blocks_key); - - let value_to_mul = rand::random::() % modulus; - let encrypted_value_to_add: RadixCiphertext = cks.encrypt(value_to_mul); - - kv_store_mul.execute((&mut map, &encrypted_key, &encrypted_value_to_add)); - - let new_value = clear_store - .get(&key_target) - .map(|v| v.wrapping_mul(value_to_mul) % modulus) - .unwrap(); - clear_store.insert(key_target, new_value); - - panic_if_not_properly_updated(&map, &clear_store, key_target, &cks) - } -} - fn default_kv_store_get_update_test( params: P, mut kv_store_get: T1, @@ -415,7 +157,7 @@ where &'a RadixCiphertext, &'a dyn Fn(RadixCiphertext) -> RadixCiphertext, ), - (RadixCiphertext, BooleanBlock), + (RadixCiphertext, RadixCiphertext, BooleanBlock), >, { let params = params.into(); @@ -448,7 +190,7 @@ where let key = generate_unused_key(&clear_store); let encrypted_key = cks.as_ref().encrypt_radix(key, nb_blocks_key); - let (_, is_some) = kv_store_map.execute((&mut map, &encrypted_key, &function)); + let (_, _, is_some) = kv_store_map.execute((&mut map, &encrypted_key, &function)); assert!(!cks.decrypt_bool(&is_some)); panic_if_not_the_same(&map, &clear_store, &cks); @@ -460,19 +202,21 @@ where let key_target = *clear_store.iter().nth(key_index).unwrap().0; let encrypted_key = cks.as_ref().encrypt_radix(key_target, nb_blocks_key); - let new_value = clear_store + let (old_value, new_value) = clear_store .get(&key_target) .copied() - .map(clear_function) + .map(|old_value| (old_value, clear_function(old_value))) .unwrap(); - let (result, is_some) = kv_store_map.execute(( + let (e_old_value, e_new_value, is_some) = kv_store_map.execute(( &mut map, &encrypted_key, &function as &dyn Fn(RadixCiphertext) -> RadixCiphertext, )); assert!(cks.decrypt_bool(&is_some)); - assert_eq!(cks.decrypt::(&result), new_value); + assert_eq!(cks.decrypt::(&e_new_value), new_value); + assert_eq!(cks.decrypt::(&e_old_value), old_value); + clear_store.insert(key_target, new_value); panic_if_not_properly_updated(&map, &clear_store, key_target, &cks);