chore(tfhe): refactor keycache in its own crate

The general parts of shortint keycache has been moved to its own
crate. This enable boolean layer to get access to traits without
having to import shortint::keycache module.
This commit is contained in:
David Testé
2023-08-09 15:33:39 +02:00
committed by David Testé
parent 59181d4717
commit 304932a861
13 changed files with 249 additions and 240 deletions

View File

@@ -100,7 +100,7 @@ clippy_core: install_rs_check_toolchain
.PHONY: clippy_boolean # Run clippy lints enabling the boolean features
clippy_boolean: install_rs_check_toolchain
RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy \
--features=$(TARGET_ARCH_FEATURE),boolean,shortint,internal-keycache \
--features=$(TARGET_ARCH_FEATURE),boolean \
-p tfhe -- --no-deps -D warnings
.PHONY: clippy_shortint # Run clippy lints enabling the shortint features
@@ -118,19 +118,19 @@ clippy_integer: install_rs_check_toolchain
.PHONY: clippy # Run clippy lints enabling the boolean, shortint, integer
clippy: install_rs_check_toolchain
RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy --all-targets \
--features=$(TARGET_ARCH_FEATURE),boolean,shortint,integer,internal-keycache \
--features=$(TARGET_ARCH_FEATURE),boolean,shortint,integer \
-p tfhe -- --no-deps -D warnings
.PHONY: clippy_c_api # Run clippy lints enabling the boolean, shortint and the C API
clippy_c_api: install_rs_check_toolchain
RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy \
--features=$(TARGET_ARCH_FEATURE),boolean-c-api,shortint-c-api,internal-keycache \
--features=$(TARGET_ARCH_FEATURE),boolean-c-api,shortint-c-api \
-p tfhe -- --no-deps -D warnings
.PHONY: clippy_js_wasm_api # Run clippy lints enabling the boolean, shortint, integer and the js wasm API
clippy_js_wasm_api: install_rs_check_toolchain
RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy \
--features=boolean-client-js-wasm-api,shortint-client-js-wasm-api,integer-client-js-wasm-api,internal-keycache \
--features=boolean-client-js-wasm-api,shortint-client-js-wasm-api,integer-client-js-wasm-api \
-p tfhe -- --no-deps -D warnings
.PHONY: clippy_tasks # Run clippy lints on helper tasks crate.
@@ -141,7 +141,7 @@ clippy_tasks:
.PHONY: clippy_all_targets # Run clippy lints on all targets (benches, examples, etc.)
clippy_all_targets:
RUSTFLAGS="$(RUSTFLAGS)" cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" clippy --all-targets \
--features=$(TARGET_ARCH_FEATURE),boolean,shortint,integer,internal-keycache \
--features=$(TARGET_ARCH_FEATURE),boolean,shortint,integer \
-p tfhe -- --no-deps -D warnings
.PHONY: clippy_all # Run all clippy targets
@@ -199,7 +199,7 @@ build_tfhe_full: install_rs_build_toolchain
.PHONY: build_c_api # Build the C API for boolean, shortint and integer
build_c_api: install_rs_check_toolchain
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_CHECK_TOOLCHAIN) build --profile $(CARGO_PROFILE) \
--features=$(TARGET_ARCH_FEATURE),boolean-c-api,shortint-c-api,high-level-c-api \
--features=$(TARGET_ARCH_FEATURE),boolean-c-api,shortint-c-api,high-level-c-api, \
-p tfhe
.PHONY: build_c_api_experimental_deterministic_fft # Build the C API for boolean, shortint and integer with experimental deterministic FFT
@@ -349,7 +349,7 @@ docs: doc
lint_doc: install_rs_check_toolchain
RUSTDOCFLAGS="--html-in-header katex-header.html -Dwarnings" \
cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" doc \
--features=$(TARGET_ARCH_FEATURE),internal-keycache,boolean,shortint,integer --no-deps
--features=$(TARGET_ARCH_FEATURE),boolean,shortint,integer --no-deps
.PHONY: lint_docs # Build rust doc with linting enabled alias for lint_doc
lint_docs: lint_doc

View File

@@ -9,7 +9,7 @@ use tfhe::boolean::parameters::{
BooleanParameters, DEFAULT_PARAMETERS, PARAMETERS_ERROR_PROB_2_POW_MINUS_165,
};
use tfhe::core_crypto::prelude::*;
use tfhe::shortint::keycache::NamedParam;
use tfhe::keycache::NamedParam;
use tfhe::shortint::parameters::*;
use tfhe::shortint::ClassicPBSParameters;

View File

@@ -13,7 +13,7 @@ use rand::Rng;
use std::vec::IntoIter;
use tfhe::integer::keycache::KEY_CACHE;
use tfhe::integer::{RadixCiphertext, ServerKey};
use tfhe::shortint::keycache::NamedParam;
use tfhe::keycache::NamedParam;
#[allow(unused_imports)]
use tfhe::shortint::parameters::{

View File

@@ -5,14 +5,13 @@ use crate::utilities::{write_to_json, OperatorType};
use std::env;
use criterion::{criterion_group, Criterion};
use tfhe::shortint::keycache::NamedParam;
use tfhe::keycache::NamedParam;
use tfhe::shortint::parameters::*;
use tfhe::shortint::{Ciphertext, ClassicPBSParameters, ServerKey, ShortintParameterSet};
use rand::Rng;
use tfhe::shortint::keycache::KEY_CACHE;
use tfhe::shortint::keycache::{KEY_CACHE, KEY_CACHE_WOPBS};
use tfhe::shortint::keycache::KEY_CACHE_WOPBS;
use tfhe::shortint::parameters::parameters_wopbs::WOPBS_PARAM_MESSAGE_4_NORM2_6_KS_PBS;
const SERVER_KEY_BENCH_PARAMS: [ClassicPBSParameters; 4] = [

View File

@@ -1,5 +1,6 @@
use clap::{Arg, ArgAction, Command};
use tfhe::shortint::keycache::{NamedParam, KEY_CACHE, KEY_CACHE_WOPBS};
use tfhe::keycache::NamedParam;
use tfhe::shortint::keycache::{KEY_CACHE, KEY_CACHE_WOPBS};
use tfhe::shortint::parameters::parameters_wopbs_message_carry::{
WOPBS_PARAM_MESSAGE_1_CARRY_1_KS_PBS, WOPBS_PARAM_MESSAGE_2_CARRY_2_KS_PBS,
WOPBS_PARAM_MESSAGE_3_CARRY_3_KS_PBS, WOPBS_PARAM_MESSAGE_4_CARRY_4_KS_PBS,

View File

@@ -7,8 +7,8 @@ use std::fs::{File, OpenOptions};
use std::io::Write;
use std::path::Path;
use tfhe::integer::U256;
use tfhe::keycache::NamedParam;
use tfhe::prelude::*;
use tfhe::shortint::keycache::NamedParam;
use tfhe::shortint::parameters::{
PARAM_MESSAGE_2_CARRY_2_COMPACT_PK_KS_PBS, PARAM_MESSAGE_2_CARRY_2_COMPACT_PK_PBS_KS,
};

View File

@@ -4,7 +4,7 @@ use std::path::Path;
use tfhe::boolean::parameters::{BooleanParameters, VEC_BOOLEAN_PARAM};
use tfhe::core_crypto::commons::dispersion::StandardDev;
use tfhe::core_crypto::commons::parameters::{GlweDimension, LweDimension, PolynomialSize};
use tfhe::shortint::keycache::NamedParam;
use tfhe::keycache::NamedParam;
use tfhe::shortint::parameters::multi_bit::ALL_MULTI_BIT_PARAMETER_VEC;
use tfhe::shortint::parameters::{ShortintParameterSet, ALL_PARAMETER_VEC};

View File

@@ -5,7 +5,8 @@ use crate::utilities::{write_to_json, OperatorType};
use std::fs::{File, OpenOptions};
use std::io::Write;
use std::path::Path;
use tfhe::shortint::keycache::{NamedParam, KEY_CACHE};
use tfhe::keycache::NamedParam;
use tfhe::shortint::keycache::KEY_CACHE;
use tfhe::shortint::parameters::{
PARAM_MESSAGE_1_CARRY_1_KS_PBS, PARAM_MESSAGE_2_CARRY_2_KS_PBS, PARAM_MESSAGE_3_CARRY_3_KS_PBS,
PARAM_MESSAGE_4_CARRY_4_KS_PBS, PARAM_MULTI_BIT_MESSAGE_1_CARRY_1_GROUP_2_KS_PBS,

View File

@@ -9,9 +9,9 @@ use std::fs;
use std::fs::{File, OpenOptions};
use std::io::Write;
use std::path::Path;
use tfhe::keycache::NamedParam;
use tfhe::shortint::keycache::{
NamedParam, PARAM_MESSAGE_2_CARRY_2_COMPACT_PK_KS_PBS_NAME,
PARAM_MESSAGE_2_CARRY_2_COMPACT_PK_PBS_KS_NAME,
PARAM_MESSAGE_2_CARRY_2_COMPACT_PK_KS_PBS_NAME, PARAM_MESSAGE_2_CARRY_2_COMPACT_PK_PBS_KS_NAME,
};
use tfhe::shortint::parameters::{
PARAM_MESSAGE_2_CARRY_2_COMPACT_PK_KS_PBS, PARAM_MESSAGE_2_CARRY_2_COMPACT_PK_PBS_KS,

View File

@@ -24,7 +24,8 @@ pub use crate::core_crypto::commons::parameters::{
LweDimension, PolynomialSize,
};
use crate::shortint::keycache::NamedParam;
#[cfg(any(test, doctest, feature = "internal-keycache"))]
use crate::keycache::NamedParam;
use serde::{Deserialize, Serialize};
/// A set of cryptographic parameters for homomorphic Boolean circuit evaluation.
@@ -197,6 +198,7 @@ pub const TFHE_LIB_PARAMETERS: BooleanParameters = BooleanParameters {
pub const VEC_BOOLEAN_PARAM: [BooleanParameters; 2] = [DEFAULT_PARAMETERS, TFHE_LIB_PARAMETERS];
#[cfg(any(test, doctest, feature = "internal-keycache"))]
impl NamedParam for BooleanParameters {
fn name(&self) -> &'static str {
if *self == DEFAULT_PARAMETERS {

View File

@@ -0,0 +1,220 @@
pub use utils::{
FileStorage, KeyCache as ImplKeyCache, NamedParam, PersistentStorage,
SharedKey as GenericSharedKey,
};
#[macro_use]
pub mod utils {
use fs2::FileExt;
use once_cell::sync::OnceCell;
use serde::de::DeserializeOwned;
use serde::Serialize;
use std::fs::File;
use std::io::{BufReader, BufWriter};
use std::ops::Deref;
use std::path::PathBuf;
use std::sync::{Arc, RwLock};
pub trait PersistentStorage<P, K> {
fn load(&self, param: P) -> Option<K>;
fn store(&self, param: P, key: &K);
}
pub trait NamedParam {
fn name(&self) -> &'static str;
}
#[macro_export]
macro_rules! named_params_impl(
(expose $($const_param:ident),* $(,)? ) => {
$(
paste::paste! {
pub const [<$const_param _NAME>]: &'static str = stringify!($const_param);
}
)*
};
($param_type:ty => $($const_param:ident),* $(,)? ) => {
named_params_impl!(expose $($const_param),*);
impl NamedParam for $param_type {
fn name(&self) -> &'static str {
named_params_impl!({*self; $param_type} == ( $($const_param),* ));
}
}
};
({$thing:expr; $param_type:ty} == ( $($const_param:ident),* $(,)? )) => {
$(
paste::paste! {
if $thing == <$param_type>::from($const_param) {
return [<$const_param _NAME>];
}
}
)*
panic!("Unnamed parameters");
}
);
pub struct FileStorage {
prefix: String,
}
impl FileStorage {
pub fn new(prefix: String) -> Self {
Self { prefix }
}
}
impl<P, K> PersistentStorage<P, K> for FileStorage
where
P: NamedParam + DeserializeOwned + Serialize + PartialEq,
K: DeserializeOwned + Serialize,
{
fn load(&self, param: P) -> Option<K> {
let mut path_buf = PathBuf::with_capacity(256);
path_buf.push(&self.prefix);
path_buf.push(param.name());
path_buf.set_extension("bin");
if path_buf.exists() {
let file = File::open(&path_buf).unwrap();
// Lock for reading
file.lock_shared().unwrap();
let file_reader = BufReader::new(file);
bincode::deserialize_from::<_, (P, K)>(file_reader)
.ok()
.and_then(|(p, k)| if p == param { Some(k) } else { None })
} else {
None
}
}
fn store(&self, param: P, key: &K) {
let mut path_buf = PathBuf::with_capacity(256);
path_buf.push(&self.prefix);
std::fs::create_dir_all(&path_buf).unwrap();
path_buf.push(param.name());
path_buf.set_extension("bin");
let file = File::create(&path_buf).unwrap();
// Lock for writing
file.lock_exclusive().unwrap();
let file_writer = BufWriter::new(file);
bincode::serialize_into(file_writer, &(param, key)).unwrap();
}
}
pub struct SharedKey<K> {
inner: Arc<OnceCell<K>>,
}
impl<K> Clone for SharedKey<K> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
}
}
}
impl<K> Deref for SharedKey<K> {
type Target = K;
fn deref(&self) -> &Self::Target {
self.inner.get().unwrap()
}
}
pub struct KeyCache<P, K, S> {
// Where the keys will be stored persistently
// So they are not generated between each run
persistent_storage: S,
// Temporary memory storage to avoid querying the persistent storage each time
// the outer Arc makes it so that we don't clone the OnceCell contents when initializing it
memory_storage: RwLock<Vec<(P, SharedKey<K>)>>,
}
impl<P, K, S> KeyCache<P, K, S> {
pub fn new(storage: S) -> Self {
Self {
persistent_storage: storage,
memory_storage: RwLock::new(vec![]),
}
}
pub fn clear_in_memory_cache(&self) {
let mut memory_storage = self.memory_storage.write().unwrap();
memory_storage.clear();
}
}
impl<P, K, S> KeyCache<P, K, S>
where
P: Copy + PartialEq + NamedParam,
S: PersistentStorage<P, K>,
K: From<P> + Clone,
{
pub fn get(&self, param: P) -> SharedKey<K> {
self.with_key(param, |k| k.clone())
}
pub fn with_key<F, R>(&self, param: P, f: F) -> R
where
F: FnOnce(&SharedKey<K>) -> R,
{
let load_from_persistent_storage = || {
// we check if we can load the key from persistent storage
let persistent_storage = &self.persistent_storage;
let maybe_key = persistent_storage.load(param);
match maybe_key {
Some(key) => key,
None => {
let key = K::from(param);
persistent_storage.store(param, &key);
key
}
}
};
let try_load_from_memory_and_init = || {
// we only hold a read lock for a short duration to find the key
let maybe_shared_cell = {
let memory_storage = self.memory_storage.read().unwrap();
memory_storage
.iter()
.find(|(p, _)| *p == param)
.map(|param_key| param_key.1.clone())
};
if let Some(shared_cell) = maybe_shared_cell {
shared_cell.inner.get_or_init(load_from_persistent_storage);
Ok(shared_cell)
} else {
Err(())
}
};
match try_load_from_memory_and_init() {
Ok(result) => f(&result),
Err(()) => {
{
// we only hold a write lock for a short duration to push the lazily
// evaluated key without actually evaluating the key
let mut memory_storage = self.memory_storage.write().unwrap();
if !memory_storage.iter().any(|(p, _)| *p == param) {
memory_storage.push((
param,
SharedKey {
inner: Arc::new(OnceCell::new()),
},
));
}
}
f(&try_load_from_memory_and_init().ok().unwrap())
}
}
}
}
}

View File

@@ -55,5 +55,10 @@ mod test_user_docs;
/// cbindgen:ignore
#[cfg(any(feature = "boolean", feature = "shortint", feature = "integer"))]
pub(crate) mod high_level_api;
#[cfg(any(feature = "boolean", feature = "shortint", feature = "integer"))]
pub use high_level_api::*;
/// cbindgen:ignore
#[cfg(any(test, doctest, feature = "internal-keycache"))]
pub mod keycache;

View File

@@ -1,3 +1,5 @@
use crate::keycache::*;
use crate::named_params_impl;
use crate::shortint::parameters::multi_bit::*;
use crate::shortint::parameters::parameters_compact_pk::*;
use crate::shortint::parameters::parameters_wopbs::*;
@@ -9,227 +11,6 @@ use crate::shortint::{ClientKey, ServerKey};
use lazy_static::*;
use serde::{Deserialize, Serialize};
pub use utils::{
FileStorage, KeyCache as ImplKeyCache, NamedParam, PersistentStorage,
SharedKey as GenericSharedKey,
};
#[macro_use]
pub mod utils {
use fs2::FileExt;
use once_cell::sync::OnceCell;
use serde::de::DeserializeOwned;
use serde::Serialize;
use std::fs::File;
use std::io::{BufReader, BufWriter};
use std::ops::Deref;
use std::path::PathBuf;
use std::sync::{Arc, RwLock};
pub trait PersistentStorage<P, K> {
fn load(&self, param: P) -> Option<K>;
fn store(&self, param: P, key: &K);
}
pub trait NamedParam {
fn name(&self) -> &'static str;
}
#[macro_export]
macro_rules! named_params_impl(
(expose $($const_param:ident),* $(,)? ) => {
$(
paste::paste! {
pub const [<$const_param _NAME>]: &'static str = stringify!($const_param);
}
)*
};
($param_type:ty => $($const_param:ident),* $(,)? ) => {
named_params_impl!(expose $($const_param),*);
impl NamedParam for $param_type {
fn name(&self) -> &'static str {
named_params_impl!({*self; $param_type} == ( $($const_param),* ));
}
}
};
({$thing:expr; $param_type:ty} == ( $($const_param:ident),* $(,)? )) => {
$(
paste::paste! {
if $thing == <$param_type>::from($const_param) {
return [<$const_param _NAME>];
}
}
)*
panic!("Unnamed parameters");
}
);
pub struct FileStorage {
prefix: String,
}
impl FileStorage {
pub fn new(prefix: String) -> Self {
Self { prefix }
}
}
impl<P, K> PersistentStorage<P, K> for FileStorage
where
P: NamedParam + DeserializeOwned + Serialize + PartialEq,
K: DeserializeOwned + Serialize,
{
fn load(&self, param: P) -> Option<K> {
let mut path_buf = PathBuf::with_capacity(256);
path_buf.push(&self.prefix);
path_buf.push(param.name());
path_buf.set_extension("bin");
if path_buf.exists() {
let file = File::open(&path_buf).unwrap();
// Lock for reading
file.lock_shared().unwrap();
let file_reader = BufReader::new(file);
bincode::deserialize_from::<_, (P, K)>(file_reader)
.ok()
.and_then(|(p, k)| if p == param { Some(k) } else { None })
} else {
None
}
}
fn store(&self, param: P, key: &K) {
let mut path_buf = PathBuf::with_capacity(256);
path_buf.push(&self.prefix);
std::fs::create_dir_all(&path_buf).unwrap();
path_buf.push(param.name());
path_buf.set_extension("bin");
let file = File::create(&path_buf).unwrap();
// Lock for writing
file.lock_exclusive().unwrap();
let file_writer = BufWriter::new(file);
bincode::serialize_into(file_writer, &(param, key)).unwrap();
}
}
pub struct SharedKey<K> {
inner: Arc<OnceCell<K>>,
}
impl<K> Clone for SharedKey<K> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
}
}
}
impl<K> Deref for SharedKey<K> {
type Target = K;
fn deref(&self) -> &Self::Target {
self.inner.get().unwrap()
}
}
pub struct KeyCache<P, K, S> {
// Where the keys will be stored persistently
// So they are not generated between each run
persistent_storage: S,
// Temporary memory storage to avoid querying the persistent storage each time
// the outer Arc makes it so that we don't clone the OnceCell contents when initializing it
memory_storage: RwLock<Vec<(P, SharedKey<K>)>>,
}
impl<P, K, S> KeyCache<P, K, S> {
pub fn new(storage: S) -> Self {
Self {
persistent_storage: storage,
memory_storage: RwLock::new(vec![]),
}
}
pub fn clear_in_memory_cache(&self) {
let mut memory_storage = self.memory_storage.write().unwrap();
memory_storage.clear();
}
}
impl<P, K, S> KeyCache<P, K, S>
where
P: Copy + PartialEq + NamedParam,
S: PersistentStorage<P, K>,
K: From<P> + Clone,
{
pub fn get(&self, param: P) -> SharedKey<K> {
self.with_key(param, |k| k.clone())
}
pub fn with_key<F, R>(&self, param: P, f: F) -> R
where
F: FnOnce(&SharedKey<K>) -> R,
{
let load_from_persistent_storage = || {
// we check if we can load the key from persistent storage
let persistent_storage = &self.persistent_storage;
let maybe_key = persistent_storage.load(param);
match maybe_key {
Some(key) => key,
None => {
let key = K::from(param);
persistent_storage.store(param, &key);
key
}
}
};
let try_load_from_memory_and_init = || {
// we only hold a read lock for a short duration to find the key
let maybe_shared_cell = {
let memory_storage = self.memory_storage.read().unwrap();
memory_storage
.iter()
.find(|(p, _)| *p == param)
.map(|param_key| param_key.1.clone())
};
if let Some(shared_cell) = maybe_shared_cell {
shared_cell.inner.get_or_init(load_from_persistent_storage);
Ok(shared_cell)
} else {
Err(())
}
};
match try_load_from_memory_and_init() {
Ok(result) => f(&result),
Err(()) => {
{
// we only hold a write lock for a short duration to push the lazily
// evaluated key without actually evaluating the key
let mut memory_storage = self.memory_storage.write().unwrap();
if !memory_storage.iter().any(|(p, _)| *p == param) {
memory_storage.push((
param,
SharedKey {
inner: Arc::new(OnceCell::new()),
},
));
}
}
f(&try_load_from_memory_and_init().ok().unwrap())
}
}
}
}
}
named_params_impl!( ShortintParameterSet =>
PARAM_MESSAGE_1_CARRY_0_KS_PBS,
PARAM_MESSAGE_1_CARRY_1_KS_PBS,