mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-10 07:08:03 -05:00
chore(tfhe): refactor keycache in its own crate
The general parts of shortint keycache has been moved to its own crate. This enable boolean layer to get access to traits without having to import shortint::keycache module.
This commit is contained in:
14
Makefile
14
Makefile
@@ -100,7 +100,7 @@ clippy_core: install_rs_check_toolchain
|
||||
.PHONY: clippy_boolean # Run clippy lints enabling the boolean features
|
||||
clippy_boolean: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy \
|
||||
--features=$(TARGET_ARCH_FEATURE),boolean,shortint,internal-keycache \
|
||||
--features=$(TARGET_ARCH_FEATURE),boolean \
|
||||
-p tfhe -- --no-deps -D warnings
|
||||
|
||||
.PHONY: clippy_shortint # Run clippy lints enabling the shortint features
|
||||
@@ -118,19 +118,19 @@ clippy_integer: install_rs_check_toolchain
|
||||
.PHONY: clippy # Run clippy lints enabling the boolean, shortint, integer
|
||||
clippy: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy --all-targets \
|
||||
--features=$(TARGET_ARCH_FEATURE),boolean,shortint,integer,internal-keycache \
|
||||
--features=$(TARGET_ARCH_FEATURE),boolean,shortint,integer \
|
||||
-p tfhe -- --no-deps -D warnings
|
||||
|
||||
.PHONY: clippy_c_api # Run clippy lints enabling the boolean, shortint and the C API
|
||||
clippy_c_api: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy \
|
||||
--features=$(TARGET_ARCH_FEATURE),boolean-c-api,shortint-c-api,internal-keycache \
|
||||
--features=$(TARGET_ARCH_FEATURE),boolean-c-api,shortint-c-api \
|
||||
-p tfhe -- --no-deps -D warnings
|
||||
|
||||
.PHONY: clippy_js_wasm_api # Run clippy lints enabling the boolean, shortint, integer and the js wasm API
|
||||
clippy_js_wasm_api: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy \
|
||||
--features=boolean-client-js-wasm-api,shortint-client-js-wasm-api,integer-client-js-wasm-api,internal-keycache \
|
||||
--features=boolean-client-js-wasm-api,shortint-client-js-wasm-api,integer-client-js-wasm-api \
|
||||
-p tfhe -- --no-deps -D warnings
|
||||
|
||||
.PHONY: clippy_tasks # Run clippy lints on helper tasks crate.
|
||||
@@ -141,7 +141,7 @@ clippy_tasks:
|
||||
.PHONY: clippy_all_targets # Run clippy lints on all targets (benches, examples, etc.)
|
||||
clippy_all_targets:
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy --all-targets \
|
||||
--features=$(TARGET_ARCH_FEATURE),boolean,shortint,integer,internal-keycache \
|
||||
--features=$(TARGET_ARCH_FEATURE),boolean,shortint,integer \
|
||||
-p tfhe -- --no-deps -D warnings
|
||||
|
||||
.PHONY: clippy_all # Run all clippy targets
|
||||
@@ -199,7 +199,7 @@ build_tfhe_full: install_rs_build_toolchain
|
||||
.PHONY: build_c_api # Build the C API for boolean, shortint and integer
|
||||
build_c_api: install_rs_check_toolchain
|
||||
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_CHECK_TOOLCHAIN) build --profile $(CARGO_PROFILE) \
|
||||
--features=$(TARGET_ARCH_FEATURE),boolean-c-api,shortint-c-api,high-level-c-api \
|
||||
--features=$(TARGET_ARCH_FEATURE),boolean-c-api,shortint-c-api,high-level-c-api, \
|
||||
-p tfhe
|
||||
|
||||
.PHONY: build_c_api_experimental_deterministic_fft # Build the C API for boolean, shortint and integer with experimental deterministic FFT
|
||||
@@ -349,7 +349,7 @@ docs: doc
|
||||
lint_doc: install_rs_check_toolchain
|
||||
RUSTDOCFLAGS="--html-in-header katex-header.html -Dwarnings" \
|
||||
cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" doc \
|
||||
--features=$(TARGET_ARCH_FEATURE),internal-keycache,boolean,shortint,integer --no-deps
|
||||
--features=$(TARGET_ARCH_FEATURE),boolean,shortint,integer --no-deps
|
||||
|
||||
.PHONY: lint_docs # Build rust doc with linting enabled alias for lint_doc
|
||||
lint_docs: lint_doc
|
||||
|
||||
@@ -9,7 +9,7 @@ use tfhe::boolean::parameters::{
|
||||
BooleanParameters, DEFAULT_PARAMETERS, PARAMETERS_ERROR_PROB_2_POW_MINUS_165,
|
||||
};
|
||||
use tfhe::core_crypto::prelude::*;
|
||||
use tfhe::shortint::keycache::NamedParam;
|
||||
use tfhe::keycache::NamedParam;
|
||||
use tfhe::shortint::parameters::*;
|
||||
use tfhe::shortint::ClassicPBSParameters;
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ use rand::Rng;
|
||||
use std::vec::IntoIter;
|
||||
use tfhe::integer::keycache::KEY_CACHE;
|
||||
use tfhe::integer::{RadixCiphertext, ServerKey};
|
||||
use tfhe::shortint::keycache::NamedParam;
|
||||
use tfhe::keycache::NamedParam;
|
||||
|
||||
#[allow(unused_imports)]
|
||||
use tfhe::shortint::parameters::{
|
||||
|
||||
@@ -5,14 +5,13 @@ use crate::utilities::{write_to_json, OperatorType};
|
||||
use std::env;
|
||||
|
||||
use criterion::{criterion_group, Criterion};
|
||||
use tfhe::shortint::keycache::NamedParam;
|
||||
use tfhe::keycache::NamedParam;
|
||||
use tfhe::shortint::parameters::*;
|
||||
use tfhe::shortint::{Ciphertext, ClassicPBSParameters, ServerKey, ShortintParameterSet};
|
||||
|
||||
use rand::Rng;
|
||||
use tfhe::shortint::keycache::KEY_CACHE;
|
||||
use tfhe::shortint::keycache::{KEY_CACHE, KEY_CACHE_WOPBS};
|
||||
|
||||
use tfhe::shortint::keycache::KEY_CACHE_WOPBS;
|
||||
use tfhe::shortint::parameters::parameters_wopbs::WOPBS_PARAM_MESSAGE_4_NORM2_6_KS_PBS;
|
||||
|
||||
const SERVER_KEY_BENCH_PARAMS: [ClassicPBSParameters; 4] = [
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use clap::{Arg, ArgAction, Command};
|
||||
use tfhe::shortint::keycache::{NamedParam, KEY_CACHE, KEY_CACHE_WOPBS};
|
||||
use tfhe::keycache::NamedParam;
|
||||
use tfhe::shortint::keycache::{KEY_CACHE, KEY_CACHE_WOPBS};
|
||||
use tfhe::shortint::parameters::parameters_wopbs_message_carry::{
|
||||
WOPBS_PARAM_MESSAGE_1_CARRY_1_KS_PBS, WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS,
|
||||
WOPBS_PARAM_MESSAGE_3_CARRY_3_KS_PBS, WOPBS_PARAM_MESSAGE_4_CARRY_4_KS_PBS,
|
||||
|
||||
@@ -7,8 +7,8 @@ use std::fs::{File, OpenOptions};
|
||||
use std::io::Write;
|
||||
use std::path::Path;
|
||||
use tfhe::integer::U256;
|
||||
use tfhe::keycache::NamedParam;
|
||||
use tfhe::prelude::*;
|
||||
use tfhe::shortint::keycache::NamedParam;
|
||||
use tfhe::shortint::parameters::{
|
||||
PARAM_MESSAGE_2_CARRY_2_COMPACT_PK_KS_PBS, PARAM_MESSAGE_2_CARRY_2_COMPACT_PK_PBS_KS,
|
||||
};
|
||||
|
||||
@@ -4,7 +4,7 @@ use std::path::Path;
|
||||
use tfhe::boolean::parameters::{BooleanParameters, VEC_BOOLEAN_PARAM};
|
||||
use tfhe::core_crypto::commons::dispersion::StandardDev;
|
||||
use tfhe::core_crypto::commons::parameters::{GlweDimension, LweDimension, PolynomialSize};
|
||||
use tfhe::shortint::keycache::NamedParam;
|
||||
use tfhe::keycache::NamedParam;
|
||||
use tfhe::shortint::parameters::multi_bit::ALL_MULTI_BIT_PARAMETER_VEC;
|
||||
use tfhe::shortint::parameters::{ShortintParameterSet, ALL_PARAMETER_VEC};
|
||||
|
||||
|
||||
@@ -5,7 +5,8 @@ use crate::utilities::{write_to_json, OperatorType};
|
||||
use std::fs::{File, OpenOptions};
|
||||
use std::io::Write;
|
||||
use std::path::Path;
|
||||
use tfhe::shortint::keycache::{NamedParam, KEY_CACHE};
|
||||
use tfhe::keycache::NamedParam;
|
||||
use tfhe::shortint::keycache::KEY_CACHE;
|
||||
use tfhe::shortint::parameters::{
|
||||
PARAM_MESSAGE_1_CARRY_1_KS_PBS, PARAM_MESSAGE_2_CARRY_2_KS_PBS, PARAM_MESSAGE_3_CARRY_3_KS_PBS,
|
||||
PARAM_MESSAGE_4_CARRY_4_KS_PBS, PARAM_MULTI_BIT_MESSAGE_1_CARRY_1_GROUP_2_KS_PBS,
|
||||
|
||||
@@ -9,9 +9,9 @@ use std::fs;
|
||||
use std::fs::{File, OpenOptions};
|
||||
use std::io::Write;
|
||||
use std::path::Path;
|
||||
use tfhe::keycache::NamedParam;
|
||||
use tfhe::shortint::keycache::{
|
||||
NamedParam, PARAM_MESSAGE_2_CARRY_2_COMPACT_PK_KS_PBS_NAME,
|
||||
PARAM_MESSAGE_2_CARRY_2_COMPACT_PK_PBS_KS_NAME,
|
||||
PARAM_MESSAGE_2_CARRY_2_COMPACT_PK_KS_PBS_NAME, PARAM_MESSAGE_2_CARRY_2_COMPACT_PK_PBS_KS_NAME,
|
||||
};
|
||||
use tfhe::shortint::parameters::{
|
||||
PARAM_MESSAGE_2_CARRY_2_COMPACT_PK_KS_PBS, PARAM_MESSAGE_2_CARRY_2_COMPACT_PK_PBS_KS,
|
||||
|
||||
@@ -24,7 +24,8 @@ pub use crate::core_crypto::commons::parameters::{
|
||||
LweDimension, PolynomialSize,
|
||||
};
|
||||
|
||||
use crate::shortint::keycache::NamedParam;
|
||||
#[cfg(any(test, doctest, feature = "internal-keycache"))]
|
||||
use crate::keycache::NamedParam;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// A set of cryptographic parameters for homomorphic Boolean circuit evaluation.
|
||||
@@ -197,6 +198,7 @@ pub const TFHE_LIB_PARAMETERS: BooleanParameters = BooleanParameters {
|
||||
|
||||
pub const VEC_BOOLEAN_PARAM: [BooleanParameters; 2] = [DEFAULT_PARAMETERS, TFHE_LIB_PARAMETERS];
|
||||
|
||||
#[cfg(any(test, doctest, feature = "internal-keycache"))]
|
||||
impl NamedParam for BooleanParameters {
|
||||
fn name(&self) -> &'static str {
|
||||
if *self == DEFAULT_PARAMETERS {
|
||||
|
||||
@@ -0,0 +1,220 @@
|
||||
pub use utils::{
|
||||
FileStorage, KeyCache as ImplKeyCache, NamedParam, PersistentStorage,
|
||||
SharedKey as GenericSharedKey,
|
||||
};
|
||||
|
||||
#[macro_use]
|
||||
pub mod utils {
|
||||
use fs2::FileExt;
|
||||
use once_cell::sync::OnceCell;
|
||||
use serde::de::DeserializeOwned;
|
||||
use serde::Serialize;
|
||||
use std::fs::File;
|
||||
use std::io::{BufReader, BufWriter};
|
||||
use std::ops::Deref;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::{Arc, RwLock};
|
||||
|
||||
pub trait PersistentStorage<P, K> {
|
||||
fn load(&self, param: P) -> Option<K>;
|
||||
fn store(&self, param: P, key: &K);
|
||||
}
|
||||
|
||||
pub trait NamedParam {
|
||||
fn name(&self) -> &'static str;
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! named_params_impl(
|
||||
(expose $($const_param:ident),* $(,)? ) => {
|
||||
$(
|
||||
paste::paste! {
|
||||
pub const [<$const_param _NAME>]: &'static str = stringify!($const_param);
|
||||
}
|
||||
)*
|
||||
};
|
||||
|
||||
($param_type:ty => $($const_param:ident),* $(,)? ) => {
|
||||
named_params_impl!(expose $($const_param),*);
|
||||
|
||||
impl NamedParam for $param_type {
|
||||
fn name(&self) -> &'static str {
|
||||
named_params_impl!({*self; $param_type} == ( $($const_param),* ));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
({$thing:expr; $param_type:ty} == ( $($const_param:ident),* $(,)? )) => {
|
||||
$(
|
||||
paste::paste! {
|
||||
if $thing == <$param_type>::from($const_param) {
|
||||
return [<$const_param _NAME>];
|
||||
}
|
||||
}
|
||||
)*
|
||||
|
||||
panic!("Unnamed parameters");
|
||||
}
|
||||
);
|
||||
|
||||
pub struct FileStorage {
|
||||
prefix: String,
|
||||
}
|
||||
|
||||
impl FileStorage {
|
||||
pub fn new(prefix: String) -> Self {
|
||||
Self { prefix }
|
||||
}
|
||||
}
|
||||
|
||||
impl<P, K> PersistentStorage<P, K> for FileStorage
|
||||
where
|
||||
P: NamedParam + DeserializeOwned + Serialize + PartialEq,
|
||||
K: DeserializeOwned + Serialize,
|
||||
{
|
||||
fn load(&self, param: P) -> Option<K> {
|
||||
let mut path_buf = PathBuf::with_capacity(256);
|
||||
path_buf.push(&self.prefix);
|
||||
path_buf.push(param.name());
|
||||
path_buf.set_extension("bin");
|
||||
|
||||
if path_buf.exists() {
|
||||
let file = File::open(&path_buf).unwrap();
|
||||
// Lock for reading
|
||||
file.lock_shared().unwrap();
|
||||
let file_reader = BufReader::new(file);
|
||||
bincode::deserialize_from::<_, (P, K)>(file_reader)
|
||||
.ok()
|
||||
.and_then(|(p, k)| if p == param { Some(k) } else { None })
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn store(&self, param: P, key: &K) {
|
||||
let mut path_buf = PathBuf::with_capacity(256);
|
||||
path_buf.push(&self.prefix);
|
||||
std::fs::create_dir_all(&path_buf).unwrap();
|
||||
path_buf.push(param.name());
|
||||
path_buf.set_extension("bin");
|
||||
|
||||
let file = File::create(&path_buf).unwrap();
|
||||
// Lock for writing
|
||||
file.lock_exclusive().unwrap();
|
||||
|
||||
let file_writer = BufWriter::new(file);
|
||||
bincode::serialize_into(file_writer, &(param, key)).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SharedKey<K> {
|
||||
inner: Arc<OnceCell<K>>,
|
||||
}
|
||||
|
||||
impl<K> Clone for SharedKey<K> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
inner: self.inner.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<K> Deref for SharedKey<K> {
|
||||
type Target = K;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
self.inner.get().unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct KeyCache<P, K, S> {
|
||||
// Where the keys will be stored persistently
|
||||
// So they are not generated between each run
|
||||
persistent_storage: S,
|
||||
// Temporary memory storage to avoid querying the persistent storage each time
|
||||
// the outer Arc makes it so that we don't clone the OnceCell contents when initializing it
|
||||
memory_storage: RwLock<Vec<(P, SharedKey<K>)>>,
|
||||
}
|
||||
|
||||
impl<P, K, S> KeyCache<P, K, S> {
|
||||
pub fn new(storage: S) -> Self {
|
||||
Self {
|
||||
persistent_storage: storage,
|
||||
memory_storage: RwLock::new(vec![]),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn clear_in_memory_cache(&self) {
|
||||
let mut memory_storage = self.memory_storage.write().unwrap();
|
||||
memory_storage.clear();
|
||||
}
|
||||
}
|
||||
|
||||
impl<P, K, S> KeyCache<P, K, S>
|
||||
where
|
||||
P: Copy + PartialEq + NamedParam,
|
||||
S: PersistentStorage<P, K>,
|
||||
K: From<P> + Clone,
|
||||
{
|
||||
pub fn get(&self, param: P) -> SharedKey<K> {
|
||||
self.with_key(param, |k| k.clone())
|
||||
}
|
||||
|
||||
pub fn with_key<F, R>(&self, param: P, f: F) -> R
|
||||
where
|
||||
F: FnOnce(&SharedKey<K>) -> R,
|
||||
{
|
||||
let load_from_persistent_storage = || {
|
||||
// we check if we can load the key from persistent storage
|
||||
let persistent_storage = &self.persistent_storage;
|
||||
let maybe_key = persistent_storage.load(param);
|
||||
match maybe_key {
|
||||
Some(key) => key,
|
||||
None => {
|
||||
let key = K::from(param);
|
||||
persistent_storage.store(param, &key);
|
||||
key
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let try_load_from_memory_and_init = || {
|
||||
// we only hold a read lock for a short duration to find the key
|
||||
let maybe_shared_cell = {
|
||||
let memory_storage = self.memory_storage.read().unwrap();
|
||||
memory_storage
|
||||
.iter()
|
||||
.find(|(p, _)| *p == param)
|
||||
.map(|param_key| param_key.1.clone())
|
||||
};
|
||||
|
||||
if let Some(shared_cell) = maybe_shared_cell {
|
||||
shared_cell.inner.get_or_init(load_from_persistent_storage);
|
||||
Ok(shared_cell)
|
||||
} else {
|
||||
Err(())
|
||||
}
|
||||
};
|
||||
|
||||
match try_load_from_memory_and_init() {
|
||||
Ok(result) => f(&result),
|
||||
Err(()) => {
|
||||
{
|
||||
// we only hold a write lock for a short duration to push the lazily
|
||||
// evaluated key without actually evaluating the key
|
||||
let mut memory_storage = self.memory_storage.write().unwrap();
|
||||
if !memory_storage.iter().any(|(p, _)| *p == param) {
|
||||
memory_storage.push((
|
||||
param,
|
||||
SharedKey {
|
||||
inner: Arc::new(OnceCell::new()),
|
||||
},
|
||||
));
|
||||
}
|
||||
}
|
||||
f(&try_load_from_memory_and_init().ok().unwrap())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -55,5 +55,10 @@ mod test_user_docs;
|
||||
/// cbindgen:ignore
|
||||
#[cfg(any(feature = "boolean", feature = "shortint", feature = "integer"))]
|
||||
pub(crate) mod high_level_api;
|
||||
|
||||
#[cfg(any(feature = "boolean", feature = "shortint", feature = "integer"))]
|
||||
pub use high_level_api::*;
|
||||
|
||||
/// cbindgen:ignore
|
||||
#[cfg(any(test, doctest, feature = "internal-keycache"))]
|
||||
pub mod keycache;
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use crate::keycache::*;
|
||||
use crate::named_params_impl;
|
||||
use crate::shortint::parameters::multi_bit::*;
|
||||
use crate::shortint::parameters::parameters_compact_pk::*;
|
||||
use crate::shortint::parameters::parameters_wopbs::*;
|
||||
@@ -9,227 +11,6 @@ use crate::shortint::{ClientKey, ServerKey};
|
||||
use lazy_static::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub use utils::{
|
||||
FileStorage, KeyCache as ImplKeyCache, NamedParam, PersistentStorage,
|
||||
SharedKey as GenericSharedKey,
|
||||
};
|
||||
|
||||
#[macro_use]
|
||||
pub mod utils {
|
||||
use fs2::FileExt;
|
||||
use once_cell::sync::OnceCell;
|
||||
use serde::de::DeserializeOwned;
|
||||
use serde::Serialize;
|
||||
use std::fs::File;
|
||||
use std::io::{BufReader, BufWriter};
|
||||
use std::ops::Deref;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::{Arc, RwLock};
|
||||
|
||||
pub trait PersistentStorage<P, K> {
|
||||
fn load(&self, param: P) -> Option<K>;
|
||||
fn store(&self, param: P, key: &K);
|
||||
}
|
||||
|
||||
pub trait NamedParam {
|
||||
fn name(&self) -> &'static str;
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! named_params_impl(
|
||||
(expose $($const_param:ident),* $(,)? ) => {
|
||||
$(
|
||||
paste::paste! {
|
||||
pub const [<$const_param _NAME>]: &'static str = stringify!($const_param);
|
||||
}
|
||||
)*
|
||||
};
|
||||
|
||||
($param_type:ty => $($const_param:ident),* $(,)? ) => {
|
||||
named_params_impl!(expose $($const_param),*);
|
||||
|
||||
impl NamedParam for $param_type {
|
||||
fn name(&self) -> &'static str {
|
||||
named_params_impl!({*self; $param_type} == ( $($const_param),* ));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
({$thing:expr; $param_type:ty} == ( $($const_param:ident),* $(,)? )) => {
|
||||
$(
|
||||
paste::paste! {
|
||||
if $thing == <$param_type>::from($const_param) {
|
||||
return [<$const_param _NAME>];
|
||||
}
|
||||
}
|
||||
)*
|
||||
|
||||
panic!("Unnamed parameters");
|
||||
}
|
||||
);
|
||||
|
||||
pub struct FileStorage {
|
||||
prefix: String,
|
||||
}
|
||||
|
||||
impl FileStorage {
|
||||
pub fn new(prefix: String) -> Self {
|
||||
Self { prefix }
|
||||
}
|
||||
}
|
||||
|
||||
impl<P, K> PersistentStorage<P, K> for FileStorage
|
||||
where
|
||||
P: NamedParam + DeserializeOwned + Serialize + PartialEq,
|
||||
K: DeserializeOwned + Serialize,
|
||||
{
|
||||
fn load(&self, param: P) -> Option<K> {
|
||||
let mut path_buf = PathBuf::with_capacity(256);
|
||||
path_buf.push(&self.prefix);
|
||||
path_buf.push(param.name());
|
||||
path_buf.set_extension("bin");
|
||||
|
||||
if path_buf.exists() {
|
||||
let file = File::open(&path_buf).unwrap();
|
||||
// Lock for reading
|
||||
file.lock_shared().unwrap();
|
||||
let file_reader = BufReader::new(file);
|
||||
bincode::deserialize_from::<_, (P, K)>(file_reader)
|
||||
.ok()
|
||||
.and_then(|(p, k)| if p == param { Some(k) } else { None })
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn store(&self, param: P, key: &K) {
|
||||
let mut path_buf = PathBuf::with_capacity(256);
|
||||
path_buf.push(&self.prefix);
|
||||
std::fs::create_dir_all(&path_buf).unwrap();
|
||||
path_buf.push(param.name());
|
||||
path_buf.set_extension("bin");
|
||||
|
||||
let file = File::create(&path_buf).unwrap();
|
||||
// Lock for writing
|
||||
file.lock_exclusive().unwrap();
|
||||
|
||||
let file_writer = BufWriter::new(file);
|
||||
bincode::serialize_into(file_writer, &(param, key)).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SharedKey<K> {
|
||||
inner: Arc<OnceCell<K>>,
|
||||
}
|
||||
|
||||
impl<K> Clone for SharedKey<K> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
inner: self.inner.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<K> Deref for SharedKey<K> {
|
||||
type Target = K;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
self.inner.get().unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct KeyCache<P, K, S> {
|
||||
// Where the keys will be stored persistently
|
||||
// So they are not generated between each run
|
||||
persistent_storage: S,
|
||||
// Temporary memory storage to avoid querying the persistent storage each time
|
||||
// the outer Arc makes it so that we don't clone the OnceCell contents when initializing it
|
||||
memory_storage: RwLock<Vec<(P, SharedKey<K>)>>,
|
||||
}
|
||||
|
||||
impl<P, K, S> KeyCache<P, K, S> {
|
||||
pub fn new(storage: S) -> Self {
|
||||
Self {
|
||||
persistent_storage: storage,
|
||||
memory_storage: RwLock::new(vec![]),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn clear_in_memory_cache(&self) {
|
||||
let mut memory_storage = self.memory_storage.write().unwrap();
|
||||
memory_storage.clear();
|
||||
}
|
||||
}
|
||||
|
||||
impl<P, K, S> KeyCache<P, K, S>
|
||||
where
|
||||
P: Copy + PartialEq + NamedParam,
|
||||
S: PersistentStorage<P, K>,
|
||||
K: From<P> + Clone,
|
||||
{
|
||||
pub fn get(&self, param: P) -> SharedKey<K> {
|
||||
self.with_key(param, |k| k.clone())
|
||||
}
|
||||
|
||||
pub fn with_key<F, R>(&self, param: P, f: F) -> R
|
||||
where
|
||||
F: FnOnce(&SharedKey<K>) -> R,
|
||||
{
|
||||
let load_from_persistent_storage = || {
|
||||
// we check if we can load the key from persistent storage
|
||||
let persistent_storage = &self.persistent_storage;
|
||||
let maybe_key = persistent_storage.load(param);
|
||||
match maybe_key {
|
||||
Some(key) => key,
|
||||
None => {
|
||||
let key = K::from(param);
|
||||
persistent_storage.store(param, &key);
|
||||
key
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let try_load_from_memory_and_init = || {
|
||||
// we only hold a read lock for a short duration to find the key
|
||||
let maybe_shared_cell = {
|
||||
let memory_storage = self.memory_storage.read().unwrap();
|
||||
memory_storage
|
||||
.iter()
|
||||
.find(|(p, _)| *p == param)
|
||||
.map(|param_key| param_key.1.clone())
|
||||
};
|
||||
|
||||
if let Some(shared_cell) = maybe_shared_cell {
|
||||
shared_cell.inner.get_or_init(load_from_persistent_storage);
|
||||
Ok(shared_cell)
|
||||
} else {
|
||||
Err(())
|
||||
}
|
||||
};
|
||||
|
||||
match try_load_from_memory_and_init() {
|
||||
Ok(result) => f(&result),
|
||||
Err(()) => {
|
||||
{
|
||||
// we only hold a write lock for a short duration to push the lazily
|
||||
// evaluated key without actually evaluating the key
|
||||
let mut memory_storage = self.memory_storage.write().unwrap();
|
||||
if !memory_storage.iter().any(|(p, _)| *p == param) {
|
||||
memory_storage.push((
|
||||
param,
|
||||
SharedKey {
|
||||
inner: Arc::new(OnceCell::new()),
|
||||
},
|
||||
));
|
||||
}
|
||||
}
|
||||
f(&try_load_from_memory_and_init().ok().unwrap())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
named_params_impl!( ShortintParameterSet =>
|
||||
PARAM_MESSAGE_1_CARRY_0_KS_PBS,
|
||||
PARAM_MESSAGE_1_CARRY_1_KS_PBS,
|
||||
|
||||
Reference in New Issue
Block a user