fix(api): Add min/max on owned types

This commit is contained in:
swarnabhasinha
2025-06-28 18:50:23 +05:30
committed by tmontaigu
parent 48c10e91f7
commit 099345df02
8 changed files with 116 additions and 0 deletions

View File

@@ -168,6 +168,16 @@ where
}
}
impl<Id> FheMax<Self> for FheInt<Id>
where
Id: FheIntId,
{
type Output = Self;
fn max(&self, rhs: Self) -> Self::Output {
self.max(&rhs)
}
}
impl<Id> FheMin<&Self> for FheInt<Id>
where
Id: FheIntId,
@@ -219,6 +229,16 @@ where
}
}
impl<Id> FheMin<Self> for FheInt<Id>
where
Id: FheIntId,
{
type Output = Self;
fn min(&self, rhs: Self) -> Self::Output {
self.min(&rhs)
}
}
impl<Id> FheEq<Self> for FheInt<Id>
where
Id: FheIntId,

View File

@@ -110,6 +110,12 @@ fn test_integer_compress_decompress() {
super::test_case_integer_compress_decompress(&client_key);
}
#[test]
fn test_min_max() {
let client_key = setup_default_cpu();
super::test_case_min_max(&client_key);
}
#[test]
fn test_trivial_fhe_int8() {
let config = ConfigBuilder::default().build();

View File

@@ -77,6 +77,14 @@ fn test_leading_trailing_zeros_ones() {
test_case_leading_trailing_zeros_ones(&client_key);
}
#[test]
fn test_min_max() {
let client_key = crate::high_level_api::integers::unsigned::tests::gpu::setup_gpu(Some(
PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS,
));
super::test_case_min_max(&client_key);
}
#[test]
fn test_gpu_get_add_sub_size_on_gpu() {
let cks = setup_gpu(Some(PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS));

View File

@@ -453,3 +453,28 @@ fn test_case_ilog2(cks: &ClientKey) {
assert!(!is_ok);
}
}
fn test_case_min_max(cks: &ClientKey) {
let mut rng = rand::thread_rng();
let a_val: i8 = rng.gen();
let b_val: i8 = rng.gen();
let a = FheInt8::encrypt(a_val, cks);
let b = FheInt8::encrypt(b_val, cks);
// Test by-reference operations
let encrypted_min = a.min(&b);
let encrypted_max = a.max(&b);
let decrypted_min: i8 = encrypted_min.decrypt(cks);
let decrypted_max: i8 = encrypted_max.decrypt(cks);
assert_eq!(decrypted_min, a_val.min(b_val));
assert_eq!(decrypted_max, a_val.max(b_val));
// Test by-value operations
let encrypted_min = a.min(b.clone());
let encrypted_max = a.max(b);
let decrypted_min: i8 = encrypted_min.decrypt(cks);
let decrypted_max: i8 = encrypted_max.decrypt(cks);
assert_eq!(decrypted_min, a_val.min(b_val));
assert_eq!(decrypted_max, a_val.max(b_val));
}

View File

@@ -277,6 +277,16 @@ where
}
}
impl<Id> FheMax<Self> for FheUint<Id>
where
Id: FheUintId,
{
type Output = Self;
fn max(&self, rhs: Self) -> Self::Output {
self.max(&rhs)
}
}
impl<Id> FheMin<&Self> for FheUint<Id>
where
Id: FheUintId,
@@ -330,6 +340,16 @@ where
}
}
impl<Id> FheMin<Self> for FheUint<Id>
where
Id: FheUintId,
{
type Output = Self;
fn min(&self, rhs: Self) -> Self::Output {
self.min(&rhs)
}
}
impl<Id> FheEq<Self> for FheUint<Id>
where
Id: FheUintId,

View File

@@ -444,6 +444,12 @@ fn test_sum() {
super::test_case_sum(&client_key);
}
#[test]
fn test_min_max() {
let client_key = setup_default_cpu();
super::test_case_min_max(&client_key);
}
#[test]
fn test_safe_deserialize_conformant_fhe_uint32() {
let block_params = PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;

View File

@@ -169,6 +169,12 @@ fn test_ilog2_multibit() {
super::test_case_ilog2(&client_key);
}
#[test]
fn test_min_max() {
let client_key = setup_gpu(Some(PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS));
super::test_case_min_max(&client_key);
}
#[test]
fn test_gpu_get_add_and_sub_size_on_gpu() {
let cks = setup_gpu(Some(PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS));

View File

@@ -715,3 +715,28 @@ fn test_case_is_even_is_odd(cks: &ClientKey) {
);
}
}
fn test_case_min_max(cks: &ClientKey) {
let mut rng = rand::thread_rng();
let a_val: u8 = rng.gen();
let b_val: u8 = rng.gen();
let a = FheUint8::encrypt(a_val, cks);
let b = FheUint8::encrypt(b_val, cks);
// Test by-reference operations
let encrypted_min = a.min(&b);
let encrypted_max = a.max(&b);
let decrypted_min: u8 = encrypted_min.decrypt(cks);
let decrypted_max: u8 = encrypted_max.decrypt(cks);
assert_eq!(decrypted_min, a_val.min(b_val));
assert_eq!(decrypted_max, a_val.max(b_val));
// Test by-value operations
let encrypted_min = a.min(b.clone());
let encrypted_max = a.max(b);
let decrypted_min: u8 = encrypted_min.decrypt(cks);
let decrypted_max: u8 = encrypted_max.decrypt(cks);
assert_eq!(decrypted_min, a_val.min(b_val));
assert_eq!(decrypted_max, a_val.max(b_val));
}