From 304932a861c051e1033cd299dacd11015b7488c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20Test=C3=A9?= Date: Wed, 9 Aug 2023 15:33:39 +0200 Subject: [PATCH] chore(tfhe): refactor keycache in its own crate The general parts of shortint keycache has been moved to its own crate. This enable boolean layer to get access to traits without having to import shortint::keycache module. --- Makefile | 14 +- tfhe/benches/core_crypto/pbs_bench.rs | 2 +- tfhe/benches/integer/bench.rs | 2 +- tfhe/benches/shortint/bench.rs | 5 +- .../examples/utilities/generates_test_keys.rs | 3 +- .../utilities/hlapi_compact_pk_ct_sizes.rs | 2 +- tfhe/examples/utilities/params_to_file.rs | 2 +- tfhe/examples/utilities/shortint_key_sizes.rs | 3 +- .../utilities/wasm_benchmarks_parser.rs | 4 +- tfhe/src/boolean/parameters/mod.rs | 4 +- tfhe/src/keycache/mod.rs | 220 +++++++++++++++++ tfhe/src/lib.rs | 5 + tfhe/src/shortint/keycache.rs | 223 +----------------- 13 files changed, 249 insertions(+), 240 deletions(-) diff --git a/Makefile b/Makefile index 782879118..3549764db 100644 --- a/Makefile +++ b/Makefile @@ -100,7 +100,7 @@ clippy_core: install_rs_check_toolchain .PHONY: clippy_boolean # Run clippy lints enabling the boolean features clippy_boolean: install_rs_check_toolchain RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy \ - --features=$(TARGET_ARCH_FEATURE),boolean,shortint,internal-keycache \ + --features=$(TARGET_ARCH_FEATURE),boolean \ -p tfhe -- --no-deps -D warnings .PHONY: clippy_shortint # Run clippy lints enabling the shortint features @@ -118,19 +118,19 @@ clippy_integer: install_rs_check_toolchain .PHONY: clippy # Run clippy lints enabling the boolean, shortint, integer clippy: install_rs_check_toolchain RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy --all-targets \ - --features=$(TARGET_ARCH_FEATURE),boolean,shortint,integer,internal-keycache \ + --features=$(TARGET_ARCH_FEATURE),boolean,shortint,integer \ -p tfhe -- --no-deps -D warnings .PHONY: clippy_c_api # Run clippy lints enabling the boolean, shortint and the C API clippy_c_api: install_rs_check_toolchain RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy \ - --features=$(TARGET_ARCH_FEATURE),boolean-c-api,shortint-c-api,internal-keycache \ + --features=$(TARGET_ARCH_FEATURE),boolean-c-api,shortint-c-api \ -p tfhe -- --no-deps -D warnings .PHONY: clippy_js_wasm_api # Run clippy lints enabling the boolean, shortint, integer and the js wasm API clippy_js_wasm_api: install_rs_check_toolchain RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy \ - --features=boolean-client-js-wasm-api,shortint-client-js-wasm-api,integer-client-js-wasm-api,internal-keycache \ + --features=boolean-client-js-wasm-api,shortint-client-js-wasm-api,integer-client-js-wasm-api \ -p tfhe -- --no-deps -D warnings .PHONY: clippy_tasks # Run clippy lints on helper tasks crate. @@ -141,7 +141,7 @@ clippy_tasks: .PHONY: clippy_all_targets # Run clippy lints on all targets (benches, examples, etc.) clippy_all_targets: RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy --all-targets \ - --features=$(TARGET_ARCH_FEATURE),boolean,shortint,integer,internal-keycache \ + --features=$(TARGET_ARCH_FEATURE),boolean,shortint,integer \ -p tfhe -- --no-deps -D warnings .PHONY: clippy_all # Run all clippy targets @@ -199,7 +199,7 @@ build_tfhe_full: install_rs_build_toolchain .PHONY: build_c_api # Build the C API for boolean, shortint and integer build_c_api: install_rs_check_toolchain RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_CHECK_TOOLCHAIN) build --profile $(CARGO_PROFILE) \ - --features=$(TARGET_ARCH_FEATURE),boolean-c-api,shortint-c-api,high-level-c-api \ + --features=$(TARGET_ARCH_FEATURE),boolean-c-api,shortint-c-api,high-level-c-api, \ -p tfhe .PHONY: build_c_api_experimental_deterministic_fft # Build the C API for boolean, shortint and integer with experimental deterministic FFT @@ -349,7 +349,7 @@ docs: doc lint_doc: install_rs_check_toolchain RUSTDOCFLAGS="--html-in-header katex-header.html -Dwarnings" \ cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" doc \ - --features=$(TARGET_ARCH_FEATURE),internal-keycache,boolean,shortint,integer --no-deps + --features=$(TARGET_ARCH_FEATURE),boolean,shortint,integer --no-deps .PHONY: lint_docs # Build rust doc with linting enabled alias for lint_doc lint_docs: lint_doc diff --git a/tfhe/benches/core_crypto/pbs_bench.rs b/tfhe/benches/core_crypto/pbs_bench.rs index 105e95fb9..13644bde1 100644 --- a/tfhe/benches/core_crypto/pbs_bench.rs +++ b/tfhe/benches/core_crypto/pbs_bench.rs @@ -9,7 +9,7 @@ use tfhe::boolean::parameters::{ BooleanParameters, DEFAULT_PARAMETERS, PARAMETERS_ERROR_PROB_2_POW_MINUS_165, }; use tfhe::core_crypto::prelude::*; -use tfhe::shortint::keycache::NamedParam; +use tfhe::keycache::NamedParam; use tfhe::shortint::parameters::*; use tfhe::shortint::ClassicPBSParameters; diff --git a/tfhe/benches/integer/bench.rs b/tfhe/benches/integer/bench.rs index 7363b75f6..34756a348 100644 --- a/tfhe/benches/integer/bench.rs +++ b/tfhe/benches/integer/bench.rs @@ -13,7 +13,7 @@ use rand::Rng; use std::vec::IntoIter; use tfhe::integer::keycache::KEY_CACHE; use tfhe::integer::{RadixCiphertext, ServerKey}; -use tfhe::shortint::keycache::NamedParam; +use tfhe::keycache::NamedParam; #[allow(unused_imports)] use tfhe::shortint::parameters::{ diff --git a/tfhe/benches/shortint/bench.rs b/tfhe/benches/shortint/bench.rs index 5c03c928c..2e7523583 100644 --- a/tfhe/benches/shortint/bench.rs +++ b/tfhe/benches/shortint/bench.rs @@ -5,14 +5,13 @@ use crate::utilities::{write_to_json, OperatorType}; use std::env; use criterion::{criterion_group, Criterion}; -use tfhe::shortint::keycache::NamedParam; +use tfhe::keycache::NamedParam; use tfhe::shortint::parameters::*; use tfhe::shortint::{Ciphertext, ClassicPBSParameters, ServerKey, ShortintParameterSet}; use rand::Rng; -use tfhe::shortint::keycache::KEY_CACHE; +use tfhe::shortint::keycache::{KEY_CACHE, KEY_CACHE_WOPBS}; -use tfhe::shortint::keycache::KEY_CACHE_WOPBS; use tfhe::shortint::parameters::parameters_wopbs::WOPBS_PARAM_MESSAGE_4_NORM2_6_KS_PBS; const SERVER_KEY_BENCH_PARAMS: [ClassicPBSParameters; 4] = [ diff --git a/tfhe/examples/utilities/generates_test_keys.rs b/tfhe/examples/utilities/generates_test_keys.rs index 5598557d6..7303b1b76 100644 --- a/tfhe/examples/utilities/generates_test_keys.rs +++ b/tfhe/examples/utilities/generates_test_keys.rs @@ -1,5 +1,6 @@ use clap::{Arg, ArgAction, Command}; -use tfhe::shortint::keycache::{NamedParam, KEY_CACHE, KEY_CACHE_WOPBS}; +use tfhe::keycache::NamedParam; +use tfhe::shortint::keycache::{KEY_CACHE, KEY_CACHE_WOPBS}; use tfhe::shortint::parameters::parameters_wopbs_message_carry::{ WOPBS_PARAM_MESSAGE_1_CARRY_1_KS_PBS, WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS, WOPBS_PARAM_MESSAGE_3_CARRY_3_KS_PBS, WOPBS_PARAM_MESSAGE_4_CARRY_4_KS_PBS, diff --git a/tfhe/examples/utilities/hlapi_compact_pk_ct_sizes.rs b/tfhe/examples/utilities/hlapi_compact_pk_ct_sizes.rs index 19f72aa32..dac914ddf 100644 --- a/tfhe/examples/utilities/hlapi_compact_pk_ct_sizes.rs +++ b/tfhe/examples/utilities/hlapi_compact_pk_ct_sizes.rs @@ -7,8 +7,8 @@ use std::fs::{File, OpenOptions}; use std::io::Write; use std::path::Path; use tfhe::integer::U256; +use tfhe::keycache::NamedParam; use tfhe::prelude::*; -use tfhe::shortint::keycache::NamedParam; use tfhe::shortint::parameters::{ PARAM_MESSAGE_2_CARRY_2_COMPACT_PK_KS_PBS, PARAM_MESSAGE_2_CARRY_2_COMPACT_PK_PBS_KS, }; diff --git a/tfhe/examples/utilities/params_to_file.rs b/tfhe/examples/utilities/params_to_file.rs index e5f82513f..cec9f2cc4 100644 --- a/tfhe/examples/utilities/params_to_file.rs +++ b/tfhe/examples/utilities/params_to_file.rs @@ -4,7 +4,7 @@ use std::path::Path; use tfhe::boolean::parameters::{BooleanParameters, VEC_BOOLEAN_PARAM}; use tfhe::core_crypto::commons::dispersion::StandardDev; use tfhe::core_crypto::commons::parameters::{GlweDimension, LweDimension, PolynomialSize}; -use tfhe::shortint::keycache::NamedParam; +use tfhe::keycache::NamedParam; use tfhe::shortint::parameters::multi_bit::ALL_MULTI_BIT_PARAMETER_VEC; use tfhe::shortint::parameters::{ShortintParameterSet, ALL_PARAMETER_VEC}; diff --git a/tfhe/examples/utilities/shortint_key_sizes.rs b/tfhe/examples/utilities/shortint_key_sizes.rs index 931abe3b9..63f816e4f 100644 --- a/tfhe/examples/utilities/shortint_key_sizes.rs +++ b/tfhe/examples/utilities/shortint_key_sizes.rs @@ -5,7 +5,8 @@ use crate::utilities::{write_to_json, OperatorType}; use std::fs::{File, OpenOptions}; use std::io::Write; use std::path::Path; -use tfhe::shortint::keycache::{NamedParam, KEY_CACHE}; +use tfhe::keycache::NamedParam; +use tfhe::shortint::keycache::KEY_CACHE; use tfhe::shortint::parameters::{ PARAM_MESSAGE_1_CARRY_1_KS_PBS, PARAM_MESSAGE_2_CARRY_2_KS_PBS, PARAM_MESSAGE_3_CARRY_3_KS_PBS, PARAM_MESSAGE_4_CARRY_4_KS_PBS, PARAM_MULTI_BIT_MESSAGE_1_CARRY_1_GROUP_2_KS_PBS, diff --git a/tfhe/examples/utilities/wasm_benchmarks_parser.rs b/tfhe/examples/utilities/wasm_benchmarks_parser.rs index 088365b2c..433ae7315 100644 --- a/tfhe/examples/utilities/wasm_benchmarks_parser.rs +++ b/tfhe/examples/utilities/wasm_benchmarks_parser.rs @@ -9,9 +9,9 @@ use std::fs; use std::fs::{File, OpenOptions}; use std::io::Write; use std::path::Path; +use tfhe::keycache::NamedParam; use tfhe::shortint::keycache::{ - NamedParam, PARAM_MESSAGE_2_CARRY_2_COMPACT_PK_KS_PBS_NAME, - PARAM_MESSAGE_2_CARRY_2_COMPACT_PK_PBS_KS_NAME, + PARAM_MESSAGE_2_CARRY_2_COMPACT_PK_KS_PBS_NAME, PARAM_MESSAGE_2_CARRY_2_COMPACT_PK_PBS_KS_NAME, }; use tfhe::shortint::parameters::{ PARAM_MESSAGE_2_CARRY_2_COMPACT_PK_KS_PBS, PARAM_MESSAGE_2_CARRY_2_COMPACT_PK_PBS_KS, diff --git a/tfhe/src/boolean/parameters/mod.rs b/tfhe/src/boolean/parameters/mod.rs index d5ccf678b..be41578e3 100644 --- a/tfhe/src/boolean/parameters/mod.rs +++ b/tfhe/src/boolean/parameters/mod.rs @@ -24,7 +24,8 @@ pub use crate::core_crypto::commons::parameters::{ LweDimension, PolynomialSize, }; -use crate::shortint::keycache::NamedParam; +#[cfg(any(test, doctest, feature = "internal-keycache"))] +use crate::keycache::NamedParam; use serde::{Deserialize, Serialize}; /// A set of cryptographic parameters for homomorphic Boolean circuit evaluation. @@ -197,6 +198,7 @@ pub const TFHE_LIB_PARAMETERS: BooleanParameters = BooleanParameters { pub const VEC_BOOLEAN_PARAM: [BooleanParameters; 2] = [DEFAULT_PARAMETERS, TFHE_LIB_PARAMETERS]; +#[cfg(any(test, doctest, feature = "internal-keycache"))] impl NamedParam for BooleanParameters { fn name(&self) -> &'static str { if *self == DEFAULT_PARAMETERS { diff --git a/tfhe/src/keycache/mod.rs b/tfhe/src/keycache/mod.rs index e69de29bb..b16940637 100644 --- a/tfhe/src/keycache/mod.rs +++ b/tfhe/src/keycache/mod.rs @@ -0,0 +1,220 @@ +pub use utils::{ + FileStorage, KeyCache as ImplKeyCache, NamedParam, PersistentStorage, + SharedKey as GenericSharedKey, +}; + +#[macro_use] +pub mod utils { + use fs2::FileExt; + use once_cell::sync::OnceCell; + use serde::de::DeserializeOwned; + use serde::Serialize; + use std::fs::File; + use std::io::{BufReader, BufWriter}; + use std::ops::Deref; + use std::path::PathBuf; + use std::sync::{Arc, RwLock}; + + pub trait PersistentStorage { + fn load(&self, param: P) -> Option; + fn store(&self, param: P, key: &K); + } + + pub trait NamedParam { + fn name(&self) -> &'static str; + } + + #[macro_export] + macro_rules! named_params_impl( + (expose $($const_param:ident),* $(,)? ) => { + $( + paste::paste! { + pub const [<$const_param _NAME>]: &'static str = stringify!($const_param); + } + )* + }; + + ($param_type:ty => $($const_param:ident),* $(,)? ) => { + named_params_impl!(expose $($const_param),*); + + impl NamedParam for $param_type { + fn name(&self) -> &'static str { + named_params_impl!({*self; $param_type} == ( $($const_param),* )); + } + } + }; + + ({$thing:expr; $param_type:ty} == ( $($const_param:ident),* $(,)? )) => { + $( + paste::paste! { + if $thing == <$param_type>::from($const_param) { + return [<$const_param _NAME>]; + } + } + )* + + panic!("Unnamed parameters"); + } + ); + + pub struct FileStorage { + prefix: String, + } + + impl FileStorage { + pub fn new(prefix: String) -> Self { + Self { prefix } + } + } + + impl PersistentStorage for FileStorage + where + P: NamedParam + DeserializeOwned + Serialize + PartialEq, + K: DeserializeOwned + Serialize, + { + fn load(&self, param: P) -> Option { + let mut path_buf = PathBuf::with_capacity(256); + path_buf.push(&self.prefix); + path_buf.push(param.name()); + path_buf.set_extension("bin"); + + if path_buf.exists() { + let file = File::open(&path_buf).unwrap(); + // Lock for reading + file.lock_shared().unwrap(); + let file_reader = BufReader::new(file); + bincode::deserialize_from::<_, (P, K)>(file_reader) + .ok() + .and_then(|(p, k)| if p == param { Some(k) } else { None }) + } else { + None + } + } + + fn store(&self, param: P, key: &K) { + let mut path_buf = PathBuf::with_capacity(256); + path_buf.push(&self.prefix); + std::fs::create_dir_all(&path_buf).unwrap(); + path_buf.push(param.name()); + path_buf.set_extension("bin"); + + let file = File::create(&path_buf).unwrap(); + // Lock for writing + file.lock_exclusive().unwrap(); + + let file_writer = BufWriter::new(file); + bincode::serialize_into(file_writer, &(param, key)).unwrap(); + } + } + + pub struct SharedKey { + inner: Arc>, + } + + impl Clone for SharedKey { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + } + } + } + + impl Deref for SharedKey { + type Target = K; + + fn deref(&self) -> &Self::Target { + self.inner.get().unwrap() + } + } + + pub struct KeyCache { + // Where the keys will be stored persistently + // So they are not generated between each run + persistent_storage: S, + // Temporary memory storage to avoid querying the persistent storage each time + // the outer Arc makes it so that we don't clone the OnceCell contents when initializing it + memory_storage: RwLock)>>, + } + + impl KeyCache { + pub fn new(storage: S) -> Self { + Self { + persistent_storage: storage, + memory_storage: RwLock::new(vec![]), + } + } + + pub fn clear_in_memory_cache(&self) { + let mut memory_storage = self.memory_storage.write().unwrap(); + memory_storage.clear(); + } + } + + impl KeyCache + where + P: Copy + PartialEq + NamedParam, + S: PersistentStorage, + K: From

+ Clone, + { + pub fn get(&self, param: P) -> SharedKey { + self.with_key(param, |k| k.clone()) + } + + pub fn with_key(&self, param: P, f: F) -> R + where + F: FnOnce(&SharedKey) -> R, + { + let load_from_persistent_storage = || { + // we check if we can load the key from persistent storage + let persistent_storage = &self.persistent_storage; + let maybe_key = persistent_storage.load(param); + match maybe_key { + Some(key) => key, + None => { + let key = K::from(param); + persistent_storage.store(param, &key); + key + } + } + }; + + let try_load_from_memory_and_init = || { + // we only hold a read lock for a short duration to find the key + let maybe_shared_cell = { + let memory_storage = self.memory_storage.read().unwrap(); + memory_storage + .iter() + .find(|(p, _)| *p == param) + .map(|param_key| param_key.1.clone()) + }; + + if let Some(shared_cell) = maybe_shared_cell { + shared_cell.inner.get_or_init(load_from_persistent_storage); + Ok(shared_cell) + } else { + Err(()) + } + }; + + match try_load_from_memory_and_init() { + Ok(result) => f(&result), + Err(()) => { + { + // we only hold a write lock for a short duration to push the lazily + // evaluated key without actually evaluating the key + let mut memory_storage = self.memory_storage.write().unwrap(); + if !memory_storage.iter().any(|(p, _)| *p == param) { + memory_storage.push(( + param, + SharedKey { + inner: Arc::new(OnceCell::new()), + }, + )); + } + } + f(&try_load_from_memory_and_init().ok().unwrap()) + } + } + } + } +} diff --git a/tfhe/src/lib.rs b/tfhe/src/lib.rs index 0b64b619e..c9fc6bb18 100644 --- a/tfhe/src/lib.rs +++ b/tfhe/src/lib.rs @@ -55,5 +55,10 @@ mod test_user_docs; /// cbindgen:ignore #[cfg(any(feature = "boolean", feature = "shortint", feature = "integer"))] pub(crate) mod high_level_api; + #[cfg(any(feature = "boolean", feature = "shortint", feature = "integer"))] pub use high_level_api::*; + +/// cbindgen:ignore +#[cfg(any(test, doctest, feature = "internal-keycache"))] +pub mod keycache; diff --git a/tfhe/src/shortint/keycache.rs b/tfhe/src/shortint/keycache.rs index 45356b6d2..b72d407f5 100644 --- a/tfhe/src/shortint/keycache.rs +++ b/tfhe/src/shortint/keycache.rs @@ -1,3 +1,5 @@ +use crate::keycache::*; +use crate::named_params_impl; use crate::shortint::parameters::multi_bit::*; use crate::shortint::parameters::parameters_compact_pk::*; use crate::shortint::parameters::parameters_wopbs::*; @@ -9,227 +11,6 @@ use crate::shortint::{ClientKey, ServerKey}; use lazy_static::*; use serde::{Deserialize, Serialize}; -pub use utils::{ - FileStorage, KeyCache as ImplKeyCache, NamedParam, PersistentStorage, - SharedKey as GenericSharedKey, -}; - -#[macro_use] -pub mod utils { - use fs2::FileExt; - use once_cell::sync::OnceCell; - use serde::de::DeserializeOwned; - use serde::Serialize; - use std::fs::File; - use std::io::{BufReader, BufWriter}; - use std::ops::Deref; - use std::path::PathBuf; - use std::sync::{Arc, RwLock}; - - pub trait PersistentStorage { - fn load(&self, param: P) -> Option; - fn store(&self, param: P, key: &K); - } - - pub trait NamedParam { - fn name(&self) -> &'static str; - } - - #[macro_export] - macro_rules! named_params_impl( - (expose $($const_param:ident),* $(,)? ) => { - $( - paste::paste! { - pub const [<$const_param _NAME>]: &'static str = stringify!($const_param); - } - )* - }; - - ($param_type:ty => $($const_param:ident),* $(,)? ) => { - named_params_impl!(expose $($const_param),*); - - impl NamedParam for $param_type { - fn name(&self) -> &'static str { - named_params_impl!({*self; $param_type} == ( $($const_param),* )); - } - } - }; - - ({$thing:expr; $param_type:ty} == ( $($const_param:ident),* $(,)? )) => { - $( - paste::paste! { - if $thing == <$param_type>::from($const_param) { - return [<$const_param _NAME>]; - } - } - )* - - panic!("Unnamed parameters"); - } - ); - - pub struct FileStorage { - prefix: String, - } - - impl FileStorage { - pub fn new(prefix: String) -> Self { - Self { prefix } - } - } - - impl PersistentStorage for FileStorage - where - P: NamedParam + DeserializeOwned + Serialize + PartialEq, - K: DeserializeOwned + Serialize, - { - fn load(&self, param: P) -> Option { - let mut path_buf = PathBuf::with_capacity(256); - path_buf.push(&self.prefix); - path_buf.push(param.name()); - path_buf.set_extension("bin"); - - if path_buf.exists() { - let file = File::open(&path_buf).unwrap(); - // Lock for reading - file.lock_shared().unwrap(); - let file_reader = BufReader::new(file); - bincode::deserialize_from::<_, (P, K)>(file_reader) - .ok() - .and_then(|(p, k)| if p == param { Some(k) } else { None }) - } else { - None - } - } - - fn store(&self, param: P, key: &K) { - let mut path_buf = PathBuf::with_capacity(256); - path_buf.push(&self.prefix); - std::fs::create_dir_all(&path_buf).unwrap(); - path_buf.push(param.name()); - path_buf.set_extension("bin"); - - let file = File::create(&path_buf).unwrap(); - // Lock for writing - file.lock_exclusive().unwrap(); - - let file_writer = BufWriter::new(file); - bincode::serialize_into(file_writer, &(param, key)).unwrap(); - } - } - - pub struct SharedKey { - inner: Arc>, - } - - impl Clone for SharedKey { - fn clone(&self) -> Self { - Self { - inner: self.inner.clone(), - } - } - } - - impl Deref for SharedKey { - type Target = K; - - fn deref(&self) -> &Self::Target { - self.inner.get().unwrap() - } - } - - pub struct KeyCache { - // Where the keys will be stored persistently - // So they are not generated between each run - persistent_storage: S, - // Temporary memory storage to avoid querying the persistent storage each time - // the outer Arc makes it so that we don't clone the OnceCell contents when initializing it - memory_storage: RwLock)>>, - } - - impl KeyCache { - pub fn new(storage: S) -> Self { - Self { - persistent_storage: storage, - memory_storage: RwLock::new(vec![]), - } - } - - pub fn clear_in_memory_cache(&self) { - let mut memory_storage = self.memory_storage.write().unwrap(); - memory_storage.clear(); - } - } - - impl KeyCache - where - P: Copy + PartialEq + NamedParam, - S: PersistentStorage, - K: From

+ Clone, - { - pub fn get(&self, param: P) -> SharedKey { - self.with_key(param, |k| k.clone()) - } - - pub fn with_key(&self, param: P, f: F) -> R - where - F: FnOnce(&SharedKey) -> R, - { - let load_from_persistent_storage = || { - // we check if we can load the key from persistent storage - let persistent_storage = &self.persistent_storage; - let maybe_key = persistent_storage.load(param); - match maybe_key { - Some(key) => key, - None => { - let key = K::from(param); - persistent_storage.store(param, &key); - key - } - } - }; - - let try_load_from_memory_and_init = || { - // we only hold a read lock for a short duration to find the key - let maybe_shared_cell = { - let memory_storage = self.memory_storage.read().unwrap(); - memory_storage - .iter() - .find(|(p, _)| *p == param) - .map(|param_key| param_key.1.clone()) - }; - - if let Some(shared_cell) = maybe_shared_cell { - shared_cell.inner.get_or_init(load_from_persistent_storage); - Ok(shared_cell) - } else { - Err(()) - } - }; - - match try_load_from_memory_and_init() { - Ok(result) => f(&result), - Err(()) => { - { - // we only hold a write lock for a short duration to push the lazily - // evaluated key without actually evaluating the key - let mut memory_storage = self.memory_storage.write().unwrap(); - if !memory_storage.iter().any(|(p, _)| *p == param) { - memory_storage.push(( - param, - SharedKey { - inner: Arc::new(OnceCell::new()), - }, - )); - } - } - f(&try_load_from_memory_and_init().ok().unwrap()) - } - } - } - } -} - named_params_impl!( ShortintParameterSet => PARAM_MESSAGE_1_CARRY_0_KS_PBS, PARAM_MESSAGE_1_CARRY_1_KS_PBS,