mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-09 14:47:56 -05:00
feat(integer): add min,max and comparisons ops
This commit is contained in:
@@ -277,6 +277,38 @@ define_server_key_bench_unary_fn!(smart_neg);
|
||||
define_server_key_bench_unary_fn!(full_propagate);
|
||||
define_server_key_bench_unary_fn!(full_propagate_parallelized);
|
||||
|
||||
define_server_key_bench_fn!(unchecked_max);
|
||||
define_server_key_bench_fn!(unchecked_min);
|
||||
define_server_key_bench_fn!(unchecked_eq);
|
||||
define_server_key_bench_fn!(unchecked_lt);
|
||||
define_server_key_bench_fn!(unchecked_le);
|
||||
define_server_key_bench_fn!(unchecked_gt);
|
||||
define_server_key_bench_fn!(unchecked_ge);
|
||||
|
||||
define_server_key_bench_fn!(unchecked_max_parallelized);
|
||||
define_server_key_bench_fn!(unchecked_min_parallelized);
|
||||
define_server_key_bench_fn!(unchecked_eq_parallelized);
|
||||
define_server_key_bench_fn!(unchecked_lt_parallelized);
|
||||
define_server_key_bench_fn!(unchecked_le_parallelized);
|
||||
define_server_key_bench_fn!(unchecked_gt_parallelized);
|
||||
define_server_key_bench_fn!(unchecked_ge_parallelized);
|
||||
|
||||
define_server_key_bench_fn!(smart_max);
|
||||
define_server_key_bench_fn!(smart_min);
|
||||
define_server_key_bench_fn!(smart_eq);
|
||||
define_server_key_bench_fn!(smart_lt);
|
||||
define_server_key_bench_fn!(smart_le);
|
||||
define_server_key_bench_fn!(smart_gt);
|
||||
define_server_key_bench_fn!(smart_ge);
|
||||
|
||||
define_server_key_bench_fn!(smart_max_parallelized);
|
||||
define_server_key_bench_fn!(smart_min_parallelized);
|
||||
define_server_key_bench_fn!(smart_eq_parallelized);
|
||||
define_server_key_bench_fn!(smart_lt_parallelized);
|
||||
define_server_key_bench_fn!(smart_le_parallelized);
|
||||
define_server_key_bench_fn!(smart_gt_parallelized);
|
||||
define_server_key_bench_fn!(smart_ge_parallelized);
|
||||
|
||||
criterion_group!(
|
||||
smart_arithmetic_operation,
|
||||
smart_neg,
|
||||
@@ -285,6 +317,13 @@ criterion_group!(
|
||||
smart_bitand,
|
||||
smart_bitor,
|
||||
smart_bitxor,
|
||||
smart_max,
|
||||
smart_min,
|
||||
smart_eq,
|
||||
smart_lt,
|
||||
smart_le,
|
||||
smart_gt,
|
||||
smart_ge,
|
||||
);
|
||||
|
||||
criterion_group!(
|
||||
@@ -295,6 +334,13 @@ criterion_group!(
|
||||
smart_bitand_parallelized,
|
||||
smart_bitor_parallelized,
|
||||
smart_bitxor_parallelized,
|
||||
smart_max_parallelized,
|
||||
smart_min_parallelized,
|
||||
smart_eq_parallelized,
|
||||
smart_lt_parallelized,
|
||||
smart_le_parallelized,
|
||||
smart_gt_parallelized,
|
||||
smart_ge_parallelized,
|
||||
);
|
||||
|
||||
criterion_group!(
|
||||
@@ -319,6 +365,13 @@ criterion_group!(
|
||||
unchecked_bitand,
|
||||
unchecked_bitor,
|
||||
unchecked_bitxor,
|
||||
unchecked_max,
|
||||
unchecked_min,
|
||||
unchecked_eq,
|
||||
unchecked_lt,
|
||||
unchecked_le,
|
||||
unchecked_gt,
|
||||
unchecked_ge,
|
||||
);
|
||||
|
||||
criterion_group!(
|
||||
@@ -326,6 +379,13 @@ criterion_group!(
|
||||
unchecked_scalar_add,
|
||||
unchecked_scalar_sub,
|
||||
unchecked_small_scalar_mul,
|
||||
unchecked_max_parallelized,
|
||||
unchecked_min_parallelized,
|
||||
unchecked_eq_parallelized,
|
||||
unchecked_lt_parallelized,
|
||||
unchecked_le_parallelized,
|
||||
unchecked_gt_parallelized,
|
||||
unchecked_ge_parallelized,
|
||||
);
|
||||
|
||||
criterion_group!(misc, full_propagate, full_propagate_parallelized);
|
||||
|
||||
@@ -77,9 +77,31 @@ impl ClearText for u128 {
|
||||
unsafe { std::slice::from_raw_parts_mut(u128_slc.as_mut_ptr() as *mut u64, 2) }
|
||||
}
|
||||
}
|
||||
#[derive(Copy, Clone, Debug, PartialEq)]
|
||||
#[derive(Default, Copy, Clone, Debug, PartialEq, Eq)]
|
||||
pub struct U256([u128; 2]);
|
||||
|
||||
impl U256 {
|
||||
#[inline]
|
||||
pub fn low(&self) -> u128 {
|
||||
self.0[0]
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn high(&self) -> u128 {
|
||||
self.0[1]
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn low_mut(&mut self) -> &mut u128 {
|
||||
&mut self.0[0]
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn high_mut(&mut self) -> &mut u128 {
|
||||
&mut self.0[1]
|
||||
}
|
||||
}
|
||||
|
||||
impl ClearText for U256 {
|
||||
fn as_words(&self) -> &[u64] {
|
||||
let u128_slc = self.0.as_slice();
|
||||
@@ -92,12 +114,66 @@ impl ClearText for U256 {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl rand::distributions::Distribution<U256> for rand::distributions::Standard {
|
||||
fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> U256 {
|
||||
let low = rng.gen::<u128>();
|
||||
let high = rng.gen::<u128>();
|
||||
U256::from((low, high))
|
||||
}
|
||||
}
|
||||
|
||||
// Since we store as [low, high], deriving ord
|
||||
// would produces bad ordering
|
||||
impl std::cmp::Ord for U256 {
|
||||
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
|
||||
let high_bits_ord = self.high().cmp(&other.high());
|
||||
if let std::cmp::Ordering::Equal = high_bits_ord {
|
||||
self.low().cmp(&other.low())
|
||||
} else {
|
||||
high_bits_ord
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::Add<Self> for U256 {
|
||||
type Output = Self;
|
||||
|
||||
fn add(self, rhs: Self) -> Self::Output {
|
||||
let (new_low, has_overflowed) = self.low().overflowing_add(rhs.low());
|
||||
let new_high = self
|
||||
.high()
|
||||
.wrapping_add(rhs.high())
|
||||
.wrapping_add(u128::from(has_overflowed));
|
||||
|
||||
Self::from((new_low, new_high))
|
||||
}
|
||||
}
|
||||
|
||||
impl std::ops::AddAssign<Self> for U256 {
|
||||
fn add_assign(&mut self, rhs: Self) {
|
||||
*self = *self + rhs;
|
||||
}
|
||||
}
|
||||
|
||||
impl std::cmp::PartialOrd for U256 {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
|
||||
Some(self.cmp(other))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<(u128, u128)> for U256 {
|
||||
fn from(v: (u128, u128)) -> Self {
|
||||
Self([v.0, v.1])
|
||||
}
|
||||
}
|
||||
|
||||
impl From<u128> for U256 {
|
||||
fn from(value: u128) -> Self {
|
||||
Self::from((value, 0))
|
||||
}
|
||||
}
|
||||
|
||||
impl ClientKey {
|
||||
/// Creates a Client Key.
|
||||
///
|
||||
|
||||
1227
tfhe/src/integer/server_key/comparator.rs
Normal file
1227
tfhe/src/integer/server_key/comparator.rs
Normal file
File diff suppressed because it is too large
Load Diff
@@ -2,6 +2,7 @@
|
||||
//!
|
||||
//! This module implements the generation of the server public key, together with all the
|
||||
//! available homomorphic integer operations.
|
||||
pub mod comparator;
|
||||
mod crt;
|
||||
mod crt_parallel;
|
||||
mod radix;
|
||||
|
||||
480
tfhe/src/integer/server_key/radix/comparison.rs
Normal file
480
tfhe/src/integer/server_key/radix/comparison.rs
Normal file
@@ -0,0 +1,480 @@
|
||||
use super::ServerKey;
|
||||
|
||||
use crate::integer::server_key::comparator::Comparator;
|
||||
use crate::integer::RadixCiphertext;
|
||||
|
||||
impl ServerKey {
|
||||
/// Compares for equality 2 ciphertexts
|
||||
///
|
||||
/// Returns a ciphertext containing 1 if lhs == rhs, otherwise 0
|
||||
///
|
||||
/// Requires carry bits to be empty
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use tfhe::integer::gen_keys_radix;
|
||||
/// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
|
||||
///
|
||||
/// let size = 4;
|
||||
///
|
||||
/// // Generate the client key and the server key:
|
||||
/// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size);
|
||||
///
|
||||
/// let msg1 = 14;
|
||||
/// let msg2 = 97;
|
||||
///
|
||||
/// let ct1 = cks.encrypt(msg1);
|
||||
/// let ct2 = cks.encrypt(msg2);
|
||||
///
|
||||
/// let ct_res = sks.unchecked_eq(&ct1, &ct2);
|
||||
///
|
||||
/// // Decrypt:
|
||||
/// let dec_result = cks.decrypt(&ct_res);
|
||||
/// assert_eq!(dec_result, u64::from(msg1 == msg2));
|
||||
/// ```
|
||||
pub fn unchecked_eq(&self, lhs: &RadixCiphertext, rhs: &RadixCiphertext) -> RadixCiphertext {
|
||||
Comparator::new(self).unchecked_eq(lhs, rhs)
|
||||
}
|
||||
|
||||
/// Compares if lhs is strictly greater than rhs
|
||||
///
|
||||
/// Returns a ciphertext containing 1 if lhs > rhs, otherwise 0
|
||||
///
|
||||
/// Requires carry bits to be empty
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use tfhe::integer::gen_keys_radix;
|
||||
/// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
|
||||
///
|
||||
/// let size = 4;
|
||||
///
|
||||
/// // Generate the client key and the server key:
|
||||
/// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size);
|
||||
///
|
||||
/// let msg1 = 14;
|
||||
/// let msg2 = 97;
|
||||
///
|
||||
/// let ct1 = cks.encrypt(msg1);
|
||||
/// let ct2 = cks.encrypt(msg2);
|
||||
///
|
||||
/// let ct_res = sks.unchecked_gt(&ct1, &ct2);
|
||||
///
|
||||
/// // Decrypt:
|
||||
/// let dec_result = cks.decrypt(&ct_res);
|
||||
/// assert_eq!(dec_result, u64::from(msg1 > msg2));
|
||||
/// ```
|
||||
pub fn unchecked_gt(&self, lhs: &RadixCiphertext, rhs: &RadixCiphertext) -> RadixCiphertext {
|
||||
Comparator::new(self).unchecked_gt(lhs, rhs)
|
||||
}
|
||||
|
||||
/// Compares if lhs is greater or equal than rhs
|
||||
///
|
||||
/// Returns a ciphertext containing 1 if lhs >= rhs, otherwise 0
|
||||
///
|
||||
/// Requires carry bits to be empty
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use tfhe::integer::gen_keys_radix;
|
||||
/// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
|
||||
///
|
||||
/// let size = 4;
|
||||
///
|
||||
/// // Generate the client key and the server key:
|
||||
/// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size);
|
||||
///
|
||||
/// let msg1 = 97;
|
||||
/// let msg2 = 97;
|
||||
///
|
||||
/// let ct1 = cks.encrypt(msg1);
|
||||
/// let ct2 = cks.encrypt(msg2);
|
||||
///
|
||||
/// let ct_res = sks.unchecked_ge(&ct1, &ct2);
|
||||
///
|
||||
/// // Decrypt:
|
||||
/// let dec_result = cks.decrypt(&ct_res);
|
||||
/// assert_eq!(dec_result, u64::from(msg1 >= msg2));
|
||||
/// ```
|
||||
pub fn unchecked_ge(&self, lhs: &RadixCiphertext, rhs: &RadixCiphertext) -> RadixCiphertext {
|
||||
Comparator::new(self).unchecked_ge(lhs, rhs)
|
||||
}
|
||||
|
||||
/// Compares if lhs is strictly lower than rhs
|
||||
///
|
||||
/// Returns a ciphertext containing 1 if lhs < rhs, otherwise 0
|
||||
///
|
||||
/// Requires carry bits to be empty
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use tfhe::integer::gen_keys_radix;
|
||||
/// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
|
||||
///
|
||||
/// let size = 4;
|
||||
///
|
||||
/// // Generate the client key and the server key:
|
||||
/// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size);
|
||||
///
|
||||
/// let msg1 = 237;
|
||||
/// let msg2 = 23;
|
||||
///
|
||||
/// let ct1 = cks.encrypt(msg1);
|
||||
/// let ct2 = cks.encrypt(msg2);
|
||||
///
|
||||
/// let ct_res = sks.unchecked_lt(&ct1, &ct2);
|
||||
///
|
||||
/// // Decrypt:
|
||||
/// let dec_result = cks.decrypt(&ct_res);
|
||||
/// assert_eq!(dec_result, u64::from(msg1 < msg2));
|
||||
/// ```
|
||||
pub fn unchecked_lt(&self, lhs: &RadixCiphertext, rhs: &RadixCiphertext) -> RadixCiphertext {
|
||||
Comparator::new(self).unchecked_lt(lhs, rhs)
|
||||
}
|
||||
|
||||
/// Compares if lhs is lower or equal than rhs
|
||||
///
|
||||
/// Returns a ciphertext containing 1 if lhs < rhs, otherwise 0
|
||||
///
|
||||
/// Requires carry bits to be empty
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use tfhe::integer::gen_keys_radix;
|
||||
/// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
|
||||
///
|
||||
/// let size = 4;
|
||||
///
|
||||
/// // Generate the client key and the server key:
|
||||
/// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size);
|
||||
///
|
||||
/// let msg1 = 237;
|
||||
/// let msg2 = 23;
|
||||
///
|
||||
/// let ct1 = cks.encrypt(msg1);
|
||||
/// let ct2 = cks.encrypt(msg2);
|
||||
///
|
||||
/// let ct_res = sks.unchecked_le(&ct1, &ct2);
|
||||
///
|
||||
/// // Decrypt:
|
||||
/// let dec_result = cks.decrypt(&ct_res);
|
||||
/// assert_eq!(dec_result, u64::from(msg1 < msg2));
|
||||
/// ```
|
||||
pub fn unchecked_le(&self, lhs: &RadixCiphertext, rhs: &RadixCiphertext) -> RadixCiphertext {
|
||||
Comparator::new(self).unchecked_le(lhs, rhs)
|
||||
}
|
||||
|
||||
/// Computes the max of two encrypted values
|
||||
///
|
||||
///
|
||||
/// Requires carry bits to be empty
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use tfhe::integer::gen_keys_radix;
|
||||
/// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
|
||||
///
|
||||
/// let size = 4;
|
||||
///
|
||||
/// // Generate the client key and the server key:
|
||||
/// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size);
|
||||
///
|
||||
/// let msg1 = 237;
|
||||
/// let msg2 = 23;
|
||||
///
|
||||
/// let ct1 = cks.encrypt(msg1);
|
||||
/// let ct2 = cks.encrypt(msg2);
|
||||
///
|
||||
/// let ct_res = sks.unchecked_max(&ct1, &ct2);
|
||||
///
|
||||
/// // Decrypt:
|
||||
/// let dec_result = cks.decrypt(&ct_res);
|
||||
/// assert_eq!(dec_result, std::cmp::max(msg1, msg2));
|
||||
/// ```
|
||||
pub fn unchecked_max(&self, lhs: &RadixCiphertext, rhs: &RadixCiphertext) -> RadixCiphertext {
|
||||
Comparator::new(self).unchecked_max(lhs, rhs)
|
||||
}
|
||||
|
||||
/// Computes the min of two encrypted values
|
||||
///
|
||||
///
|
||||
/// Requires carry bits to be empty
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use tfhe::integer::gen_keys_radix;
|
||||
/// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
|
||||
///
|
||||
/// let size = 4;
|
||||
///
|
||||
/// // Generate the client key and the server key:
|
||||
/// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size);
|
||||
///
|
||||
/// let msg1 = 237;
|
||||
/// let msg2 = 23;
|
||||
///
|
||||
/// let ct1 = cks.encrypt(msg1);
|
||||
/// let ct2 = cks.encrypt(msg2);
|
||||
///
|
||||
/// let ct_res = sks.unchecked_min(&ct1, &ct2);
|
||||
///
|
||||
/// // Decrypt:
|
||||
/// let dec_result = cks.decrypt(&ct_res);
|
||||
/// assert_eq!(dec_result, std::cmp::min(msg1, msg2));
|
||||
/// ```
|
||||
pub fn unchecked_min(&self, lhs: &RadixCiphertext, rhs: &RadixCiphertext) -> RadixCiphertext {
|
||||
Comparator::new(self).unchecked_min(lhs, rhs)
|
||||
}
|
||||
|
||||
/// Compares for equality 2 ciphertexts
|
||||
///
|
||||
/// Returns a ciphertext containing 1 if lhs == rhs, otherwise 0
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use tfhe::integer::gen_keys_radix;
|
||||
/// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
|
||||
///
|
||||
/// let size = 4;
|
||||
///
|
||||
/// // Generate the client key and the server key:
|
||||
/// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size);
|
||||
///
|
||||
/// let msg1 = 14;
|
||||
/// let msg2 = 97;
|
||||
///
|
||||
/// let mut ct1 = cks.encrypt(msg1);
|
||||
/// let mut ct2 = cks.encrypt(msg2);
|
||||
///
|
||||
/// let ct_res = sks.smart_eq(&mut ct1, &mut ct2);
|
||||
///
|
||||
/// // Decrypt:
|
||||
/// let dec_result = cks.decrypt(&ct_res);
|
||||
/// assert_eq!(dec_result, u64::from(msg1 == msg2));
|
||||
/// ```
|
||||
pub fn smart_eq(
|
||||
&self,
|
||||
lhs: &mut RadixCiphertext,
|
||||
rhs: &mut RadixCiphertext,
|
||||
) -> RadixCiphertext {
|
||||
Comparator::new(self).smart_eq(lhs, rhs)
|
||||
}
|
||||
|
||||
/// Compares if lhs is strictly greater than rhs
|
||||
///
|
||||
/// Returns a ciphertext containing 1 if lhs == rhs, otherwise 0
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use tfhe::integer::gen_keys_radix;
|
||||
/// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
|
||||
///
|
||||
/// let size = 4;
|
||||
///
|
||||
/// // Generate the client key and the server key:
|
||||
/// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size);
|
||||
///
|
||||
/// let msg1 = 14;
|
||||
/// let msg2 = 97;
|
||||
///
|
||||
/// let mut ct1 = cks.encrypt(msg1);
|
||||
/// let mut ct2 = cks.encrypt(msg2);
|
||||
///
|
||||
/// let ct_res = sks.smart_gt(&mut ct1, &mut ct2);
|
||||
///
|
||||
/// // Decrypt:
|
||||
/// let dec_result = cks.decrypt(&ct_res);
|
||||
/// assert_eq!(dec_result, u64::from(msg1 > msg2));
|
||||
/// ```
|
||||
pub fn smart_gt(
|
||||
&self,
|
||||
lhs: &mut RadixCiphertext,
|
||||
rhs: &mut RadixCiphertext,
|
||||
) -> RadixCiphertext {
|
||||
Comparator::new(self).smart_gt(lhs, rhs)
|
||||
}
|
||||
|
||||
/// Compares if lhs is greater or equal than rhs
|
||||
///
|
||||
/// Returns a ciphertext containing 1 if lhs >= rhs, otherwise 0
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use tfhe::integer::gen_keys_radix;
|
||||
/// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
|
||||
///
|
||||
/// let size = 4;
|
||||
///
|
||||
/// // Generate the client key and the server key:
|
||||
/// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size);
|
||||
///
|
||||
/// let msg1 = 14;
|
||||
/// let msg2 = 97;
|
||||
///
|
||||
/// let mut ct1 = cks.encrypt(msg1);
|
||||
/// let mut ct2 = cks.encrypt(msg2);
|
||||
///
|
||||
/// let ct_res = sks.smart_gt(&mut ct1, &mut ct2);
|
||||
///
|
||||
/// // Decrypt:
|
||||
/// let dec_result = cks.decrypt(&ct_res);
|
||||
/// assert_eq!(dec_result, u64::from(msg1 >= msg2));
|
||||
/// ```
|
||||
pub fn smart_ge(
|
||||
&self,
|
||||
lhs: &mut RadixCiphertext,
|
||||
rhs: &mut RadixCiphertext,
|
||||
) -> RadixCiphertext {
|
||||
Comparator::new(self).smart_ge(lhs, rhs)
|
||||
}
|
||||
|
||||
/// Compares if lhs is strictly lower than rhs
|
||||
///
|
||||
/// Returns a ciphertext containing 1 if lhs < rhs, otherwise 0
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use tfhe::integer::gen_keys_radix;
|
||||
/// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
|
||||
///
|
||||
/// let size = 4;
|
||||
///
|
||||
/// // Generate the client key and the server key:
|
||||
/// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size);
|
||||
///
|
||||
/// let msg1 = 14;
|
||||
/// let msg2 = 97;
|
||||
///
|
||||
/// let mut ct1 = cks.encrypt(msg1);
|
||||
/// let mut ct2 = cks.encrypt(msg2);
|
||||
///
|
||||
/// let ct_res = sks.smart_lt(&mut ct1, &mut ct2);
|
||||
///
|
||||
/// // Decrypt:
|
||||
/// let dec_result = cks.decrypt(&ct_res);
|
||||
/// assert_eq!(dec_result, u64::from(msg1 < msg2));
|
||||
/// ```
|
||||
pub fn smart_lt(
|
||||
&self,
|
||||
lhs: &mut RadixCiphertext,
|
||||
rhs: &mut RadixCiphertext,
|
||||
) -> RadixCiphertext {
|
||||
Comparator::new(self).smart_lt(lhs, rhs)
|
||||
}
|
||||
|
||||
/// Compares if lhs is lower or equal than rhs
|
||||
///
|
||||
/// Returns a ciphertext containing 1 if lhs <= rhs, otherwise 0
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use tfhe::integer::gen_keys_radix;
|
||||
/// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
|
||||
///
|
||||
/// let size = 4;
|
||||
///
|
||||
/// // Generate the client key and the server key:
|
||||
/// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size);
|
||||
///
|
||||
/// let msg1 = 14;
|
||||
/// let msg2 = 97;
|
||||
///
|
||||
/// let mut ct1 = cks.encrypt(msg1);
|
||||
/// let mut ct2 = cks.encrypt(msg2);
|
||||
///
|
||||
/// let ct_res = sks.smart_le(&mut ct1, &mut ct2);
|
||||
///
|
||||
/// // Decrypt:
|
||||
/// let dec_result = cks.decrypt(&ct_res);
|
||||
/// assert_eq!(dec_result, u64::from(msg1 <= msg2));
|
||||
/// ```
|
||||
pub fn smart_le(
|
||||
&self,
|
||||
lhs: &mut RadixCiphertext,
|
||||
rhs: &mut RadixCiphertext,
|
||||
) -> RadixCiphertext {
|
||||
Comparator::new(self).smart_le(lhs, rhs)
|
||||
}
|
||||
|
||||
/// Computes the max of two encrypted values
|
||||
///
|
||||
/// Returns a ciphertext containing 1 if lhs < rhs, otherwise 0
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use tfhe::integer::gen_keys_radix;
|
||||
/// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
|
||||
///
|
||||
/// let size = 4;
|
||||
///
|
||||
/// // Generate the client key and the server key:
|
||||
/// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size);
|
||||
///
|
||||
/// let msg1 = 14;
|
||||
/// let msg2 = 97;
|
||||
///
|
||||
/// let mut ct1 = cks.encrypt(msg1);
|
||||
/// let mut ct2 = cks.encrypt(msg2);
|
||||
///
|
||||
/// let ct_res = sks.smart_max(&mut ct1, &mut ct2);
|
||||
///
|
||||
/// // Decrypt:
|
||||
/// let dec_result = cks.decrypt(&ct_res);
|
||||
/// assert_eq!(dec_result, std::cmp::max(msg1, msg2));
|
||||
/// ```
|
||||
pub fn smart_max(
|
||||
&self,
|
||||
lhs: &mut RadixCiphertext,
|
||||
rhs: &mut RadixCiphertext,
|
||||
) -> RadixCiphertext {
|
||||
Comparator::new(self).smart_max(lhs, rhs)
|
||||
}
|
||||
|
||||
/// Computes the min of two encrypted values
|
||||
///
|
||||
/// Returns a ciphertext containing 1 if lhs < rhs, otherwise 0
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use tfhe::integer::gen_keys_radix;
|
||||
/// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
|
||||
///
|
||||
/// let size = 4;
|
||||
///
|
||||
/// // Generate the client key and the server key:
|
||||
/// let (cks, sks) = gen_keys_radix(&PARAM_MESSAGE_2_CARRY_2, size);
|
||||
///
|
||||
/// let msg1 = 14;
|
||||
/// let msg2 = 97;
|
||||
///
|
||||
/// let mut ct1 = cks.encrypt(msg1);
|
||||
/// let mut ct2 = cks.encrypt(msg2);
|
||||
///
|
||||
/// let ct_res = sks.smart_min(&mut ct1, &mut ct2);
|
||||
///
|
||||
/// // Decrypt:
|
||||
/// let dec_result = cks.decrypt(&ct_res);
|
||||
/// assert_eq!(dec_result, std::cmp::min(msg1, msg2));
|
||||
/// ```
|
||||
pub fn smart_min(
|
||||
&self,
|
||||
lhs: &mut RadixCiphertext,
|
||||
rhs: &mut RadixCiphertext,
|
||||
) -> RadixCiphertext {
|
||||
Comparator::new(self).smart_min(lhs, rhs)
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
mod add;
|
||||
mod bitwise_op;
|
||||
mod comparison;
|
||||
mod mul;
|
||||
mod neg;
|
||||
mod scalar_add;
|
||||
|
||||
@@ -73,8 +73,8 @@ fn integer_encrypt_decrypt_128_bits(param: Parameters) {
|
||||
|
||||
// RNG
|
||||
let mut rng = rand::thread_rng();
|
||||
let num_block = (64f64 / (param.message_modulus.0 as f64).log(2.0)).ceil() as usize * 2;
|
||||
for _ in 0..1 {
|
||||
let num_block = (128f64 / (param.message_modulus.0 as f64).log(2.0)).ceil() as usize;
|
||||
for _ in 0..10 {
|
||||
let clear = rng.gen::<u128>();
|
||||
|
||||
//encryption
|
||||
@@ -92,7 +92,6 @@ fn integer_encrypt_decrypt_128_bits(param: Parameters) {
|
||||
fn integer_encrypt_decrypt_128_bits_specific_values(param: Parameters) {
|
||||
let (cks, sks) = KEY_CACHE.get_from_params(param);
|
||||
|
||||
// let num_block = (64f64 / (param.message_modulus.0 as f64).log(2.0)).ceil() as usize * 2;
|
||||
let num_block = (128f64 / (param.message_modulus.0 as f64).log(2.0)).ceil() as usize;
|
||||
{
|
||||
let a = u64::MAX as u128;
|
||||
@@ -181,7 +180,7 @@ fn integer_encrypt_decrypt_256_bits(param: Parameters) {
|
||||
let mut rng = rand::thread_rng();
|
||||
let num_block = (256f64 / (param.message_modulus.0 as f64).log(2.0)).ceil() as usize;
|
||||
|
||||
for _ in 0..1 {
|
||||
for _ in 0..10 {
|
||||
let clear0 = rng.gen::<u128>();
|
||||
let clear1 = rng.gen::<u128>();
|
||||
|
||||
|
||||
118
tfhe/src/integer/server_key/radix_parallel/comparison.rs
Normal file
118
tfhe/src/integer/server_key/radix_parallel/comparison.rs
Normal file
@@ -0,0 +1,118 @@
|
||||
use super::ServerKey;
|
||||
|
||||
use crate::integer::server_key::comparator::Comparator;
|
||||
use crate::integer::RadixCiphertext;
|
||||
|
||||
impl ServerKey {
|
||||
pub fn unchecked_eq_parallelized(
|
||||
&self,
|
||||
lhs: &RadixCiphertext,
|
||||
rhs: &RadixCiphertext,
|
||||
) -> RadixCiphertext {
|
||||
Comparator::new(self).unchecked_eq_parallelized(lhs, rhs)
|
||||
}
|
||||
|
||||
pub fn unchecked_gt_parallelized(
|
||||
&self,
|
||||
lhs: &RadixCiphertext,
|
||||
rhs: &RadixCiphertext,
|
||||
) -> RadixCiphertext {
|
||||
Comparator::new(self).unchecked_gt_parallelized(lhs, rhs)
|
||||
}
|
||||
|
||||
pub fn unchecked_ge_parallelized(
|
||||
&self,
|
||||
lhs: &RadixCiphertext,
|
||||
rhs: &RadixCiphertext,
|
||||
) -> RadixCiphertext {
|
||||
Comparator::new(self).unchecked_ge_parallelized(lhs, rhs)
|
||||
}
|
||||
|
||||
pub fn unchecked_lt_parallelized(
|
||||
&self,
|
||||
lhs: &RadixCiphertext,
|
||||
rhs: &RadixCiphertext,
|
||||
) -> RadixCiphertext {
|
||||
Comparator::new(self).unchecked_lt_parallelized(lhs, rhs)
|
||||
}
|
||||
|
||||
pub fn unchecked_le_parallelized(
|
||||
&self,
|
||||
lhs: &RadixCiphertext,
|
||||
rhs: &RadixCiphertext,
|
||||
) -> RadixCiphertext {
|
||||
Comparator::new(self).unchecked_le_parallelized(lhs, rhs)
|
||||
}
|
||||
|
||||
pub fn unchecked_max_parallelized(
|
||||
&self,
|
||||
lhs: &RadixCiphertext,
|
||||
rhs: &RadixCiphertext,
|
||||
) -> RadixCiphertext {
|
||||
Comparator::new(self).unchecked_max_parallelized(lhs, rhs)
|
||||
}
|
||||
|
||||
pub fn unchecked_min_parallelized(
|
||||
&self,
|
||||
lhs: &RadixCiphertext,
|
||||
rhs: &RadixCiphertext,
|
||||
) -> RadixCiphertext {
|
||||
Comparator::new(self).unchecked_min_parallelized(lhs, rhs)
|
||||
}
|
||||
|
||||
pub fn smart_eq_parallelized(
|
||||
&self,
|
||||
lhs: &mut RadixCiphertext,
|
||||
rhs: &mut RadixCiphertext,
|
||||
) -> RadixCiphertext {
|
||||
Comparator::new(self).smart_eq_parallelized(lhs, rhs)
|
||||
}
|
||||
|
||||
pub fn smart_gt_parallelized(
|
||||
&self,
|
||||
lhs: &mut RadixCiphertext,
|
||||
rhs: &mut RadixCiphertext,
|
||||
) -> RadixCiphertext {
|
||||
Comparator::new(self).smart_gt_parallelized(lhs, rhs)
|
||||
}
|
||||
|
||||
pub fn smart_ge_parallelized(
|
||||
&self,
|
||||
lhs: &mut RadixCiphertext,
|
||||
rhs: &mut RadixCiphertext,
|
||||
) -> RadixCiphertext {
|
||||
Comparator::new(self).smart_ge_parallelized(lhs, rhs)
|
||||
}
|
||||
|
||||
pub fn smart_lt_parallelized(
|
||||
&self,
|
||||
lhs: &mut RadixCiphertext,
|
||||
rhs: &mut RadixCiphertext,
|
||||
) -> RadixCiphertext {
|
||||
Comparator::new(self).smart_lt_parallelized(lhs, rhs)
|
||||
}
|
||||
|
||||
pub fn smart_le_parallelized(
|
||||
&self,
|
||||
lhs: &mut RadixCiphertext,
|
||||
rhs: &mut RadixCiphertext,
|
||||
) -> RadixCiphertext {
|
||||
Comparator::new(self).smart_le_parallelized(lhs, rhs)
|
||||
}
|
||||
|
||||
pub fn smart_max_parallelized(
|
||||
&self,
|
||||
lhs: &mut RadixCiphertext,
|
||||
rhs: &mut RadixCiphertext,
|
||||
) -> RadixCiphertext {
|
||||
Comparator::new(self).smart_max_parallelized(lhs, rhs)
|
||||
}
|
||||
|
||||
pub fn smart_min_parallelized(
|
||||
&self,
|
||||
lhs: &mut RadixCiphertext,
|
||||
rhs: &mut RadixCiphertext,
|
||||
) -> RadixCiphertext {
|
||||
Comparator::new(self).smart_min_parallelized(lhs, rhs)
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
mod add;
|
||||
mod bitwise_op;
|
||||
mod comparison;
|
||||
mod mul;
|
||||
mod neg;
|
||||
mod scalar_add;
|
||||
|
||||
Reference in New Issue
Block a user