diff --git a/.gitignore b/.gitignore index 13088a25b..0b77f566a 100644 --- a/.gitignore +++ b/.gitignore @@ -3,9 +3,9 @@ target/ .vscode/ # Path we use for internal-keycache during tests -keys/ +./keys/ # In case of symlinked keys -keys +./keys **/Cargo.lock **/*.bin diff --git a/tfhe/src/integer/u256.rs b/tfhe/src/integer/u256.rs index 303f5f7c9..7f890c79b 100644 --- a/tfhe/src/integer/u256.rs +++ b/tfhe/src/integer/u256.rs @@ -68,6 +68,30 @@ impl From<(u128, u128)> for U256 { } } +impl From for U256 { + fn from(value: u8) -> Self { + Self::from(value as u128) + } +} + +impl From for U256 { + fn from(value: u16) -> Self { + Self::from(value as u128) + } +} + +impl From for U256 { + fn from(value: u32) -> Self { + Self::from(value as u128) + } +} + +impl From for U256 { + fn from(value: u64) -> Self { + Self::from(value as u128) + } +} + impl From for U256 { fn from(value: u128) -> Self { Self([ diff --git a/tfhe/src/lib.rs b/tfhe/src/lib.rs index 1fd5e3fb9..086cc895b 100644 --- a/tfhe/src/lib.rs +++ b/tfhe/src/lib.rs @@ -4,6 +4,8 @@ #![cfg_attr(feature = "__wasm_api", allow(dead_code))] #![cfg_attr(feature = "nightly-avx512", feature(stdsimd, avx512_target_feature))] +#![cfg_attr(all(doc, not(doctest)), feature(doc_auto_cfg))] +#![cfg_attr(all(doc, not(doctest)), feature(doc_cfg))] #![deny(rustdoc::broken_intra_doc_links)] #[cfg(feature = "__c_api")] @@ -49,3 +51,6 @@ pub use js_on_wasm_api::*; feature = "integer" ))] mod test_user_docs; + +pub(crate) mod typed_api; +pub use typed_api::*; diff --git a/tfhe/src/typed_api/booleans/client_key.rs b/tfhe/src/typed_api/booleans/client_key.rs new file mode 100644 index 000000000..a25c61cee --- /dev/null +++ b/tfhe/src/typed_api/booleans/client_key.rs @@ -0,0 +1,26 @@ +use crate::boolean::client_key::ClientKey; + +use serde::{Deserialize, Serialize}; + +use super::parameters::BooleanParameterSet; +use super::types::static_::StaticBoolParameters; +use super::FheBoolParameters; + +#[cfg_attr(all(doc, not(doctest)), cfg(feature = "boolean"))] +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct GenericBoolClientKey

+where + P: BooleanParameterSet, +{ + pub(in crate::typed_api::booleans) key: ClientKey, + _marker: std::marker::PhantomData

, +} + +impl From for GenericBoolClientKey { + fn from(parameters: FheBoolParameters) -> Self { + Self { + key: ClientKey::new(¶meters.into()), + _marker: Default::default(), + } + } +} diff --git a/tfhe/src/typed_api/booleans/keys.rs b/tfhe/src/typed_api/booleans/keys.rs new file mode 100644 index 000000000..20001b7a3 --- /dev/null +++ b/tfhe/src/typed_api/booleans/keys.rs @@ -0,0 +1,5 @@ +define_key_structs! { + Boolean { + bool: FheBool, + } +} diff --git a/tfhe/src/typed_api/booleans/mod.rs b/tfhe/src/typed_api/booleans/mod.rs new file mode 100644 index 000000000..e838ba7d2 --- /dev/null +++ b/tfhe/src/typed_api/booleans/mod.rs @@ -0,0 +1,14 @@ +pub(crate) use keys::{BooleanClientKey, BooleanConfig, BooleanPublicKey, BooleanServerKey}; +pub use parameters::FheBoolParameters; +pub use types::{CompressedFheBool, FheBool, GenericBool}; + +mod client_key; +mod keys; +mod public_key; +mod server_key; +mod types; + +mod parameters; + +#[cfg(test)] +mod tests; diff --git a/tfhe/src/typed_api/booleans/parameters.rs b/tfhe/src/typed_api/booleans/parameters.rs new file mode 100644 index 000000000..03954b07f --- /dev/null +++ b/tfhe/src/typed_api/booleans/parameters.rs @@ -0,0 +1,76 @@ +use crate::boolean::parameters::{ + BooleanParameters, DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, + PolynomialSize, StandardDev, +}; +pub use crate::boolean::parameters::{DEFAULT_PARAMETERS, TFHE_LIB_PARAMETERS}; + +use serde::{Deserialize, Serialize}; + +pub trait BooleanParameterSet: Into { + type Id: Copy; +} + +/// Parameters for [FheBool]. +/// +/// [FheBool]: crate::typed_api::FheBool +#[cfg_attr(all(doc, not(doctest)), cfg(feature = "boolean"))] +#[derive(Copy, Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct FheBoolParameters { + pub lwe_dimension: LweDimension, + pub glwe_dimension: GlweDimension, + pub polynomial_size: PolynomialSize, + pub lwe_modular_std_dev: StandardDev, + pub glwe_modular_std_dev: StandardDev, + pub pbs_base_log: DecompositionBaseLog, + pub pbs_level: DecompositionLevelCount, + pub ks_base_log: DecompositionBaseLog, + pub ks_level: DecompositionLevelCount, +} + +impl FheBoolParameters { + pub fn tfhe_lib() -> Self { + Self::from_static(&TFHE_LIB_PARAMETERS) + } + + fn from_static(params: &'static BooleanParameters) -> Self { + (*params).into() + } +} + +impl Default for FheBoolParameters { + fn default() -> Self { + Self::from_static(&DEFAULT_PARAMETERS) + } +} + +impl From for BooleanParameters { + fn from(params: FheBoolParameters) -> Self { + Self { + lwe_dimension: params.lwe_dimension, + glwe_dimension: params.glwe_dimension, + polynomial_size: params.polynomial_size, + lwe_modular_std_dev: params.lwe_modular_std_dev, + glwe_modular_std_dev: params.glwe_modular_std_dev, + pbs_base_log: params.pbs_base_log, + pbs_level: params.pbs_level, + ks_base_log: params.ks_base_log, + ks_level: params.ks_level, + } + } +} + +impl From for FheBoolParameters { + fn from(params: BooleanParameters) -> FheBoolParameters { + Self { + lwe_dimension: params.lwe_dimension, + glwe_dimension: params.glwe_dimension, + polynomial_size: params.polynomial_size, + lwe_modular_std_dev: params.lwe_modular_std_dev, + glwe_modular_std_dev: params.glwe_modular_std_dev, + pbs_base_log: params.pbs_base_log, + pbs_level: params.pbs_level, + ks_base_log: params.ks_base_log, + ks_level: params.ks_level, + } + } +} diff --git a/tfhe/src/typed_api/booleans/public_key.rs b/tfhe/src/typed_api/booleans/public_key.rs new file mode 100644 index 000000000..3f42d606c --- /dev/null +++ b/tfhe/src/typed_api/booleans/public_key.rs @@ -0,0 +1,27 @@ +use crate::typed_api::booleans::client_key::GenericBoolClientKey; +use crate::typed_api::booleans::parameters::BooleanParameterSet; + +use serde::{Deserialize, Serialize}; + +#[cfg_attr(all(doc, not(doctest)), cfg(feature = "boolean"))] +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct GenericBoolPublicKey

+where + P: BooleanParameterSet, +{ + pub(in crate::typed_api::booleans) key: crate::boolean::public_key::PublicKey, + _marker: std::marker::PhantomData

, +} + +impl

GenericBoolPublicKey

+where + P: BooleanParameterSet, +{ + pub fn new(client_key: &GenericBoolClientKey

) -> Self { + let key = crate::boolean::public_key::PublicKey::new(&client_key.key); + Self { + key, + _marker: Default::default(), + } + } +} diff --git a/tfhe/src/typed_api/booleans/server_key.rs b/tfhe/src/typed_api/booleans/server_key.rs new file mode 100644 index 000000000..ae9ae1ca9 --- /dev/null +++ b/tfhe/src/typed_api/booleans/server_key.rs @@ -0,0 +1,93 @@ +use super::client_key::GenericBoolClientKey; +use super::parameters::BooleanParameterSet; +use super::types::GenericBool; +use crate::boolean::server_key::{BinaryBooleanGates, ServerKey}; + +use serde::{Deserialize, Serialize}; + +#[cfg_attr(all(doc, not(doctest)), cfg(feature = "boolean"))] +#[derive(Clone, Serialize, Deserialize)] +pub struct GenericBoolServerKey

+where + P: BooleanParameterSet, +{ + pub(in crate::typed_api::booleans) key: ServerKey, + _marker: std::marker::PhantomData

, +} + +impl

GenericBoolServerKey

+where + P: BooleanParameterSet, +{ + pub(crate) fn new(key: &GenericBoolClientKey

) -> Self { + Self { + key: ServerKey::new(&key.key), + _marker: Default::default(), + } + } + + pub(in crate::typed_api::booleans) fn and( + &self, + lhs: &GenericBool

, + rhs: &GenericBool

, + ) -> GenericBool

{ + let ciphertext = self.key.and(&lhs.ciphertext, &rhs.ciphertext); + GenericBool::

::new(ciphertext, lhs.id) + } + + pub(in crate::typed_api::booleans) fn or( + &self, + lhs: &GenericBool

, + rhs: &GenericBool

, + ) -> GenericBool

{ + let ciphertext = self.key.or(&lhs.ciphertext, &rhs.ciphertext); + GenericBool::

::new(ciphertext, lhs.id) + } + + pub(in crate::typed_api::booleans) fn xor( + &self, + lhs: &GenericBool

, + rhs: &GenericBool

, + ) -> GenericBool

{ + let ciphertext = self.key.xor(&lhs.ciphertext, &rhs.ciphertext); + GenericBool::

::new(ciphertext, lhs.id) + } + + pub(in crate::typed_api::booleans) fn xnor( + &self, + lhs: &GenericBool

, + rhs: &GenericBool

, + ) -> GenericBool

{ + let ciphertext = self.key.xnor(&lhs.ciphertext, &rhs.ciphertext); + GenericBool::

::new(ciphertext, lhs.id) + } + + pub(in crate::typed_api::booleans) fn nand( + &self, + lhs: &GenericBool

, + rhs: &GenericBool

, + ) -> GenericBool

{ + let ciphertext = self.key.nand(&lhs.ciphertext, &rhs.ciphertext); + GenericBool::

::new(ciphertext, lhs.id) + } + + pub(in crate::typed_api::booleans) fn not(&self, lhs: &GenericBool

) -> GenericBool

{ + let ciphertext = self.key.not(&lhs.ciphertext); + GenericBool::

::new(ciphertext, lhs.id) + } + + #[allow(dead_code)] + pub(in crate::typed_api::booleans) fn mux( + &self, + condition: &GenericBool

, + then_result: &GenericBool

, + else_result: &GenericBool

, + ) -> GenericBool

{ + let ciphertext = self.key.mux( + &condition.ciphertext, + &then_result.ciphertext, + &else_result.ciphertext, + ); + GenericBool::

::new(ciphertext, condition.id) + } +} diff --git a/tfhe/src/typed_api/booleans/tests.rs b/tfhe/src/typed_api/booleans/tests.rs new file mode 100644 index 000000000..80b06f9b2 --- /dev/null +++ b/tfhe/src/typed_api/booleans/tests.rs @@ -0,0 +1,192 @@ +// Without this, clippy will conplain about equal expressions to `ffalse & ffalse` +// However since we overloaded these operators, we want to test them to see +// if they are correct +#![allow(clippy::eq_op)] +#![allow(clippy::bool_assert_comparison)] +use std::ops::{BitAnd, BitOr, BitXor, Not}; + +use crate::typed_api::prelude::*; +use crate::typed_api::{ + generate_keys, set_server_key, ClientKey, CompressedFheBool, ConfigBuilder, FheBool, + FheBoolParameters, +}; + +fn setup_static_default() -> ClientKey { + let config = ConfigBuilder::all_disabled().enable_default_bool().build(); + + let (my_keys, server_keys) = generate_keys(config); + + set_server_key(server_keys); + my_keys +} + +fn setup_static_tfhe() -> ClientKey { + let config = ConfigBuilder::all_disabled() + .enable_custom_bool(FheBoolParameters::tfhe_lib()) + .build(); + + let (my_keys, server_keys) = generate_keys(config); + + set_server_key(server_keys); + my_keys +} + +#[test] +fn test_xor_truth_table_static_default() { + let keys = setup_static_default(); + + let ttrue = FheBool::encrypt(true, &keys); + let ffalse = FheBool::encrypt(false, &keys); + + xor_truth_table(&ttrue, &ffalse, &keys); +} + +#[test] +fn test_and_truth_table_static_default() { + let keys = setup_static_default(); + + let ttrue = FheBool::encrypt(true, &keys); + let ffalse = FheBool::encrypt(false, &keys); + + and_truth_table(&ttrue, &ffalse, &keys); +} + +#[test] +fn test_or_truth_table_static_default() { + let keys = setup_static_default(); + + let ttrue = FheBool::encrypt(true, &keys); + let ffalse = FheBool::encrypt(false, &keys); + + or_truth_table(&ttrue, &ffalse, &keys); +} + +#[test] +fn test_not_truth_table_static_default() { + let keys = setup_static_default(); + + let ttrue = FheBool::encrypt(true, &keys); + let ffalse = FheBool::encrypt(false, &keys); + + not_truth_table(&ttrue, &ffalse, &keys); +} + +#[test] +fn test_xor_truth_table_static_tfhe() { + let keys = setup_static_tfhe(); + + let ttrue = FheBool::encrypt(true, &keys); + let ffalse = FheBool::encrypt(false, &keys); + + xor_truth_table(&ttrue, &ffalse, &keys); +} + +#[test] +fn test_and_truth_table_static_tfhe() { + let keys = setup_static_tfhe(); + + let ttrue = FheBool::encrypt(true, &keys); + let ffalse = FheBool::encrypt(false, &keys); + + and_truth_table(&ttrue, &ffalse, &keys); +} + +#[test] +fn test_or_truth_table_static_tfhe() { + let keys = setup_static_tfhe(); + + let ttrue = FheBool::encrypt(true, &keys); + let ffalse = FheBool::encrypt(false, &keys); + + or_truth_table(&ttrue, &ffalse, &keys); +} + +#[test] +fn test_not_truth_table_static_tfhe() { + let keys = setup_static_tfhe(); + + let ttrue = FheBool::encrypt(true, &keys); + let ffalse = FheBool::encrypt(false, &keys); + + not_truth_table(&ttrue, &ffalse, &keys); +} + +fn xor_truth_table<'a, BoolType>(ttrue: &'a BoolType, ffalse: &'a BoolType, key: &ClientKey) +where + &'a BoolType: BitXor<&'a BoolType, Output = BoolType>, + BoolType: FheDecrypt, +{ + let r = ffalse ^ ffalse; + assert_eq!(r.decrypt(key), false); + + let r = ffalse ^ ttrue; + assert_eq!(r.decrypt(key), true); + + let r = ttrue ^ ffalse; + assert_eq!(r.decrypt(key), true); + + let r = ttrue ^ ttrue; + assert_eq!(r.decrypt(key), false); +} + +fn and_truth_table<'a, BoolType>(ttrue: &'a BoolType, ffalse: &'a BoolType, key: &ClientKey) +where + &'a BoolType: BitAnd<&'a BoolType, Output = BoolType>, + BoolType: FheDecrypt, +{ + let r = ffalse & ffalse; + assert_eq!(r.decrypt(key), false); + + let r = ffalse & ttrue; + assert_eq!(r.decrypt(key), false); + + let r = ttrue & ffalse; + assert_eq!(r.decrypt(key), false); + + let r = ttrue & ttrue; + assert_eq!(r.decrypt(key), true); +} + +fn or_truth_table<'a, BoolType>(ttrue: &'a BoolType, ffalse: &'a BoolType, key: &ClientKey) +where + &'a BoolType: BitOr<&'a BoolType, Output = BoolType>, + BoolType: FheDecrypt, +{ + let r = ffalse | ffalse; + assert_eq!(r.decrypt(key), false); + + let r = ffalse | ttrue; + assert_eq!(r.decrypt(key), true); + + let r = ttrue | ffalse; + assert_eq!(r.decrypt(key), true); + + let r = ttrue | ttrue; + assert_eq!(r.decrypt(key), true); +} + +fn not_truth_table<'a, BoolType>(ttrue: &'a BoolType, ffalse: &'a BoolType, key: &ClientKey) +where + &'a BoolType: Not, + BoolType: FheDecrypt, +{ + let r = !ffalse; + assert_eq!(r.decrypt(key), true); + + let r = !ttrue; + assert_eq!(r.decrypt(key), false); +} + +#[test] +fn test_compressed_bool() { + let keys = setup_static_default(); + + let cttrue = CompressedFheBool::encrypt(true, &keys); + let cffalse = CompressedFheBool::encrypt(false, &keys); + + let a = FheBool::from(cttrue); + let b = FheBool::from(cffalse); + + assert_eq!(a.decrypt(&keys), true); + assert_eq!(b.decrypt(&keys), false); +} diff --git a/tfhe/src/typed_api/booleans/types/base.rs b/tfhe/src/typed_api/booleans/types/base.rs new file mode 100644 index 000000000..7e0f8a5be --- /dev/null +++ b/tfhe/src/typed_api/booleans/types/base.rs @@ -0,0 +1,309 @@ +use std::borrow::Borrow; +use std::ops::{BitAnd, BitOr, BitXor}; + +use crate::boolean::ciphertext::{Ciphertext, CompressedCiphertext}; +use serde::{Deserialize, Serialize}; + +use crate::typed_api::booleans::client_key::GenericBoolClientKey; +use crate::typed_api::booleans::parameters::BooleanParameterSet; +use crate::typed_api::booleans::public_key::GenericBoolPublicKey; +use crate::typed_api::booleans::server_key::GenericBoolServerKey; +use crate::typed_api::global_state::WithGlobalKey; +use crate::typed_api::keys::{ClientKey, PublicKey, RefKeyFromKeyChain, RefKeyFromPublicKeyChain}; +use crate::typed_api::traits::{ + FheDecrypt, FheEncrypt, FheEq, FheTrivialEncrypt, FheTryEncrypt, FheTryTrivialEncrypt, +}; + +/// The FHE boolean data type. +/// +/// To be able to use this type, the cargo feature `booleans` must be enabled, +/// and your config should also enable the type with either default parameters or custom ones. +/// +/// # Example +/// ```rust +/// use tfhe::prelude::*; +/// use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheBool}; +/// +/// // Enable booleans in the config +/// let config = ConfigBuilder::all_disabled().enable_default_bool().build(); +/// +/// // With the booleans enabled in the config, the needed keys and details +/// // can be taken care of. +/// let (client_key, server_key) = generate_keys(config); +/// +/// let ttrue = FheBool::encrypt(true, &client_key); +/// let ffalse = FheBool::encrypt(false, &client_key); +/// +/// // Do not forget to set the server key before doing any computation +/// set_server_key(server_key); +/// +/// let fhe_result = ttrue & ffalse; +/// +/// let clear_result = fhe_result.decrypt(&client_key); +/// assert_eq!(clear_result, false); +/// ``` +#[cfg_attr(all(doc, not(doctest)), cfg(feature = "boolean"))] +#[derive(Clone, Serialize, Deserialize)] +pub struct GenericBool

+where + P: BooleanParameterSet, +{ + pub(in crate::typed_api::booleans) ciphertext: Ciphertext, + pub(in crate::typed_api::booleans) id: P::Id, +} + +pub struct CompressedBool

+where + P: BooleanParameterSet, +{ + pub(in crate::typed_api::booleans) ciphertext: CompressedCiphertext, + pub(in crate::typed_api::booleans) id: P::Id, +} + +impl

GenericBool

+where + P: BooleanParameterSet, +{ + pub(in crate::typed_api::booleans) fn new(ciphertext: Ciphertext, id: P::Id) -> Self { + Self { ciphertext, id } + } +} + +impl

GenericBool

+where + P: BooleanParameterSet, + P::Id: WithGlobalKey>, +{ + pub fn nand(&self, rhs: &Self) -> Self { + self.id.with_unwrapped_global(|key| key.nand(self, rhs)) + } + + pub fn neq(&self, other: &Self) -> Self { + self.id.with_unwrapped_global(|key| { + let eq = key.xnor(self, other); + key.not(&eq) + }) + } +} + +impl FheEq for GenericBool

+where + B: Borrow, + P: BooleanParameterSet, + P::Id: WithGlobalKey>, +{ + type Output = Self; + + fn eq(&self, other: B) -> Self { + self.id + .with_unwrapped_global(|key| key.xnor(self, other.borrow())) + } +} + +#[allow(dead_code)] +#[cfg_attr(all(doc, not(doctest)), cfg(feature = "boolean"))] +pub fn if_then_else(ct_condition: B1, ct_then: B2, ct_else: B2) -> GenericBool

+where + B1: Borrow>, + B2: Borrow>, + P: BooleanParameterSet, + P::Id: WithGlobalKey>, +{ + let ct_condition = ct_condition.borrow(); + ct_condition + .id + .with_unwrapped_global(|key| key.mux(ct_condition, ct_then.borrow(), ct_else.borrow())) +} + +impl

CompressedBool

+where + P: BooleanParameterSet, +{ + fn new(ciphertext: CompressedCiphertext, id: P::Id) -> Self { + Self { ciphertext, id } + } +} + +impl

From> for GenericBool

+where + P: BooleanParameterSet, +{ + fn from(value: CompressedBool

) -> Self { + Self::new(value.ciphertext.into(), value.id) + } +} + +impl

FheTryEncrypt for CompressedBool

+where + P: BooleanParameterSet, + P::Id: RefKeyFromKeyChain> + Default, +{ + type Error = crate::typed_api::errors::Error; + + fn try_encrypt(value: bool, key: &ClientKey) -> Result { + let id = P::Id::default(); + let key = id.ref_key(key)?; + let ciphertext = key.key.encrypt_compressed(value); + Ok(CompressedBool::

::new(ciphertext, id)) + } +} + +impl

FheEncrypt for CompressedBool

+where + P: BooleanParameterSet, + P::Id: RefKeyFromKeyChain> + Default, +{ + #[track_caller] + fn encrypt(value: bool, key: &ClientKey) -> Self { + Self::try_encrypt(value, key).unwrap() + } +} + +impl

FheEncrypt for GenericBool

+where + P: BooleanParameterSet, + P::Id: RefKeyFromKeyChain> + Default, +{ + #[track_caller] + fn encrypt(value: bool, key: &ClientKey) -> Self { + >::try_encrypt(value, key).unwrap() + } +} + +impl

FheEncrypt for GenericBool

+where + P: BooleanParameterSet, + P::Id: RefKeyFromPublicKeyChain> + Default, +{ + #[track_caller] + fn encrypt(value: bool, key: &PublicKey) -> Self { + >::try_encrypt(value, key).unwrap() + } +} + +impl

FheTryEncrypt for GenericBool

+where + P: BooleanParameterSet, + P::Id: RefKeyFromKeyChain> + Default, +{ + type Error = crate::typed_api::errors::Error; + + fn try_encrypt(value: bool, key: &ClientKey) -> Result { + let id = P::Id::default(); + let key = id.ref_key(key)?; + let ciphertext = key.key.encrypt(value); + Ok(GenericBool::

::new(ciphertext, id)) + } +} + +impl

FheTryTrivialEncrypt for GenericBool

+where + P: BooleanParameterSet, + P::Id: Default + WithGlobalKey>, +{ + type Error = crate::typed_api::errors::Error; + + fn try_encrypt_trivial(value: bool) -> Result { + let id = P::Id::default(); + id.with_global(|key| { + let ciphertext = key.key.trivial_encrypt(value); + Ok(GenericBool::new(ciphertext, id)) + })? + } +} + +impl

FheTrivialEncrypt for GenericBool

+where + P: BooleanParameterSet, + P::Id: Default + WithGlobalKey>, +{ + #[track_caller] + fn encrypt_trivial(value: bool) -> Self { + Self::try_encrypt_trivial(value).unwrap() + } +} + +impl

FheTryEncrypt for GenericBool

+where + P: BooleanParameterSet, + P::Id: RefKeyFromPublicKeyChain> + Default, +{ + type Error = crate::typed_api::errors::Error; + + fn try_encrypt(value: bool, key: &PublicKey) -> Result { + let id = P::Id::default(); + let key = id.ref_key(key)?; + let ciphertext = key.key.encrypt(value); + Ok(GenericBool::

::new(ciphertext, id)) + } +} + +impl

FheDecrypt for GenericBool

+where + P: BooleanParameterSet, + P::Id: RefKeyFromKeyChain>, +{ + #[track_caller] + fn decrypt(&self, key: &ClientKey) -> bool { + let key = self.id.unwrapped_ref_key(key); + key.key.decrypt(&self.ciphertext) + } +} + +macro_rules! fhe_bool_impl_operation( + ($trait_name:ident($trait_method:ident) => $key_method:ident) => { + impl $trait_name for GenericBool

+ where B: Borrow>, + P: BooleanParameterSet, + P::Id: WithGlobalKey>, + { + type Output = GenericBool

; + + fn $trait_method(self, rhs: B) -> Self::Output { + <&Self as $trait_name>::$trait_method(&self, rhs) + } + } + + impl $trait_name for &GenericBool

+ where B: Borrow>, + P: BooleanParameterSet, + P::Id: WithGlobalKey>, + { + type Output = GenericBool

; + + fn $trait_method(self, rhs: B) -> Self::Output { + self.id.with_unwrapped_global(|key| { + key.$key_method(self, rhs.borrow()) + }) + } + } + }; +); + +fhe_bool_impl_operation!(BitAnd(bitand) => and); +fhe_bool_impl_operation!(BitOr(bitor) => or); +fhe_bool_impl_operation!(BitXor(bitxor) => xor); + +impl

::std::ops::Not for GenericBool

+where + P: BooleanParameterSet, + P::Id: WithGlobalKey>, +{ + type Output = Self; + + fn not(self) -> Self::Output { + self.id.with_unwrapped_global(|key| key.not(&self)) + } +} + +impl

::std::ops::Not for &GenericBool

+where + P: BooleanParameterSet, + P::Id: WithGlobalKey>, +{ + type Output = GenericBool

; + + fn not(self) -> Self::Output { + self.id.with_unwrapped_global(|key| key.not(self)) + } +} diff --git a/tfhe/src/typed_api/booleans/types/mod.rs b/tfhe/src/typed_api/booleans/types/mod.rs new file mode 100644 index 000000000..6fe907fd5 --- /dev/null +++ b/tfhe/src/typed_api/booleans/types/mod.rs @@ -0,0 +1,5 @@ +pub use base::{CompressedBool, GenericBool}; +pub use static_::{CompressedFheBool, FheBool}; + +mod base; +pub mod static_; diff --git a/tfhe/src/typed_api/booleans/types/static_.rs b/tfhe/src/typed_api/booleans/types/static_.rs new file mode 100644 index 000000000..0cebda494 --- /dev/null +++ b/tfhe/src/typed_api/booleans/types/static_.rs @@ -0,0 +1,83 @@ +use crate::boolean::parameters::BooleanParameters; +use serde::{Deserialize, Serialize}; + +use crate::typed_api::booleans::client_key::GenericBoolClientKey; +use crate::typed_api::booleans::parameters::BooleanParameterSet; +pub use crate::typed_api::booleans::parameters::FheBoolParameters; +use crate::typed_api::booleans::public_key::GenericBoolPublicKey; +use crate::typed_api::booleans::server_key::GenericBoolServerKey; +use crate::typed_api::booleans::types::CompressedBool; +use crate::typed_api::errors::Type; + +use super::base::GenericBool; + +// Has Overridable Operator: +// - and => BitAnd => & +// - not => Not => ! +// - or => BitOr => | +// - xor => BitXor => ^ +// +// Does Not have overridable operator: +// - mux -> But maybe by using a macro_rules with regular function we can have some sufficiently +// nice syntax sugar +// - nand +// - nor +// - xnor should be Eq => ==, But Eq requires to return a bool not a FHE bool So we cant do it +// - ||, && cannot be overloaded, maybe a well-crafted macro-rules that implements `if-else` could +// bring this syntax sugar + +/// The struct to identify the static boolean type +#[derive(Copy, Clone, Default, Serialize, Deserialize)] +pub struct FheBoolId; + +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +pub struct StaticBoolParameters(pub(crate) FheBoolParameters); + +impl From for BooleanParameters { + fn from(p: StaticBoolParameters) -> Self { + p.0.into() + } +} + +impl From for StaticBoolParameters { + fn from(p: FheBoolParameters) -> Self { + Self(p) + } +} + +impl BooleanParameterSet for StaticBoolParameters { + type Id = FheBoolId; +} + +pub type FheBool = GenericBool; +pub type CompressedFheBool = CompressedBool; +pub(in crate::typed_api::booleans) type FheBoolClientKey = + GenericBoolClientKey; +pub(in crate::typed_api::booleans) type FheBoolServerKey = + GenericBoolServerKey; +pub(in crate::typed_api::booleans) type FheBoolPublicKey = + GenericBoolPublicKey; + +impl_with_global_key!( + for FheBoolId { + key_type: FheBoolServerKey, + keychain_member: boolean_key.bool_key, + type_variant: Type::FheBool, + } +); + +impl_ref_key_from_keychain!( + for FheBoolId { + key_type: FheBoolClientKey, + keychain_member: boolean_key.bool_key, + type_variant: Type::FheBool, + } +); + +impl_ref_key_from_public_keychain!( + for FheBoolId { + key_type: FheBoolPublicKey, + keychain_member: boolean_key.bool_key, + type_variant: Type::FheBool, + } +); diff --git a/tfhe/src/typed_api/config.rs b/tfhe/src/typed_api/config.rs new file mode 100644 index 000000000..832d0d04d --- /dev/null +++ b/tfhe/src/typed_api/config.rs @@ -0,0 +1,167 @@ +#[cfg(feature = "boolean")] +use crate::typed_api::booleans::{BooleanConfig, FheBoolParameters}; +#[cfg(feature = "integer")] +use crate::typed_api::integers::IntegerConfig; +#[cfg(feature = "shortint")] +use crate::typed_api::shortints::ShortIntConfig; + +/// The config type +#[derive(Clone, Debug)] +pub struct Config { + #[cfg(feature = "boolean")] + pub(crate) boolean_config: BooleanConfig, + #[cfg(feature = "integer")] + pub(crate) integer_config: IntegerConfig, + #[cfg(feature = "shortint")] + pub(crate) shortint_config: ShortIntConfig, +} + +/// The builder to create your config +/// +/// This struct is what you will to use to build your +/// configuration. +/// +/// # Why ? +/// +/// The configuration is needed to select which types you are going to use or not +/// and which parameters you wish to use for these types (whether it is the default parameters or +/// some custom parameters). +/// +/// To be able to configure a type, its "cargo feature kind" must be enabled (see the [table]). +/// +/// The configuration is needed for the crate to be able to initialize and generate +/// all the needed client and server keys as well as other internal details. +/// +/// As generating these keys and details for types that you are not going to use would be +/// a waste of time and space (both memory and disk if you serialize), generating a config is an +/// important step. +/// +/// [table]: index.html#data-types +#[derive(Clone)] +pub struct ConfigBuilder { + config: Config, +} + +impl ConfigBuilder { + /// Create a new builder with all the data types activated with their default parameters + pub fn all_enabled() -> Self { + Self { + config: Config { + #[cfg(feature = "boolean")] + boolean_config: BooleanConfig::all_default(), + #[cfg(feature = "integer")] + integer_config: IntegerConfig::all_default(), + #[cfg(feature = "shortint")] + shortint_config: ShortIntConfig::all_default(), + }, + } + } + + /// Create a new builder with all the data types disabled + pub fn all_disabled() -> Self { + Self { + config: Config { + #[cfg(feature = "boolean")] + boolean_config: BooleanConfig::all_none(), + #[cfg(feature = "integer")] + integer_config: IntegerConfig::all_none(), + #[cfg(feature = "shortint")] + shortint_config: ShortIntConfig::all_none(), + }, + } + } + + #[cfg(feature = "boolean")] + pub fn enable_default_bool(mut self) -> Self { + self.config.boolean_config.bool_params = Some(Default::default()); + self + } + + #[cfg(feature = "boolean")] + pub fn enable_custom_bool(mut self, params: FheBoolParameters) -> Self { + self.config.boolean_config.bool_params = Some(params); + self + } + + #[cfg(feature = "boolean")] + pub fn disable_bool(mut self) -> Self { + self.config.boolean_config.bool_params = None; + self + } + + #[cfg(feature = "shortint")] + pub fn enable_default_uint2(mut self) -> Self { + self.config.shortint_config.uint2_params = Some(Default::default()); + self + } + + #[cfg(feature = "shortint")] + pub fn enable_default_uint3(mut self) -> Self { + self.config.shortint_config.uint3_params = Some(Default::default()); + self + } + + #[cfg(feature = "shortint")] + pub fn enable_default_uint4(mut self) -> Self { + self.config.shortint_config.uint4_params = Some(Default::default()); + self + } + + #[cfg(feature = "integer")] + pub fn enable_default_uint8(mut self) -> Self { + self.config.integer_config.uint8_params = Some(Default::default()); + self + } + + #[cfg(feature = "integer")] + pub fn disable_uint8(mut self) -> Self { + self.config.integer_config.uint8_params = None; + self + } + + #[cfg(feature = "integer")] + pub fn enable_default_uint12(mut self) -> Self { + self.config.integer_config.uint12_params = Some(Default::default()); + self + } + + #[cfg(feature = "integer")] + pub fn disable_uint12(mut self) -> Self { + self.config.integer_config.uint12_params = None; + self + } + + #[cfg(feature = "integer")] + pub fn enable_default_uint16(mut self) -> Self { + self.config.integer_config.uint16_params = Some(Default::default()); + self + } + + #[cfg(feature = "integer")] + pub fn disable_uint16(mut self) -> Self { + self.config.integer_config.uint16_params = None; + self + } + + #[cfg(feature = "integer")] + pub fn enable_default_uint256(mut self) -> Self { + self.config.integer_config.uint256_params = Some(Default::default()); + self + } + + #[cfg(feature = "integer")] + pub fn disable_uint256(mut self) -> Self { + self.config.integer_config.uint256_params = None; + self + } + + pub fn build(self) -> Config { + self.config + } +} + +impl From for Config { + fn from(builder: ConfigBuilder) -> Self { + builder.build() + } +} diff --git a/tfhe/src/typed_api/design.md b/tfhe/src/typed_api/design.md new file mode 100644 index 000000000..1d9ec32ae --- /dev/null +++ b/tfhe/src/typed_api/design.md @@ -0,0 +1,79 @@ +# typed_api + +The `typed_api` module main goal is to provide +an API that is higher level than what the `boolean`, `shortint`, `integer` +modules offers. + +The way it is done is by exposing FHE types `FheBool`, `FheUint2`, etc +that are closer to the `u8`, `u16` than what Ciphertext are, +this is mainly achieved by overloading operators (`+` , `-`, `*`, etc). + +Since all operations (add, sub, etc) have to be done via a `ServerKey` +it means it has to be managed by the `typed_api`, to hide it in order +to allow operator overloading. + +## How the FHE types are created in this module + +Crypto parameters `tfhe::boolean::Parameters` and `tfhe::shortint::Parameters` +are what defines the number of bits a Ciphertext can store. + +The way the different FHE types are created can be summarized as: + +> Instead of having one struct to represent multiple parameters, +> we will create one struct per each parameter we use. + + +To understand that last sentence a bit more we'll explain a simplified example on shortint: + +We want to provide FheUint2, FheUint4 that are based on `tfhe::shortint` + +We first create a wrapper struct, that will wrap the Ciphertext type. +This struct is generic over some type called `P`, it is this genericity +that will enable us to easily create our types by using type specialization / type aliases + +```rust +struct GenericShortint

{ + inner: shortint::Ciphertext + // other details +} +``` + +We also create "parameter structs" in order the be able to +specialize our generic wrapper struct with`type FheName = GenericWrapperStruct`. + +For example, depending on the values in a `shortint::Parameters` instance, the number of bits in the `shortint::Ciphertext` +is not the same: +* `PARAM_MESSAGE_2_CARRY_2` -> 2 bits of message, 2 bits of carry +* `PARAM_MESSAGE_4_CARRY_4` -> 4 bits of message, 4 bits of carry +(both are of type `shortint::Parameters`) +And, generally, some ciphertext encrypted with some parameters values A, +will not be compatible with another ciphertext encrypted with some parameters values B. + +So in the `typed_api` we create 2 structs `struct FheUint2Parameters { ... }` and `struct FheUint4Parameters { ... }` +which are made so that FheUint2Parameters only contains `PARAM_MESSAGE_2_CARRY_2` +and FheUint4Parameters only contains `PARAM_MESSAGE_4_CARRY_4`. + +This way, we can specialize our wrapper types: +* `type FheUint2 = GenericWrapperStruct` +* `type FheUint4 = GenericWrapperStruct` +and now we have two disctint types, that have specific crypto parameters associated with them. +Also, they are type safe (can't `+` a FheUint2 to a FheUint4 without a compilatation error +unless the implemenation explicitely allows it since the type are different, which is not the case if you +use the 'raw' shortint api) + +In practice it is a bit more complex as we have to introduce traits to internally manipulate the +"parameter struct", eg a trait to convert the `FheUint2Parameters` and `FheUint4Parameters` back into `shortint::Parameters`, + +```rust +pub trait ShortIntegerParameter: Copy + Into { + // ... +} +``` + +The `ShortIntegerParameter` is meant to be implemented on "parameter struct" +that map to specific `shortint::Parameters`, like FheUint2Parameters and FheUint2Parameters +does, and so we require the `Into` convertion to be able to internally +interact with the shortint API. + +The same wrapping proccess is done for ClientKey, ServerKey, PublicKey, etc. + diff --git a/tfhe/src/typed_api/details.rs b/tfhe/src/typed_api/details.rs new file mode 100644 index 000000000..aa60ecd34 --- /dev/null +++ b/tfhe/src/typed_api/details.rs @@ -0,0 +1,124 @@ +#[cfg(any(feature = "boolean", feature = "shortint", feature = "integer"))] +macro_rules! define_key_structs { + ( + $base_struct_name:ident { + $( + $name:ident: $base_ty_name:ident + ),* + $(,)? + } + ) => { + + ::paste::paste!{ + $( + use super::types::static_::{ + [<$base_ty_name Parameters>], + [<$base_ty_name ClientKey>], + [<$base_ty_name PublicKey>], + [<$base_ty_name ServerKey>] + }; + )* + + /////////////////////// + /// Config + /////////////////////// + #[derive(Clone, Debug)] + pub(crate) struct [<$base_struct_name Config>] { + $( + pub(crate) [<$name _params>]: Option<[<$base_ty_name Parameters>]>, + )* + } + + impl [<$base_struct_name Config>] { + pub(crate) fn all_default() -> Self { + Self { + $( + [<$name _params>]: Some(Default::default()), + )* + } + } + + pub(crate) fn all_none() -> Self { + Self { + $( + [<$name _params>]: None, + )* + } + } + } + + /////////////////////// + /// Client Key + /////////////////////// + #[derive(Clone, Debug, serde::Deserialize, serde::Serialize)] + pub(crate) struct [<$base_struct_name ClientKey>] { + $( + pub(super) [<$name _key>]: Option<[<$base_ty_name ClientKey>]>, + )* + } + + impl From<[<$base_struct_name Config>]> for [<$base_struct_name ClientKey>] { + fn from(config: [<$base_struct_name Config>]) -> Self { + Self { + $( + [<$name _key>]: config.[<$name _params>].map(<[<$base_ty_name ClientKey>]>::from), + )* + } + } + } + + /////////////////////// + /// Public Key + /////////////////////// + #[derive(Clone, Debug, serde::Deserialize, serde::Serialize)] + pub(crate) struct [<$base_struct_name PublicKey>] { + $( + pub(super) [<$name _key>]: Option<[<$base_ty_name PublicKey>]>, + )* + } + + impl [<$base_struct_name PublicKey>] { + pub(crate) fn new(client_key: &[<$base_struct_name ClientKey>]) -> Self { + Self { + $( + [<$name _key>]: client_key + .[<$name _key>] + .as_ref() + .map(<[<$base_ty_name PublicKey>]>::new), + )* + } + } + } + + /////////////////////// + /// Server Key + /////////////////////// + #[derive(Clone, serde::Deserialize, serde::Serialize)] + pub(crate) struct [<$base_struct_name ServerKey>] { + $( + pub(super) [<$name _key>]: Option<[<$base_ty_name ServerKey>]>, + )* + } + + impl [<$base_struct_name ServerKey>] { + pub(crate) fn new(client_key: &[<$base_struct_name ClientKey>]) -> Self { + Self { + $( + [<$name _key>]: client_key.[<$name _key>].as_ref().map(<[<$base_ty_name ServerKey>]>::new), + )* + } + } + } + + impl Default for [<$base_struct_name ServerKey>] { + fn default() -> Self { + Self { + $( + [<$name _key>]: None, + )* + } + } + } + } + } +} diff --git a/tfhe/src/typed_api/errors.rs b/tfhe/src/typed_api/errors.rs new file mode 100644 index 000000000..600ea4889 --- /dev/null +++ b/tfhe/src/typed_api/errors.rs @@ -0,0 +1,171 @@ +use std::fmt::{Display, Formatter}; + +/// Unwrap 'Extension' trait +/// +/// The goal of this trait is to add a method similar to `unwrap` to `Result` +/// that uses the implementation of `Display` and not `Debug` as the +/// message in the panic. +pub trait UnwrapResultExt { + fn unwrap_display(self) -> T; +} + +impl UnwrapResultExt for Result +where + E: Display, +{ + #[track_caller] + fn unwrap_display(self) -> T { + match self { + Ok(t) => t, + Err(e) => panic!("{}", e), + } + } +} + +/// Enum that lists types available +/// +/// Mainly used to provide good errors. +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +pub enum Type { + #[cfg(feature = "boolean")] + FheBool, + #[cfg(feature = "shortint")] + FheUint2, + #[cfg(feature = "shortint")] + FheUint3, + #[cfg(feature = "shortint")] + FheUint4, + #[cfg(feature = "integer")] + FheUint8, + #[cfg(feature = "integer")] + FheUint10, + #[cfg(feature = "integer")] + FheUint12, + #[cfg(feature = "integer")] + FheUint14, + #[cfg(feature = "integer")] + FheUint16, + #[cfg(feature = "integer")] + FheUint256, +} + +/// The server key of a given type was not initialized +#[derive(Debug)] +pub struct UninitializedServerKey(pub(crate) Type); + +impl Display for UninitializedServerKey { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "The server key for the type '{:?}' was not properly initialized\n\ + Did you forget to call `set_server_key` in this thread or forget to + enable the type in the config ? + ", + self.0 + ) + } +} + +impl std::error::Error for UninitializedServerKey {} + +/// The client key of a given type was not initialized +#[derive(Debug)] +pub struct UninitializedClientKey(pub(crate) Type); + +impl Display for UninitializedClientKey { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "The client key for the type '{:?}' was not properly initialized\n\ + Dis you forget to enable the type in the config ? + ", + self.0 + ) + } +} + +impl std::error::Error for UninitializedClientKey {} + +/// The client key of a given type was not initialized +#[derive(Debug)] +pub struct UninitializedPublicKey(pub(crate) Type); + +impl Display for UninitializedPublicKey { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "The public key for the type '{:?}' was not properly initialized\n\ + Dis you forget do enable the type in the config ? + ", + self.0 + ) + } +} + +impl std::error::Error for UninitializedPublicKey {} + +/// Error when trying to create a short integer from a value that was too big to be represented +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +pub struct OutOfRangeError; + +impl Display for OutOfRangeError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "Value is out of range") + } +} + +impl std::error::Error for OutOfRangeError {} + +#[non_exhaustive] +#[derive(Debug, Eq, PartialEq)] +pub enum Error { + OutOfRange, + UninitializedClientKey(Type), + UninitializedPublicKey(Type), + UninitializedServerKey(Type), +} + +impl From for Error { + fn from(_: OutOfRangeError) -> Self { + Self::OutOfRange + } +} + +impl From for Error { + fn from(value: UninitializedClientKey) -> Self { + Self::UninitializedClientKey(value.0) + } +} + +impl From for Error { + fn from(value: UninitializedPublicKey) -> Self { + Self::UninitializedPublicKey(value.0) + } +} + +impl From for Error { + fn from(value: UninitializedServerKey) -> Self { + Self::UninitializedServerKey(value.0) + } +} + +impl Display for Error { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Error::OutOfRange => { + write!(f, "{OutOfRangeError}") + } + Error::UninitializedClientKey(ty) => { + write!(f, "{}", UninitializedClientKey(*ty)) + } + Error::UninitializedPublicKey(ty) => { + write!(f, "{}", UninitializedPublicKey(*ty)) + } + Error::UninitializedServerKey(ty) => { + write!(f, "{}", UninitializedServerKey(*ty)) + } + } + } +} + +impl std::error::Error for Error {} diff --git a/tfhe/src/typed_api/global_state.rs b/tfhe/src/typed_api/global_state.rs new file mode 100644 index 000000000..f5029d53d --- /dev/null +++ b/tfhe/src/typed_api/global_state.rs @@ -0,0 +1,153 @@ +//! In this module, we store the hidden (to the end-user) internal state/keys that are needed to +//! perform operations. +use crate::typed_api::errors::{UninitializedServerKey, UnwrapResultExt}; +use std::cell::RefCell; + +use crate::typed_api::keys::ServerKey; + +/// We store the internal keys as thread local, meaning each thread has its own set of keys. +/// +/// This means that the user can do computations in multiple threads +/// (eg a web server that processes multiple requests in multiple threads). +/// The user however, has to initialize the internal keys each time it starts a thread. +thread_local! { + static INTERNAL_KEYS: RefCell = RefCell::new(ServerKey::default()); +} + +/// The function used to initialize internal keys. +/// +/// As each thread has its own set of keys, +/// this function must be called at least once on each thread to initialize its keys. +/// +/// +/// # Example +/// +/// Only working in the `main` thread +/// +/// ``` +/// use tfhe; +/// +/// # let config = tfhe::ConfigBuilder::all_disabled().build(); +/// let (client_key, server_key) = tfhe::generate_keys(config); +/// +/// tfhe::set_server_key(server_key); +/// // Now we can do operations on homomorphic types +/// ``` +/// +/// +/// Working with multiple threads +/// +/// ``` +/// use std::thread; +/// use tfhe; +/// use tfhe::ConfigBuilder; +/// +/// # let config = tfhe::ConfigBuilder::all_disabled().build(); +/// let (client_key, server_key) = tfhe::generate_keys(config); +/// let server_key_2 = server_key.clone(); +/// +/// let th1 = thread::spawn(move || { +/// tfhe::set_server_key(server_key); +/// // Now, this thread we can do operations on homomorphic types +/// }); +/// +/// let th2 = thread::spawn(move || { +/// tfhe::set_server_key(server_key_2); +/// // Now, this thread we can do operations on homomorphic types +/// }); +/// +/// th2.join(); +/// th1.join(); +/// ``` +pub fn set_server_key(keys: ServerKey) { + INTERNAL_KEYS.with(|internal_keys| internal_keys.replace_with(|_old| keys)); +} + +pub fn unset_server_key() -> ServerKey { + INTERNAL_KEYS.with(|internal_keys| internal_keys.replace_with(|_old| Default::default())) +} + +pub fn with_server_key_as_context(keys: ServerKey, f: F) -> (T, ServerKey) +where + F: FnOnce() -> T, +{ + set_server_key(keys); + let result = f(); + let keys = unset_server_key(); + (result, keys) +} + +/// Convenience function that allows to write functions that needs to access the internal keys. +#[cfg(any(feature = "integer", feature = "shortint", feature = "boolean"))] +#[inline] +pub(crate) fn with_internal_keys(func: F) -> T +where + F: FnOnce(&ServerKey) -> T, +{ + // Should use `with_borrow` when its stabilized + INTERNAL_KEYS.with(|keys| { + let key = &*keys.borrow(); + func(key) + }) +} + +/// Helper macro to help reduce boiler plate +/// needed to implement `WithGlobalKey` since for +/// our keys, the implementation is the same, only a few things change. +/// +/// It expects: +/// - The implementor type +/// - The `name` of the key type for which the trait will be implemented. +/// - The identifier (or identifier chain) that points to the member in the `ServerKey` that holds +/// the key for which the trait is implemented. +/// - Type Variant used to identify the type at runtime (see `error.rs`) +#[cfg(any(feature = "integer", feature = "shortint", feature = "boolean"))] +macro_rules! impl_with_global_key { + ( + for $implementor:ty { + key_type: $key_type:ty, + keychain_member: $($member:ident).*, + type_variant: $enum_variant:expr, + } + ) => { + impl crate::typed_api::global_state::WithGlobalKey for $implementor { + type Key = $key_type; + + fn with_global(self, func: F) -> Result + where + F: FnOnce(&Self::Key) -> R, + { + crate::typed_api::global_state::with_internal_keys(|keys| { + keys$(.$member)* + .as_ref() + .map(func) + .ok_or(crate::typed_api::errors::UninitializedServerKey($enum_variant)) + }) + } + } + } +} + +/// Global key access trait +/// +/// Each type we will expose to the user is going to need to have some internal keys. +/// This trait is there to make each of these internal keys have a convenience function that gives +/// access to the internal keys of its type. +/// +/// Typically, the implementation of the trait will be on the 'internal' key type +/// and will call [with_internal_keys_mut] and select the right member of the [ServerKey] type. +pub trait WithGlobalKey: Sized { + type Key; + + fn with_global(self, func: F) -> Result + where + F: FnOnce(&Self::Key) -> R; + + #[track_caller] + fn with_unwrapped_global(self, func: F) -> R + where + F: FnOnce(&Self::Key) -> R, + { + self.with_global(func).unwrap_display() + } +} diff --git a/tfhe/src/typed_api/integers/client_key.rs b/tfhe/src/typed_api/integers/client_key.rs new file mode 100644 index 000000000..5ad1f90fc --- /dev/null +++ b/tfhe/src/typed_api/integers/client_key.rs @@ -0,0 +1,72 @@ +use serde::{Deserialize, Serialize}; + +use crate::integer::{CrtCiphertext, CrtClientKey, RadixCiphertext, RadixClientKey, U256}; +use crate::typed_api::integers::parameters::IntegerParameter; +use crate::typed_api::internal_traits::{DecryptionKey, EncryptionKey, FromParameters}; + +impl EncryptionKey for RadixClientKey { + type Ciphertext = RadixCiphertext; + + fn encrypt(&self, value: u64) -> Self::Ciphertext { + self.encrypt(value) + } +} + +impl EncryptionKey for RadixClientKey { + type Ciphertext = RadixCiphertext; + + fn encrypt(&self, value: U256) -> Self::Ciphertext { + self.as_ref().encrypt_radix(value, self.num_blocks()) + } +} + +impl DecryptionKey for RadixClientKey { + type Ciphertext = RadixCiphertext; + + fn decrypt(&self, ciphertext: &Self::Ciphertext) -> u64 { + self.decrypt(ciphertext) + } +} + +impl DecryptionKey for RadixClientKey { + type Ciphertext = RadixCiphertext; + + fn decrypt(&self, ciphertext: &Self::Ciphertext) -> U256 { + let mut r = U256::default(); + self.as_ref().decrypt_radix_into(ciphertext, &mut r); + r + } +} + +impl EncryptionKey for CrtClientKey { + type Ciphertext = CrtCiphertext; + + fn encrypt(&self, value: u64) -> Self::Ciphertext { + self.encrypt(value) + } +} + +impl DecryptionKey for CrtClientKey { + type Ciphertext = CrtCiphertext; + + fn decrypt(&self, ciphertext: &Self::Ciphertext) -> u64 { + self.decrypt(ciphertext) + } +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct GenericIntegerClientKey { + pub(in crate::typed_api::integers) inner: P::InnerClientKey, + pub(in crate::typed_api::integers) params: P, +} + +impl

From

for GenericIntegerClientKey

+where + P: IntegerParameter, + P::InnerClientKey: FromParameters

, +{ + fn from(params: P) -> Self { + let key = P::InnerClientKey::from_parameters(params.clone()); + Self { inner: key, params } + } +} diff --git a/tfhe/src/typed_api/integers/keys.rs b/tfhe/src/typed_api/integers/keys.rs new file mode 100644 index 000000000..e47a70300 --- /dev/null +++ b/tfhe/src/typed_api/integers/keys.rs @@ -0,0 +1,10 @@ +define_key_structs! { + Integer { + uint8: FheUint8, + uint10: FheUint10, + uint12: FheUint12, + uint14: FheUint14, + uint16: FheUint16, + uint256: FheUint256, + } +} diff --git a/tfhe/src/typed_api/integers/mod.rs b/tfhe/src/typed_api/integers/mod.rs new file mode 100644 index 000000000..1c5936d63 --- /dev/null +++ b/tfhe/src/typed_api/integers/mod.rs @@ -0,0 +1,12 @@ +pub(crate) use keys::{IntegerClientKey, IntegerConfig, IntegerPublicKey, IntegerServerKey}; +pub use parameters::{CrtParameters, RadixParameters}; +pub use types::{FheUint10, FheUint12, FheUint14, FheUint16, FheUint256, FheUint8, GenericInteger}; + +mod client_key; +mod keys; +mod parameters; +mod public_key; +mod server_key; +#[cfg(test)] +mod tests; +mod types; diff --git a/tfhe/src/typed_api/integers/parameters.rs b/tfhe/src/typed_api/integers/parameters.rs new file mode 100644 index 000000000..b2292d993 --- /dev/null +++ b/tfhe/src/typed_api/integers/parameters.rs @@ -0,0 +1,122 @@ +use crate::integer::{CrtCiphertext, CrtClientKey, RadixCiphertext, RadixClientKey}; +use crate::typed_api::internal_traits::{FromParameters, ParameterType}; +use serde::{Deserialize, Serialize}; + +/// Parameters for 'radix' decomposition +/// +/// Radix decomposition works by using multiple shortint blocks +/// with the same parameters to represent an integer. +/// +/// For example, by taking 4 blocks with parameters +/// for 2bits shortints, with have a 4 * 2 = 8 bit integer. +#[derive(Copy, Clone, Debug, Serialize, Deserialize)] +pub struct RadixParameters { + pub block_parameters: crate::shortint::Parameters, + pub num_block: usize, + pub wopbs_block_parameters: crate::shortint::Parameters, +} + +/// Parameters for 'CRT' decomposition +/// +/// (Chinese Remainder Theorem) +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct CrtParameters { + pub block_parameters: crate::shortint::Parameters, + pub moduli: Vec, + pub wopbs_block_parameters: crate::shortint::Parameters, +} + +/// Meant to be implemented on the inner server key +/// eg the crate::integer::ServerKey +pub trait EvaluationIntegerKey { + fn new(client_key: &ClientKey) -> Self; + + fn new_wopbs_key( + client_key: &ClientKey, + server_key: &Self, + wopbs_block_parameters: crate::shortint::Parameters, + ) -> crate::integer::wopbs::WopbsKey; +} + +impl

FromParameters

for crate::integer::RadixClientKey +where + P: Into, +{ + fn from_parameters(parameters: P) -> Self { + let params = parameters.into(); + #[cfg(feature = "internal-keycache")] + { + use crate::integer::keycache::KEY_CACHE; + let key = KEY_CACHE.get_from_params(params.block_parameters).0; + crate::integer::RadixClientKey::from((key, params.num_block)) + } + #[cfg(not(feature = "internal-keycache"))] + { + crate::integer::RadixClientKey::new(params.block_parameters, params.num_block) + } + } +} + +impl

FromParameters

for crate::integer::CrtClientKey +where + P: Into, +{ + fn from_parameters(parameters: P) -> Self { + let params = parameters.into(); + #[cfg(feature = "internal-keycache")] + { + use crate::integer::keycache::KEY_CACHE; + let key = KEY_CACHE.get_from_params(params.block_parameters).0; + crate::integer::CrtClientKey::from((key, params.moduli)) + } + #[cfg(not(feature = "internal-keycache"))] + { + crate::integer::CrtClientKey::new(params.block_parameters, params.moduli) + } + } +} + +/// Trait to mark parameters type for integers +pub trait IntegerParameter: ParameterType { + fn wopbs_block_parameters(&self) -> crate::shortint::Parameters; + + fn block_parameters(&self) -> crate::shortint::Parameters; +} + +/// Marker struct for the RadixRepresentation +#[derive(Default, Copy, Clone, Debug, PartialEq, Eq)] +pub struct RadixRepresentation; +/// Marker struct for the CrtRepresentation +#[derive(Default, Copy, Clone, Debug, PartialEq, Eq)] +pub struct CrtRepresentation; + +/// Trait to mark parameters type for static integers +/// +/// Static means the integer types with parameters provided by +/// the crate, so parameters for which we know the number of +/// bits the represent. +pub trait StaticIntegerParameter: IntegerParameter { + type Representation: Default + Eq; + + const MESSAGE_BITS: usize; +} + +pub trait StaticRadixParameter: + StaticIntegerParameter +where + Self: IntegerParameter< + InnerClientKey = RadixClientKey, + InnerServerKey = crate::integer::ServerKey, + InnerCiphertext = RadixCiphertext, + >, +{ +} +pub trait StaticCrtParameter: StaticIntegerParameter +where + Self: IntegerParameter< + InnerClientKey = CrtClientKey, + InnerServerKey = crate::integer::ServerKey, + InnerCiphertext = CrtCiphertext, + >, +{ +} diff --git a/tfhe/src/typed_api/integers/public_key.rs b/tfhe/src/typed_api/integers/public_key.rs new file mode 100644 index 000000000..2a58c89f3 --- /dev/null +++ b/tfhe/src/typed_api/integers/public_key.rs @@ -0,0 +1,95 @@ +use crate::typed_api::integers::client_key::GenericIntegerClientKey; + +use crate::integer::{CrtCiphertext, CrtClientKey, RadixCiphertext, RadixClientKey, U256}; +use crate::typed_api::internal_traits::{EncryptionKey, ParameterType}; +use serde::{Deserialize, Serialize}; + +use super::parameters::IntegerParameter; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct RadixPublicKey { + key: crate::integer::PublicKey, + num_blocks: usize, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct CrtPublicKey { + key: crate::integer::PublicKey, + moduli: Vec, +} + +pub trait IntegerPublicKey { + type ClientKey; + + fn new(client_key: &Self::ClientKey) -> Self; +} + +impl IntegerPublicKey for RadixPublicKey { + type ClientKey = RadixClientKey; + + fn new(client_key: &Self::ClientKey) -> Self { + Self { + key: crate::integer::PublicKey::new(client_key.as_ref()), + num_blocks: client_key.num_blocks(), + } + } +} + +impl EncryptionKey for RadixPublicKey { + type Ciphertext = RadixCiphertext; + + fn encrypt(&self, value: u64) -> Self::Ciphertext { + self.key.encrypt_radix(value, self.num_blocks) + } +} + +impl EncryptionKey for RadixPublicKey { + type Ciphertext = RadixCiphertext; + + fn encrypt(&self, value: U256) -> Self::Ciphertext { + self.key.encrypt_radix(value, self.num_blocks) + } +} + +impl IntegerPublicKey for CrtPublicKey { + type ClientKey = CrtClientKey; + + fn new(client_key: &Self::ClientKey) -> Self { + Self { + key: crate::integer::PublicKey::new(client_key.as_ref()), + moduli: client_key.moduli().to_vec(), + } + } +} + +impl EncryptionKey for CrtPublicKey { + type Ciphertext = CrtCiphertext; + + fn encrypt(&self, value: u64) -> Self::Ciphertext { + self.key.encrypt_crt(value, self.moduli.clone()) + } +} + +#[cfg_attr(all(doc, not(doctest)), cfg(feature = "integer"))] +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct GenericIntegerPublicKey

+where + P: IntegerParameter, +{ + pub(in crate::typed_api::integers) inner: P::InnerPublicKey, + _marker: std::marker::PhantomData

, +} + +impl

GenericIntegerPublicKey

+where + P: IntegerParameter, +

::InnerPublicKey: IntegerPublicKey, +{ + pub fn new(client_key: &GenericIntegerClientKey

) -> Self { + let key =

::InnerPublicKey::new(&client_key.inner); + Self { + inner: key, + _marker: Default::default(), + } + } +} diff --git a/tfhe/src/typed_api/integers/server_key.rs b/tfhe/src/typed_api/integers/server_key.rs new file mode 100644 index 000000000..d926388a0 --- /dev/null +++ b/tfhe/src/typed_api/integers/server_key.rs @@ -0,0 +1,186 @@ +use std::marker::PhantomData; + +use crate::typed_api::integers::parameters::EvaluationIntegerKey; + +use super::client_key::GenericIntegerClientKey; +use super::parameters::IntegerParameter; + +use crate::integer::wopbs::WopbsKey; + +#[derive(Clone, serde::Deserialize, serde::Serialize)] +pub struct GenericIntegerServerKey { + pub(in crate::typed_api::integers) inner: P::InnerServerKey, + pub(in crate::typed_api::integers) wopbs_key: WopbsKey, + _marker: PhantomData

, +} + +impl

GenericIntegerServerKey

+where + P: IntegerParameter, + P::InnerServerKey: EvaluationIntegerKey, +{ + pub(super) fn new(client_key: &GenericIntegerClientKey

) -> Self { + let inner = P::InnerServerKey::new(&client_key.inner); + let wopbs_key = P::InnerServerKey::new_wopbs_key( + &client_key.inner, + &inner, + client_key.params.wopbs_block_parameters(), + ); + Self { + inner, + wopbs_key, + _marker: Default::default(), + } + } +} + +pub(super) trait SmartNeg { + type Output; + fn smart_neg(&self, lhs: Ciphertext) -> Self::Output; +} + +macro_rules! define_smart_server_key_op { + ($op_name:ident) => { + paste::paste! { + pub trait [< Smart $op_name >] { + type Output; + + fn [< smart_ $op_name:lower >]( + &self, + lhs: Lhs, + rhs: Rhs, + ) -> Self::Output; + } + + pub trait [< Smart $op_name Assign >] { + fn [< smart_ $op_name:lower _assign >]( + &self, + lhs: &mut Lhs, + rhs: Rhs, + ); + } + } + }; + ($($op:ident),*) => { + $( + define_smart_server_key_op!($op); + )* + }; +} + +define_smart_server_key_op!( + Add, Sub, Mul, BitAnd, BitOr, BitXor, Shl, Shr, Eq, Ge, Gt, Le, Lt, Max, Min +); + +macro_rules! impl_smart_op_for_tfhe_integer_server_key { + ($smart_trait:ident($smart_trait_fn:ident) => ($ciphertext:ty, $method:ident)) => { + impl $smart_trait<&mut $ciphertext, &mut $ciphertext> for crate::integer::ServerKey { + type Output = $ciphertext; + + fn $smart_trait_fn( + &self, + lhs: &mut $ciphertext, + rhs: &mut $ciphertext, + ) -> Self::Output { + self.$method(lhs, rhs) + } + } + }; +} + +macro_rules! impl_smart_assign_op_for_tfhe_integer_server_key { + ($smart_trait:ident($smart_trait_fn:ident) => ($ciphertext:ty, $method:ident)) => { + impl $smart_trait<$ciphertext, &mut $ciphertext> for crate::integer::ServerKey { + fn $smart_trait_fn(&self, lhs: &mut $ciphertext, rhs: &mut $ciphertext) { + self.$method(lhs, rhs); + } + } + }; +} + +macro_rules! impl_smart_scalar_op_for_tfhe_integer_server_key { + ($smart_trait:ident($smart_trait_fn:ident) => ($ciphertext:ty, $method:ident)) => { + impl $smart_trait<&mut $ciphertext, u64> for crate::integer::ServerKey { + type Output = $ciphertext; + + fn $smart_trait_fn(&self, lhs: &mut $ciphertext, rhs: u64) -> Self::Output { + self.$method(lhs, rhs.try_into().unwrap()) + } + } + }; +} + +macro_rules! impl_smart_scalar_assign_op_for_tfhe_integer_server_key { + ($smart_trait:ident($smart_trait_fn:ident) => ($ciphertext:ty, $method:ident)) => { + impl $smart_trait<$ciphertext, u64> for crate::integer::ServerKey { + fn $smart_trait_fn(&self, lhs: &mut $ciphertext, rhs: u64) { + self.$method(lhs, rhs.try_into().unwrap()); + } + } + }; +} + +impl SmartNeg<&mut crate::integer::RadixCiphertext> for crate::integer::ServerKey { + type Output = crate::integer::RadixCiphertext; + fn smart_neg(&self, lhs: &mut crate::integer::RadixCiphertext) -> Self::Output { + self.smart_neg_parallelized(lhs) + } +} + +impl_smart_op_for_tfhe_integer_server_key!(SmartAdd(smart_add) => (crate::integer::RadixCiphertext, smart_add_parallelized)); +impl_smart_op_for_tfhe_integer_server_key!(SmartSub(smart_sub) => (crate::integer::RadixCiphertext, smart_sub_parallelized)); +impl_smart_op_for_tfhe_integer_server_key!(SmartMul(smart_mul) => (crate::integer::RadixCiphertext, smart_mul_parallelized)); +impl_smart_op_for_tfhe_integer_server_key!(SmartBitAnd(smart_bitand) => (crate::integer::RadixCiphertext, smart_bitand_parallelized)); +impl_smart_op_for_tfhe_integer_server_key!(SmartBitOr(smart_bitor) => (crate::integer::RadixCiphertext, smart_bitor_parallelized)); +impl_smart_op_for_tfhe_integer_server_key!(SmartBitXor(smart_bitxor) => (crate::integer::RadixCiphertext, smart_bitxor_parallelized)); +impl_smart_op_for_tfhe_integer_server_key!(SmartEq(smart_eq) => (crate::integer::RadixCiphertext, smart_eq_parallelized)); +impl_smart_op_for_tfhe_integer_server_key!(SmartGe(smart_ge) => (crate::integer::RadixCiphertext, smart_ge_parallelized)); +impl_smart_op_for_tfhe_integer_server_key!(SmartGt(smart_gt) => (crate::integer::RadixCiphertext, smart_gt_parallelized)); +impl_smart_op_for_tfhe_integer_server_key!(SmartLe(smart_le) => (crate::integer::RadixCiphertext, smart_le_parallelized)); +impl_smart_op_for_tfhe_integer_server_key!(SmartLt(smart_lt) => (crate::integer::RadixCiphertext, smart_lt_parallelized)); +impl_smart_op_for_tfhe_integer_server_key!(SmartMax(smart_max) => (crate::integer::RadixCiphertext, smart_max_parallelized)); +impl_smart_op_for_tfhe_integer_server_key!(SmartMin(smart_min) => (crate::integer::RadixCiphertext, smart_min_parallelized)); + +impl_smart_assign_op_for_tfhe_integer_server_key!(SmartAddAssign(smart_add_assign) => (crate::integer::RadixCiphertext, smart_add_assign_parallelized)); +impl_smart_assign_op_for_tfhe_integer_server_key!(SmartSubAssign(smart_sub_assign) => (crate::integer::RadixCiphertext, smart_sub_assign_parallelized)); +impl_smart_assign_op_for_tfhe_integer_server_key!(SmartMulAssign(smart_mul_assign) => (crate::integer::RadixCiphertext, smart_mul_assign_parallelized)); +impl_smart_assign_op_for_tfhe_integer_server_key!(SmartBitAndAssign(smart_bitand_assign) => (crate::integer::RadixCiphertext, smart_bitand_assign_parallelized)); +impl_smart_assign_op_for_tfhe_integer_server_key!(SmartBitOrAssign(smart_bitor_assign) => (crate::integer::RadixCiphertext, smart_bitor_assign_parallelized)); +impl_smart_assign_op_for_tfhe_integer_server_key!(SmartBitXorAssign(smart_bitxor_assign) => (crate::integer::RadixCiphertext, smart_bitxor_assign_parallelized)); + +impl_smart_scalar_op_for_tfhe_integer_server_key!(SmartAdd(smart_add) => (crate::integer::RadixCiphertext, smart_scalar_add_parallelized)); +impl_smart_scalar_op_for_tfhe_integer_server_key!(SmartSub(smart_sub) => (crate::integer::RadixCiphertext, smart_scalar_sub_parallelized)); +impl_smart_scalar_op_for_tfhe_integer_server_key!(SmartMul(smart_mul) => (crate::integer::RadixCiphertext, smart_scalar_mul_parallelized)); +impl_smart_scalar_op_for_tfhe_integer_server_key!(SmartShl(smart_shl) => (crate::integer::RadixCiphertext, unchecked_scalar_left_shift_parallelized)); +impl_smart_scalar_op_for_tfhe_integer_server_key!(SmartShr(smart_shr) => (crate::integer::RadixCiphertext, unchecked_scalar_right_shift_parallelized)); + +impl_smart_scalar_assign_op_for_tfhe_integer_server_key!(SmartAddAssign(smart_add_assign) => (crate::integer::RadixCiphertext, smart_scalar_add_assign_parallelized)); +impl_smart_scalar_assign_op_for_tfhe_integer_server_key!(SmartSubAssign(smart_sub_assign) => (crate::integer::RadixCiphertext, smart_scalar_sub_assign_parallelized)); +impl_smart_scalar_assign_op_for_tfhe_integer_server_key!(SmartMulAssign(smart_mul_assign) => (crate::integer::RadixCiphertext, smart_scalar_mul_assign_parallelized)); +impl_smart_scalar_assign_op_for_tfhe_integer_server_key!(SmartShlAssign(smart_shl_assign) => (crate::integer::RadixCiphertext, unchecked_scalar_left_shift_assign_parallelized)); +impl_smart_scalar_assign_op_for_tfhe_integer_server_key!(SmartShrAssign(smart_shr_assign) => (crate::integer::RadixCiphertext, unchecked_scalar_right_shift_assign_parallelized)); + +// Crt + +impl_smart_op_for_tfhe_integer_server_key!(SmartAdd(smart_add) => (crate::integer::CrtCiphertext, smart_crt_add_parallelized)); +impl_smart_op_for_tfhe_integer_server_key!(SmartSub(smart_sub) => (crate::integer::CrtCiphertext, smart_crt_sub_parallelized)); +impl_smart_op_for_tfhe_integer_server_key!(SmartMul(smart_mul) => (crate::integer::CrtCiphertext, smart_crt_mul_parallelized)); + +impl_smart_assign_op_for_tfhe_integer_server_key!(SmartAddAssign(smart_add_assign) => (crate::integer::CrtCiphertext, smart_crt_add_assign_parallelized)); +impl_smart_assign_op_for_tfhe_integer_server_key!(SmartSubAssign(smart_sub_assign) => (crate::integer::CrtCiphertext, smart_crt_sub_parallelized)); +impl_smart_assign_op_for_tfhe_integer_server_key!(SmartMulAssign(smart_mul_assign) => (crate::integer::CrtCiphertext, smart_crt_mul_assign_parallelized)); + +impl_smart_scalar_op_for_tfhe_integer_server_key!(SmartAdd(smart_add) => (crate::integer::CrtCiphertext, smart_crt_scalar_add)); +impl_smart_scalar_op_for_tfhe_integer_server_key!(SmartSub(smart_sub) => (crate::integer::CrtCiphertext, smart_crt_scalar_sub)); +impl_smart_scalar_op_for_tfhe_integer_server_key!(SmartMul(smart_mul) => (crate::integer::CrtCiphertext, smart_crt_scalar_mul)); + +impl_smart_scalar_assign_op_for_tfhe_integer_server_key!(SmartAddAssign(smart_add_assign) => (crate::integer::CrtCiphertext, smart_crt_scalar_add_assign)); +impl_smart_scalar_assign_op_for_tfhe_integer_server_key!(SmartSubAssign(smart_sub_assign) => (crate::integer::CrtCiphertext, smart_crt_scalar_sub_assign)); +impl_smart_scalar_assign_op_for_tfhe_integer_server_key!(SmartMulAssign(smart_mul_assign) => (crate::integer::CrtCiphertext, smart_crt_scalar_mul_assign)); + +impl SmartNeg<&mut crate::integer::CrtCiphertext> for crate::integer::ServerKey { + type Output = crate::integer::CrtCiphertext; + fn smart_neg(&self, lhs: &mut crate::integer::CrtCiphertext) -> Self::Output { + self.smart_crt_neg_parallelized(lhs) + } +} diff --git a/tfhe/src/typed_api/integers/tests.rs b/tfhe/src/typed_api/integers/tests.rs new file mode 100644 index 000000000..fe154cf6e --- /dev/null +++ b/tfhe/src/typed_api/integers/tests.rs @@ -0,0 +1,65 @@ +use crate::typed_api::prelude::*; +use crate::typed_api::{generate_keys, set_server_key, ConfigBuilder, FheUint8}; + +#[test] +fn test_quickstart_uint8() { + let config = ConfigBuilder::all_disabled().enable_default_uint8().build(); + + let (client_key, server_key) = generate_keys(config); + + set_server_key(server_key); + + let clear_a = 27u8; + let clear_b = 128u8; + + let a = FheUint8::encrypt(clear_a, &client_key); + let b = FheUint8::encrypt(clear_b, &client_key); + + let result = a + b; + + let decrypted_result: u8 = result.decrypt(&client_key); + + let clear_result = clear_a + clear_b; + + assert_eq!(decrypted_result, clear_result); +} + +#[test] +fn test_uint8_compare() { + let config = ConfigBuilder::all_disabled().enable_default_uint8().build(); + + let (client_key, server_key) = generate_keys(config); + + set_server_key(server_key); + + let clear_a = 27u8; + let clear_b = 128u8; + + let a = FheUint8::encrypt(clear_a, &client_key); + let b = FheUint8::encrypt(clear_b, &client_key); + + let result = &a.eq(&b); + let decrypted_result: u8 = result.decrypt(&client_key); + let clear_result = u8::from(clear_a == clear_b); + assert_eq!(decrypted_result, clear_result); + + let result = &a.le(&b); + let decrypted_result: u8 = result.decrypt(&client_key); + let clear_result = u8::from(clear_a <= clear_b); + assert_eq!(decrypted_result, clear_result); + + let result = &a.lt(&b); + let decrypted_result: u8 = result.decrypt(&client_key); + let clear_result = u8::from(clear_a < clear_b); + assert_eq!(decrypted_result, clear_result); + + let result = &a.ge(&b); + let decrypted_result: u8 = result.decrypt(&client_key); + let clear_result = u8::from(clear_a >= clear_b); + assert_eq!(decrypted_result, clear_result); + + let result = &a.gt(&b); + let decrypted_result: u8 = result.decrypt(&client_key); + let clear_result = u8::from(clear_a >= clear_b); + assert_eq!(decrypted_result, clear_result); +} diff --git a/tfhe/src/typed_api/integers/types/base.rs b/tfhe/src/typed_api/integers/types/base.rs new file mode 100644 index 000000000..642a1688d --- /dev/null +++ b/tfhe/src/typed_api/integers/types/base.rs @@ -0,0 +1,773 @@ +use std::borrow::Borrow; +use std::cell::RefCell; +use std::ops::{ + Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Mul, MulAssign, + Neg, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign, +}; + +use crate::integer::wopbs::WopbsKey; +use crate::integer::{CrtCiphertext, RadixCiphertext, U256}; +use crate::typed_api::global_state::WithGlobalKey; +use crate::typed_api::integers::client_key::GenericIntegerClientKey; +use crate::typed_api::integers::parameters::{ + CrtRepresentation, IntegerParameter, RadixRepresentation, StaticCrtParameter, + StaticIntegerParameter, StaticRadixParameter, +}; +use crate::typed_api::integers::public_key::GenericIntegerPublicKey; +use crate::typed_api::integers::server_key::{ + GenericIntegerServerKey, SmartAdd, SmartAddAssign, SmartBitAnd, SmartBitAndAssign, SmartBitOr, + SmartBitOrAssign, SmartBitXor, SmartBitXorAssign, SmartEq, SmartGe, SmartGt, SmartLe, SmartLt, + SmartMax, SmartMin, SmartMul, SmartMulAssign, SmartNeg, SmartShl, SmartShlAssign, SmartShr, + SmartShrAssign, SmartSub, SmartSubAssign, +}; +use crate::typed_api::internal_traits::{DecryptionKey, EncryptionKey}; +use crate::typed_api::keys::{RefKeyFromKeyChain, RefKeyFromPublicKeyChain}; +use crate::typed_api::traits::{FheBootstrap, FheDecrypt, FheEq, FheOrd, FheTryEncrypt}; +use crate::typed_api::{ClientKey, PublicKey}; + +/// A Generic FHE unsigned integer +/// +/// Contrary to *shortints*, these integers can in theory by parametrized to +/// represent integers of any number of bits (eg: 16, 24, 32, 64). +/// +/// However, in practice going above 16 bits may not be ideal as the +/// computations would not scale and become very expensive. +/// +/// Integers works by combining together multiple shortints +/// with one of the available representation. +/// +/// This struct is generic over some parameters, as its the parameters +/// that controls how many bit they represent. +/// You will need to use one of this type specialization (e.g., [FheUint8], [FheUint12], +/// [FheUint16]). +/// +/// Its the type that overloads the operators (`+`, `-`, `*`), +/// since the `GenericInteger` type is not `Copy` the operators are also overloaded +/// to work with references. +/// +/// +/// To be able to use this type, the cargo feature `integers` must be enabled, +/// and your config should also enable the type with either default parameters or custom ones. +/// +/// +/// [FheUint8]: crate::typed_api::FheUint8 +/// [FheUint12]: crate::typed_api::FheUint12 +/// [FheUint16]: crate::typed_api::FheUint16 +#[cfg_attr(all(doc, not(doctest)), doc(cfg(feature = "integer")))] +#[derive(Clone, serde::Deserialize, serde::Serialize)] +pub struct GenericInteger { + pub(in crate::typed_api::integers) ciphertext: RefCell, + pub(in crate::typed_api::integers) id: P::Id, +} + +impl

GenericInteger

+where + P: IntegerParameter, +{ + pub(in crate::typed_api::integers) fn new(ciphertext: P::InnerCiphertext, id: P::Id) -> Self { + Self { + ciphertext: RefCell::new(ciphertext), + id, + } + } +} + +impl

FheDecrypt for GenericInteger

+where + P: IntegerParameter, + P::Id: RefKeyFromKeyChain>, + P::InnerClientKey: DecryptionKey, +{ + fn decrypt(&self, key: &ClientKey) -> u64 { + let key = self.id.unwrapped_ref_key(key); + key.inner.decrypt(&self.ciphertext.borrow()) + } +} + +impl

FheDecrypt for GenericInteger

+where + P: IntegerParameter, + P::Id: RefKeyFromKeyChain>, + P::InnerClientKey: DecryptionKey, +{ + fn decrypt(&self, key: &ClientKey) -> U256 { + let key = self.id.unwrapped_ref_key(key); + key.inner.decrypt(&self.ciphertext.borrow()) + } +} + +impl FheTryEncrypt for GenericInteger

+where + T: Into, + P: StaticIntegerParameter, + P::Id: RefKeyFromKeyChain> + Default, + P::InnerClientKey: EncryptionKey, +{ + type Error = crate::typed_api::errors::Error; + + fn try_encrypt(value: T, key: &ClientKey) -> Result { + let value = value.into(); + let id = P::Id::default(); + let key = id.ref_key(key)?; + let ciphertext = key.inner.encrypt(value); + Ok(Self::new(ciphertext, id)) + } +} + +impl FheTryEncrypt for GenericInteger

+where + T: Into, + P: StaticIntegerParameter, + P::Id: RefKeyFromPublicKeyChain> + Default, + P::InnerPublicKey: EncryptionKey, +{ + type Error = crate::typed_api::errors::Error; + + fn try_encrypt(value: T, key: &PublicKey) -> Result { + let value = value.into(); + let id = P::Id::default(); + let key = id.ref_key(key)?; + let ciphertext = key.inner.encrypt(value); + Ok(Self::new(ciphertext, id)) + } +} + +impl

GenericInteger

+where + P: IntegerParameter, + GenericInteger

: Clone, + P::Id: WithGlobalKey>, + P::InnerServerKey: for<'a> SmartMax< + &'a mut P::InnerCiphertext, + &'a mut P::InnerCiphertext, + Output = P::InnerCiphertext, + >, +{ + pub fn max(&self, rhs: &Self) -> Self { + let inner_result = self.id.with_unwrapped_global(|server_key| { + if std::ptr::eq(self, rhs) { + let cloned = (*rhs).clone(); + let r = server_key.inner.smart_max( + &mut self.ciphertext.borrow_mut(), + &mut cloned.ciphertext.borrow_mut(), + ); + r + } else { + server_key.inner.smart_max( + &mut self.ciphertext.borrow_mut(), + &mut rhs.ciphertext.borrow_mut(), + ) + } + }); + + GenericInteger::new(inner_result, self.id) + } +} + +impl

GenericInteger

+where + P: IntegerParameter, + GenericInteger

: Clone, + P::Id: WithGlobalKey>, + P::InnerServerKey: for<'a> SmartMin< + &'a mut P::InnerCiphertext, + &'a mut P::InnerCiphertext, + Output = P::InnerCiphertext, + >, +{ + pub fn min(&self, rhs: &Self) -> Self { + let inner_result = self.id.with_unwrapped_global(|server_key| { + if std::ptr::eq(self, rhs) { + let cloned = (*rhs).clone(); + let r = server_key.inner.smart_min( + &mut self.ciphertext.borrow_mut(), + &mut cloned.ciphertext.borrow_mut(), + ); + r + } else { + server_key.inner.smart_min( + &mut self.ciphertext.borrow_mut(), + &mut rhs.ciphertext.borrow_mut(), + ) + } + }); + + GenericInteger::new(inner_result, self.id) + } +} + +impl FheEq for GenericInteger

+where + B: Borrow>, + P: IntegerParameter, + GenericInteger

: Clone, + P::Id: WithGlobalKey>, + P::InnerServerKey: for<'a> SmartEq< + &'a mut P::InnerCiphertext, + &'a mut P::InnerCiphertext, + Output = P::InnerCiphertext, + >, +{ + type Output = Self; + + fn eq(&self, rhs: B) -> Self::Output { + let inner_result = self.id.with_unwrapped_global(|server_key| { + let borrowed = rhs.borrow(); + if std::ptr::eq(self, borrowed) { + let cloned = (*borrowed).clone(); + let r = server_key.inner.smart_eq( + &mut self.ciphertext.borrow_mut(), + &mut cloned.ciphertext.borrow_mut(), + ); + r + } else { + server_key.inner.smart_eq( + &mut self.ciphertext.borrow_mut(), + &mut borrowed.ciphertext.borrow_mut(), + ) + } + }); + + GenericInteger::new(inner_result, self.id) + } +} + +impl FheOrd for GenericInteger

+where + B: Borrow>, + P: IntegerParameter, + GenericInteger

: Clone, + P::Id: WithGlobalKey>, + P::InnerServerKey: for<'a> SmartGe< + &'a mut P::InnerCiphertext, + &'a mut P::InnerCiphertext, + Output = P::InnerCiphertext, + > + for<'a> SmartGt< + &'a mut P::InnerCiphertext, + &'a mut P::InnerCiphertext, + Output = P::InnerCiphertext, + > + for<'a> SmartLe< + &'a mut P::InnerCiphertext, + &'a mut P::InnerCiphertext, + Output = P::InnerCiphertext, + > + for<'a> SmartLt< + &'a mut P::InnerCiphertext, + &'a mut P::InnerCiphertext, + Output = P::InnerCiphertext, + >, +{ + type Output = Self; + + fn lt(&self, other: B) -> Self::Output { + let inner_result = self.id.with_unwrapped_global(|server_key| { + let borrowed = other.borrow(); + if std::ptr::eq(self, borrowed) { + let cloned = borrowed.clone(); + let r = server_key.inner.smart_lt( + &mut self.ciphertext.borrow_mut(), + &mut cloned.ciphertext.borrow_mut(), + ); + r + } else { + server_key.inner.smart_lt( + &mut self.ciphertext.borrow_mut(), + &mut borrowed.ciphertext.borrow_mut(), + ) + } + }); + + GenericInteger::new(inner_result, self.id) + } + + fn le(&self, other: B) -> Self::Output { + let inner_result = self.id.with_unwrapped_global(|server_key| { + let borrowed = other.borrow(); + if std::ptr::eq(self, borrowed) { + let cloned = borrowed.clone(); + let r = server_key.inner.smart_le( + &mut self.ciphertext.borrow_mut(), + &mut cloned.ciphertext.borrow_mut(), + ); + r + } else { + server_key.inner.smart_le( + &mut self.ciphertext.borrow_mut(), + &mut borrowed.ciphertext.borrow_mut(), + ) + } + }); + + GenericInteger::new(inner_result, self.id) + } + + fn gt(&self, other: B) -> Self::Output { + let inner_result = self.id.with_unwrapped_global(|server_key| { + let borrowed = other.borrow(); + if std::ptr::eq(self, borrowed) { + let cloned = borrowed.clone(); + let r = server_key.inner.smart_gt( + &mut self.ciphertext.borrow_mut(), + &mut cloned.ciphertext.borrow_mut(), + ); + r + } else { + server_key.inner.smart_gt( + &mut self.ciphertext.borrow_mut(), + &mut borrowed.ciphertext.borrow_mut(), + ) + } + }); + + GenericInteger::new(inner_result, self.id) + } + + fn ge(&self, other: B) -> Self::Output { + let inner_result = self.id.with_unwrapped_global(|server_key| { + let borrowed = other.borrow(); + if std::ptr::eq(self, borrowed) { + let cloned = borrowed.clone(); + let r = server_key.inner.smart_ge( + &mut self.ciphertext.borrow_mut(), + &mut cloned.ciphertext.borrow_mut(), + ); + r + } else { + server_key.inner.smart_ge( + &mut self.ciphertext.borrow_mut(), + &mut borrowed.ciphertext.borrow_mut(), + ) + } + }); + + GenericInteger::new(inner_result, self.id) + } +} + +// This extra trait is needed as otherwise +// +// impl

FheBootstrap for GenericInteger

+// where P: StaticCrtParameters, +// P: IntegerParameter, +// { /* sutff */ } +// +// impl

FheBootstrap for GenericInteger

+// where P: StaticRadixParameters, +// P: IntegerParameter, +// P::Id: WithGlobalKey>, +// { /* sutff */ } +// +// Leads to errors about conflicting impl +pub trait WopbsExecutor< + P: StaticIntegerParameter, + R =

::Representation, +> +{ + fn execute_wopbs u64>( + &self, + ct_in: &GenericInteger

, + func: F, + ) -> GenericInteger

; + + fn execute_bivariate_wopbs u64>( + &self, + lhs: &GenericInteger

, + rhs: &GenericInteger

, + func: F, + ) -> GenericInteger

; +} + +pub(crate) fn wopbs_radix( + wopbs_key: &WopbsKey, + server_key: &crate::integer::ServerKey, + ct_in: &RadixCiphertext, + func: impl Fn(u64) -> u64, +) -> RadixCiphertext { + let switched_ct = wopbs_key.keyswitch_to_wopbs_params(server_key, ct_in); + let luts = wopbs_key.generate_lut_radix(&switched_ct, func); + let res = wopbs_key.wopbs(&switched_ct, luts.as_slice()); + wopbs_key.keyswitch_to_pbs_params(&res) +} + +pub(crate) fn bivariate_wopbs_radix( + wopbs_key: &WopbsKey, + server_key: &crate::integer::ServerKey, + lhs: &RadixCiphertext, + rhs: &RadixCiphertext, + func: impl Fn(u64, u64) -> u64, +) -> RadixCiphertext { + let switched_lhs = wopbs_key.keyswitch_to_wopbs_params(server_key, lhs); + let switched_rhs = wopbs_key.keyswitch_to_wopbs_params(server_key, rhs); + let lut = wopbs_key.generate_lut_bivariate_radix(&switched_lhs, &switched_rhs, func); + let res = wopbs_key.bivariate_wopbs_with_degree(&switched_lhs, &switched_rhs, lut.as_slice()); + wopbs_key.keyswitch_to_pbs_params(&res) +} + +pub(crate) fn wopbs_crt( + wopbs_key: &WopbsKey, + server_key: &crate::integer::ServerKey, + ct_in: &CrtCiphertext, + func: impl Fn(u64) -> u64, +) -> CrtCiphertext { + let switched_ct = wopbs_key.keyswitch_to_wopbs_params(server_key, ct_in); + let luts = wopbs_key.generate_lut_crt(&switched_ct, func); + let res = wopbs_key.wopbs(&switched_ct, luts.as_slice()); + wopbs_key.keyswitch_to_pbs_params(&res) +} + +pub(crate) fn bivariate_wopbs_crt( + wopbs_key: &WopbsKey, + server_key: &crate::integer::ServerKey, + lhs: &CrtCiphertext, + rhs: &CrtCiphertext, + func: impl Fn(u64, u64) -> u64, +) -> CrtCiphertext { + let switched_lhs = wopbs_key.keyswitch_to_wopbs_params(server_key, lhs); + let switched_rhs = wopbs_key.keyswitch_to_wopbs_params(server_key, rhs); + let lut = wopbs_key.generate_lut_bivariate_crt(&switched_lhs, &switched_rhs, func); + let res = wopbs_key.bivariate_wopbs_native_crt(&switched_lhs, &switched_rhs, lut.as_slice()); + wopbs_key.keyswitch_to_pbs_params(&res) +} + +impl

WopbsExecutor for GenericIntegerServerKey

+where + P: StaticRadixParameter, +{ + fn execute_wopbs u64>( + &self, + ct_in: &GenericInteger

, + func: F, + ) -> GenericInteger

{ + let ct = ct_in.ciphertext.borrow(); + let res = wopbs_radix(&self.wopbs_key, &self.inner, &ct, func); + GenericInteger::

::new(res, ct_in.id) + } + + fn execute_bivariate_wopbs u64>( + &self, + lhs: &GenericInteger

, + rhs: &GenericInteger

, + func: F, + ) -> GenericInteger

{ + let lhs_ct = lhs.ciphertext.borrow(); + let rhs_ct = rhs.ciphertext.borrow(); + + let res_ct = bivariate_wopbs_radix(&self.wopbs_key, &self.inner, &lhs_ct, &rhs_ct, func); + + GenericInteger::

::new(res_ct, lhs.id) + } +} + +impl

WopbsExecutor for GenericIntegerServerKey

+where + P: StaticCrtParameter, +{ + fn execute_wopbs u64>( + &self, + ct_in: &GenericInteger

, + func: F, + ) -> GenericInteger

{ + let ct = ct_in.ciphertext.borrow(); + let res = wopbs_crt(&self.wopbs_key, &self.inner, &ct, func); + GenericInteger::

::new(res, ct_in.id) + } + + fn execute_bivariate_wopbs u64>( + &self, + lhs: &GenericInteger

, + rhs: &GenericInteger

, + func: F, + ) -> GenericInteger

{ + let lhs_ct = lhs.ciphertext.borrow(); + let rhs_ct = rhs.ciphertext.borrow(); + + let res_ct = bivariate_wopbs_crt(&self.wopbs_key, &self.inner, &lhs_ct, &rhs_ct, func); + GenericInteger::

::new(res_ct, lhs.id) + } +} + +impl

FheBootstrap for GenericInteger

+where + P: StaticIntegerParameter, + P::Id: WithGlobalKey>, + GenericIntegerServerKey

: WopbsExecutor::Representation>, +{ + fn map u64>(&self, func: F) -> Self { + self.id + .with_unwrapped_global(|key| key.execute_wopbs(self, func)) + } + + fn apply u64>(&mut self, func: F) { + let result = self.map(func); + *self = result; + } +} + +impl

GenericInteger

+where + P: StaticIntegerParameter, + P::Id: WithGlobalKey>, + GenericIntegerServerKey

: WopbsExecutor::Representation>, +{ + pub fn bivariate_function(&self, other: &Self, func: F) -> Self + where + F: Fn(u64, u64) -> u64, + { + self.id + .with_unwrapped_global(|key| key.execute_bivariate_wopbs(self, other, func)) + } +} + +macro_rules! generic_integer_impl_operation ( + ($trait_name:ident($trait_method:ident,$op:tt, $smart_trait:ident) => $key_method:ident) => { + #[doc = concat!(" Allows using the `", stringify!($op), "` operator between a")] + #[doc = " `GenericInteger` and a `GenericInteger` or a `&GenericInteger`"] + #[doc = " "] + #[doc = " # Examples "] + #[doc = " "] + #[doc = " ```"] + #[doc = " # fn main() -> Result<(), tfhe::Error> {"] + #[doc = " use tfhe::prelude::*;"] + #[doc = " use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheUint8};"] + #[doc = " use std::num::Wrapping;"] + #[doc = " "] + #[doc = " let config = ConfigBuilder::all_disabled()"] + #[doc = " .enable_default_uint8()"] + #[doc = " .build();"] + #[doc = " let (keys, server_key) = generate_keys(config);"] + #[doc = " "] + #[doc = " let a = FheUint8::try_encrypt(142u32, &keys)?;"] + #[doc = " let b = FheUint8::try_encrypt(83u32, &keys)?;"] + #[doc = " "] + #[doc = " set_server_key(server_key);"] + #[doc = " "] + #[doc = concat!(" let c = a ", stringify!($op), " b;")] + #[doc = " let decrypted: u8 = c.decrypt(&keys);"] + #[doc = concat!(" let expected = Wrapping(142u8) ", stringify!($op), " Wrapping(83u8);")] + #[doc = " assert_eq!(decrypted, expected.0);"] + #[doc = " # Ok(())"] + #[doc = " # }"] + #[doc = " ```"] + #[doc = " "] + #[doc = " "] + #[doc = " ```"] + #[doc = " # fn main() -> Result<(), tfhe::Error> {"] + #[doc = " use tfhe::prelude::*;"] + #[doc = " use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheUint8};"] + #[doc = " use std::num::Wrapping;"] + #[doc = " "] + #[doc = " let config = ConfigBuilder::all_disabled()"] + #[doc = " .enable_default_uint8()"] + #[doc = " .build();"] + #[doc = " let (keys, server_key) = generate_keys(config);"] + #[doc = " "] + #[doc = " let a = FheUint8::try_encrypt(208u32, &keys)?;"] + #[doc = " let b = FheUint8::try_encrypt(29u32, &keys)?;"] + #[doc = " "] + #[doc = " set_server_key(server_key);"] + #[doc = " "] + #[doc = concat!(" let c = a ", stringify!($op), " &b;")] + #[doc = " let decrypted: u8 = c.decrypt(&keys);"] + #[doc = concat!(" let expected = Wrapping(208u8) ", stringify!($op), " Wrapping(29u8);")] + #[doc = " assert_eq!(decrypted, expected.0);"] + #[doc = " # Ok(())"] + #[doc = " # }"] + #[doc = " ```"] + impl $trait_name for GenericInteger

+ where + P: IntegerParameter, + B: Borrow, + GenericInteger

: Clone, + P::Id: WithGlobalKey>, + P::InnerServerKey: for<'a> $smart_trait< + &'a mut P::InnerCiphertext, + &'a mut P::InnerCiphertext, + Output=P::InnerCiphertext>, + { + type Output = Self; + + fn $trait_method(self, rhs: B) -> Self::Output { + <&Self as $trait_name>::$trait_method(&self, rhs) + } + } + + impl $trait_name for &GenericInteger

+ where + P: IntegerParameter, + B: Borrow>, + GenericInteger

: Clone, + P::Id: WithGlobalKey>, + P::InnerServerKey: for<'a> $smart_trait< + &'a mut P::InnerCiphertext, + &'a mut P::InnerCiphertext, + Output=P::InnerCiphertext>, + { + type Output = GenericInteger

; + + fn $trait_method(self, rhs: B) -> Self::Output { + let ciphertext = self.id.with_unwrapped_global(|key| { + let borrowed = rhs.borrow(); + if std::ptr::eq(self, borrowed) { + let cloned = (*borrowed).clone(); + let r = key.inner.$key_method( + &mut self.ciphertext.borrow_mut(), + &mut cloned.ciphertext.borrow_mut(), + ); + r + } else { + key.inner.$key_method( + &mut self.ciphertext.borrow_mut(), + &mut borrowed.ciphertext.borrow_mut(), + ) + } + }); + + GenericInteger::

::new(ciphertext, self.id) + } + } + } +); + +macro_rules! generic_integer_impl_operation_assign ( + ($trait_name:ident($trait_method:ident, $op:tt, $smart_assign_trait:ident) => $key_method:ident) => { + impl $trait_name for GenericInteger

+ where + P: IntegerParameter, + P::Id: WithGlobalKey>, + P::InnerServerKey: for<'a> $smart_assign_trait, + I: Borrow, + { + fn $trait_method(&mut self, rhs: I) { + self.id.with_unwrapped_global(|key| { + key.inner.$key_method( + self.ciphertext.get_mut(), + &mut rhs.borrow().ciphertext.borrow_mut() + ) + }) + } + } + } +); + +macro_rules! generic_integer_impl_scalar_operation { + ($trait_name:ident($trait_method:ident, $smart_trait:ident) => $key_method:ident($($scalar_type:ty),*)) => { + $( + impl

$trait_name<$scalar_type> for GenericInteger

+ where + P: IntegerParameter, + P::Id: WithGlobalKey>, + P::InnerServerKey: for<'a> $smart_trait< + &'a mut P::InnerCiphertext, + u64, + Output=P::InnerCiphertext>, + { + type Output = GenericInteger

; + + fn $trait_method(self, rhs: $scalar_type) -> Self::Output { + <&Self as $trait_name<$scalar_type>>::$trait_method(&self, rhs) + } + } + + impl

$trait_name<$scalar_type> for &GenericInteger

+ where + P: IntegerParameter, + P::Id: WithGlobalKey>, + P::InnerServerKey: for<'a> $smart_trait< + &'a mut P::InnerCiphertext, + u64, + Output=P::InnerCiphertext>, + { + type Output = GenericInteger

; + + fn $trait_method(self, rhs: $scalar_type) -> Self::Output { + let ciphertext = self.id.with_unwrapped_global(|key| { + key.inner.$key_method( + &mut self.ciphertext.borrow_mut(), + u64::from(rhs) + ) + }); + + GenericInteger::

::new(ciphertext, self.id) + } + } + )* + }; +} + +macro_rules! generic_integer_impl_scalar_operation_assign { + ($trait_name:ident($trait_method:ident,$smart_assign_trait:ident) => $key_method:ident($($scalar_type:ty),*)) => { + $( + impl

$trait_name<$scalar_type> for GenericInteger

+ where + P: IntegerParameter, + P::Id: WithGlobalKey>, + P::InnerServerKey: for<'a> $smart_assign_trait, + { + fn $trait_method(&mut self, rhs: $scalar_type) { + self.id.with_unwrapped_global(|key| { + key.inner.$key_method( + &mut self.ciphertext.borrow_mut(), + u64::from(rhs) + ) + }); + } + } + )* + } +} + +generic_integer_impl_operation!(Add(add,+, SmartAdd) => smart_add); +generic_integer_impl_operation!(Sub(sub,-, SmartSub) => smart_sub); +generic_integer_impl_operation!(Mul(mul,*, SmartMul) => smart_mul); +generic_integer_impl_operation!(BitAnd(bitand,&, SmartBitAnd) => smart_bitand); +generic_integer_impl_operation!(BitOr(bitor,|, SmartBitOr) => smart_bitor); +generic_integer_impl_operation!(BitXor(bitxor,^, SmartBitXor) => smart_bitxor); + +generic_integer_impl_operation_assign!(AddAssign(add_assign,+=, SmartAddAssign) => smart_add_assign); +generic_integer_impl_operation_assign!(SubAssign(sub_assign,-=, SmartSubAssign) => smart_sub_assign); +generic_integer_impl_operation_assign!(MulAssign(mul_assign,*=, SmartMulAssign) => smart_mul_assign); +generic_integer_impl_operation_assign!(BitAndAssign(bitand_assign,&=, SmartBitAndAssign) => smart_bitand_assign); +generic_integer_impl_operation_assign!(BitOrAssign(bitor_assign,|=, SmartBitOrAssign) => smart_bitor_assign); +generic_integer_impl_operation_assign!(BitXorAssign(bitxor_assign,^=, SmartBitXorAssign) => smart_bitxor_assign); + +generic_integer_impl_scalar_operation!(Add(add, SmartAdd) => smart_add(u8, u16, u32, u64)); +generic_integer_impl_scalar_operation!(Sub(sub, SmartSub) => smart_sub(u8, u16, u32, u64)); +generic_integer_impl_scalar_operation!(Mul(mul, SmartMul) => smart_mul(u8, u16, u32, u64)); +generic_integer_impl_scalar_operation!(Shl(shl, SmartShl) => smart_shl(u8, u16, u32, u64)); +generic_integer_impl_scalar_operation!(Shr(shr, SmartShr) => smart_shr(u8, u16, u32, u64)); + +generic_integer_impl_scalar_operation_assign!(AddAssign(add_assign, SmartAddAssign) => smart_add_assign(u8, u16, u32, u64)); +generic_integer_impl_scalar_operation_assign!(SubAssign(sub_assign, SmartSubAssign) => smart_sub_assign(u8, u16, u32, u64)); +generic_integer_impl_scalar_operation_assign!(MulAssign(mul_assign, SmartMulAssign) => smart_mul_assign(u8, u16, u32, u64)); +generic_integer_impl_scalar_operation_assign!(ShlAssign(shl_assign, SmartShlAssign) => smart_shl_assign(u8, u16, u32, u64)); +generic_integer_impl_scalar_operation_assign!(ShrAssign(shr_assign, SmartShrAssign) => smart_shr_assign(u8, u16, u32, u64)); + +impl

Neg for GenericInteger

+where + P: IntegerParameter, + P::Id: WithGlobalKey>, + GenericIntegerServerKey

: for<'a> SmartNeg<&'a GenericInteger

, Output = GenericInteger

>, + P::InnerServerKey: for<'a> SmartNeg<&'a mut P::InnerCiphertext, Output = P::InnerCiphertext>, +{ + type Output = GenericInteger

; + + fn neg(self) -> Self::Output { + <&Self as Neg>::neg(&self) + } +} + +impl

Neg for &GenericInteger

+where + P: IntegerParameter, + P::Id: WithGlobalKey>, + P::InnerServerKey: for<'a> SmartNeg<&'a mut P::InnerCiphertext, Output = P::InnerCiphertext>, +{ + type Output = GenericInteger

; + + fn neg(self) -> Self::Output { + let ciphertext = self + .id + .with_unwrapped_global(|key| key.inner.smart_neg(&mut self.ciphertext.borrow_mut())); + + GenericInteger::

::new(ciphertext, self.id) + } +} diff --git a/tfhe/src/typed_api/integers/types/mod.rs b/tfhe/src/typed_api/integers/types/mod.rs new file mode 100644 index 000000000..06e356011 --- /dev/null +++ b/tfhe/src/typed_api/integers/types/mod.rs @@ -0,0 +1,5 @@ +pub use base::GenericInteger; +pub use static_::{FheUint10, FheUint12, FheUint14, FheUint16, FheUint256, FheUint8}; + +pub(super) mod base; +pub(super) mod static_; diff --git a/tfhe/src/typed_api/integers/types/static_.rs b/tfhe/src/typed_api/integers/types/static_.rs new file mode 100644 index 000000000..5832b44fe --- /dev/null +++ b/tfhe/src/typed_api/integers/types/static_.rs @@ -0,0 +1,415 @@ +use serde::{Deserialize, Serialize}; + +use crate::typed_api::integers::client_key::GenericIntegerClientKey; +use crate::typed_api::integers::parameters::{ + EvaluationIntegerKey, IntegerParameter, RadixParameters, RadixRepresentation, + StaticIntegerParameter, StaticRadixParameter, +}; +use crate::typed_api::integers::public_key::GenericIntegerPublicKey; +use crate::typed_api::integers::server_key::GenericIntegerServerKey; +use crate::typed_api::keys::RefKeyFromKeyChain; +use crate::typed_api::traits::{FheDecrypt, FheEncrypt}; +use crate::typed_api::ClientKey; + +use super::base::GenericInteger; +#[cfg(feature = "internal-keycache")] +use crate::integer::keycache::{KEY_CACHE, KEY_CACHE_WOPBS}; +use crate::integer::wopbs::WopbsKey; +use crate::typed_api::internal_traits::ParameterType; +use paste::paste; + +macro_rules! define_static_integer_parameters { + ( + Radix { + num_bits: $num_bits:literal, + block_parameters: $block_parameters:expr, + num_block: $num_block:literal, + wopbs_block_parameters: $wopbs_block_parameters:expr, + } + ) => { + paste! { + #[doc = concat!("Id for the [FheUint", stringify!($num_bits), "] data type.")] + #[derive(Copy, Clone, Debug, Default, Serialize, Deserialize)] + pub struct []; + + #[doc = concat!("Parameters for the [FheUint", stringify!($num_bits), "] data type.")] + #[derive(Copy, Clone, Debug, Serialize, Deserialize)] + pub struct [](RadixParameters); + + impl Default for [] { + fn default() -> Self { + Self( + RadixParameters { + block_parameters: $block_parameters, + num_block: $num_block, + wopbs_block_parameters: $wopbs_block_parameters, + }, + ) + } + } + + impl ParameterType for [] { + type Id = []; + type InnerCiphertext = crate::integer::RadixCiphertext; + type InnerClientKey = crate::integer::RadixClientKey; + type InnerPublicKey = crate::typed_api::integers::public_key::RadixPublicKey; + type InnerServerKey = crate::integer::ServerKey; + } + + impl IntegerParameter for [] { + fn wopbs_block_parameters(&self) -> crate::shortint::Parameters { + self.0.wopbs_block_parameters + } + + fn block_parameters(&self) -> crate::shortint::Parameters { + self.0.block_parameters + } + } + + impl From<[]> for RadixParameters { + fn from(p: []) -> Self { + p.0 + } + } + + impl StaticIntegerParameter for [] { + type Representation = RadixRepresentation; + const MESSAGE_BITS: usize = $num_bits; + } + + impl StaticRadixParameter for [] {} + } + }; + ( + Crt { + num_bits: $num_bits:literal, + block_parameters: $block_parameters:expr, + moduli: $moduli:expr, + wopbs_block_parameters: $wopbs_block_parameters:expr, + } + ) => { + paste! { + #[doc = concat!("Id for the [FheUint", stringify!($num_bits), "] data type.")] + #[derive(Copy, Clone, Debug, Default, Serialize, Deserialize)] + pub struct []; + + #[doc = concat!("Parameters for the [FheUint", stringify!($num_bits), "] data type.")] + #[derive(Copy, Clone, Debug, Default, Serialize, Deserialize)] + pub struct [](CrtParameters); + + impl Default for [] { + fn default() -> Self { + Self( + CrtParameters { + block_parameters: $block_parameters, + moduli: $moduli, + wopbs_block_parameters: $wopbs_block_parameters, + }, + ) + } + } + + impl ParameterType for [] { + type Id = []; + type InnerCiphertext = crate::integer::CrtCiphertext; + type InnerClientKey = crate::integer::CrtClientKey; + type InnerPublicKey = crate::typed_api::integers::public_key::CrtPublicKey; + type InnerServerKey = crate::integer::ServerKey; + } + + impl IntegerParameter for [] { + fn wopbs_block_parameters(&self) -> crate::shortint::Parameters { + self.0.wopbs_block_parameters + } + + fn block_parameters(&self) -> crate::shortint::Parameters { + self.0.block_parameters + } + } + + impl From<[]> for CrtCiphertext { + fn from(p: []) -> Self { + p.0 + } + } + + impl StaticIntegerParameter for [] { + type Representation = CrtRepresentation; + const MESSAGE_BITS: usize = $num_bits; + } + + impl StaticCrtParameter for [] {} + } + }; +} + +macro_rules! static_int_type { + // This rule generates the types specialization + // as well as call the macros + // that implement necessary traits for the ClientKey and ServerKey + // + // This is not meant to be used directly, instead see the other rules below + ( + @impl_types_and_key_traits, + $(#[$outer:meta])* + $name:ident { + num_bits: $num_bits:literal, + keychain_member: $($member:ident).*, + } + ) => { + paste! { + #[doc = concat!("ClientKey for the [", stringify!($name), "] data type.")] + pub(in crate::typed_api::integers) type [<$name ClientKey>] = GenericIntegerClientKey<[<$name Parameters>]>; + + #[doc = concat!("PublicKey for the [", stringify!($name), "] data type.")] + pub(in crate::typed_api::integers) type [<$name PublicKey>] = GenericIntegerPublicKey<[<$name Parameters>]>; + + #[doc = concat!("ServerKey for the [", stringify!($name), "] data type.")] + pub(in crate::typed_api::integers) type [<$name ServerKey>] = GenericIntegerServerKey<[<$name Parameters>]>; + + #[doc = concat!("An unsigned integer type with", stringify!($num_bits), "bits")] + $(#[$outer])* + #[cfg_attr(all(doc, not(doctest)), cfg(feature = "integer"))] + pub type $name = GenericInteger<[<$name Parameters>]>; + + impl_ref_key_from_keychain!( + for <[<$name Parameters>] as ParameterType>::Id { + key_type: [<$name ClientKey>], + keychain_member: $($member).*, + type_variant: crate::typed_api::errors::Type::$name, + } + ); + + impl_ref_key_from_public_keychain!( + for <[<$name Parameters>] as ParameterType>::Id { + key_type: [<$name PublicKey>], + keychain_member: $($member).*, + type_variant: crate::typed_api::errors::Type::$name, + } + ); + + impl_with_global_key!( + for <[<$name Parameters>] as ParameterType>::Id { + key_type: [<$name ServerKey>], + keychain_member: $($member).*, + type_variant: crate::typed_api::errors::Type::$name, + } + ); + } + }; + + // Defines a static integer type that uses + // the `Radix` representation + ( + $(#[$outer:meta])* + { + num_bits: $num_bits:literal, + keychain_member: $($member:ident).*, + parameters: Radix { + block_parameters: $block_parameters:expr, + num_block: $num_block:literal, + wopbs_block_parameters: $wopbs_block_parameters:expr, + }, + } + ) => { + define_static_integer_parameters!( + Radix { + num_bits: $num_bits, + block_parameters: $block_parameters, + num_block: $num_block, + wopbs_block_parameters: $wopbs_block_parameters, + } + ); + + ::paste::paste!{ + static_int_type!( + @impl_types_and_key_traits, + $(#[$outer])* + [] { + num_bits: $num_bits, + keychain_member: $($member).*, + } + ); + } + }; + + // Defines a static integer type that uses + // the `CRT` representation + ( + $(#[$outer:meta])* + { + num_bits: $num_bits:literal, + keychain_member: $($member:ident).*, + parameters: Crt { + block_parameters: $block_parameters:expr, + moduli: $moduli:expr, + wopbs_block_parameters: $wopbs_block_parameters:expr, + }, + } + ) => { + define_static_integer_parameters!( + Crt { + num_bits: $num_bits, + block_parameters: $block_parameters, + moduli: $moduli, + wopbs_block_parameters: $wopbs_block_parameters, + } + ); + + ::paste::paste!{ + static_int_type!( + @impl_types_and_key_traits, + $(#[$outer])* + [] { + num_bits: $num_bits, + keychain_member: $($member).*, + } + ); + } + }; +} + +impl EvaluationIntegerKey for crate::integer::ServerKey +where + C: AsRef, +{ + fn new(client_key: &C) -> Self { + #[cfg(feature = "internal-keycache")] + { + KEY_CACHE + .get_from_params(client_key.as_ref().parameters()) + .1 + } + #[cfg(not(feature = "internal-keycache"))] + { + crate::integer::ServerKey::new(client_key) + } + } + + fn new_wopbs_key( + client_key: &C, + server_key: &Self, + wopbs_block_parameters: crate::shortint::Parameters, + ) -> WopbsKey { + #[cfg(not(feature = "internal-keycache"))] + { + WopbsKey::new_wopbs_key(client_key.as_ref(), server_key, &wopbs_block_parameters) + } + #[cfg(feature = "internal-keycache")] + { + let _ = &server_key; // silence warning + KEY_CACHE_WOPBS + .get_from_params((client_key.as_ref().parameters(), wopbs_block_parameters)) + } + } +} + +static_int_type! { + { + num_bits: 8, + keychain_member: integer_key.uint8_key, + parameters: Radix { + block_parameters: crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2, + num_block: 4, + wopbs_block_parameters: crate::shortint::parameters::parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2, + }, + } +} + +static_int_type! { + { + num_bits: 10, + keychain_member: integer_key.uint10_key, + parameters: Radix { + block_parameters: crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2, + num_block: 5, + wopbs_block_parameters: crate::shortint::parameters::parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2, + }, + } +} + +static_int_type! { + { + num_bits: 12, + keychain_member: integer_key.uint12_key, + parameters: Radix { + block_parameters: crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2, + num_block: 6, + wopbs_block_parameters: crate::shortint::parameters::parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2, + }, + } +} + +static_int_type! { + { + num_bits: 14, + keychain_member: integer_key.uint14_key, + parameters: Radix { + block_parameters: crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2, + num_block: 7, + wopbs_block_parameters: crate::shortint::parameters::parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2, + }, + } +} + +static_int_type! { + { + num_bits: 16, + keychain_member: integer_key.uint16_key, + parameters: Radix { + block_parameters: crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2, + num_block: 8, + wopbs_block_parameters: crate::shortint::parameters::parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2, + }, + } +} + +static_int_type! { + { + num_bits: 256, + keychain_member: integer_key.uint256_key, + parameters: Radix { + block_parameters: crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2, + num_block: 128, + wopbs_block_parameters: crate::shortint::parameters::parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2, + }, + } +} + +impl FheEncrypt for GenericInteger { + #[track_caller] + fn encrypt(value: u8, key: &ClientKey) -> Self { + let id = ::Id::default(); + let key = id.unwrapped_ref_key(key); + let ciphertext = key.inner.encrypt(u64::from(value)); + Self::new(ciphertext, id) + } +} + +impl FheDecrypt for FheUint8 { + #[track_caller] + fn decrypt(&self, key: &ClientKey) -> u8 { + let id = ::Id::default(); + let key = id.unwrapped_ref_key(key); + key.inner.decrypt(&self.ciphertext.borrow()) as u8 + } +} + +impl FheEncrypt for FheUint16 { + #[track_caller] + fn encrypt(value: u16, key: &ClientKey) -> Self { + let id = ::Id::default(); + let key = id.unwrapped_ref_key(key); + let ciphertext = key.inner.encrypt(u64::from(value)); + Self::new(ciphertext, id) + } +} + +impl FheDecrypt for FheUint16 { + #[track_caller] + fn decrypt(&self, key: &ClientKey) -> u16 { + let id = ::Id::default(); + let key = id.unwrapped_ref_key(key); + key.inner.decrypt(&self.ciphertext.borrow()) as u16 + } +} diff --git a/tfhe/src/typed_api/internal_traits.rs b/tfhe/src/typed_api/internal_traits.rs new file mode 100644 index 000000000..7b3212a5e --- /dev/null +++ b/tfhe/src/typed_api/internal_traits.rs @@ -0,0 +1,35 @@ +/// Trait to be implemented on keys that encrypts clear values into ciphertexts +pub(crate) trait EncryptionKey { + /// The type of ciphertext returned as a result of the encryption + type Ciphertext; + + /// The encryption process + fn encrypt(&self, value: ClearType) -> Self::Ciphertext; +} + +/// Trait to be implemented on keys that decrypts ciphertext into clear values +pub(crate) trait DecryptionKey { + /// The type of ciphertext that this key decrypts + type Ciphertext; + + /// The decryption process + fn decrypt(&self, ciphertext: &Self::Ciphertext) -> ClearType; +} + +pub trait FromParameters

{ + fn from_parameters(parameters: P) -> Self; +} + +pub trait ParameterType: Clone { + /// The Id allows to differentiate the different parameters + /// as well as retrieving the corresponding client key and server key + type Id: Copy; + /// The ciphertext type that will be wrapped. + type InnerCiphertext: serde::Serialize + for<'de> serde::Deserialize<'de>; + /// The client key type that will be wrapped. + type InnerClientKey; + /// The public key that will be wrapped; + type InnerPublicKey; + /// The server key type that will be wrapped. + type InnerServerKey; +} diff --git a/tfhe/src/typed_api/keys/client.rs b/tfhe/src/typed_api/keys/client.rs new file mode 100644 index 000000000..b0a35e093 --- /dev/null +++ b/tfhe/src/typed_api/keys/client.rs @@ -0,0 +1,107 @@ +//! This module defines ClientKey +//! +//! - [ClientKey] aggregates the keys used to encrypt/decrypt between normal and homomorphic types. + +#[cfg(feature = "boolean")] +use crate::typed_api::booleans::BooleanClientKey; +use crate::typed_api::config::Config; +use crate::typed_api::errors::{UninitializedClientKey, UnwrapResultExt}; +#[cfg(feature = "integer")] +use crate::typed_api::integers::IntegerClientKey; +#[cfg(feature = "shortint")] +use crate::typed_api::shortints::ShortIntClientKey; + +use super::ServerKey; + +/// Key of the client +/// +/// This struct contains the keys that are of interest to the user +/// as they will allow to encrypt and decrypt data. +/// +/// This key **MUST NOT** be sent to the server. +#[derive(Clone, Debug)] +pub struct ClientKey { + #[cfg(feature = "boolean")] + pub(crate) boolean_key: BooleanClientKey, + #[cfg(feature = "shortint")] + pub(crate) shortint_key: ShortIntClientKey, + #[cfg(feature = "integer")] + pub(crate) integer_key: IntegerClientKey, +} + +impl ClientKey { + /// Generates a new keys. + pub fn generate>(config: C) -> ClientKey { + #[allow(unused_variables)] + let config: Config = config.into(); + ClientKey { + #[cfg(feature = "boolean")] + boolean_key: BooleanClientKey::from(config.boolean_config), + #[cfg(feature = "shortint")] + shortint_key: ShortIntClientKey::from(config.shortint_config), + #[cfg(feature = "integer")] + integer_key: IntegerClientKey::from(config.integer_config), + } + } + + /// Generates a new ServerKeyChain + /// + /// The `ServerKeyChain` generated is meant to be used to initialize the global state + /// using [crate::typed_api::set_server_key]. + pub fn generate_server_key(&self) -> ServerKey { + ServerKey::new(self) + } +} + +/// Trait to be implemented on the client key types that have a corresponding member +/// in the `ClientKeyChain`. +/// +/// This is to allow the writing of generic functions. +pub trait RefKeyFromKeyChain: Sized { + type Key; + + /// The method to implement, shall return a ref to the key or an error if + /// the key member in the key was not initialized + fn ref_key(self, keys: &ClientKey) -> Result<&Self::Key, UninitializedClientKey>; + + /// Returns a mutable ref to the key member of the key + /// + /// # Panic + /// + /// This will panic if the key was not initialized + #[track_caller] + fn unwrapped_ref_key(self, keys: &ClientKey) -> &Self::Key { + self.ref_key(keys).unwrap_display() + } +} + +/// Helper macro to help reduce boiler plate +/// needed to implement `RefKeyFromKeyChain` since for +/// our keys, the implementation is the same, only a few things change. +/// +/// It expects: +/// - The implementor type +/// - The `name` of the key type for which the trait will be implemented. +/// - The identifier (or identifier chain) that points to the member in the `ClientKey` that holds +/// the key for which the trait is implemented. +/// - Type Variant used to identify the type at runtime (see `error.rs`) +#[cfg(any(feature = "integer", feature = "shortint", feature = "boolean"))] +macro_rules! impl_ref_key_from_keychain { + ( + for $implementor:ty { + key_type: $key_type:ty, + keychain_member: $($member:ident).*, + type_variant: $enum_variant:expr, + } + ) => { + impl crate::typed_api::keys::RefKeyFromKeyChain for $implementor { + type Key = $key_type; + + fn ref_key(self, keys: &crate::typed_api::keys::ClientKey) -> Result<&Self::Key, crate::typed_api::errors::UninitializedClientKey> { + keys$(.$member)* + .as_ref() + .ok_or(crate::typed_api::errors::UninitializedClientKey($enum_variant)) + } + } + } +} diff --git a/tfhe/src/typed_api/keys/mod.rs b/tfhe/src/typed_api/keys/mod.rs new file mode 100644 index 000000000..88d4aae9e --- /dev/null +++ b/tfhe/src/typed_api/keys/mod.rs @@ -0,0 +1,31 @@ +#[macro_use] +mod client; +#[macro_use] +mod public; +mod server; + +pub use client::{ClientKey, RefKeyFromKeyChain}; +pub use public::{PublicKey, RefKeyFromPublicKeyChain}; +pub use server::ServerKey; + +use crate::typed_api::config::Config; + +/// Generates keys using the provided config. +/// +/// # Example +/// +/// ``` +/// # #[cfg(feature = "shortint")] +/// # { +/// use tfhe::{generate_keys, ConfigBuilder}; +/// +/// let config = ConfigBuilder::all_disabled().enable_default_uint3().build(); +/// let (client_key, server_key) = generate_keys(config); +/// # } +/// ``` +pub fn generate_keys>(config: C) -> (ClientKey, ServerKey) { + let client_kc = ClientKey::generate(config); + let server_kc = client_kc.generate_server_key(); + + (client_kc, server_kc) +} diff --git a/tfhe/src/typed_api/keys/public.rs b/tfhe/src/typed_api/keys/public.rs new file mode 100644 index 000000000..ef653f90c --- /dev/null +++ b/tfhe/src/typed_api/keys/public.rs @@ -0,0 +1,82 @@ +//! This module defines PublicKey +//! +//! - [PublicKey] aggregates a key that can be made public, and that allows to encrypt (only) + +#[cfg(feature = "boolean")] +use crate::typed_api::booleans::BooleanPublicKey; +use crate::typed_api::errors::{UninitializedPublicKey, UnwrapResultExt}; +#[cfg(feature = "integer")] +use crate::typed_api::integers::IntegerPublicKey; +#[cfg(feature = "shortint")] +use crate::typed_api::shortints::ShortIntPublicKey; + +use super::ClientKey; + +pub struct PublicKey { + #[cfg(feature = "boolean")] + pub(crate) boolean_key: BooleanPublicKey, + #[cfg(feature = "shortint")] + pub(crate) shortint_key: ShortIntPublicKey, + #[cfg(feature = "integer")] + pub(crate) integer_key: IntegerPublicKey, +} + +impl PublicKey { + pub fn new(client_key: &ClientKey) -> Self { + // Silence warning about unused variable when none of these feature is used + #[cfg(not(any(feature = "boolean", feature = "shortint", feature = "integer")))] + let _ = client_key; + + Self { + #[cfg(feature = "boolean")] + boolean_key: BooleanPublicKey::new(&client_key.boolean_key), + #[cfg(feature = "shortint")] + shortint_key: ShortIntPublicKey::new(&client_key.shortint_key), + #[cfg(feature = "integer")] + integer_key: IntegerPublicKey::new(&client_key.integer_key), + } + } +} + +/// Trait to be implemented on the public key types that have a corresponding member +/// in the `PublicKey`. +/// +/// This is to allow the writing of generic functions. +pub trait RefKeyFromPublicKeyChain: Sized { + type Key; + + /// The method to implement, shall return a ref to the key or an error if + /// the key member in the key was not initialized + fn ref_key(self, keys: &PublicKey) -> Result<&Self::Key, UninitializedPublicKey>; + + /// Returns a mutable ref to the key member of the key + /// + /// # Panic + /// + /// This will panic if the key was not initialized + #[track_caller] + fn unwrapped_ref_key(self, keys: &PublicKey) -> &Self::Key { + self.ref_key(keys).unwrap_display() + } +} + +#[cfg(any(feature = "integer", feature = "shortint", feature = "boolean"))] +macro_rules! impl_ref_key_from_public_keychain { + ( + for $implementor:ty { + key_type: $key_type:ty, + keychain_member: $($member:ident).*, + type_variant: $enum_variant:expr, + } + ) => { + impl crate::typed_api::keys::RefKeyFromPublicKeyChain for $implementor { + type Key = $key_type; + + fn ref_key(self, keys: &crate::typed_api::keys::PublicKey) -> Result<&Self::Key, crate::typed_api::errors::UninitializedPublicKey> { + keys$(.$member)* + .as_ref() + .ok_or(crate::typed_api::errors::UninitializedPublicKey($enum_variant)) + } + } + } +} diff --git a/tfhe/src/typed_api/keys/server.rs b/tfhe/src/typed_api/keys/server.rs new file mode 100644 index 000000000..41d6b0d07 --- /dev/null +++ b/tfhe/src/typed_api/keys/server.rs @@ -0,0 +1,45 @@ +#[cfg(feature = "boolean")] +use crate::typed_api::booleans::BooleanServerKey; +#[cfg(feature = "integer")] +use crate::typed_api::integers::IntegerServerKey; +#[cfg(feature = "shortint")] +use crate::typed_api::shortints::ShortIntServerKey; + +#[cfg(any(feature = "boolean", feature = "shortint", feature = "integer"))] +use std::sync::Arc; + +use super::ClientKey; + +/// Key of the server +/// +/// This key contains the different keys needed to be able to do computations for +/// each data type. +/// +/// For a server to be able to do some FHE computations, the client needs to send this key +/// beforehand. +// Keys are stored in an Arc, so that cloning them is cheap +// (compared to an actual clone hundreds of MB / GB), and cheap cloning is needed for +// multithreading with less overhead) +#[derive(Clone, Default)] +pub struct ServerKey { + #[cfg(feature = "boolean")] + pub(crate) boolean_key: Arc, + #[cfg(feature = "shortint")] + pub(crate) shortint_key: Arc, + #[cfg(feature = "integer")] + pub(crate) integer_key: Arc, +} + +impl ServerKey { + #[allow(unused_variables)] + pub(crate) fn new(keys: &ClientKey) -> Self { + Self { + #[cfg(feature = "boolean")] + boolean_key: Arc::new(BooleanServerKey::new(&keys.boolean_key)), + #[cfg(feature = "shortint")] + shortint_key: Arc::new(ShortIntServerKey::new(&keys.shortint_key)), + #[cfg(feature = "integer")] + integer_key: Arc::new(IntegerServerKey::new(&keys.integer_key)), + } + } +} diff --git a/tfhe/src/typed_api/mod.rs b/tfhe/src/typed_api/mod.rs new file mode 100644 index 000000000..8486b9f43 --- /dev/null +++ b/tfhe/src/typed_api/mod.rs @@ -0,0 +1,42 @@ +#![allow(unused_doc_comments)] +pub use config::{Config, ConfigBuilder}; +pub use errors::{Error, OutOfRangeError}; +pub use global_state::{set_server_key, unset_server_key, with_server_key_as_context}; +pub use keys::{generate_keys, ClientKey, PublicKey, ServerKey}; + +#[cfg(test)] +mod tests; + +#[cfg(feature = "boolean")] +pub use crate::typed_api::booleans::{CompressedFheBool, FheBool, FheBoolParameters}; +#[cfg(feature = "integer")] +pub use crate::typed_api::integers::{ + CrtParameters, FheUint10, FheUint12, FheUint14, FheUint16, FheUint256, FheUint8, + GenericInteger, RadixParameters, +}; +#[cfg(feature = "shortint")] +pub use crate::typed_api::shortints::{ + CompressedFheUint2, CompressedFheUint3, CompressedFheUint4, FheUint2, FheUint2Parameters, + FheUint3, FheUint3Parameters, FheUint4, FheUint4Parameters, +}; +#[macro_use] +mod details; +#[macro_use] +mod global_state; +#[macro_use] +mod keys; +mod config; +mod internal_traits; +mod traits; + +#[cfg(feature = "boolean")] +mod booleans; +pub mod errors; +#[cfg(feature = "integer")] +mod integers; +/// The tfhe prelude. +pub mod prelude; +#[cfg(feature = "shortint")] +mod shortints; + +pub mod parameters {} diff --git a/tfhe/src/typed_api/prelude.rs b/tfhe/src/typed_api/prelude.rs new file mode 100644 index 000000000..971d51d9c --- /dev/null +++ b/tfhe/src/typed_api/prelude.rs @@ -0,0 +1,12 @@ +//! The purpose of this module is to make it easier to have the most commonly needed +//! traits of this crate. +//! +//! It is meant to be glob imported: +//! ``` +//! use tfhe::prelude::*; +//! ``` +pub use crate::typed_api::traits::{ + DynamicFheEncryptor, DynamicFheTrivialEncryptor, DynamicFheTryEncryptor, FheBootstrap, + FheDecrypt, FheEncrypt, FheEq, FheNumberConstant, FheOrd, FheTrivialEncrypt, FheTryEncrypt, + FheTryTrivialEncrypt, +}; diff --git a/tfhe/src/typed_api/shortints/client_key.rs b/tfhe/src/typed_api/shortints/client_key.rs new file mode 100644 index 000000000..9e76f8014 --- /dev/null +++ b/tfhe/src/typed_api/shortints/client_key.rs @@ -0,0 +1,38 @@ +use std::marker::PhantomData; + +use serde::{Deserialize, Serialize}; + +#[cfg(feature = "internal-keycache")] +use crate::shortint::keycache::KEY_CACHE; +use crate::shortint::ClientKey; + +use super::parameters::ShortIntegerParameter; + +/// The key associated to a short integer type +/// +/// Can encrypt and decrypt it. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct GenericShortIntClientKey { + pub(super) key: ClientKey, + _marker: PhantomData

, +} + +impl

From

for GenericShortIntClientKey

+where + P: ShortIntegerParameter, +{ + fn from(parameters: P) -> Self { + #[cfg(feature = "internal-keycache")] + let key = KEY_CACHE + .get_from_param(parameters.into()) + .client_key() + .clone(); + #[cfg(not(feature = "internal-keycache"))] + let key = ClientKey::new(parameters.into()); + + Self { + key, + _marker: Default::default(), + } + } +} diff --git a/tfhe/src/typed_api/shortints/keys.rs b/tfhe/src/typed_api/shortints/keys.rs new file mode 100644 index 000000000..77c59797e --- /dev/null +++ b/tfhe/src/typed_api/shortints/keys.rs @@ -0,0 +1,7 @@ +define_key_structs! { + ShortInt { + uint2: FheUint2, + uint3: FheUint3, + uint4: FheUint4, + } +} diff --git a/tfhe/src/typed_api/shortints/mod.rs b/tfhe/src/typed_api/shortints/mod.rs new file mode 100644 index 000000000..f7eb204c1 --- /dev/null +++ b/tfhe/src/typed_api/shortints/mod.rs @@ -0,0 +1,16 @@ +pub(crate) use keys::{ShortIntClientKey, ShortIntConfig, ShortIntPublicKey, ShortIntServerKey}; +pub use types::{ + CompressedFheUint2, CompressedFheUint3, CompressedFheUint4, CompressedGenericShortint, + FheUint2, FheUint2Parameters, FheUint3, FheUint3Parameters, FheUint4, FheUint4Parameters, + GenericShortInt, +}; + +mod client_key; +mod keys; +mod parameters; +mod public_key; +mod server_key; +mod types; + +#[cfg(test)] +mod tests; diff --git a/tfhe/src/typed_api/shortints/parameters.rs b/tfhe/src/typed_api/shortints/parameters.rs new file mode 100644 index 000000000..afa0a64d9 --- /dev/null +++ b/tfhe/src/typed_api/shortints/parameters.rs @@ -0,0 +1,7 @@ +pub trait ShortIntegerParameter: Copy + Into { + type Id: Copy; +} + +pub trait StaticShortIntegerParameter: ShortIntegerParameter { + const MESSAGE_BITS: u8; +} diff --git a/tfhe/src/typed_api/shortints/public_key.rs b/tfhe/src/typed_api/shortints/public_key.rs new file mode 100644 index 000000000..8cd39c8e1 --- /dev/null +++ b/tfhe/src/typed_api/shortints/public_key.rs @@ -0,0 +1,27 @@ +use crate::typed_api::shortints::client_key::GenericShortIntClientKey; + +use crate::typed_api::shortints::parameters::ShortIntegerParameter; +use serde::{Deserialize, Serialize}; + +#[cfg_attr(all(doc, not(doctest)), cfg(feature = "shortint"))] +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct GenericShortIntPublicKey

+where + P: ShortIntegerParameter, +{ + pub(in crate::typed_api::shortints) key: crate::shortint::public_key::PublicKey, + _marker: std::marker::PhantomData

, +} + +impl

GenericShortIntPublicKey

+where + P: ShortIntegerParameter, +{ + pub fn new(client_key: &GenericShortIntClientKey

) -> Self { + let key = crate::shortint::public_key::PublicKey::new(&client_key.key); + Self { + key, + _marker: Default::default(), + } + } +} diff --git a/tfhe/src/typed_api/shortints/server_key.rs b/tfhe/src/typed_api/shortints/server_key.rs new file mode 100644 index 000000000..29295b2eb --- /dev/null +++ b/tfhe/src/typed_api/shortints/server_key.rs @@ -0,0 +1,501 @@ +use std::cell::RefCell; +use std::marker::PhantomData; + +use serde::{Deserialize, Serialize}; + +#[cfg(feature = "internal-keycache")] +use crate::shortint::keycache::KEY_CACHE; +use crate::shortint::ServerKey; + +use super::client_key::GenericShortIntClientKey; +use super::parameters::ShortIntegerParameter; +use super::types::GenericShortInt; + +/// The internal key of a short integer type +/// +/// A wrapper around `tfhe-shortint` `ServerKey` +#[derive(Clone, Serialize, Deserialize)] +pub struct GenericShortIntServerKey { + pub(super) key: ServerKey, + _marker: PhantomData

, +} + +/// The internal key wraps some of the inner ServerKey methods +/// so that its input and outputs are type of this crate. +impl

GenericShortIntServerKey

+where + P: ShortIntegerParameter, +{ + pub(crate) fn new(client_key: &GenericShortIntClientKey

) -> Self { + #[cfg(feature = "internal-keycache")] + let key = KEY_CACHE + .get_from_param(client_key.key.parameters) + .server_key() + .clone(); + #[cfg(not(feature = "internal-keycache"))] + let key = ServerKey::new(&client_key.key); + + Self { + key, + _marker: Default::default(), + } + } + + pub(crate) fn smart_add( + &self, + lhs: &GenericShortInt

, + rhs: &GenericShortInt

, + ) -> GenericShortInt

{ + let ciphertext = self.key.smart_add( + &mut lhs.ciphertext.borrow_mut(), + &mut rhs.ciphertext.borrow_mut(), + ); + GenericShortInt { + ciphertext: RefCell::new(ciphertext), + id: lhs.id, + } + } + + pub(crate) fn smart_sub( + &self, + lhs: &GenericShortInt

, + rhs: &GenericShortInt

, + ) -> GenericShortInt

{ + let ciphertext = self.key.smart_sub( + &mut lhs.ciphertext.borrow_mut(), + &mut rhs.ciphertext.borrow_mut(), + ); + GenericShortInt { + ciphertext: RefCell::new(ciphertext), + id: lhs.id, + } + } + + pub(crate) fn smart_mul( + &self, + lhs: &GenericShortInt

, + rhs: &GenericShortInt

, + ) -> GenericShortInt

{ + let ciphertext = self.key.smart_mul_lsb( + &mut lhs.ciphertext.borrow_mut(), + &mut rhs.ciphertext.borrow_mut(), + ); + GenericShortInt { + ciphertext: RefCell::new(ciphertext), + id: lhs.id, + } + } + + pub(crate) fn smart_div( + &self, + lhs: &GenericShortInt

, + rhs: &GenericShortInt

, + ) -> GenericShortInt

{ + let ciphertext = self.key.smart_div( + &mut lhs.ciphertext.borrow_mut(), + &mut rhs.ciphertext.borrow_mut(), + ); + GenericShortInt { + ciphertext: RefCell::new(ciphertext), + id: lhs.id, + } + } + + pub(crate) fn smart_add_assign(&self, lhs: &GenericShortInt

, rhs: &GenericShortInt

) { + self.key.smart_add_assign( + &mut lhs.ciphertext.borrow_mut(), + &mut rhs.ciphertext.borrow_mut(), + ); + } + + pub(crate) fn smart_sub_assign(&self, lhs: &GenericShortInt

, rhs: &GenericShortInt

) { + self.key.smart_sub_assign( + &mut lhs.ciphertext.borrow_mut(), + &mut rhs.ciphertext.borrow_mut(), + ); + } + + pub(crate) fn smart_mul_assign(&self, lhs: &GenericShortInt

, rhs: &GenericShortInt

) { + self.key.smart_mul_lsb_assign( + &mut lhs.ciphertext.borrow_mut(), + &mut rhs.ciphertext.borrow_mut(), + ); + } + + pub(crate) fn smart_div_assign(&self, lhs: &GenericShortInt

, rhs: &GenericShortInt

) { + self.key.smart_div_assign( + &mut lhs.ciphertext.borrow_mut(), + &mut rhs.ciphertext.borrow_mut(), + ) + } + + pub(crate) fn smart_bitand_assign(&self, lhs: &GenericShortInt

, rhs: &GenericShortInt

) { + self.key.smart_bitand_assign( + &mut lhs.ciphertext.borrow_mut(), + &mut rhs.ciphertext.borrow_mut(), + ); + } + + pub(crate) fn smart_bitor_assign(&self, lhs: &GenericShortInt

, rhs: &GenericShortInt

) { + self.key.smart_bitor_assign( + &mut lhs.ciphertext.borrow_mut(), + &mut rhs.ciphertext.borrow_mut(), + ); + } + + pub(crate) fn smart_bitxor_assign(&self, lhs: &GenericShortInt

, rhs: &GenericShortInt

) { + self.key.smart_bitxor_assign( + &mut lhs.ciphertext.borrow_mut(), + &mut rhs.ciphertext.borrow_mut(), + ); + } + + pub(crate) fn smart_scalar_sub(&self, lhs: &GenericShortInt

, rhs: u8) -> GenericShortInt

{ + let ciphertext = self + .key + .smart_scalar_sub(&mut lhs.ciphertext.borrow_mut(), rhs); + GenericShortInt { + ciphertext: RefCell::new(ciphertext), + id: lhs.id, + } + } + + pub(crate) fn smart_scalar_mul(&self, lhs: &GenericShortInt

, rhs: u8) -> GenericShortInt

{ + let ciphertext = self + .key + .smart_scalar_mul(&mut lhs.ciphertext.borrow_mut(), rhs); + GenericShortInt { + ciphertext: RefCell::new(ciphertext), + id: lhs.id, + } + } + + pub(crate) fn smart_scalar_add( + &self, + lhs: &GenericShortInt

, + scalar: u8, + ) -> GenericShortInt

{ + let ciphertext = self + .key + .smart_scalar_add(&mut lhs.ciphertext.borrow_mut(), scalar); + GenericShortInt { + ciphertext: RefCell::new(ciphertext), + id: lhs.id, + } + } + + pub(crate) fn smart_scalar_add_assign(&self, lhs: &mut GenericShortInt

, rhs: u8) { + self.key + .smart_scalar_add_assign(&mut lhs.ciphertext.borrow_mut(), rhs) + } + + pub(crate) fn smart_scalar_mul_assign(&self, lhs: &mut GenericShortInt

, rhs: u8) { + self.key + .smart_scalar_mul_assign(&mut lhs.ciphertext.borrow_mut(), rhs) + } + + pub(crate) fn smart_scalar_sub_assign(&self, lhs: &mut GenericShortInt

, rhs: u8) { + self.key + .smart_scalar_sub_assign(&mut lhs.ciphertext.borrow_mut(), rhs) + } + + pub(crate) fn smart_bitand( + &self, + lhs: &GenericShortInt

, + rhs: &GenericShortInt

, + ) -> GenericShortInt

{ + let ciphertext = self.key.smart_bitand( + &mut lhs.ciphertext.borrow_mut(), + &mut rhs.ciphertext.borrow_mut(), + ); + GenericShortInt { + ciphertext: RefCell::new(ciphertext), + id: lhs.id, + } + } + + pub(crate) fn smart_bitor( + &self, + lhs: &GenericShortInt

, + rhs: &GenericShortInt

, + ) -> GenericShortInt

{ + let ciphertext = self.key.smart_bitor( + &mut lhs.ciphertext.borrow_mut(), + &mut rhs.ciphertext.borrow_mut(), + ); + GenericShortInt { + ciphertext: RefCell::new(ciphertext), + id: lhs.id, + } + } + + pub(crate) fn smart_bitxor( + &self, + lhs: &GenericShortInt

, + rhs: &GenericShortInt

, + ) -> GenericShortInt

{ + let ciphertext = self.key.smart_bitxor( + &mut lhs.ciphertext.borrow_mut(), + &mut rhs.ciphertext.borrow_mut(), + ); + GenericShortInt { + ciphertext: RefCell::new(ciphertext), + id: lhs.id, + } + } + + pub(crate) fn smart_less( + &self, + lhs: &GenericShortInt

, + rhs: &GenericShortInt

, + ) -> GenericShortInt

{ + let ciphertext = self.key.smart_less( + &mut lhs.ciphertext.borrow_mut(), + &mut rhs.ciphertext.borrow_mut(), + ); + GenericShortInt { + ciphertext: RefCell::new(ciphertext), + id: lhs.id, + } + } + + pub(crate) fn smart_less_or_equal( + &self, + lhs: &GenericShortInt

, + rhs: &GenericShortInt

, + ) -> GenericShortInt

{ + let ciphertext = self.key.smart_less_or_equal( + &mut lhs.ciphertext.borrow_mut(), + &mut rhs.ciphertext.borrow_mut(), + ); + GenericShortInt { + ciphertext: RefCell::new(ciphertext), + id: lhs.id, + } + } + + pub(crate) fn smart_greater( + &self, + lhs: &GenericShortInt

, + rhs: &GenericShortInt

, + ) -> GenericShortInt

{ + let ciphertext = self.key.smart_greater( + &mut lhs.ciphertext.borrow_mut(), + &mut rhs.ciphertext.borrow_mut(), + ); + GenericShortInt { + ciphertext: RefCell::new(ciphertext), + id: lhs.id, + } + } + + pub(crate) fn smart_greater_or_equal( + &self, + lhs: &GenericShortInt

, + rhs: &GenericShortInt

, + ) -> GenericShortInt

{ + let ciphertext = self.key.smart_greater_or_equal( + &mut lhs.ciphertext.borrow_mut(), + &mut rhs.ciphertext.borrow_mut(), + ); + GenericShortInt { + ciphertext: RefCell::new(ciphertext), + id: lhs.id, + } + } + + pub(crate) fn smart_equal( + &self, + lhs: &GenericShortInt

, + rhs: &GenericShortInt

, + ) -> GenericShortInt

{ + let ciphertext = self.key.smart_equal( + &mut lhs.ciphertext.borrow_mut(), + &mut rhs.ciphertext.borrow_mut(), + ); + GenericShortInt { + ciphertext: RefCell::new(ciphertext), + id: lhs.id, + } + } + + pub(crate) fn smart_scalar_equal( + &self, + lhs: &GenericShortInt

, + scalar: u8, + ) -> GenericShortInt

{ + let ciphertext = self + .key + .smart_scalar_equal(&lhs.ciphertext.borrow(), scalar); + GenericShortInt { + ciphertext: RefCell::new(ciphertext), + id: lhs.id, + } + } + + pub(crate) fn smart_scalar_greater_or_equal( + &self, + lhs: &GenericShortInt

, + scalar: u8, + ) -> GenericShortInt

{ + let ciphertext = self + .key + .smart_scalar_greater_or_equal(&lhs.ciphertext.borrow(), scalar); + GenericShortInt { + ciphertext: RefCell::new(ciphertext), + id: lhs.id, + } + } + + pub(crate) fn smart_scalar_less_or_equal( + &self, + lhs: &GenericShortInt

, + scalar: u8, + ) -> GenericShortInt

{ + let ciphertext = self + .key + .smart_scalar_less_or_equal(&lhs.ciphertext.borrow(), scalar); + GenericShortInt { + ciphertext: RefCell::new(ciphertext), + id: lhs.id, + } + } + + pub(crate) fn smart_scalar_greater( + &self, + lhs: &GenericShortInt

, + scalar: u8, + ) -> GenericShortInt

{ + let ciphertext = self + .key + .smart_scalar_greater(&lhs.ciphertext.borrow(), scalar); + GenericShortInt { + ciphertext: RefCell::new(ciphertext), + id: lhs.id, + } + } + + pub(crate) fn smart_scalar_less( + &self, + lhs: &GenericShortInt

, + scalar: u8, + ) -> GenericShortInt

{ + let ciphertext = self.key.smart_scalar_less(&lhs.ciphertext.borrow(), scalar); + GenericShortInt { + ciphertext: RefCell::new(ciphertext), + id: lhs.id, + } + } + + pub(crate) fn smart_scalar_left_shift( + &self, + lhs: &GenericShortInt

, + rhs: u8, + ) -> GenericShortInt

{ + let ciphertext = self + .key + .smart_scalar_left_shift(&mut lhs.ciphertext.borrow_mut(), rhs); + GenericShortInt { + ciphertext: RefCell::new(ciphertext), + id: lhs.id, + } + } + + pub(crate) fn unchecked_scalar_right_shift( + &self, + lhs: &GenericShortInt

, + rhs: u8, + ) -> GenericShortInt

{ + let ciphertext = self + .key + .unchecked_scalar_right_shift(&lhs.ciphertext.borrow(), rhs); + GenericShortInt { + ciphertext: RefCell::new(ciphertext), + id: lhs.id, + } + } + + pub(crate) fn unchecked_scalar_div( + &self, + lhs: &GenericShortInt

, + rhs: u8, + ) -> GenericShortInt

{ + let ciphertext = self.key.unchecked_scalar_div(&lhs.ciphertext.borrow(), rhs); + GenericShortInt { + ciphertext: RefCell::new(ciphertext), + id: lhs.id, + } + } + + pub(crate) fn unchecked_scalar_mod( + &self, + lhs: &GenericShortInt

, + rhs: u8, + ) -> GenericShortInt

{ + let ciphertext = self.key.unchecked_scalar_mod(&lhs.ciphertext.borrow(), rhs); + GenericShortInt { + ciphertext: RefCell::new(ciphertext), + id: lhs.id, + } + } + + pub(crate) fn smart_neg(&self, lhs: &GenericShortInt

) -> GenericShortInt

{ + let ciphertext = self.key.smart_neg(&mut lhs.ciphertext.borrow_mut()); + GenericShortInt { + ciphertext: RefCell::new(ciphertext), + id: lhs.id, + } + } + + pub(super) fn bootstrap_with( + &self, + ciphertext: &GenericShortInt

, + func: F, + ) -> GenericShortInt

+ where + F: Fn(u64) -> u64, + { + let accumulator = self.key.generate_accumulator(func); + let new_ciphertext = self + .key + .keyswitch_programmable_bootstrap(&ciphertext.ciphertext.borrow(), &accumulator); + GenericShortInt { + ciphertext: RefCell::new(new_ciphertext), + id: ciphertext.id, + } + } + + pub(super) fn bootstrap_inplace_with(&self, ciphertext: &mut GenericShortInt

, func: F) + where + F: Fn(u64) -> u64, + { + let accumulator = self.key.generate_accumulator(func); + self.key.keyswitch_programmable_bootstrap_assign( + &mut ciphertext.ciphertext.borrow_mut(), + &accumulator, + ) + } + + pub(super) fn bivariate_pbs( + &self, + lhs_ct: &GenericShortInt

, + rhs_ct: &GenericShortInt

, + func: F, + ) -> GenericShortInt

+ where + P: ShortIntegerParameter, + F: Fn(u8, u8) -> u8, + { + let wrapped_f = |lhs: u64, rhs: u64| -> u64 { u64::from(func(lhs as u8, rhs as u8)) }; + + let ciphertext = self.key.smart_functional_bivariate_pbs( + &mut lhs_ct.ciphertext.borrow_mut(), + &mut rhs_ct.ciphertext.borrow_mut(), + wrapped_f, + ); + GenericShortInt { + ciphertext: RefCell::new(ciphertext), + id: lhs_ct.id, + } + } +} diff --git a/tfhe/src/typed_api/shortints/tests.rs b/tfhe/src/typed_api/shortints/tests.rs new file mode 100644 index 000000000..adb4ac52c --- /dev/null +++ b/tfhe/src/typed_api/shortints/tests.rs @@ -0,0 +1,13 @@ +use crate::typed_api::prelude::*; +use crate::typed_api::{generate_keys, CompressedFheUint2, ConfigBuilder, FheUint2}; + +#[test] +fn test_shortint_compressed() { + let config = ConfigBuilder::all_enabled().enable_default_uint2().build(); + let (client_key, _) = generate_keys(config); + + let compressed: CompressedFheUint2 = CompressedFheUint2::try_encrypt(2, &client_key).unwrap(); + let a = FheUint2::from(compressed); + let decompressed = a.decrypt(&client_key); + assert_eq!(decompressed, 2); +} diff --git a/tfhe/src/typed_api/shortints/types/base.rs b/tfhe/src/typed_api/shortints/types/base.rs new file mode 100644 index 000000000..c18da01e2 --- /dev/null +++ b/tfhe/src/typed_api/shortints/types/base.rs @@ -0,0 +1,807 @@ +use std::borrow::Borrow; +use std::cell::RefCell; +use std::ops::{ + Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Div, DivAssign, + Mul, MulAssign, Neg, Rem, Shl, Shr, Sub, SubAssign, +}; + +use serde::{Deserialize, Serialize}; + +use crate::shortint::ciphertext::Ciphertext; + +use crate::typed_api::errors::OutOfRangeError; +use crate::typed_api::global_state::WithGlobalKey; +use crate::typed_api::keys::{ClientKey, RefKeyFromKeyChain, RefKeyFromPublicKeyChain}; +use crate::typed_api::traits::{ + FheBootstrap, FheDecrypt, FheEq, FheNumberConstant, FheOrd, FheTryEncrypt, FheTryTrivialEncrypt, +}; +use crate::typed_api::PublicKey; + +use super::{GenericShortIntClientKey, GenericShortIntServerKey}; + +use crate::typed_api::shortints::parameters::{ShortIntegerParameter, StaticShortIntegerParameter}; +use crate::typed_api::shortints::public_key::GenericShortIntPublicKey; + +/// A Generic short FHE unsigned integer +/// +/// Short means less than 7 bits. +/// +/// It is generic over some parameters, as its the parameters +/// that controls how many bit they represent. +/// +/// Its the type that overloads the operators (`+`, `-`, `*`). +/// Since the `GenericShortInt` type is not `Copy` the operators are also overloaded +/// to work with references. +/// +/// You will need to use one of this type specialization (e.g., [FheUint2], [FheUint3], [FheUint4]). +/// +/// To be able to use this type, the cargo feature `shortints` must be enabled, +/// and your config should also enable the type with either default parameters or custom ones. +/// +/// # Example +/// +/// To use FheUint2 +/// +/// ``` +/// # #[cfg(feature = "shortint")] +/// # fn main() -> Result<(), Box> { +/// use tfhe::prelude::*; +/// use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheUint2}; +/// +/// // Enable the FheUint2 type in the config +/// let config = ConfigBuilder::all_disabled().enable_default_uint2().build(); +/// +/// // With the FheUint2 type enabled in the config, the needed keys and details +/// // can be taken care of. +/// let (client_key, server_key) = generate_keys(config); +/// +/// let a = FheUint2::try_encrypt(0, &client_key)?; +/// let b = FheUint2::try_encrypt(1, &client_key)?; +/// +/// // Do not forget to set the server key before doing any computation +/// set_server_key(server_key); +/// +/// // Since FHE types are bigger than native rust type they are not `Copy`, +/// // meaning that to reuse the same value in a computation and avoid the cost +/// // of calling `clone`, you'll have to use references: +/// let c = a + &b; +/// // `a` was moved but not `b`, so `a` cannot be reused, but `b` can +/// let d = &c + b; +/// // `b` was moved but not `c`, so `b` cannot be reused, but `c` can +/// let fhe_result = d + c; +/// // both `d` and `c` were moved. +/// +/// let expected: u8 = { +/// let a = 0; +/// let b = 1; +/// +/// let c = a + b; +/// let d = c + b; +/// d + c +/// }; +/// let clear_result = fhe_result.decrypt(&client_key); +/// assert_eq!(expected, 3); +/// assert_eq!(clear_result, expected); +/// +/// # Ok(()) +/// # } +/// ``` +/// +/// [FheUint2]: crate::FheUint2 +/// [FheUint3]: crate::FheUint3 +/// [FheUint4]: crate::FheUint4 +#[cfg_attr(all(doc, not(doctest)), cfg(feature = "shortint"))] +#[derive(Clone, Serialize, Deserialize)] +pub struct GenericShortInt { + /// The actual ciphertext. + /// Wrapped inside a RefCell because some methods + /// of the corresponding `ServerKey` (in tfhe-shortint) + /// require the ciphertext to be a `&mut`, + /// while we also overloads rust operators for have a `&` references + pub(in crate::typed_api::shortints) ciphertext: RefCell, + pub(in crate::typed_api::shortints) id: P::Id, +} + +impl

GenericShortInt

+where + P: ShortIntegerParameter, +{ + pub(crate) fn new(inner: Ciphertext, id: P::Id) -> Self { + Self { + ciphertext: RefCell::new(inner), + id, + } + } +} + +impl

GenericShortInt

+where + P: ShortIntegerParameter, +{ + pub fn message_max(&self) -> u64 { + self.message_modulus() - 1 + } + + pub fn message_modulus(&self) -> u64 { + self.ciphertext.borrow().message_modulus.0 as u64 + } +} + +impl

GenericShortInt

+where + P: StaticShortIntegerParameter, +{ + /// Minimum value this type can hold, always 0. + pub const MIN: u8 = 0; + + /// Maximum value this type can hold. + pub const MAX: u8 = (1 << P::MESSAGE_BITS) - 1; + + pub const MODULUS: u8 = (1 << P::MESSAGE_BITS); +} + +impl

FheNumberConstant for GenericShortInt

+where + P: StaticShortIntegerParameter, +{ + const MIN: u64 = 0; + + const MAX: u64 = Self::MAX as u64; + + const MODULUS: u64 = Self::MODULUS as u64; +} + +impl FheTryEncrypt for GenericShortInt

+where + T: TryInto, + P: StaticShortIntegerParameter, + P::Id: Default + RefKeyFromKeyChain>, +{ + type Error = OutOfRangeError; + + /// Try to create a new value. + /// + /// As Shortints exposed by this crate have between 2 and 7 bits, + /// creating a value from a rust `u8` may not be possible. + /// + /// # Example + /// + /// ``` + /// # #[cfg(feature = "shortint")] + /// # { + /// # use tfhe::{ConfigBuilder, FheUint3, generate_keys, set_server_key}; + /// # let config = ConfigBuilder::all_disabled().enable_default_uint3().build(); + /// # let (client_key, server_key) = generate_keys(config); + /// # set_server_key(server_key); + /// use tfhe::prelude::*; + /// use tfhe::Error; + /// + /// // The maximum value that can be represented with 3 bits is 7. + /// let a = FheUint3::try_encrypt(8, &client_key); + /// assert_eq!(a.is_err(), true); + /// + /// let a = FheUint3::try_encrypt(7, &client_key); + /// assert_eq!(a.is_ok(), true); + /// # } + /// ``` + #[track_caller] + fn try_encrypt(value: T, key: &ClientKey) -> Result { + let value = value.try_into().map_err(|_err| OutOfRangeError)?; + if value > Self::MAX { + Err(OutOfRangeError) + } else { + let id = P::Id::default(); + let key = id.unwrapped_ref_key(key); + let ciphertext = key.key.encrypt(u64::from(value)); + Ok(Self { + ciphertext: RefCell::new(ciphertext), + id, + }) + } + } +} + +impl FheTryEncrypt for GenericShortInt

+where + T: TryInto, + P: StaticShortIntegerParameter, + P::Id: Default + RefKeyFromPublicKeyChain>, +{ + type Error = crate::typed_api::errors::Error; + + /// Try to create a new value. + /// + /// As Shortints exposed by this crate have between 2 and 7 bits, + /// creating a value from a rust `u8` may not be possible. + /// + /// # Example + /// + /// ``` + /// # use tfhe::PublicKey; + /// #[cfg(feature = "shortint")] + /// # { + /// # use tfhe::{ConfigBuilder, PublicKey, FheUint2, generate_keys, set_server_key}; + /// # let config = ConfigBuilder::all_disabled().enable_default_uint2().build(); + /// # let (client_key, server_key) = generate_keys(config); + /// # set_server_key(server_key); + /// use tfhe::prelude::*; + /// use tfhe::Error; + /// + /// let public_key = PublicKey::new(&client_key); + /// + /// // The maximum value that can be represented with 2 bits is 3. + /// let a = FheUint2::try_encrypt(8, &public_key); + /// assert_eq!(a.is_err(), true); + /// + /// let a = FheUint2::try_encrypt(3, &public_key); + /// assert_eq!(a.is_ok(), true); + /// # } + /// ``` + #[track_caller] + fn try_encrypt(value: T, key: &PublicKey) -> Result { + let value = value.try_into().map_err(|_err| OutOfRangeError)?; + if value > Self::MAX { + Err(OutOfRangeError.into()) + } else { + let id = P::Id::default(); + let key = id.unwrapped_ref_key(key); + let ciphertext = key.key.encrypt(u64::from(value)); + Ok(Self { + ciphertext: RefCell::new(ciphertext), + id, + }) + } + } +} + +impl FheTryTrivialEncrypt for GenericShortInt

+where + Clear: TryInto, + P: StaticShortIntegerParameter, + P::Id: Default + WithGlobalKey>, +{ + type Error = crate::typed_api::errors::Error; + + fn try_encrypt_trivial(value: Clear) -> Result { + let value = value.try_into().map_err(|_err| OutOfRangeError)?; + if value > Self::MAX { + Err(OutOfRangeError.into()) + } else { + let id = P::Id::default(); + id.with_global(|key| { + let ciphertext = key.key.create_trivial(value.into()); + Ok(Self { + ciphertext: RefCell::new(ciphertext), + id, + }) + })? + } + } +} + +impl

GenericShortInt

+where + P: ShortIntegerParameter, + P::Id: WithGlobalKey>, +{ + pub fn bivariate_function(&self, other: &Self, func: F) -> Self + where + F: Fn(u8, u8) -> u8, + { + self.id + .with_unwrapped_global(|server_key| server_key.bivariate_pbs(self, other, func)) + } +} + +impl

FheOrd for GenericShortInt

+where + P: ShortIntegerParameter, + P::Id: WithGlobalKey>, +{ + type Output = Self; + + fn lt(&self, rhs: u8) -> Self { + self.id + .with_unwrapped_global(|server_key| server_key.smart_scalar_less(self, rhs)) + } + + fn le(&self, rhs: u8) -> Self { + self.id + .with_unwrapped_global(|server_key| server_key.smart_scalar_less_or_equal(self, rhs)) + } + + fn gt(&self, rhs: u8) -> Self { + self.id + .with_unwrapped_global(|server_key| server_key.smart_scalar_greater(self, rhs)) + } + + fn ge(&self, rhs: u8) -> Self { + self.id + .with_unwrapped_global(|server_key| server_key.smart_scalar_greater_or_equal(self, rhs)) + } +} + +impl

FheEq for GenericShortInt

+where + P: ShortIntegerParameter, + P::Id: WithGlobalKey>, +{ + type Output = Self; + + fn eq(&self, rhs: u8) -> Self::Output { + self.id + .with_unwrapped_global(|server_key| server_key.smart_scalar_equal(self, rhs)) + } +} + +impl FheOrd for GenericShortInt

+where + B: Borrow, + P: ShortIntegerParameter, + P::Id: WithGlobalKey>, +{ + type Output = Self; + + fn lt(&self, other: B) -> Self::Output { + self.id + .with_unwrapped_global(|server_key| server_key.smart_less(self, other.borrow())) + } + + fn le(&self, other: B) -> Self::Output { + self.id.with_unwrapped_global(|server_key| { + server_key.smart_less_or_equal(self, other.borrow()) + }) + } + + fn gt(&self, other: B) -> Self::Output { + self.id + .with_unwrapped_global(|server_key| server_key.smart_greater(self, other.borrow())) + } + + fn ge(&self, other: B) -> Self::Output { + self.id.with_unwrapped_global(|server_key| { + server_key.smart_greater_or_equal(self, other.borrow()) + }) + } +} + +impl FheEq for GenericShortInt

+where + B: Borrow, + P: ShortIntegerParameter, + P::Id: WithGlobalKey>, +{ + type Output = Self; + + fn eq(&self, other: B) -> Self { + self.id + .with_unwrapped_global(|server_key| server_key.smart_equal(self, other.borrow())) + } +} + +impl

FheBootstrap for GenericShortInt

+where + P: ShortIntegerParameter, + P::Id: WithGlobalKey>, +{ + fn map(&self, func: F) -> Self + where + F: Fn(u64) -> u64, + { + self.id + .with_unwrapped_global(|key| key.bootstrap_with(self, func)) + } + + fn apply(&mut self, func: F) + where + F: Fn(u64) -> u64, + { + self.id.with_unwrapped_global(|key| { + key.bootstrap_inplace_with(self, func); + }) + } +} + +impl std::iter::Sum for GenericShortInt

+where + B: Borrow, + P: ShortIntegerParameter, + P::Id: WithGlobalKey>, + Self: FheTryTrivialEncrypt + AddAssign, +{ + fn sum>(iter: I) -> Self { + let mut sum = Self::try_encrypt_trivial(0u8).expect("Failed to trivially encrypt zero"); + for item in iter { + sum += item; + } + sum + } +} + +impl std::iter::Product for GenericShortInt

+where + P: ShortIntegerParameter, + P::Id: WithGlobalKey>, + Self: FheTryTrivialEncrypt + MulAssign, +{ + fn product>(iter: I) -> Self { + let mut product = Self::try_encrypt_trivial(1u8).expect( + "Failed to trivially encrypt +one", + ); + for item in iter { + product *= item; + } + product + } +} + +impl

FheDecrypt for GenericShortInt

+where + P: ShortIntegerParameter, + P::Id: RefKeyFromKeyChain>, +{ + /// Decrypt the encrypted value to a u8 + /// # Example + /// + /// ``` + /// # #[cfg(feature = "shortint")] + /// # fn main() -> Result<(), Box> { + /// # use tfhe::{ConfigBuilder, FheUint3, FheUint2, generate_keys, set_server_key}; + /// # let config = ConfigBuilder::all_disabled().enable_default_uint3().enable_default_uint2().build(); + /// # let (client_key, server_key) = generate_keys(config); + /// # set_server_key(server_key); + /// use tfhe::Error; + /// use tfhe::prelude::*; + /// + /// let a = FheUint2::try_encrypt(2, &client_key)?; + /// let a_clear = a.decrypt(&client_key); + /// assert_eq!(a_clear, 2); + /// + /// let a = FheUint3::try_encrypt(7, &client_key)?; + /// let a_clear = a.decrypt(&client_key); + /// assert_eq!(a_clear, 7); + /// # Ok(()) + /// # } + /// ``` + #[track_caller] + fn decrypt(&self, key: &ClientKey) -> u8 { + let key = self.id.unwrapped_ref_key(key); + key.key.decrypt(&self.ciphertext.borrow()) as u8 + } +} + +macro_rules! short_int_impl_operation ( + ($trait_name:ident($trait_method:ident, $op:tt) => $key_method:ident) => { + #[doc = concat!(" Allows using the `", stringify!($op), "` operator between a")] + #[doc = " `GenericFheUint` and a `GenericFheUint` or a `&GenericFheUint`"] + #[doc = " "] + #[doc = " # Examples "] + #[doc = " "] + #[doc = " ```"] + #[doc = " # fn main() -> Result<(), tfhe::Error> {"] + #[doc = " use tfhe::prelude::*;"] + #[doc = " use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheUint2};"] + #[doc = " "] + #[doc = " let config = ConfigBuilder::all_disabled()"] + #[doc = " .enable_default_uint2()"] + #[doc = " .build();"] + #[doc = " let (keys, server_key) = generate_keys(config);"] + #[doc = " "] + #[doc = " let a = FheUint2::try_encrypt(2, &keys)?;"] + #[doc = " let b = FheUint2::try_encrypt(1, &keys)?;"] + #[doc = " "] + #[doc = " set_server_key(server_key);"] + #[doc = " "] + #[doc = concat!(" let c = a ", stringify!($op), " b;")] + #[doc = " let decrypted = c.decrypt(&keys);"] + #[doc = concat!(" let expected = 2 ", stringify!($op), " 1;")] + #[doc = " assert_eq!(decrypted, expected);"] + #[doc = " # Ok(())"] + #[doc = " # }"] + #[doc = " ```"] + #[doc = " "] + #[doc = " "] + #[doc = " ```"] + #[doc = " # fn main() -> Result<(), tfhe::Error> {"] + #[doc = " use tfhe::prelude::*;"] + #[doc = " use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheUint2};"] + #[doc = " "] + #[doc = " let config = ConfigBuilder::all_disabled()"] + #[doc = " .enable_default_uint2()"] + #[doc = " .build();"] + #[doc = " let (keys, server_key) = generate_keys(config);"] + #[doc = " "] + #[doc = " let a = FheUint2::try_encrypt(2, &keys)?;"] + #[doc = " let b = FheUint2::try_encrypt(1, &keys)?;"] + #[doc = " "] + #[doc = " set_server_key(server_key);"] + #[doc = " "] + #[doc = concat!(" let c = a ", stringify!($op), " &b;")] + #[doc = " let decrypted = c.decrypt(&keys);"] + #[doc = concat!(" let expected = 2 ", stringify!($op), " 1;")] + #[doc = " assert_eq!(decrypted, expected);"] + #[doc = " # Ok(())"] + #[doc = " # }"] + #[doc = " ```"] + impl $trait_name for GenericShortInt

+ where + P: ShortIntegerParameter, + GenericShortInt

: Clone, + P::Id: WithGlobalKey>, + I: Borrow, + { + type Output = Self; + + fn $trait_method(self, rhs: I) -> Self::Output { + <&Self as $trait_name>::$trait_method(&self, rhs) + } + } + + #[doc = concat!(" Allows using the `", stringify!($op), "` operator between a")] + #[doc = " `&GenericFheUint` and a `GenericFheUint` or a `&GenericFheUint`"] + #[doc = " "] + #[doc = " # Examples "] + #[doc = " "] + #[doc = " ```"] + #[doc = " # fn main() -> Result<(), tfhe::Error> {"] + #[doc = " use tfhe::prelude::*;"] + #[doc = " use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheUint2};"] + #[doc = " "] + #[doc = " let config = ConfigBuilder::all_disabled()"] + #[doc = " .enable_default_uint2()"] + #[doc = " .build();"] + #[doc = " let (keys, server_key) = generate_keys(config);"] + #[doc = " "] + #[doc = " let a = FheUint2::try_encrypt(2, &keys)?;"] + #[doc = " let b = FheUint2::try_encrypt(1, &keys)?;"] + #[doc = " "] + #[doc = " set_server_key(server_key);"] + #[doc = " "] + #[doc = concat!(" let c = &a ", stringify!($op), " b;")] + #[doc = " let decrypted = c.decrypt(&keys);"] + #[doc = concat!(" let expected = 2 ", stringify!($op), " 1;")] + #[doc = " assert_eq!(decrypted, expected);"] + #[doc = " # Ok(())"] + #[doc = " # }"] + #[doc = " ```"] + #[doc = " "] + #[doc = " "] + #[doc = " ```"] + #[doc = " # fn main() -> Result<(), tfhe::Error> {"] + #[doc = " use tfhe::prelude::*;"] + #[doc = " use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheUint2};"] + #[doc = " "] + #[doc = " let config = ConfigBuilder::all_disabled()"] + #[doc = " .enable_default_uint2()"] + #[doc = " .build();"] + #[doc = " let (keys, server_key) = generate_keys(config);"] + #[doc = " "] + #[doc = " let a = FheUint2::try_encrypt(2, &keys)?;"] + #[doc = " let b = FheUint2::try_encrypt(1, &keys)?;"] + #[doc = " "] + #[doc = " set_server_key(server_key);"] + #[doc = " "] + #[doc = concat!(" let c = &a ", stringify!($op), " &b;")] + #[doc = " let decrypted = c.decrypt(&keys);"] + #[doc = concat!(" let expected = 2 ", stringify!($op), " 1;")] + #[doc = " assert_eq!(decrypted, expected);"] + #[doc = " # Ok(())"] + #[doc = " # }"] + #[doc = " ```"] + impl $trait_name for &GenericShortInt

+ where + P: ShortIntegerParameter, + GenericShortInt

: Clone, + P::Id: WithGlobalKey>, + I: Borrow>, + { + type Output = GenericShortInt

; + + fn $trait_method(self, rhs: I) -> Self::Output { + self.id.with_unwrapped_global(|key| { + let borrowed = rhs.borrow(); + if std::ptr::eq(self, borrowed) { + let cloned = (*borrowed).clone(); + key.$key_method(self, &cloned) + } else { + key.$key_method(self, borrowed) + } + }) + } + } + }; +); + +macro_rules! short_int_impl_operation_assign ( + ($trait_name:ident($trait_method:ident, $op:tt) => $key_method:ident) => { + #[doc = concat!(" Allows using the `", stringify!($op), "` operator between a")] + #[doc = " `GenericFheUint` and a `GenericFheUint` or a `&GenericFheUint`"] + #[doc = " "] + #[doc = " # Examples "] + #[doc = " "] + #[doc = " ```"] + #[doc = " # fn main() -> Result<(), tfhe::Error> {"] + #[doc = " use tfhe::prelude::*;"] + #[doc = " use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheUint2};"] + #[doc = " "] + #[doc = " let config = ConfigBuilder::all_disabled()"] + #[doc = " .enable_default_uint2()"] + #[doc = " .build();"] + #[doc = " let (keys, server_key) = generate_keys(config);"] + #[doc = " "] + #[doc = " let mut a = FheUint2::try_encrypt(2, &keys)?;"] + #[doc = " let b = FheUint2::try_encrypt(1, &keys)?;"] + #[doc = " "] + #[doc = " set_server_key(server_key);"] + #[doc = " "] + #[doc = concat!(" a ", stringify!($op), " b;")] + #[doc = " let decrypted = a.decrypt(&keys);"] + #[doc = " let mut expected = 2;"] + #[doc = concat!(" expected ", stringify!($op), " 1;")] + #[doc = " assert_eq!(decrypted, expected);"] + #[doc = " # Ok(())"] + #[doc = " # }"] + #[doc = " ```"] + #[doc = " "] + #[doc = " "] + #[doc = " ```"] + #[doc = " # fn main() -> Result<(), tfhe::Error> {"] + #[doc = " use tfhe::prelude::*;"] + #[doc = " use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheUint2};"] + #[doc = " "] + #[doc = " let config = ConfigBuilder::all_disabled()"] + #[doc = " .enable_default_uint2()"] + #[doc = " .build();"] + #[doc = " let (keys, server_key) = generate_keys(config);"] + #[doc = " "] + #[doc = " let mut a = FheUint2::try_encrypt(2, &keys)?;"] + #[doc = " let b = FheUint2::try_encrypt(1, &keys)?;"] + #[doc = " "] + #[doc = " set_server_key(server_key);"] + #[doc = " "] + #[doc = concat!(" a ", stringify!($op), " &b;")] + #[doc = " let decrypted = a.decrypt(&keys);"] + #[doc = " let mut expected = 2;"] + #[doc = concat!(" expected ", stringify!($op), " 1;")] + #[doc = " assert_eq!(decrypted, expected);"] + #[doc = " # Ok(())"] + #[doc = " # }"] + #[doc = " ```"] + impl $trait_name for GenericShortInt

+ where + P: ShortIntegerParameter, + P::Id: WithGlobalKey>, + I: Borrow, + { + fn $trait_method(&mut self, rhs: I) { + // no need to check if self == rhs as, since we have &mut to self + // we know its exclusive + self.id.with_unwrapped_global(|key| { + key.$key_method(&self, rhs.borrow()) + }) + } + } + } +); + +// Scalar operations +macro_rules! short_int_impl_scalar_operation { + ($trait_name:ident($trait_method:ident) => $key_method:ident) => { + impl

$trait_name for &GenericShortInt

+ where + P: ShortIntegerParameter, + P::Id: WithGlobalKey>, + { + type Output = GenericShortInt

; + + fn $trait_method(self, rhs: u8) -> Self::Output { + self.id + .with_unwrapped_global(|key| key.$key_method(self, rhs)) + } + } + + impl

$trait_name for GenericShortInt

+ where + P: ShortIntegerParameter, + P::Id: WithGlobalKey>, + { + type Output = GenericShortInt

; + + fn $trait_method(self, rhs: u8) -> Self::Output { + <&Self as $trait_name>::$trait_method(&self, rhs) + } + } + + impl

$trait_name<&GenericShortInt

> for u8 + where + P: ShortIntegerParameter, + P::Id: WithGlobalKey>, + { + type Output = GenericShortInt

; + + fn $trait_method(self, rhs: &GenericShortInt

) -> Self::Output { + <&GenericShortInt

as $trait_name>::$trait_method(rhs, self) + } + } + + impl

$trait_name> for u8 + where + P: ShortIntegerParameter, + P::Id: WithGlobalKey>, + { + type Output = GenericShortInt

; + + fn $trait_method(self, rhs: GenericShortInt

) -> Self::Output { + >>::$trait_method(self, &rhs) + } + } + }; +} + +macro_rules! short_int_impl_scalar_operation_assign { + ($trait_name:ident($trait_method:ident) => $key_method:ident) => { + impl

$trait_name for GenericShortInt

+ where + P: ShortIntegerParameter, + P::Id: WithGlobalKey>, + { + fn $trait_method(&mut self, rhs: u8) { + self.id + .with_unwrapped_global(|key| key.$key_method(self, rhs)) + } + } + }; +} + +impl

Neg for GenericShortInt

+where + P: ShortIntegerParameter, + P::Id: WithGlobalKey>, +{ + type Output = Self; + + fn neg(self) -> Self::Output { + self.id.with_unwrapped_global(|key| key.smart_neg(&self)) + } +} + +impl

Neg for &GenericShortInt

+where + P: ShortIntegerParameter, + P::Id: WithGlobalKey>, +{ + type Output = GenericShortInt

; + + fn neg(self) -> Self::Output { + self.id.with_unwrapped_global(|key| key.smart_neg(self)) + } +} + +short_int_impl_operation!(Add(add,+) => smart_add); +short_int_impl_operation!(Sub(sub,-) => smart_sub); +short_int_impl_operation!(Mul(mul,*) => smart_mul); +short_int_impl_operation!(Div(div,/) => smart_div); +short_int_impl_operation!(BitAnd(bitand,&) => smart_bitand); +short_int_impl_operation!(BitOr(bitor,|) => smart_bitor); +short_int_impl_operation!(BitXor(bitxor,^) => smart_bitxor); + +short_int_impl_operation_assign!(AddAssign(add_assign,+=) => smart_add_assign); +short_int_impl_operation_assign!(SubAssign(sub_assign,-=) => smart_sub_assign); +short_int_impl_operation_assign!(MulAssign(mul_assign,*=) => smart_mul_assign); +short_int_impl_operation_assign!(DivAssign(div_assign,/=) => smart_div_assign); +short_int_impl_operation_assign!(BitAndAssign(bitand_assign,&=) => smart_bitand_assign); +short_int_impl_operation_assign!(BitOrAssign(bitor_assign,|=) => smart_bitor_assign); +short_int_impl_operation_assign!(BitXorAssign(bitxor_assign,^=) => smart_bitxor_assign); + +short_int_impl_scalar_operation!(Add(add) => smart_scalar_add); +short_int_impl_scalar_operation!(Sub(sub) => smart_scalar_sub); +short_int_impl_scalar_operation!(Mul(mul) => smart_scalar_mul); +short_int_impl_scalar_operation!(Div(div) => unchecked_scalar_div); +short_int_impl_scalar_operation!(Rem(rem) => unchecked_scalar_mod); +short_int_impl_scalar_operation!(Shl(shl) => smart_scalar_left_shift); +short_int_impl_scalar_operation!(Shr(shr) => unchecked_scalar_right_shift); + +short_int_impl_scalar_operation_assign!(AddAssign(add_assign) => smart_scalar_add_assign); +short_int_impl_scalar_operation_assign!(SubAssign(sub_assign) => smart_scalar_sub_assign); +short_int_impl_scalar_operation_assign!(MulAssign(mul_assign) => smart_scalar_mul_assign); diff --git a/tfhe/src/typed_api/shortints/types/compressed.rs b/tfhe/src/typed_api/shortints/types/compressed.rs new file mode 100644 index 000000000..eb6369756 --- /dev/null +++ b/tfhe/src/typed_api/shortints/types/compressed.rs @@ -0,0 +1,53 @@ +use crate::shortint::CompressedCiphertext; +use crate::typed_api::keys::RefKeyFromKeyChain; +use crate::typed_api::shortints::client_key::GenericShortIntClientKey; +use crate::typed_api::shortints::parameters::ShortIntegerParameter; +use crate::typed_api::shortints::GenericShortInt; +use crate::typed_api::traits::FheTryEncrypt; +use crate::typed_api::ClientKey; + +pub struct CompressedGenericShortint

+where + P: ShortIntegerParameter, +{ + pub(in crate::typed_api::shortints) ciphertext: CompressedCiphertext, + pub(in crate::typed_api::shortints) id: P::Id, +} + +impl

CompressedGenericShortint

+where + P: ShortIntegerParameter, +{ + pub(crate) fn new(inner: CompressedCiphertext, id: P::Id) -> Self { + Self { + ciphertext: inner, + id, + } + } +} + +impl

From> for GenericShortInt

+where + P: ShortIntegerParameter, +{ + fn from(value: CompressedGenericShortint

) -> Self { + let inner = value.ciphertext.into(); + Self::new(inner, value.id) + } +} + +impl

FheTryEncrypt for CompressedGenericShortint

+where + P: ShortIntegerParameter, + P::Id: Default + RefKeyFromKeyChain>, +{ + type Error = crate::typed_api::errors::Error; + + fn try_encrypt(value: u8, key: &ClientKey) -> Result { + let id = P::Id::default(); + let key = id.ref_key(key)?; + + let inner = key.key.encrypt_compressed(value as u64); + Ok(Self::new(inner, id)) + } +} diff --git a/tfhe/src/typed_api/shortints/types/mod.rs b/tfhe/src/typed_api/shortints/types/mod.rs new file mode 100644 index 000000000..dc512081e --- /dev/null +++ b/tfhe/src/typed_api/shortints/types/mod.rs @@ -0,0 +1,15 @@ +pub use base::GenericShortInt; +pub use compressed::CompressedGenericShortint; + +pub use static_::{ + CompressedFheUint2, CompressedFheUint3, CompressedFheUint4, FheUint2, FheUint2Parameters, + FheUint3, FheUint3Parameters, FheUint4, FheUint4Parameters, +}; + +use super::client_key::GenericShortIntClientKey; +use super::public_key::GenericShortIntPublicKey; +use super::server_key::GenericShortIntServerKey; + +mod base; +mod compressed; +pub(crate) mod static_; diff --git a/tfhe/src/typed_api/shortints/types/static_.rs b/tfhe/src/typed_api/shortints/types/static_.rs new file mode 100644 index 000000000..b8f32cbbf --- /dev/null +++ b/tfhe/src/typed_api/shortints/types/static_.rs @@ -0,0 +1,332 @@ +use serde::{Deserialize, Serialize}; +use std::fmt::Formatter; + +use crate::shortint::parameters::{ + CarryModulus, DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, + MessageModulus, Parameters, PolynomialSize, StandardDev, +}; + +use crate::typed_api::shortints::{CompressedGenericShortint, GenericShortInt}; + +use super::{GenericShortIntClientKey, GenericShortIntPublicKey, GenericShortIntServerKey}; + +use crate::typed_api::shortints::parameters::{ShortIntegerParameter, StaticShortIntegerParameter}; + +use paste::paste; + +/// Generic Parameter struct for short integers. +/// +/// It allows to customize the same parameters as the ones +/// from the underlying `crate::shortint` with the exception of +/// the number of bits of message as its embeded in the type. +#[derive(Copy, Clone, Debug)] +pub struct ShortIntegerParameterSet { + pub lwe_dimension: LweDimension, + pub glwe_dimension: GlweDimension, + pub polynomial_size: PolynomialSize, + pub lwe_modular_std_dev: StandardDev, + pub glwe_modular_std_dev: StandardDev, + pub pbs_base_log: DecompositionBaseLog, + pub pbs_level: DecompositionLevelCount, + pub ks_base_log: DecompositionBaseLog, + pub ks_level: DecompositionLevelCount, + pub pfks_level: DecompositionLevelCount, + pub pfks_base_log: DecompositionBaseLog, + pub pfks_modular_std_dev: StandardDev, + pub cbs_level: DecompositionLevelCount, + pub cbs_base_log: DecompositionBaseLog, + pub carry_modulus: CarryModulus, +} + +impl ShortIntegerParameterSet { + const fn from_static(params: &'static Parameters) -> Self { + if params.message_modulus.0 != 1 << MESSAGE_BITS as usize { + panic!("Invalid bit number"); + } + Self { + lwe_dimension: params.lwe_dimension, + glwe_dimension: params.glwe_dimension, + polynomial_size: params.polynomial_size, + lwe_modular_std_dev: params.lwe_modular_std_dev, + glwe_modular_std_dev: params.glwe_modular_std_dev, + pbs_base_log: params.pbs_base_log, + pbs_level: params.pbs_level, + ks_base_log: params.ks_base_log, + ks_level: params.ks_level, + pfks_level: params.pfks_level, + pfks_base_log: params.pfks_base_log, + pfks_modular_std_dev: params.pfks_modular_std_dev, + cbs_level: params.cbs_level, + cbs_base_log: params.cbs_base_log, + carry_modulus: params.carry_modulus, + } + } +} + +impl From> for Parameters { + fn from(params: ShortIntegerParameterSet) -> Self { + Self { + lwe_dimension: params.lwe_dimension, + glwe_dimension: params.glwe_dimension, + polynomial_size: params.polynomial_size, + lwe_modular_std_dev: params.lwe_modular_std_dev, + glwe_modular_std_dev: params.glwe_modular_std_dev, + pbs_base_log: params.pbs_base_log, + pbs_level: params.pbs_level, + ks_base_log: params.ks_base_log, + ks_level: params.ks_level, + pfks_level: params.pfks_level, + pfks_base_log: params.pfks_base_log, + pfks_modular_std_dev: params.pfks_modular_std_dev, + cbs_level: params.cbs_level, + cbs_base_log: params.cbs_base_log, + message_modulus: MessageModulus(1 << MESSAGE_BITS as usize), + carry_modulus: params.carry_modulus, + } + } +} + +/// The Id that is used to identify and retrieve the corresponding keys +#[derive(Copy, Clone, Default, Debug, Eq, PartialEq)] +pub struct ShorIntId; + +impl Serialize for ShorIntId { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_unit_struct("ShorIntId") + } +} + +impl<'de, const MESSAGE_BITS: u8> Deserialize<'de> for ShorIntId { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + struct Visitor; + + impl<'de, const MESSAGE_BITS: u8> serde::de::Visitor<'de> for Visitor { + type Value = ShorIntId; + + fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result { + formatter.write_str("struct ShorIntId") + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(ShorIntId::) + } + } + + deserializer.deserialize_unit_struct("ShorIntId", Visitor::) + } +} + +impl ShortIntegerParameter for ShortIntegerParameterSet { + type Id = ShorIntId; +} + +impl StaticShortIntegerParameter + for ShortIntegerParameterSet +{ + const MESSAGE_BITS: u8 = MESSAGE_BITS; +} + +/// Defines a new static shortint type. +/// +/// It needs as input the: +/// - name of the type +/// - the number of bits of message the type has +/// - the keychain member where ClientKey / Server Key is stored +/// +/// It generates code: +/// - type alias for the client key, server key, parameter and shortint types +/// - the trait impl on the id type to access the keys +macro_rules! static_shortint_type { + ( + $(#[$outer:meta])* + $name:ident { + num_bits: $num_bits:literal, + keychain_member: $($member:ident).*, + } + ) => { + paste! { + + #[doc = concat!("Parameters for the [", stringify!($name), "] data type.")] + #[cfg_attr(all(doc, not(doctest)), cfg(feature = "shortint"))] + pub type [<$name Parameters>] = ShortIntegerParameterSet<$num_bits>; + + pub(in crate::typed_api) type [<$name ClientKey>] = GenericShortIntClientKey<[<$name Parameters>]>; + pub(in crate::typed_api) type [<$name PublicKey>] = GenericShortIntPublicKey<[<$name Parameters>]>; + pub(in crate::typed_api) type [<$name ServerKey>] = GenericShortIntServerKey<[<$name Parameters>]>; + + $(#[$outer])* + #[doc=concat!("An unsigned integer type with ", stringify!($num_bits), " bits.")] + #[cfg_attr(all(doc, not(doctest)), cfg(feature = "shortint"))] + pub type $name = GenericShortInt<[<$name Parameters>]>; + + #[cfg_attr(all(doc, not(doctest)), cfg(feature = "shortint"))] + pub type [] = CompressedGenericShortint<[<$name Parameters>]>; + + impl_ref_key_from_keychain!( + for <[<$name Parameters>] as ShortIntegerParameter>::Id { + key_type: [<$name ClientKey>], + keychain_member: $($member).*, + type_variant: crate::typed_api::errors::Type::$name, + } + ); + + impl_ref_key_from_public_keychain!( + for <[<$name Parameters>] as ShortIntegerParameter>::Id { + key_type: [<$name PublicKey>], + keychain_member: $($member).*, + type_variant: crate::typed_api::errors::Type::$name, + } + ); + + impl_with_global_key!( + for <[<$name Parameters>] as ShortIntegerParameter>::Id { + key_type: [<$name ServerKey>], + keychain_member: $($member).*, + type_variant: crate::typed_api::errors::Type::$name, + } + ); + } + }; +} + +static_shortint_type! { + FheUint2 { + num_bits: 2, + keychain_member: shortint_key.uint2_key, + } +} + +static_shortint_type! { + FheUint3 { + num_bits: 3, + keychain_member: shortint_key.uint3_key, + } +} + +static_shortint_type! { + FheUint4 { + num_bits: 4, + keychain_member: shortint_key.uint4_key, + } +} + +impl FheUint2Parameters { + pub fn with_carry_1() -> Self { + Self::from_static(&crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_1) + } + + pub fn with_carry_2() -> Self { + Self::from_static(&crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2) + } + + pub fn with_carry_3() -> Self { + Self::from_static(&crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_3) + } + + pub fn with_carry_4() -> Self { + Self::from_static(&crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_4) + } + + pub fn with_carry_5() -> Self { + Self::from_static(&crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_5) + } + + pub fn with_carry_6() -> Self { + Self::from_static(&crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_6) + } + + pub fn wopbs_default() -> Self { + Self::from_static(&crate::shortint::parameters::parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2) + } +} + +impl Default for FheUint2Parameters { + fn default() -> Self { + Self::with_carry_2() + } +} + +impl FheUint3Parameters { + pub fn with_carry_1() -> Self { + Self::from_static(&crate::shortint::parameters::PARAM_MESSAGE_3_CARRY_1) + } + + pub fn with_carry_2() -> Self { + Self::from_static(&crate::shortint::parameters::PARAM_MESSAGE_3_CARRY_2) + } + + pub fn with_carry_3() -> Self { + Self::from_static(&crate::shortint::parameters::PARAM_MESSAGE_3_CARRY_3) + } + + pub fn with_carry_4() -> Self { + Self::from_static(&crate::shortint::parameters::PARAM_MESSAGE_3_CARRY_4) + } + + pub fn with_carry_5() -> Self { + Self::from_static(&crate::shortint::parameters::PARAM_MESSAGE_3_CARRY_5) + } + + pub fn wopbs_default() -> Self { + Self::from_static(&crate::shortint::parameters::parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_3_CARRY_3) + } +} + +impl Default for FheUint3Parameters { + fn default() -> Self { + Self::with_carry_3() + } +} + +impl FheUint4Parameters { + pub fn with_carry_1() -> Self { + Self::from_static(&crate::shortint::parameters::PARAM_MESSAGE_4_CARRY_1) + } + + pub fn with_carry_2() -> Self { + Self::from_static(&crate::shortint::parameters::PARAM_MESSAGE_4_CARRY_2) + } + + pub fn with_carry_3() -> Self { + Self::from_static(&crate::shortint::parameters::PARAM_MESSAGE_4_CARRY_3) + } + + pub fn with_carry_4() -> Self { + Self::from_static(&crate::shortint::parameters::PARAM_MESSAGE_4_CARRY_4) + } + + pub fn wopbs_default() -> Self { + Self::from_static(&crate::shortint::parameters::parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_4_CARRY_4) + } +} + +impl Default for FheUint4Parameters { + fn default() -> Self { + Self::with_carry_4() + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn can_serialize_deserialize_shortint_id() { + let id = ShorIntId::<2>; + let mut cursor = std::io::Cursor::new(Vec::::new()); + bincode::serialize_into(&mut cursor, &id).unwrap(); + cursor.set_position(0); + let id2: ShorIntId<2> = bincode::deserialize_from(cursor).unwrap(); + + assert_eq!(id, id2); + } +} diff --git a/tfhe/src/typed_api/tests.rs b/tfhe/src/typed_api/tests.rs new file mode 100644 index 000000000..2da30cf3b --- /dev/null +++ b/tfhe/src/typed_api/tests.rs @@ -0,0 +1,82 @@ +use crate::typed_api::prelude::*; +#[cfg(feature = "boolean")] +use crate::typed_api::FheBool; +#[cfg(feature = "shortint")] +use crate::typed_api::FheUint2; +#[cfg(feature = "integer")] +use crate::typed_api::FheUint8; +#[cfg(any(feature = "boolean", feature = "shortint", feature = "integer"))] +use crate::typed_api::{generate_keys, ClientKey, ConfigBuilder, PublicKey}; +#[cfg(any(feature = "boolean", feature = "shortint", feature = "integer"))] +use std::fmt::Debug; + +#[cfg(any(feature = "boolean", feature = "shortint", feature = "integer"))] +fn assert_that_public_key_encryption_is_decrypted_by_client_key( + clear: ClearType, + pks: &PublicKey, + cks: &ClientKey, +) where + ClearType: Copy + Eq + Debug, + FheType: FheTryEncrypt + FheDecrypt, +{ + let encrypted = FheType::try_encrypt(clear, pks).unwrap(); + let decrypted: ClearType = encrypted.decrypt(cks); + assert_eq!(clear, decrypted); +} + +#[cfg(feature = "boolean")] +#[test] +fn test_boolean_public_key() { + let config = ConfigBuilder::all_disabled().enable_default_bool().build(); + + let (cks, _sks) = generate_keys(config); + + let pks = PublicKey::new(&cks); + + assert_that_public_key_encryption_is_decrypted_by_client_key::( + false, &pks, &cks, + ); + assert_that_public_key_encryption_is_decrypted_by_client_key::(true, &pks, &cks); +} + +#[cfg(feature = "shortint")] +#[test] +fn test_shortint_public_key() { + let config = ConfigBuilder::all_disabled().enable_default_uint2().build(); + + let (cks, _sks) = generate_keys(config); + + let pks = PublicKey::new(&cks); + + assert_that_public_key_encryption_is_decrypted_by_client_key::(0, &pks, &cks); + assert_that_public_key_encryption_is_decrypted_by_client_key::(1, &pks, &cks); + assert_that_public_key_encryption_is_decrypted_by_client_key::(2, &pks, &cks); + assert_that_public_key_encryption_is_decrypted_by_client_key::(3, &pks, &cks); +} + +#[cfg(feature = "integer")] +#[test] +fn test_integer_public_key() { + let config = ConfigBuilder::all_disabled().enable_default_uint8().build(); + + let (cks, _sks) = generate_keys(config); + + let pks = PublicKey::new(&cks); + + assert_that_public_key_encryption_is_decrypted_by_client_key::(235, &pks, &cks); +} + +#[cfg(feature = "boolean")] +#[test] +fn test_with_context() { + let config = ConfigBuilder::all_disabled().enable_default_bool().build(); + + let (cks, sks) = generate_keys(config); + + let a = FheBool::encrypt(false, &cks); + let b = FheBool::encrypt(true, &cks); + + let (r, _) = crate::typed_api::with_server_key_as_context(sks, move || a & b); + let d = r.decrypt(&cks); + assert!(d); +} diff --git a/tfhe/src/typed_api/traits.rs b/tfhe/src/typed_api/traits.rs new file mode 100644 index 000000000..d5ac7d647 --- /dev/null +++ b/tfhe/src/typed_api/traits.rs @@ -0,0 +1,123 @@ +use crate::typed_api::ClientKey; + +/// Trait used to have a generic way of creating a value of a FHE type +/// from a native value. +/// +/// This trait is for when FHE type the native value is encrypted +/// supports the same numbers of bits of precision. +/// +/// The `Key` is required as it contains the key needed to do the +/// actual encryption. +pub trait FheEncrypt { + fn encrypt(value: T, key: &Key) -> Self; +} + +pub trait DynamicFheEncryptor { + type FheType; + + fn encrypt(&self, value: T, key: &ClientKey) -> Self::FheType; +} + +// This trait has the same signature than +// `std::convert::From` however we create our own trait +// to be explicit about the `trivial` +pub trait FheTrivialEncrypt { + fn encrypt_trivial(value: T) -> Self; +} + +pub trait DynamicFheTrivialEncryptor { + type FheType; + + fn encrypt_trivial(&self, value: T) -> Self::FheType; +} + +/// Trait used to have a generic **fallible** way of creating a value of a FHE type. +/// +/// For example this trait may be implemented by FHE types which may not be able +/// to represent all the values of even the smallest native type. +/// +/// For example, `FheUint2` which has 2 bits of precision may not be constructed from +/// all values that a `u8` can hold. +pub trait FheTryEncrypt +where + Self: Sized, +{ + type Error: std::error::Error; + + fn try_encrypt(value: T, key: &Key) -> Result; +} + +/// Trait for fallible trivial encryption. +pub trait FheTryTrivialEncrypt +where + Self: Sized, +{ + type Error: std::error::Error; + + fn try_encrypt_trivial(value: T) -> Result; +} + +pub trait DynamicFheTryEncryptor { + type FheType; + type Error; + + fn try_encrypt(&self, value: T, key: &ClientKey) -> Result; +} + +/// Decrypt a FHE type to a native type. +pub trait FheDecrypt { + fn decrypt(&self, key: &ClientKey) -> T; +} + +/// Trait for fully homomorphic equality test. +/// +/// The standard trait [std::cmp::PartialEq] can not be used +/// has it requires to return a [bool]. +/// +/// This means that to compare ciphertext to another ciphertext or a scalar, +/// for equality, one cannot use the standard operator `==` but rather, use +/// the function directly. +pub trait FheEq { + type Output; + + fn eq(&self, other: Rhs) -> Self::Output; +} + +/// Trait for fully homomorphic comparisons. +/// +/// The standard trait [std::cmp::PartialOrd] can not be used +/// has it requires to return a [bool]. +/// +/// This means that to compare ciphertext to another ciphertext or a scalar, +/// one cannot use the standard operators (`>`, `<`, etc) and must use +/// the functions directly. +pub trait FheOrd { + type Output; + + fn lt(&self, other: Rhs) -> Self::Output; + fn le(&self, other: Rhs) -> Self::Output; + fn gt(&self, other: Rhs) -> Self::Output; + fn ge(&self, other: Rhs) -> Self::Output; +} + +/// Trait required to apply univariate function over homomorphic types. +/// +/// A `univariate function` is a function with one variable, e.g., of the form f(x). +pub trait FheBootstrap +where + Self: Sized, +{ + /// Compute a function over an encrypted message, and returns a new encrypted value containing + /// the result. + fn map u64>(&self, func: F) -> Self; + + /// Compute a function over the encrypted message. + fn apply u64>(&mut self, func: F); +} + +#[doc(hidden)] +pub trait FheNumberConstant { + const MIN: u64; + const MAX: u64; + const MODULUS: u64; +} diff --git a/tfhe/tests/test_fhe_number_ops.rs b/tfhe/tests/test_fhe_number_ops.rs new file mode 100644 index 000000000..adff2e968 --- /dev/null +++ b/tfhe/tests/test_fhe_number_ops.rs @@ -0,0 +1,232 @@ +#![cfg(any(feature = "integer", feature = "shortint"))] + +//! For now, in this test file, we don't want to check results +//! but rather check that short int, int, dyn short int, dyn int +//! all overload the same operators. +//! +//! For each operator overloaded operator we want to support $ variants: +//! +//! lhs + rhs +//! lhs + &rhs +//! &lhs + rhs +//! &lhs + &rhs +use std::fmt::Debug; +use std::ops::{Add, BitAnd, BitOr, BitXor, Div, Mul, Rem, Shl, Shr, Sub}; +use tfhe::prelude::{FheDecrypt, FheTryEncrypt}; +use tfhe::ClientKey; + +macro_rules! define_operation_test { + ($name:ident => ($trait:ident, $symbol:tt)) => { + fn $name(lhs: T, rhs: T) + where + T: Clone + $trait, + T: for<'a> $trait<&'a T, Output = T>, + for<'a> &'a T: $trait + $trait<&'a T, Output = T>, + { + let _ = &lhs $symbol &rhs; + + let _ = &lhs $symbol rhs.clone(); + + let _ = lhs.clone() $symbol &rhs; + + let _ = lhs $symbol rhs; + } + }; +} + +/// We keep this to improve tests later +#[allow(dead_code)] +fn static_supports_all_add_ways(lhs_clear: T, rhs_clear: T, client_key: &ClientKey) +where + T: Add + Copy + Debug + PartialEq, + FheT: FheTryEncrypt + FheDecrypt, + FheT: Clone + Add, + FheT: for<'a> Add<&'a FheT, Output = FheT>, + for<'a> &'a FheT: Add + Add<&'a FheT, Output = FheT>, +{ + let lhs = FheT::try_encrypt(lhs_clear, client_key).unwrap(); + let rhs = FheT::try_encrypt(lhs_clear, client_key).unwrap(); + + let expected = lhs_clear + rhs_clear; + + let r = &lhs + &rhs; + let dec_r = r.decrypt(client_key); + assert_eq!(dec_r, expected); + + let r = &lhs + rhs.clone(); + let dec_r = r.decrypt(client_key); + assert_eq!(dec_r, expected); + + let r = lhs.clone() + &rhs; + let dec_r = r.decrypt(client_key); + assert_eq!(dec_r, expected); + + let r = lhs + rhs; + let dec_r = r.decrypt(client_key); + assert_eq!(dec_r, expected); +} + +define_operation_test!(supports_all_add_ways => (Add, +)); +define_operation_test!(supports_all_sub_ways => (Sub, -)); +define_operation_test!(supports_all_mul_ways => (Mul, *)); +define_operation_test!(supports_all_div_ways => (Div, /)); +define_operation_test!(supports_all_bitand_ways => (BitAnd, &)); +define_operation_test!(supports_all_bitor_ways => (BitOr, |)); +define_operation_test!(supports_all_bitxor_ways => (BitXor, ^)); + +fn supports_scalar_add_with_u8(lhs: T, rhs: u8) +where + T: Clone + Add, + for<'a> &'a T: Add, + u8: Add, + u8: for<'a> Add<&'a T, Output = T>, +{ + let _ = &lhs + rhs; + let _ = lhs.clone() + rhs; + + let _ = rhs + &lhs; + let _ = rhs + lhs; +} + +fn supports_scalar_div_with_u8(lhs: T, rhs: u8) +where + T: Clone + Div, + for<'a> &'a T: Div, + u8: Div, + u8: for<'a> Div<&'a T, Output = T>, +{ + let _ = &lhs / rhs; + let _ = lhs.clone() / rhs; + + let _ = rhs / &lhs; + let _ = rhs / lhs; +} + +fn supports_scalar_shl_with_u8(lhs: T, rhs: u8) +where + T: Clone + Shl, + for<'a> &'a T: Shl, + u8: Shl, + u8: for<'a> Shl<&'a T, Output = T>, +{ + let _ = &lhs << rhs; + let _ = lhs.clone() << rhs; + + let _ = rhs << &lhs; + let _ = rhs << lhs; +} + +fn supports_scalar_shr_with_u8(lhs: T, rhs: u8) +where + T: Clone + Shr, + for<'a> &'a T: Shr, + u8: Shr, + u8: for<'a> Shr<&'a T, Output = T>, +{ + let _ = &lhs >> rhs; + let _ = lhs.clone() >> rhs; + + let _ = rhs >> &lhs; + let _ = rhs >> lhs; +} + +fn supports_scalar_mod_with_u8(lhs: T, rhs: u8) +where + T: Clone + Rem, + for<'a> &'a T: Rem, + u8: Rem, + u8: for<'a> Rem<&'a T, Output = T>, +{ + let _ = &lhs % rhs; + let _ = lhs.clone() % rhs; + + let _ = rhs % &lhs; + let _ = rhs % lhs; +} + +fn supports_scalar_mul_with_u8(lhs: T, rhs: u8) +where + T: Clone + Mul, + for<'a> &'a T: Mul, + u8: Mul, + u8: for<'a> Mul<&'a T, Output = T>, +{ + let _ = &lhs * rhs; + let _ = lhs.clone() * rhs; + + let _ = rhs * &lhs; + let _ = rhs * lhs; +} + +#[cfg(feature = "shortint")] +#[test] +fn test_static_shortint_supports_ops() { + use tfhe::prelude::*; + use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheUint2}; + + let config = ConfigBuilder::all_disabled().enable_default_uint2().build(); + let (client_key, server_key) = generate_keys(config); + + set_server_key(server_key); + + let lhs = FheUint2::try_encrypt(0, &client_key).unwrap(); + let rhs = FheUint2::try_encrypt(1, &client_key).unwrap(); + + supports_all_add_ways(lhs.clone(), rhs.clone()); + supports_all_mul_ways(lhs.clone(), rhs.clone()); + supports_all_sub_ways(lhs.clone(), rhs.clone()); + supports_all_div_ways(lhs.clone(), rhs.clone()); + supports_all_bitand_ways(lhs.clone(), rhs.clone()); + supports_all_bitor_ways(lhs.clone(), rhs.clone()); + supports_all_bitxor_ways(lhs.clone(), rhs); + supports_scalar_mul_with_u8(lhs.clone(), 1); + supports_scalar_add_with_u8(lhs.clone(), 1); + supports_scalar_div_with_u8(lhs.clone(), 1); + supports_scalar_mod_with_u8(lhs.clone(), 1); + supports_scalar_shl_with_u8(lhs.clone(), 1); + supports_scalar_shr_with_u8(lhs, 1); +} + +#[cfg(feature = "integers")] +#[test] +fn test_static_supports_ops() { + use tfhe::prelude::*; + use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheUint8}; + + let config = ConfigBuilder::all_disabled().enable_default_uint8().build(); + let (client_key, server_key) = generate_keys(config); + + set_server_key(server_key); + + let lhs = FheUint8::encrypt(0, &client_key); + let rhs = FheUint8::encrypt(1, &client_key); + + supports_all_add_ways(lhs.clone(), rhs.clone()); + supports_all_mul_ways(lhs.clone(), rhs.clone()); + supports_all_sub_ways(lhs, rhs); +} + +#[cfg(feature = "integers")] +#[test] +fn test_dynamic_supports_ops() { + use tfhe::prelude::*; + use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheUint2Parameters, RadixParameters}; + + let mut config = ConfigBuilder::all_disabled(); + let uint10_type = config.add_integer_type(RadixParameters { + block_parameters: FheUint2Parameters::default().into(), + num_block: 5, + wopbs_block_parameters: tfhe_shortint::parameters::parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2 + }); + + let (client_key, server_key) = generate_keys(config); + + set_server_key(server_key); + + let lhs = uint10_type.encrypt(127, &client_key); + let rhs = uint10_type.encrypt(100, &client_key); + + supports_all_add_ways(lhs.clone(), rhs.clone()); + supports_all_mul_ways(lhs.clone(), rhs.clone()); + supports_all_sub_ways(lhs, rhs); +} diff --git a/tfhe/tests/test_integers.rs b/tfhe/tests/test_integers.rs new file mode 100644 index 000000000..691dd6105 --- /dev/null +++ b/tfhe/tests/test_integers.rs @@ -0,0 +1,19 @@ +#![cfg(feature = "integer")] + +use tfhe::prelude::*; +use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheUint8}; + +#[test] +fn test_uint8() { + let config = ConfigBuilder::all_disabled().enable_default_uint8().build(); + + let (client_key, server_key) = generate_keys(config); + + set_server_key(server_key); + + let a = FheUint8::encrypt(27, &client_key); + let b = FheUint8::encrypt(100, &client_key); + + let c: u8 = (a + b).decrypt(&client_key); + assert_eq!(c, 127); +} diff --git a/tfhe/tests/test_shortints.rs b/tfhe/tests/test_shortints.rs new file mode 100644 index 000000000..c3b3d1ea0 --- /dev/null +++ b/tfhe/tests/test_shortints.rs @@ -0,0 +1,190 @@ +#![cfg(feature = "shortint")] +#![allow(clippy::assign_op_pattern)] +#![allow(clippy::identity_op)] + +use tfhe::prelude::*; +use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheUint2, FheUint3, FheUint4}; + +#[test] +fn test_uint2() -> Result<(), Box> { + let config = ConfigBuilder::all_disabled().enable_default_uint2().build(); + let (keys, server_keys) = generate_keys(config); + set_server_key(server_keys); + + let mut a = FheUint2::try_encrypt(0, &keys)?; + let b = FheUint2::try_encrypt(1, &keys)?; + + a += &b; + let decrypted = a.decrypt(&keys); + assert_eq!(decrypted, 1); + + a = a + &b; + let decrypted = a.decrypt(&keys); + assert_eq!(decrypted, 2); + + a = a - &b; + let decrypted = a.decrypt(&keys); + assert_eq!(decrypted, 1); + + Ok(()) +} + +#[test] +fn test_scalar_comparison_fhe_uint_3() -> Result<(), Box> { + let config = ConfigBuilder::all_disabled().enable_default_uint3().build(); + let (keys, server_keys) = generate_keys(config); + set_server_key(server_keys); + + let a = FheUint3::try_encrypt(2, &keys)?; + + let mut b = a.eq(2); + let decrypted = b.decrypt(&keys); + assert_eq!(decrypted, 1); + + b = a.ge(2); + let decrypted = b.decrypt(&keys); + assert_eq!(decrypted, 1); + + b = a.gt(2); + let decrypted = b.decrypt(&keys); + assert_eq!(decrypted, 0); + + b = a.le(2); + let decrypted = b.decrypt(&keys); + assert_eq!(decrypted, 1); + + b = a.lt(2); + let decrypted = b.decrypt(&keys); + assert_eq!(decrypted, 0); + + Ok(()) +} + +#[test] +fn test_sum_uint_3_vec() -> Result<(), Box> { + let config = ConfigBuilder::all_disabled().enable_default_uint3().build(); + let (keys, server_keys) = generate_keys(config); + set_server_key(server_keys); + + let clear_vec = vec![2, 5]; + let expected = clear_vec.iter().copied().sum::() % (2u8 << 3); + + let fhe_vec: Vec = clear_vec + .iter() + .copied() + .map(|v| FheUint3::try_encrypt(v, &keys).unwrap()) + .collect(); + + let result: FheUint3 = fhe_vec.iter().sum(); + let decrypted = result.decrypt(&keys); + assert_eq!(decrypted, expected); + + let slc = &[&fhe_vec[0], &fhe_vec[1]]; + let result: FheUint3 = slc.iter().copied().sum(); + let decrypted = result.decrypt(&keys); + assert_eq!(decrypted, expected); + + let empty_res: u8 = Vec::::new() + .into_iter() + .sum::() + .decrypt(&keys); + assert_eq!(empty_res, Vec::::new().into_iter().sum::()); + + Ok(()) +} + +#[test] +fn test_product_uint_4_vec() -> Result<(), Box> { + let config = ConfigBuilder::all_disabled().enable_default_uint4().build(); + let (keys, server_keys) = generate_keys(config); + set_server_key(server_keys); + + let clear_vec = vec![2, 5]; + let expected = clear_vec.iter().copied().product(); + + let fhe_vec: Vec = clear_vec + .iter() + .copied() + .map(|v| FheUint4::try_encrypt(v, &keys).unwrap()) + .collect(); + + let result: FheUint4 = fhe_vec.iter().product(); + let decrypted = result.decrypt(&keys); + assert_eq!(decrypted, expected); + + let slc = &[&fhe_vec[0], &fhe_vec[1]]; + let result: FheUint4 = slc.iter().copied().product(); + let decrypted = result.decrypt(&keys); + assert_eq!(decrypted, expected); + + let empty_res: u8 = Vec::::new() + .into_iter() + .product::() + .decrypt(&keys); + assert_eq!(empty_res, Vec::::new().into_iter().product::()); + + Ok(()) +} + +#[test] +fn test_programmable_bootstrap_fhe_uint2() -> Result<(), Box> { + let config = ConfigBuilder::all_disabled().enable_default_uint2().build(); + let (keys, server_keys) = generate_keys(config); + set_server_key(server_keys); + + let mut a = FheUint2::try_encrypt(2, &keys)?; + + let c = a.map(|value| value * value); + let decrypted = c.decrypt(&keys); + assert_eq!(decrypted, (2 * 2) % 2); + + a.apply(|value| value * value); + let decrypted = a.decrypt(&keys); + assert_eq!(decrypted, (2 * 2) % 2); + + Ok(()) +} + +#[test] +fn test_programmable_biviariate_bootstrap_fhe_uint3() -> Result<(), Box> { + let config = ConfigBuilder::all_disabled().enable_default_uint3().build(); + let (keys, server_keys) = generate_keys(config); + set_server_key(server_keys); + + for i in 0..FheUint3::MAX { + let clear_a = i; + let clear_b = i + 1; + + let a = FheUint3::try_encrypt(clear_a, &keys)?; + let b = FheUint3::try_encrypt(clear_b, &keys)?; + + let result = a.bivariate_function(&b, std::cmp::max); + let clear_result: u8 = result.decrypt(&keys); + assert_eq!(clear_result, std::cmp::max(clear_a, clear_b)); + + // check reversing lhs and rhs works + let result = b.bivariate_function(&a, std::cmp::max); + let clear_result: u8 = result.decrypt(&keys); + assert_eq!(clear_result, std::cmp::max(clear_b, clear_a)); + } + + Ok(()) +} + +#[test] +fn test_branchless_min_max() -> Result<(), Box> { + let config = ConfigBuilder::all_disabled().enable_default_uint4().build(); + let (keys, server_keys) = generate_keys(config); + set_server_key(server_keys); + + let x = FheUint4::try_encrypt(12, &keys)?; + let y = FheUint4::try_encrypt(4, &keys)?; + + let min = &y ^ (&x ^ &y) & -(x.lt(&y)); + let max = &x ^ (&x ^ &y) & -(x.lt(&y)); + + assert_eq!(min.decrypt(&keys), 4); + assert_eq!(max.decrypt(&keys), 12); + + Ok(()) +}