feat(hlapi): add flip operation

This commit is contained in:
tmontaigu
2025-09-29 14:09:44 +02:00
parent f0f3dd76eb
commit c95e38e26f
10 changed files with 443 additions and 4 deletions

View File

@@ -6,7 +6,9 @@ use crate::high_level_api::errors::UninitializedReRandKey;
use crate::high_level_api::integers::{FheInt, FheIntId, FheUint, FheUintId};
use crate::high_level_api::keys::InternalServerKey;
use crate::high_level_api::re_randomization::ReRandomizationMetadata;
use crate::high_level_api::traits::{FheEq, IfThenElse, ReRandomize, ScalarIfThenElse, Tagged};
use crate::high_level_api::traits::{
FheEq, Flip, IfThenElse, ReRandomize, ScalarIfThenElse, Tagged,
};
use crate::high_level_api::{global_state, CompactPublicKey};
use crate::integer::block_decomposition::DecomposableInto;
use crate::integer::ciphertext::ReRandomizationSeed;
@@ -625,6 +627,310 @@ impl IfThenElse<Self> for FheBool {
}
}
impl<Id> Flip<&FheInt<Id>, &FheInt<Id>> for FheBool
where
Id: FheIntId + std::marker::Send + std::marker::Sync,
{
type Output = FheInt<Id>;
/// Flips the two inputs based on the value of `self`.
///
/// * flip(true, a, b) returns (b, a)
/// * flip(false, a, b) returns (a, b)
///
/// # Example
///
/// ```rust
/// use tfhe::prelude::*;
/// use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheBool, FheInt32};
///
/// let (client_key, server_key) = generate_keys(ConfigBuilder::default());
/// set_server_key(server_key);
///
/// let a = i32::MIN;
/// let b = FheInt32::encrypt(i32::MAX, &client_key);
/// let cond = FheBool::encrypt(true, &client_key);
///
/// let (ra, rb) = cond.flip(a, &b);
/// let da: i32 = ra.decrypt(&client_key);
/// let db: i32 = rb.decrypt(&client_key);
/// assert_eq!((da, db), (i32::MAX, i32::MIN));
/// ```
fn flip(&self, lhs: &FheInt<Id>, rhs: &FheInt<Id>) -> (Self::Output, Self::Output) {
let ct_condition = self;
global_state::with_internal_keys(|sks| match sks {
InternalServerKey::Cpu(cpu_sks) => {
let (a, b) = cpu_sks.pbs_key().flip_parallelized(
&ct_condition.ciphertext.on_cpu(),
&*lhs.ciphertext.on_cpu(),
&*rhs.ciphertext.on_cpu(),
);
(
FheInt::new(a, cpu_sks.tag.clone(), ReRandomizationMetadata::default()),
FheInt::new(b, cpu_sks.tag.clone(), ReRandomizationMetadata::default()),
)
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => rayon::join(
|| {
let streams = &cuda_key.streams;
let inner = cuda_key.key.key.if_then_else(
&CudaBooleanBlock(self.ciphertext.on_gpu(streams).duplicate(streams)),
&*rhs.ciphertext.on_gpu(streams),
&*lhs.ciphertext.on_gpu(streams),
streams,
);
FheInt::new(
inner,
cuda_key.tag.clone(),
ReRandomizationMetadata::default(),
)
},
|| {
let streams = &cuda_key.streams;
let inner = cuda_key.key.key.if_then_else(
&CudaBooleanBlock(self.ciphertext.on_gpu(streams).duplicate(streams)),
&*lhs.ciphertext.on_gpu(streams),
&*rhs.ciphertext.on_gpu(streams),
streams,
);
FheInt::new(
inner,
cuda_key.tag.clone(),
ReRandomizationMetadata::default(),
)
},
),
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
let a = self.if_then_else(rhs, lhs);
let b = self.if_then_else(lhs, rhs);
(a, b)
}
})
}
}
impl<Id> Flip<&FheUint<Id>, &FheUint<Id>> for FheBool
where
Id: FheUintId + std::marker::Send + std::marker::Sync,
{
type Output = FheUint<Id>;
/// Flips the two inputs based on the value of `self`.
///
/// * flip(true, a, b) returns (b, a)
/// * flip(false, a, b) returns (a, b)
///
/// # Example
///
/// ```rust
/// use tfhe::prelude::*;
/// use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheBool, FheUint32};
///
/// let (client_key, server_key) = generate_keys(ConfigBuilder::default());
/// set_server_key(server_key);
///
/// let a = u32::MIN;
/// let b = FheUint32::encrypt(u32::MAX, &client_key);
/// let cond = FheBool::encrypt(true, &client_key);
///
/// let (ra, rb) = cond.flip(a, &b);
/// let da: u32 = ra.decrypt(&client_key);
/// let db: u32 = rb.decrypt(&client_key);
/// assert_eq!((da, db), (u32::MAX, u32::MIN));
/// ```
fn flip(&self, lhs: &FheUint<Id>, rhs: &FheUint<Id>) -> (Self::Output, Self::Output) {
let ct_condition = self;
global_state::with_internal_keys(|sks| match sks {
InternalServerKey::Cpu(cpu_sks) => {
let (a, b) = cpu_sks.pbs_key().flip_parallelized(
&ct_condition.ciphertext.on_cpu(),
&*lhs.ciphertext.on_cpu(),
&*rhs.ciphertext.on_cpu(),
);
(
FheUint::new(a, cpu_sks.tag.clone(), ReRandomizationMetadata::default()),
FheUint::new(b, cpu_sks.tag.clone(), ReRandomizationMetadata::default()),
)
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => rayon::join(
|| {
let streams = &cuda_key.streams;
let inner = cuda_key.key.key.if_then_else(
&CudaBooleanBlock(self.ciphertext.on_gpu(streams).duplicate(streams)),
&*rhs.ciphertext.on_gpu(streams),
&*lhs.ciphertext.on_gpu(streams),
streams,
);
FheUint::new(
inner,
cuda_key.tag.clone(),
ReRandomizationMetadata::default(),
)
},
|| {
let streams = &cuda_key.streams;
let inner = cuda_key.key.key.if_then_else(
&CudaBooleanBlock(self.ciphertext.on_gpu(streams).duplicate(streams)),
&*lhs.ciphertext.on_gpu(streams),
&*rhs.ciphertext.on_gpu(streams),
streams,
);
FheUint::new(
inner,
cuda_key.tag.clone(),
ReRandomizationMetadata::default(),
)
},
),
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
let a = self.if_then_else(rhs, lhs);
let b = self.if_then_else(lhs, rhs);
(a, b)
}
})
}
}
impl<Id, T> Flip<&FheUint<Id>, T> for FheBool
where
Id: FheUintId,
T: DecomposableInto<u64> + UnsignedNumeric,
{
type Output = FheUint<Id>;
fn flip(&self, lhs: &FheUint<Id>, rhs: T) -> (Self::Output, Self::Output) {
let ct_condition = self;
global_state::with_internal_keys(|sks| match sks {
InternalServerKey::Cpu(cpu_sks) => {
let (a, b) = cpu_sks.pbs_key().flip_parallelized(
&ct_condition.ciphertext.on_cpu(),
&*lhs.ciphertext.on_cpu(),
rhs,
);
(
FheUint::new(a, cpu_sks.tag.clone(), ReRandomizationMetadata::default()),
FheUint::new(b, cpu_sks.tag.clone(), ReRandomizationMetadata::default()),
)
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(_) => {
panic!("Gpu does not support FheBool::flip with clear input")
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support FheBool::flip with clear input")
}
})
}
}
impl<Id, T> Flip<&FheInt<Id>, T> for FheBool
where
Id: FheIntId,
T: DecomposableInto<u64> + SignedNumeric,
{
type Output = FheInt<Id>;
fn flip(&self, lhs: &FheInt<Id>, rhs: T) -> (Self::Output, Self::Output) {
let ct_condition = self;
global_state::with_internal_keys(|sks| match sks {
InternalServerKey::Cpu(cpu_sks) => {
let (a, b) = cpu_sks.pbs_key().flip_parallelized(
&ct_condition.ciphertext.on_cpu(),
&*lhs.ciphertext.on_cpu(),
rhs,
);
(
FheInt::new(a, cpu_sks.tag.clone(), ReRandomizationMetadata::default()),
FheInt::new(b, cpu_sks.tag.clone(), ReRandomizationMetadata::default()),
)
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(_) => {
panic!("Gpu does not support FheBool::flip with clear input")
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support FheBool::flip with clear input")
}
})
}
}
impl<Id, T> Flip<T, &FheInt<Id>> for FheBool
where
Id: FheIntId,
T: DecomposableInto<u64> + SignedNumeric,
{
type Output = FheInt<Id>;
fn flip(&self, lhs: T, rhs: &FheInt<Id>) -> (Self::Output, Self::Output) {
let ct_condition = self;
global_state::with_internal_keys(|sks| match sks {
InternalServerKey::Cpu(cpu_sks) => {
let (a, b) = cpu_sks.pbs_key().flip_parallelized(
&ct_condition.ciphertext.on_cpu(),
lhs,
&*rhs.ciphertext.on_cpu(),
);
(
FheInt::new(a, cpu_sks.tag.clone(), ReRandomizationMetadata::default()),
FheInt::new(b, cpu_sks.tag.clone(), ReRandomizationMetadata::default()),
)
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(_) => {
panic!("Gpu does not support FheBool::flip with clear input")
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support FheBool::flip with clear input")
}
})
}
}
impl<Id, T> Flip<T, &FheUint<Id>> for FheBool
where
Id: FheUintId,
T: DecomposableInto<u64> + UnsignedNumeric,
{
type Output = FheUint<Id>;
fn flip(&self, lhs: T, rhs: &FheUint<Id>) -> (Self::Output, Self::Output) {
let ct_condition = self;
global_state::with_internal_keys(|sks| match sks {
InternalServerKey::Cpu(cpu_sks) => {
let (a, b) = cpu_sks.pbs_key().flip_parallelized(
&ct_condition.ciphertext.on_cpu(),
lhs,
&*rhs.ciphertext.on_cpu(),
);
(
FheUint::new(a, cpu_sks.tag.clone(), ReRandomizationMetadata::default()),
FheUint::new(b, cpu_sks.tag.clone(), ReRandomizationMetadata::default()),
)
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(_) => {
panic!("Gpu does not support FheBool::flip with clear input")
}
#[cfg(feature = "hpu")]
InternalServerKey::Hpu(_device) => {
panic!("Hpu does not support FheBool::flip with clear input")
}
})
}
}
impl Tagged for FheBool {
fn tag(&self) -> &Tag {
&self.tag

View File

@@ -98,6 +98,18 @@ fn test_if_then_else() {
super::test_case_if_then_else(&client_key);
}
#[test]
fn test_flip() {
let client_key = setup_default_cpu();
super::test_case_flip(&client_key);
}
#[test]
fn test_scalar_flip() {
let client_key = setup_default_cpu();
super::test_case_scalar_flip(&client_key);
}
#[test]
fn test_abs() {
let client_key = setup_default_cpu();

View File

@@ -53,6 +53,14 @@ fn test_if_then_else() {
super::test_case_if_then_else(&client_key);
}
#[test]
fn test_flip() {
let client_key = crate::high_level_api::integers::unsigned::tests::gpu::setup_gpu(Some(
PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS,
));
super::test_case_flip(&client_key);
}
#[test]
fn test_abs() {
let client_key = crate::high_level_api::integers::unsigned::tests::gpu::setup_gpu(Some(

View File

@@ -1,5 +1,5 @@
use crate::prelude::*;
use crate::{ClientKey, FheInt16, FheInt32, FheInt64, FheInt8, FheUint64, FheUint8};
use crate::{ClientKey, FheBool, FheInt16, FheInt32, FheInt64, FheInt8, FheUint64, FheUint8};
use rand::prelude::*;
mod cpu;
@@ -390,6 +390,47 @@ fn test_case_if_then_else(cks: &ClientKey) {
);
}
fn test_case_flip(client_key: &ClientKey) {
let clear_a = rand::random::<i32>();
let clear_b = rand::random::<i32>();
let a = FheInt32::encrypt(clear_a, client_key);
let b = FheInt32::encrypt(clear_b, client_key);
let c = FheBool::encrypt(true, client_key);
let (ra, rb) = c.flip(&a, &b);
let decrypted_a: i32 = ra.decrypt(client_key);
let decrypted_b: i32 = rb.decrypt(client_key);
assert_eq!((decrypted_a, decrypted_b), (clear_b, clear_a));
let c = FheBool::encrypt(false, client_key);
let (ra, rb) = c.flip(&a, &b);
let decrypted_a: i32 = ra.decrypt(client_key);
let decrypted_b: i32 = rb.decrypt(client_key);
assert_eq!((decrypted_a, decrypted_b), (clear_a, clear_b));
}
fn test_case_scalar_flip(client_key: &ClientKey) {
let clear_a = rand::random::<i32>();
let clear_b = rand::random::<i32>();
let a = FheInt32::encrypt(clear_a, client_key);
let b = FheInt32::encrypt(clear_b, client_key);
let c = FheBool::encrypt(true, client_key);
let (ra, rb) = c.flip(&a, clear_b);
let decrypted_a: i32 = ra.decrypt(client_key);
let decrypted_b: i32 = rb.decrypt(client_key);
assert_eq!((decrypted_a, decrypted_b), (clear_b, clear_a));
let c = FheBool::encrypt(false, client_key);
let (ra, rb) = c.flip(clear_a, &b);
let decrypted_a: i32 = ra.decrypt(client_key);
let decrypted_b: i32 = rb.decrypt(client_key);
assert_eq!((decrypted_a, decrypted_b), (clear_a, clear_b));
}
fn test_case_abs(cks: &ClientKey) {
let mut rng = rand::thread_rng();

View File

@@ -386,6 +386,18 @@ fn test_if_then_else() {
super::test_case_if_then_else(&client_key);
}
#[test]
fn test_flip() {
let client_key = setup_default_cpu();
super::test_case_flip(&client_key);
}
#[test]
fn test_scalar_flip() {
let client_key = setup_default_cpu();
super::test_case_scalar_flip(&client_key);
}
#[test]
fn test_scalar_shift_when_clear_type_is_small() {
// This is a regression tests

View File

@@ -121,6 +121,12 @@ fn test_if_then_else_gpu_multibit() {
super::test_case_if_then_else(&client_key);
}
#[test]
fn test_flip() {
let client_key = setup_gpu(Some(PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS));
super::test_case_flip(&client_key);
}
#[test]
fn test_sum_gpu() {
let client_key = setup_default_gpu();

View File

@@ -88,6 +88,13 @@ fn test_case_if_then_else_hpu() {
let client_key = setup_default_hpu();
super::test_case_if_then_else(&client_key);
}
#[test]
fn test_case_flip_hpu() {
let client_key = setup_default_hpu();
super::test_case_flip(&client_key);
}
#[test]
fn test_case_uint32_div_rem_hpu() {
let client_key = setup_default_hpu();

View File

@@ -1,7 +1,7 @@
use crate::high_level_api::traits::BitSlice;
use crate::integer::U256;
use crate::prelude::*;
use crate::{ClientKey, FheUint256, FheUint32, FheUint64, FheUint8};
use crate::{ClientKey, FheBool, FheUint256, FheUint32, FheUint64, FheUint8};
use rand::{thread_rng, Rng};
mod cpu;
@@ -567,6 +567,47 @@ fn test_case_if_then_else(client_key: &ClientKey) {
);
}
fn test_case_flip(client_key: &ClientKey) {
let clear_a = rand::random::<u32>();
let clear_b = rand::random::<u32>();
let a = FheUint32::encrypt(clear_a, client_key);
let b = FheUint32::encrypt(clear_b, client_key);
let c = FheBool::encrypt(true, client_key);
let (ra, rb) = c.flip(&a, &b);
let decrypted_a: u32 = ra.decrypt(client_key);
let decrypted_b: u32 = rb.decrypt(client_key);
assert_eq!((decrypted_a, decrypted_b), (clear_b, clear_a));
let c = FheBool::encrypt(false, client_key);
let (ra, rb) = c.flip(&a, &b);
let decrypted_a: u32 = ra.decrypt(client_key);
let decrypted_b: u32 = rb.decrypt(client_key);
assert_eq!((decrypted_a, decrypted_b), (clear_a, clear_b));
}
fn test_case_scalar_flip(client_key: &ClientKey) {
let clear_a = rand::random::<u32>();
let clear_b = rand::random::<u32>();
let a = FheUint32::encrypt(clear_a, client_key);
let b = FheUint32::encrypt(clear_b, client_key);
let c = FheBool::encrypt(true, client_key);
let (ra, rb) = c.flip(&a, clear_b);
let decrypted_a: u32 = ra.decrypt(client_key);
let decrypted_b: u32 = rb.decrypt(client_key);
assert_eq!((decrypted_a, decrypted_b), (clear_b, clear_a));
let c = FheBool::encrypt(false, client_key);
let (ra, rb) = c.flip(clear_a, &b);
let decrypted_a: u32 = ra.decrypt(client_key);
let decrypted_b: u32 = rb.decrypt(client_key);
assert_eq!((decrypted_a, decrypted_b), (clear_a, clear_b));
}
fn test_case_leading_trailing_zeros_ones(cks: &ClientKey) {
let mut rng = rand::thread_rng();
for _ in 0..5 {

View File

@@ -8,7 +8,7 @@
//! ```
pub use crate::high_level_api::traits::{
BitSlice, CiphertextList, DivRem, FheDecrypt, FheEncrypt, FheEq, FheKeyswitch, FheMax, FheMin,
FheOrd, FheTrivialEncrypt, FheTryEncrypt, FheTryTrivialEncrypt, FheWait, IfThenElse,
FheOrd, FheTrivialEncrypt, FheTryEncrypt, FheTryTrivialEncrypt, FheWait, Flip, IfThenElse,
OverflowingAdd, OverflowingMul, OverflowingNeg, OverflowingSub, ReRandomize, RotateLeft,
RotateLeftAssign, RotateRight, RotateRightAssign, ScalarIfThenElse, SquashNoise, Tagged,
};

View File

@@ -163,6 +163,12 @@ pub trait ScalarIfThenElse<Lhs, Rhs> {
}
}
pub trait Flip<Lhs, Rhs> {
type Output;
fn flip(&self, lhs: Lhs, rhs: Rhs) -> (Self::Output, Self::Output);
}
pub trait OverflowingAdd<Rhs> {
type Output;