mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-10 15:18:33 -05:00
fix(api): Add min/max on owned types
This commit is contained in:
committed by
tmontaigu
parent
48c10e91f7
commit
099345df02
@@ -168,6 +168,16 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<Id> FheMax<Self> for FheInt<Id>
|
||||
where
|
||||
Id: FheIntId,
|
||||
{
|
||||
type Output = Self;
|
||||
fn max(&self, rhs: Self) -> Self::Output {
|
||||
self.max(&rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl<Id> FheMin<&Self> for FheInt<Id>
|
||||
where
|
||||
Id: FheIntId,
|
||||
@@ -219,6 +229,16 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<Id> FheMin<Self> for FheInt<Id>
|
||||
where
|
||||
Id: FheIntId,
|
||||
{
|
||||
type Output = Self;
|
||||
fn min(&self, rhs: Self) -> Self::Output {
|
||||
self.min(&rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl<Id> FheEq<Self> for FheInt<Id>
|
||||
where
|
||||
Id: FheIntId,
|
||||
|
||||
@@ -110,6 +110,12 @@ fn test_integer_compress_decompress() {
|
||||
super::test_case_integer_compress_decompress(&client_key);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_min_max() {
|
||||
let client_key = setup_default_cpu();
|
||||
super::test_case_min_max(&client_key);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_trivial_fhe_int8() {
|
||||
let config = ConfigBuilder::default().build();
|
||||
|
||||
@@ -77,6 +77,14 @@ fn test_leading_trailing_zeros_ones() {
|
||||
test_case_leading_trailing_zeros_ones(&client_key);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_min_max() {
|
||||
let client_key = crate::high_level_api::integers::unsigned::tests::gpu::setup_gpu(Some(
|
||||
PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS,
|
||||
));
|
||||
super::test_case_min_max(&client_key);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gpu_get_add_sub_size_on_gpu() {
|
||||
let cks = setup_gpu(Some(PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS));
|
||||
|
||||
@@ -453,3 +453,28 @@ fn test_case_ilog2(cks: &ClientKey) {
|
||||
assert!(!is_ok);
|
||||
}
|
||||
}
|
||||
|
||||
fn test_case_min_max(cks: &ClientKey) {
|
||||
let mut rng = rand::thread_rng();
|
||||
let a_val: i8 = rng.gen();
|
||||
let b_val: i8 = rng.gen();
|
||||
|
||||
let a = FheInt8::encrypt(a_val, cks);
|
||||
let b = FheInt8::encrypt(b_val, cks);
|
||||
|
||||
// Test by-reference operations
|
||||
let encrypted_min = a.min(&b);
|
||||
let encrypted_max = a.max(&b);
|
||||
let decrypted_min: i8 = encrypted_min.decrypt(cks);
|
||||
let decrypted_max: i8 = encrypted_max.decrypt(cks);
|
||||
assert_eq!(decrypted_min, a_val.min(b_val));
|
||||
assert_eq!(decrypted_max, a_val.max(b_val));
|
||||
|
||||
// Test by-value operations
|
||||
let encrypted_min = a.min(b.clone());
|
||||
let encrypted_max = a.max(b);
|
||||
let decrypted_min: i8 = encrypted_min.decrypt(cks);
|
||||
let decrypted_max: i8 = encrypted_max.decrypt(cks);
|
||||
assert_eq!(decrypted_min, a_val.min(b_val));
|
||||
assert_eq!(decrypted_max, a_val.max(b_val));
|
||||
}
|
||||
|
||||
@@ -277,6 +277,16 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<Id> FheMax<Self> for FheUint<Id>
|
||||
where
|
||||
Id: FheUintId,
|
||||
{
|
||||
type Output = Self;
|
||||
fn max(&self, rhs: Self) -> Self::Output {
|
||||
self.max(&rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl<Id> FheMin<&Self> for FheUint<Id>
|
||||
where
|
||||
Id: FheUintId,
|
||||
@@ -330,6 +340,16 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<Id> FheMin<Self> for FheUint<Id>
|
||||
where
|
||||
Id: FheUintId,
|
||||
{
|
||||
type Output = Self;
|
||||
fn min(&self, rhs: Self) -> Self::Output {
|
||||
self.min(&rhs)
|
||||
}
|
||||
}
|
||||
|
||||
impl<Id> FheEq<Self> for FheUint<Id>
|
||||
where
|
||||
Id: FheUintId,
|
||||
|
||||
@@ -444,6 +444,12 @@ fn test_sum() {
|
||||
super::test_case_sum(&client_key);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_min_max() {
|
||||
let client_key = setup_default_cpu();
|
||||
super::test_case_min_max(&client_key);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_safe_deserialize_conformant_fhe_uint32() {
|
||||
let block_params = PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
|
||||
|
||||
@@ -169,6 +169,12 @@ fn test_ilog2_multibit() {
|
||||
super::test_case_ilog2(&client_key);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_min_max() {
|
||||
let client_key = setup_gpu(Some(PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS));
|
||||
super::test_case_min_max(&client_key);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gpu_get_add_and_sub_size_on_gpu() {
|
||||
let cks = setup_gpu(Some(PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS));
|
||||
|
||||
@@ -715,3 +715,28 @@ fn test_case_is_even_is_odd(cks: &ClientKey) {
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn test_case_min_max(cks: &ClientKey) {
|
||||
let mut rng = rand::thread_rng();
|
||||
let a_val: u8 = rng.gen();
|
||||
let b_val: u8 = rng.gen();
|
||||
|
||||
let a = FheUint8::encrypt(a_val, cks);
|
||||
let b = FheUint8::encrypt(b_val, cks);
|
||||
|
||||
// Test by-reference operations
|
||||
let encrypted_min = a.min(&b);
|
||||
let encrypted_max = a.max(&b);
|
||||
let decrypted_min: u8 = encrypted_min.decrypt(cks);
|
||||
let decrypted_max: u8 = encrypted_max.decrypt(cks);
|
||||
assert_eq!(decrypted_min, a_val.min(b_val));
|
||||
assert_eq!(decrypted_max, a_val.max(b_val));
|
||||
|
||||
// Test by-value operations
|
||||
let encrypted_min = a.min(b.clone());
|
||||
let encrypted_max = a.max(b);
|
||||
let decrypted_min: u8 = encrypted_min.decrypt(cks);
|
||||
let decrypted_max: u8 = encrypted_max.decrypt(cks);
|
||||
assert_eq!(decrypted_min, a_val.min(b_val));
|
||||
assert_eq!(decrypted_max, a_val.max(b_val));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user