diff --git a/tfhe/src/high_level_api/booleans/base.rs b/tfhe/src/high_level_api/booleans/base.rs index 60312d7bd..c30a8bc78 100644 --- a/tfhe/src/high_level_api/booleans/base.rs +++ b/tfhe/src/high_level_api/booleans/base.rs @@ -6,7 +6,9 @@ use crate::high_level_api::errors::UninitializedReRandKey; use crate::high_level_api::integers::{FheInt, FheIntId, FheUint, FheUintId}; use crate::high_level_api::keys::InternalServerKey; use crate::high_level_api::re_randomization::ReRandomizationMetadata; -use crate::high_level_api::traits::{FheEq, IfThenElse, ReRandomize, ScalarIfThenElse, Tagged}; +use crate::high_level_api::traits::{ + FheEq, Flip, IfThenElse, ReRandomize, ScalarIfThenElse, Tagged, +}; use crate::high_level_api::{global_state, CompactPublicKey}; use crate::integer::block_decomposition::DecomposableInto; use crate::integer::ciphertext::ReRandomizationSeed; @@ -625,6 +627,310 @@ impl IfThenElse for FheBool { } } +impl Flip<&FheInt, &FheInt> for FheBool +where + Id: FheIntId + std::marker::Send + std::marker::Sync, +{ + type Output = FheInt; + + /// Flips the two inputs based on the value of `self`. + /// + /// * flip(true, a, b) returns (b, a) + /// * flip(false, a, b) returns (a, b) + /// + /// # Example + /// + /// ```rust + /// use tfhe::prelude::*; + /// use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheBool, FheInt32}; + /// + /// let (client_key, server_key) = generate_keys(ConfigBuilder::default()); + /// set_server_key(server_key); + /// + /// let a = i32::MIN; + /// let b = FheInt32::encrypt(i32::MAX, &client_key); + /// let cond = FheBool::encrypt(true, &client_key); + /// + /// let (ra, rb) = cond.flip(a, &b); + /// let da: i32 = ra.decrypt(&client_key); + /// let db: i32 = rb.decrypt(&client_key); + /// assert_eq!((da, db), (i32::MAX, i32::MIN)); + /// ``` + fn flip(&self, lhs: &FheInt, rhs: &FheInt) -> (Self::Output, Self::Output) { + let ct_condition = self; + global_state::with_internal_keys(|sks| match sks { + InternalServerKey::Cpu(cpu_sks) => { + let (a, b) = cpu_sks.pbs_key().flip_parallelized( + &ct_condition.ciphertext.on_cpu(), + &*lhs.ciphertext.on_cpu(), + &*rhs.ciphertext.on_cpu(), + ); + ( + FheInt::new(a, cpu_sks.tag.clone(), ReRandomizationMetadata::default()), + FheInt::new(b, cpu_sks.tag.clone(), ReRandomizationMetadata::default()), + ) + } + #[cfg(feature = "gpu")] + InternalServerKey::Cuda(cuda_key) => rayon::join( + || { + let streams = &cuda_key.streams; + let inner = cuda_key.key.key.if_then_else( + &CudaBooleanBlock(self.ciphertext.on_gpu(streams).duplicate(streams)), + &*rhs.ciphertext.on_gpu(streams), + &*lhs.ciphertext.on_gpu(streams), + streams, + ); + + FheInt::new( + inner, + cuda_key.tag.clone(), + ReRandomizationMetadata::default(), + ) + }, + || { + let streams = &cuda_key.streams; + let inner = cuda_key.key.key.if_then_else( + &CudaBooleanBlock(self.ciphertext.on_gpu(streams).duplicate(streams)), + &*lhs.ciphertext.on_gpu(streams), + &*rhs.ciphertext.on_gpu(streams), + streams, + ); + + FheInt::new( + inner, + cuda_key.tag.clone(), + ReRandomizationMetadata::default(), + ) + }, + ), + #[cfg(feature = "hpu")] + InternalServerKey::Hpu(_device) => { + let a = self.if_then_else(rhs, lhs); + let b = self.if_then_else(lhs, rhs); + (a, b) + } + }) + } +} + +impl Flip<&FheUint, &FheUint> for FheBool +where + Id: FheUintId + std::marker::Send + std::marker::Sync, +{ + type Output = FheUint; + + /// Flips the two inputs based on the value of `self`. + /// + /// * flip(true, a, b) returns (b, a) + /// * flip(false, a, b) returns (a, b) + /// + /// # Example + /// + /// ```rust + /// use tfhe::prelude::*; + /// use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheBool, FheUint32}; + /// + /// let (client_key, server_key) = generate_keys(ConfigBuilder::default()); + /// set_server_key(server_key); + /// + /// let a = u32::MIN; + /// let b = FheUint32::encrypt(u32::MAX, &client_key); + /// let cond = FheBool::encrypt(true, &client_key); + /// + /// let (ra, rb) = cond.flip(a, &b); + /// let da: u32 = ra.decrypt(&client_key); + /// let db: u32 = rb.decrypt(&client_key); + /// assert_eq!((da, db), (u32::MAX, u32::MIN)); + /// ``` + fn flip(&self, lhs: &FheUint, rhs: &FheUint) -> (Self::Output, Self::Output) { + let ct_condition = self; + global_state::with_internal_keys(|sks| match sks { + InternalServerKey::Cpu(cpu_sks) => { + let (a, b) = cpu_sks.pbs_key().flip_parallelized( + &ct_condition.ciphertext.on_cpu(), + &*lhs.ciphertext.on_cpu(), + &*rhs.ciphertext.on_cpu(), + ); + ( + FheUint::new(a, cpu_sks.tag.clone(), ReRandomizationMetadata::default()), + FheUint::new(b, cpu_sks.tag.clone(), ReRandomizationMetadata::default()), + ) + } + #[cfg(feature = "gpu")] + InternalServerKey::Cuda(cuda_key) => rayon::join( + || { + let streams = &cuda_key.streams; + let inner = cuda_key.key.key.if_then_else( + &CudaBooleanBlock(self.ciphertext.on_gpu(streams).duplicate(streams)), + &*rhs.ciphertext.on_gpu(streams), + &*lhs.ciphertext.on_gpu(streams), + streams, + ); + + FheUint::new( + inner, + cuda_key.tag.clone(), + ReRandomizationMetadata::default(), + ) + }, + || { + let streams = &cuda_key.streams; + let inner = cuda_key.key.key.if_then_else( + &CudaBooleanBlock(self.ciphertext.on_gpu(streams).duplicate(streams)), + &*lhs.ciphertext.on_gpu(streams), + &*rhs.ciphertext.on_gpu(streams), + streams, + ); + + FheUint::new( + inner, + cuda_key.tag.clone(), + ReRandomizationMetadata::default(), + ) + }, + ), + #[cfg(feature = "hpu")] + InternalServerKey::Hpu(_device) => { + let a = self.if_then_else(rhs, lhs); + let b = self.if_then_else(lhs, rhs); + (a, b) + } + }) + } +} + +impl Flip<&FheUint, T> for FheBool +where + Id: FheUintId, + T: DecomposableInto + UnsignedNumeric, +{ + type Output = FheUint; + + fn flip(&self, lhs: &FheUint, rhs: T) -> (Self::Output, Self::Output) { + let ct_condition = self; + global_state::with_internal_keys(|sks| match sks { + InternalServerKey::Cpu(cpu_sks) => { + let (a, b) = cpu_sks.pbs_key().flip_parallelized( + &ct_condition.ciphertext.on_cpu(), + &*lhs.ciphertext.on_cpu(), + rhs, + ); + ( + FheUint::new(a, cpu_sks.tag.clone(), ReRandomizationMetadata::default()), + FheUint::new(b, cpu_sks.tag.clone(), ReRandomizationMetadata::default()), + ) + } + #[cfg(feature = "gpu")] + InternalServerKey::Cuda(_) => { + panic!("Gpu does not support FheBool::flip with clear input") + } + #[cfg(feature = "hpu")] + InternalServerKey::Hpu(_device) => { + panic!("Hpu does not support FheBool::flip with clear input") + } + }) + } +} + +impl Flip<&FheInt, T> for FheBool +where + Id: FheIntId, + T: DecomposableInto + SignedNumeric, +{ + type Output = FheInt; + + fn flip(&self, lhs: &FheInt, rhs: T) -> (Self::Output, Self::Output) { + let ct_condition = self; + global_state::with_internal_keys(|sks| match sks { + InternalServerKey::Cpu(cpu_sks) => { + let (a, b) = cpu_sks.pbs_key().flip_parallelized( + &ct_condition.ciphertext.on_cpu(), + &*lhs.ciphertext.on_cpu(), + rhs, + ); + ( + FheInt::new(a, cpu_sks.tag.clone(), ReRandomizationMetadata::default()), + FheInt::new(b, cpu_sks.tag.clone(), ReRandomizationMetadata::default()), + ) + } + #[cfg(feature = "gpu")] + InternalServerKey::Cuda(_) => { + panic!("Gpu does not support FheBool::flip with clear input") + } + #[cfg(feature = "hpu")] + InternalServerKey::Hpu(_device) => { + panic!("Hpu does not support FheBool::flip with clear input") + } + }) + } +} + +impl Flip> for FheBool +where + Id: FheIntId, + T: DecomposableInto + SignedNumeric, +{ + type Output = FheInt; + + fn flip(&self, lhs: T, rhs: &FheInt) -> (Self::Output, Self::Output) { + let ct_condition = self; + global_state::with_internal_keys(|sks| match sks { + InternalServerKey::Cpu(cpu_sks) => { + let (a, b) = cpu_sks.pbs_key().flip_parallelized( + &ct_condition.ciphertext.on_cpu(), + lhs, + &*rhs.ciphertext.on_cpu(), + ); + ( + FheInt::new(a, cpu_sks.tag.clone(), ReRandomizationMetadata::default()), + FheInt::new(b, cpu_sks.tag.clone(), ReRandomizationMetadata::default()), + ) + } + #[cfg(feature = "gpu")] + InternalServerKey::Cuda(_) => { + panic!("Gpu does not support FheBool::flip with clear input") + } + #[cfg(feature = "hpu")] + InternalServerKey::Hpu(_device) => { + panic!("Hpu does not support FheBool::flip with clear input") + } + }) + } +} + +impl Flip> for FheBool +where + Id: FheUintId, + T: DecomposableInto + UnsignedNumeric, +{ + type Output = FheUint; + + fn flip(&self, lhs: T, rhs: &FheUint) -> (Self::Output, Self::Output) { + let ct_condition = self; + global_state::with_internal_keys(|sks| match sks { + InternalServerKey::Cpu(cpu_sks) => { + let (a, b) = cpu_sks.pbs_key().flip_parallelized( + &ct_condition.ciphertext.on_cpu(), + lhs, + &*rhs.ciphertext.on_cpu(), + ); + ( + FheUint::new(a, cpu_sks.tag.clone(), ReRandomizationMetadata::default()), + FheUint::new(b, cpu_sks.tag.clone(), ReRandomizationMetadata::default()), + ) + } + #[cfg(feature = "gpu")] + InternalServerKey::Cuda(_) => { + panic!("Gpu does not support FheBool::flip with clear input") + } + #[cfg(feature = "hpu")] + InternalServerKey::Hpu(_device) => { + panic!("Hpu does not support FheBool::flip with clear input") + } + }) + } +} + impl Tagged for FheBool { fn tag(&self) -> &Tag { &self.tag diff --git a/tfhe/src/high_level_api/integers/signed/tests/cpu.rs b/tfhe/src/high_level_api/integers/signed/tests/cpu.rs index ee1de0dec..92c3d8ece 100644 --- a/tfhe/src/high_level_api/integers/signed/tests/cpu.rs +++ b/tfhe/src/high_level_api/integers/signed/tests/cpu.rs @@ -98,6 +98,18 @@ fn test_if_then_else() { super::test_case_if_then_else(&client_key); } +#[test] +fn test_flip() { + let client_key = setup_default_cpu(); + super::test_case_flip(&client_key); +} + +#[test] +fn test_scalar_flip() { + let client_key = setup_default_cpu(); + super::test_case_scalar_flip(&client_key); +} + #[test] fn test_abs() { let client_key = setup_default_cpu(); diff --git a/tfhe/src/high_level_api/integers/signed/tests/gpu.rs b/tfhe/src/high_level_api/integers/signed/tests/gpu.rs index 3a8a3179f..f1e013c84 100644 --- a/tfhe/src/high_level_api/integers/signed/tests/gpu.rs +++ b/tfhe/src/high_level_api/integers/signed/tests/gpu.rs @@ -53,6 +53,14 @@ fn test_if_then_else() { super::test_case_if_then_else(&client_key); } +#[test] +fn test_flip() { + let client_key = crate::high_level_api::integers::unsigned::tests::gpu::setup_gpu(Some( + PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS, + )); + super::test_case_flip(&client_key); +} + #[test] fn test_abs() { let client_key = crate::high_level_api::integers::unsigned::tests::gpu::setup_gpu(Some( diff --git a/tfhe/src/high_level_api/integers/signed/tests/mod.rs b/tfhe/src/high_level_api/integers/signed/tests/mod.rs index 729f7c6cb..34ccb5c2d 100644 --- a/tfhe/src/high_level_api/integers/signed/tests/mod.rs +++ b/tfhe/src/high_level_api/integers/signed/tests/mod.rs @@ -1,5 +1,5 @@ use crate::prelude::*; -use crate::{ClientKey, FheInt16, FheInt32, FheInt64, FheInt8, FheUint64, FheUint8}; +use crate::{ClientKey, FheBool, FheInt16, FheInt32, FheInt64, FheInt8, FheUint64, FheUint8}; use rand::prelude::*; mod cpu; @@ -390,6 +390,47 @@ fn test_case_if_then_else(cks: &ClientKey) { ); } +fn test_case_flip(client_key: &ClientKey) { + let clear_a = rand::random::(); + let clear_b = rand::random::(); + + let a = FheInt32::encrypt(clear_a, client_key); + let b = FheInt32::encrypt(clear_b, client_key); + + let c = FheBool::encrypt(true, client_key); + let (ra, rb) = c.flip(&a, &b); + let decrypted_a: i32 = ra.decrypt(client_key); + let decrypted_b: i32 = rb.decrypt(client_key); + assert_eq!((decrypted_a, decrypted_b), (clear_b, clear_a)); + + let c = FheBool::encrypt(false, client_key); + + let (ra, rb) = c.flip(&a, &b); + let decrypted_a: i32 = ra.decrypt(client_key); + let decrypted_b: i32 = rb.decrypt(client_key); + assert_eq!((decrypted_a, decrypted_b), (clear_a, clear_b)); +} + +fn test_case_scalar_flip(client_key: &ClientKey) { + let clear_a = rand::random::(); + let clear_b = rand::random::(); + + let a = FheInt32::encrypt(clear_a, client_key); + let b = FheInt32::encrypt(clear_b, client_key); + + let c = FheBool::encrypt(true, client_key); + let (ra, rb) = c.flip(&a, clear_b); + let decrypted_a: i32 = ra.decrypt(client_key); + let decrypted_b: i32 = rb.decrypt(client_key); + assert_eq!((decrypted_a, decrypted_b), (clear_b, clear_a)); + + let c = FheBool::encrypt(false, client_key); + let (ra, rb) = c.flip(clear_a, &b); + let decrypted_a: i32 = ra.decrypt(client_key); + let decrypted_b: i32 = rb.decrypt(client_key); + assert_eq!((decrypted_a, decrypted_b), (clear_a, clear_b)); +} + fn test_case_abs(cks: &ClientKey) { let mut rng = rand::thread_rng(); diff --git a/tfhe/src/high_level_api/integers/unsigned/tests/cpu.rs b/tfhe/src/high_level_api/integers/unsigned/tests/cpu.rs index 0a682e8f4..471f454d6 100644 --- a/tfhe/src/high_level_api/integers/unsigned/tests/cpu.rs +++ b/tfhe/src/high_level_api/integers/unsigned/tests/cpu.rs @@ -386,6 +386,18 @@ fn test_if_then_else() { super::test_case_if_then_else(&client_key); } +#[test] +fn test_flip() { + let client_key = setup_default_cpu(); + super::test_case_flip(&client_key); +} + +#[test] +fn test_scalar_flip() { + let client_key = setup_default_cpu(); + super::test_case_scalar_flip(&client_key); +} + #[test] fn test_scalar_shift_when_clear_type_is_small() { // This is a regression tests diff --git a/tfhe/src/high_level_api/integers/unsigned/tests/gpu.rs b/tfhe/src/high_level_api/integers/unsigned/tests/gpu.rs index 1a5061413..817ca181f 100644 --- a/tfhe/src/high_level_api/integers/unsigned/tests/gpu.rs +++ b/tfhe/src/high_level_api/integers/unsigned/tests/gpu.rs @@ -121,6 +121,12 @@ fn test_if_then_else_gpu_multibit() { super::test_case_if_then_else(&client_key); } +#[test] +fn test_flip() { + let client_key = setup_gpu(Some(PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS)); + super::test_case_flip(&client_key); +} + #[test] fn test_sum_gpu() { let client_key = setup_default_gpu(); diff --git a/tfhe/src/high_level_api/integers/unsigned/tests/hpu.rs b/tfhe/src/high_level_api/integers/unsigned/tests/hpu.rs index a4ceea4dd..4b9f47c18 100644 --- a/tfhe/src/high_level_api/integers/unsigned/tests/hpu.rs +++ b/tfhe/src/high_level_api/integers/unsigned/tests/hpu.rs @@ -88,6 +88,13 @@ fn test_case_if_then_else_hpu() { let client_key = setup_default_hpu(); super::test_case_if_then_else(&client_key); } + +#[test] +fn test_case_flip_hpu() { + let client_key = setup_default_hpu(); + super::test_case_flip(&client_key); +} + #[test] fn test_case_uint32_div_rem_hpu() { let client_key = setup_default_hpu(); diff --git a/tfhe/src/high_level_api/integers/unsigned/tests/mod.rs b/tfhe/src/high_level_api/integers/unsigned/tests/mod.rs index 7bd5ad36b..f44f1a052 100644 --- a/tfhe/src/high_level_api/integers/unsigned/tests/mod.rs +++ b/tfhe/src/high_level_api/integers/unsigned/tests/mod.rs @@ -1,7 +1,7 @@ use crate::high_level_api::traits::BitSlice; use crate::integer::U256; use crate::prelude::*; -use crate::{ClientKey, FheUint256, FheUint32, FheUint64, FheUint8}; +use crate::{ClientKey, FheBool, FheUint256, FheUint32, FheUint64, FheUint8}; use rand::{thread_rng, Rng}; mod cpu; @@ -567,6 +567,47 @@ fn test_case_if_then_else(client_key: &ClientKey) { ); } +fn test_case_flip(client_key: &ClientKey) { + let clear_a = rand::random::(); + let clear_b = rand::random::(); + + let a = FheUint32::encrypt(clear_a, client_key); + let b = FheUint32::encrypt(clear_b, client_key); + + let c = FheBool::encrypt(true, client_key); + let (ra, rb) = c.flip(&a, &b); + let decrypted_a: u32 = ra.decrypt(client_key); + let decrypted_b: u32 = rb.decrypt(client_key); + assert_eq!((decrypted_a, decrypted_b), (clear_b, clear_a)); + + let c = FheBool::encrypt(false, client_key); + + let (ra, rb) = c.flip(&a, &b); + let decrypted_a: u32 = ra.decrypt(client_key); + let decrypted_b: u32 = rb.decrypt(client_key); + assert_eq!((decrypted_a, decrypted_b), (clear_a, clear_b)); +} + +fn test_case_scalar_flip(client_key: &ClientKey) { + let clear_a = rand::random::(); + let clear_b = rand::random::(); + + let a = FheUint32::encrypt(clear_a, client_key); + let b = FheUint32::encrypt(clear_b, client_key); + + let c = FheBool::encrypt(true, client_key); + let (ra, rb) = c.flip(&a, clear_b); + let decrypted_a: u32 = ra.decrypt(client_key); + let decrypted_b: u32 = rb.decrypt(client_key); + assert_eq!((decrypted_a, decrypted_b), (clear_b, clear_a)); + + let c = FheBool::encrypt(false, client_key); + let (ra, rb) = c.flip(clear_a, &b); + let decrypted_a: u32 = ra.decrypt(client_key); + let decrypted_b: u32 = rb.decrypt(client_key); + assert_eq!((decrypted_a, decrypted_b), (clear_a, clear_b)); +} + fn test_case_leading_trailing_zeros_ones(cks: &ClientKey) { let mut rng = rand::thread_rng(); for _ in 0..5 { diff --git a/tfhe/src/high_level_api/prelude.rs b/tfhe/src/high_level_api/prelude.rs index d29f3b664..acacb163d 100644 --- a/tfhe/src/high_level_api/prelude.rs +++ b/tfhe/src/high_level_api/prelude.rs @@ -8,7 +8,7 @@ //! ``` pub use crate::high_level_api::traits::{ BitSlice, CiphertextList, DivRem, FheDecrypt, FheEncrypt, FheEq, FheKeyswitch, FheMax, FheMin, - FheOrd, FheTrivialEncrypt, FheTryEncrypt, FheTryTrivialEncrypt, FheWait, IfThenElse, + FheOrd, FheTrivialEncrypt, FheTryEncrypt, FheTryTrivialEncrypt, FheWait, Flip, IfThenElse, OverflowingAdd, OverflowingMul, OverflowingNeg, OverflowingSub, ReRandomize, RotateLeft, RotateLeftAssign, RotateRight, RotateRightAssign, ScalarIfThenElse, SquashNoise, Tagged, }; diff --git a/tfhe/src/high_level_api/traits.rs b/tfhe/src/high_level_api/traits.rs index 869ed3312..d05a0cda8 100644 --- a/tfhe/src/high_level_api/traits.rs +++ b/tfhe/src/high_level_api/traits.rs @@ -163,6 +163,12 @@ pub trait ScalarIfThenElse { } } +pub trait Flip { + type Output; + + fn flip(&self, lhs: Lhs, rhs: Rhs) -> (Self::Output, Self::Output); +} + pub trait OverflowingAdd { type Output;