From 099345df02216fe4102faa8d151dbce5ce0a4351 Mon Sep 17 00:00:00 2001 From: swarnabhasinha Date: Sat, 28 Jun 2025 18:50:23 +0530 Subject: [PATCH] fix(api): Add min/max on owned types --- .../src/high_level_api/integers/signed/ops.rs | 20 +++++++++++++++ .../integers/signed/tests/cpu.rs | 6 +++++ .../integers/signed/tests/gpu.rs | 8 ++++++ .../integers/signed/tests/mod.rs | 25 +++++++++++++++++++ .../high_level_api/integers/unsigned/ops.rs | 20 +++++++++++++++ .../integers/unsigned/tests/cpu.rs | 6 +++++ .../integers/unsigned/tests/gpu.rs | 6 +++++ .../integers/unsigned/tests/mod.rs | 25 +++++++++++++++++++ 8 files changed, 116 insertions(+) diff --git a/tfhe/src/high_level_api/integers/signed/ops.rs b/tfhe/src/high_level_api/integers/signed/ops.rs index dd3e4e709..8010b7428 100644 --- a/tfhe/src/high_level_api/integers/signed/ops.rs +++ b/tfhe/src/high_level_api/integers/signed/ops.rs @@ -168,6 +168,16 @@ where } } +impl FheMax for FheInt +where + Id: FheIntId, +{ + type Output = Self; + fn max(&self, rhs: Self) -> Self::Output { + self.max(&rhs) + } +} + impl FheMin<&Self> for FheInt where Id: FheIntId, @@ -219,6 +229,16 @@ where } } +impl FheMin for FheInt +where + Id: FheIntId, +{ + type Output = Self; + fn min(&self, rhs: Self) -> Self::Output { + self.min(&rhs) + } +} + impl FheEq for FheInt where Id: FheIntId, diff --git a/tfhe/src/high_level_api/integers/signed/tests/cpu.rs b/tfhe/src/high_level_api/integers/signed/tests/cpu.rs index b33ab8e30..ee1de0dec 100644 --- a/tfhe/src/high_level_api/integers/signed/tests/cpu.rs +++ b/tfhe/src/high_level_api/integers/signed/tests/cpu.rs @@ -110,6 +110,12 @@ fn test_integer_compress_decompress() { super::test_case_integer_compress_decompress(&client_key); } +#[test] +fn test_min_max() { + let client_key = setup_default_cpu(); + super::test_case_min_max(&client_key); +} + #[test] fn test_trivial_fhe_int8() { let config = ConfigBuilder::default().build(); diff --git a/tfhe/src/high_level_api/integers/signed/tests/gpu.rs b/tfhe/src/high_level_api/integers/signed/tests/gpu.rs index fd211f9ff..3a8a3179f 100644 --- a/tfhe/src/high_level_api/integers/signed/tests/gpu.rs +++ b/tfhe/src/high_level_api/integers/signed/tests/gpu.rs @@ -77,6 +77,14 @@ fn test_leading_trailing_zeros_ones() { test_case_leading_trailing_zeros_ones(&client_key); } +#[test] +fn test_min_max() { + let client_key = crate::high_level_api::integers::unsigned::tests::gpu::setup_gpu(Some( + PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS, + )); + super::test_case_min_max(&client_key); +} + #[test] fn test_gpu_get_add_sub_size_on_gpu() { let cks = setup_gpu(Some(PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS)); diff --git a/tfhe/src/high_level_api/integers/signed/tests/mod.rs b/tfhe/src/high_level_api/integers/signed/tests/mod.rs index fd9341528..729f7c6cb 100644 --- a/tfhe/src/high_level_api/integers/signed/tests/mod.rs +++ b/tfhe/src/high_level_api/integers/signed/tests/mod.rs @@ -453,3 +453,28 @@ fn test_case_ilog2(cks: &ClientKey) { assert!(!is_ok); } } + +fn test_case_min_max(cks: &ClientKey) { + let mut rng = rand::thread_rng(); + let a_val: i8 = rng.gen(); + let b_val: i8 = rng.gen(); + + let a = FheInt8::encrypt(a_val, cks); + let b = FheInt8::encrypt(b_val, cks); + + // Test by-reference operations + let encrypted_min = a.min(&b); + let encrypted_max = a.max(&b); + let decrypted_min: i8 = encrypted_min.decrypt(cks); + let decrypted_max: i8 = encrypted_max.decrypt(cks); + assert_eq!(decrypted_min, a_val.min(b_val)); + assert_eq!(decrypted_max, a_val.max(b_val)); + + // Test by-value operations + let encrypted_min = a.min(b.clone()); + let encrypted_max = a.max(b); + let decrypted_min: i8 = encrypted_min.decrypt(cks); + let decrypted_max: i8 = encrypted_max.decrypt(cks); + assert_eq!(decrypted_min, a_val.min(b_val)); + assert_eq!(decrypted_max, a_val.max(b_val)); +} diff --git a/tfhe/src/high_level_api/integers/unsigned/ops.rs b/tfhe/src/high_level_api/integers/unsigned/ops.rs index 7e9787d73..97ff9a95f 100644 --- a/tfhe/src/high_level_api/integers/unsigned/ops.rs +++ b/tfhe/src/high_level_api/integers/unsigned/ops.rs @@ -277,6 +277,16 @@ where } } +impl FheMax for FheUint +where + Id: FheUintId, +{ + type Output = Self; + fn max(&self, rhs: Self) -> Self::Output { + self.max(&rhs) + } +} + impl FheMin<&Self> for FheUint where Id: FheUintId, @@ -330,6 +340,16 @@ where } } +impl FheMin for FheUint +where + Id: FheUintId, +{ + type Output = Self; + fn min(&self, rhs: Self) -> Self::Output { + self.min(&rhs) + } +} + impl FheEq for FheUint where Id: FheUintId, diff --git a/tfhe/src/high_level_api/integers/unsigned/tests/cpu.rs b/tfhe/src/high_level_api/integers/unsigned/tests/cpu.rs index 6a41700bf..0a682e8f4 100644 --- a/tfhe/src/high_level_api/integers/unsigned/tests/cpu.rs +++ b/tfhe/src/high_level_api/integers/unsigned/tests/cpu.rs @@ -444,6 +444,12 @@ fn test_sum() { super::test_case_sum(&client_key); } +#[test] +fn test_min_max() { + let client_key = setup_default_cpu(); + super::test_case_min_max(&client_key); +} + #[test] fn test_safe_deserialize_conformant_fhe_uint32() { let block_params = PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128; diff --git a/tfhe/src/high_level_api/integers/unsigned/tests/gpu.rs b/tfhe/src/high_level_api/integers/unsigned/tests/gpu.rs index 19ec6d7af..1a5061413 100644 --- a/tfhe/src/high_level_api/integers/unsigned/tests/gpu.rs +++ b/tfhe/src/high_level_api/integers/unsigned/tests/gpu.rs @@ -169,6 +169,12 @@ fn test_ilog2_multibit() { super::test_case_ilog2(&client_key); } +#[test] +fn test_min_max() { + let client_key = setup_gpu(Some(PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS)); + super::test_case_min_max(&client_key); +} + #[test] fn test_gpu_get_add_and_sub_size_on_gpu() { let cks = setup_gpu(Some(PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS)); diff --git a/tfhe/src/high_level_api/integers/unsigned/tests/mod.rs b/tfhe/src/high_level_api/integers/unsigned/tests/mod.rs index 27609fb96..7bd5ad36b 100644 --- a/tfhe/src/high_level_api/integers/unsigned/tests/mod.rs +++ b/tfhe/src/high_level_api/integers/unsigned/tests/mod.rs @@ -715,3 +715,28 @@ fn test_case_is_even_is_odd(cks: &ClientKey) { ); } } + +fn test_case_min_max(cks: &ClientKey) { + let mut rng = rand::thread_rng(); + let a_val: u8 = rng.gen(); + let b_val: u8 = rng.gen(); + + let a = FheUint8::encrypt(a_val, cks); + let b = FheUint8::encrypt(b_val, cks); + + // Test by-reference operations + let encrypted_min = a.min(&b); + let encrypted_max = a.max(&b); + let decrypted_min: u8 = encrypted_min.decrypt(cks); + let decrypted_max: u8 = encrypted_max.decrypt(cks); + assert_eq!(decrypted_min, a_val.min(b_val)); + assert_eq!(decrypted_max, a_val.max(b_val)); + + // Test by-value operations + let encrypted_min = a.min(b.clone()); + let encrypted_max = a.max(b); + let decrypted_min: u8 = encrypted_min.decrypt(cks); + let decrypted_max: u8 = encrypted_max.decrypt(cks); + assert_eq!(decrypted_min, a_val.min(b_val)); + assert_eq!(decrypted_max, a_val.max(b_val)); +}