feat(hlapi): bind scalar_bitwise/div/rem operations

This commit is contained in:
tmontaigu
2023-06-26 16:33:56 +02:00
parent 16be1c1c1d
commit d496cfa431
5 changed files with 400 additions and 12 deletions

View File

@@ -1,8 +1,8 @@
use crate::c_api::high_level_api::keys::{ClientKey, CompactPublicKey, PublicKey};
use crate::high_level_api::prelude::*;
use std::ops::{
Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Mul, MulAssign,
Neg, Not, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign,
Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Div, DivAssign,
Mul, MulAssign, Neg, Not, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign,
};
use crate::c_api::high_level_api::u128::U128;
@@ -17,12 +17,121 @@ macro_rules! impl_operations_for_integer_type {
name: $name:ident,
clear_scalar_type: $clear_scalar_type:ty
) => {
impl_binary_fn_on_type!($name => add, sub, mul, bitand, bitor, bitxor, shl, shr, eq, ne, ge, gt, le, lt, min, max);
impl_binary_assign_fn_on_type!($name => add_assign, sub_assign, mul_assign, bitand_assign, bitor_assign, bitxor_assign, shl_assign, shr_assign);
impl_scalar_binary_fn_on_type!($name, $clear_scalar_type => add, sub, mul, shl, shr, eq, ne, ge, gt, le, lt, min, max);
impl_scalar_binary_assign_fn_on_type!($name, $clear_scalar_type => add_assign, sub_assign, mul_assign, shl_assign, shr_assign);
impl_binary_fn_on_type!($name =>
add,
sub,
mul,
bitand,
bitor,
bitxor,
shl,
shr,
eq,
ne,
ge,
gt,
le,
lt,
min,
max,
div,
rem,
);
impl_binary_assign_fn_on_type!($name =>
add_assign,
sub_assign,
mul_assign,
bitand_assign,
bitor_assign,
bitxor_assign,
shl_assign,
shr_assign,
div_assign,
rem_assign,
);
impl_scalar_binary_fn_on_type!($name, $clear_scalar_type =>
add,
sub,
mul,
bitand,
bitor,
bitxor,
shl,
shr,
eq,
ne,
ge,
gt,
le,
lt,
min,
max,
rotate_right,
rotate_left,
div,
rem,
);
impl_scalar_binary_assign_fn_on_type!($name, $clear_scalar_type =>
add_assign,
sub_assign,
mul_assign,
bitand_assign,
bitor_assign,
bitxor_assign,
shl_assign,
shr_assign,
rotate_right_assign,
rotate_left_assign,
div_assign,
rem_assign,
);
impl_unary_fn_on_type!($name => neg, not);
// Implement div_rem.
// We can't use the macro above as div_rem returns a tuple.
//
// (Having div_rem is important for the cases where you need both
// the quotient and remainder as you may save time by using the div_rem
// instead of div and rem separately
::paste::paste! {
#[no_mangle]
pub unsafe extern "C" fn [<$name:snake _scalar_div_rem>](
lhs: *const $name,
rhs: $clear_scalar_type,
q_result: *mut *mut $name,
r_result: *mut *mut $name,
) -> c_int {
$crate::c_api::utils::catch_panic(|| {
let lhs = $crate::c_api::utils::get_ref_checked(lhs).unwrap();
let (q, r) = (&lhs.0).div_rem(rhs);
*q_result = Box::into_raw(Box::new($name(q)));
*r_result = Box::into_raw(Box::new($name(r)));
})
}
}
::paste::paste! {
#[no_mangle]
pub unsafe extern "C" fn [<$name:snake _div_rem>](
lhs: *const $name,
rhs: *const $name,
q_result: *mut *mut $name,
r_result: *mut *mut $name,
) -> c_int {
$crate::c_api::utils::catch_panic(|| {
let lhs = $crate::c_api::utils::get_ref_checked(lhs).unwrap();
let rhs = $crate::c_api::utils::get_ref_checked(rhs).unwrap();
let (q, r) = (&lhs.0).div_rem(&rhs.0);
*q_result = Box::into_raw(Box::new($name(q)));
*r_result = Box::into_raw(Box::new($name(r)));
})
}
}
};
}

View File

@@ -271,7 +271,87 @@ fn fhe_uint32_shift(config: Config) {
}
#[test]
fn test_uint32_shift() {
fn test_uint32_bitwise() {
let config = ConfigBuilder::all_disabled()
.enable_default_integers()
.build();
let (cks, sks) = generate_keys(config);
use rand::prelude::*;
let mut rng = rand::thread_rng();
let clear_a = rng.gen::<u32>();
let clear_b = rng.gen_range(0u32..32u32);
let a = FheUint32::try_encrypt(clear_a, &cks).unwrap();
let b = FheUint32::try_encrypt(clear_b, &cks).unwrap();
set_server_key(sks);
// encrypted bitwise
{
let c = &a | &b;
let decrypted: u32 = c.decrypt(&cks);
assert_eq!(decrypted, clear_a | clear_b);
let c = &a & &b;
let decrypted: u32 = c.decrypt(&cks);
assert_eq!(decrypted, clear_a & clear_b);
let c = &a ^ &b;
let decrypted: u32 = c.decrypt(&cks);
assert_eq!(decrypted, clear_a ^ clear_b);
let mut c = a.clone();
c |= &b;
let decrypted: u32 = c.decrypt(&cks);
assert_eq!(decrypted, clear_a | clear_b);
let mut c = a.clone();
c &= &b;
let decrypted: u32 = c.decrypt(&cks);
assert_eq!(decrypted, clear_a & clear_b);
let mut c = a.clone();
c ^= &b;
let decrypted: u32 = c.decrypt(&cks);
assert_eq!(decrypted, clear_a ^ clear_b);
}
// clear bitwise
{
let c = &a | b;
let decrypted: u32 = c.decrypt(&cks);
assert_eq!(decrypted, clear_a | clear_b);
let c = &a & clear_b;
let decrypted: u32 = c.decrypt(&cks);
assert_eq!(decrypted, clear_a & clear_b);
let c = &a ^ clear_b;
let decrypted: u32 = c.decrypt(&cks);
assert_eq!(decrypted, clear_a ^ clear_b);
let mut c = a.clone();
c |= clear_b;
let decrypted: u32 = c.decrypt(&cks);
assert_eq!(decrypted, clear_a | clear_b);
let mut c = a.clone();
c &= clear_b;
let decrypted: u32 = c.decrypt(&cks);
assert_eq!(decrypted, clear_a & clear_b);
let mut c = a;
c ^= clear_b;
let decrypted: u32 = c.decrypt(&cks);
assert_eq!(decrypted, clear_a ^ clear_b);
}
}
#[test]
fn test_bit_shift() {
let config = ConfigBuilder::all_disabled()
.enable_default_integers()
.build();
@@ -365,6 +445,94 @@ fn test_multi_bit_rotate() {
fhe_uint32_rotate(config);
}
fn fhe_uint32_div_rem(config: Config) {
let (cks, sks) = generate_keys(config);
use rand::prelude::*;
let mut rng = rand::thread_rng();
let clear_a = rng.gen::<u32>();
let clear_b = rng.gen_range(1u32..=u32::MAX);
let a = FheUint32::try_encrypt(clear_a, &cks).unwrap();
let b = FheUint32::try_encrypt(clear_b, &cks).unwrap();
set_server_key(sks);
// encrypted div/rem
{
let c = &a / &b;
let decrypted: u32 = c.decrypt(&cks);
assert_eq!(decrypted, clear_a / clear_b);
let c = &a % &b;
let decrypted: u32 = c.decrypt(&cks);
assert_eq!(decrypted, clear_a % clear_b);
let (q, r) = (&a).div_rem(&b);
let decrypted_q: u32 = q.decrypt(&cks);
let decrypted_r: u32 = r.decrypt(&cks);
assert_eq!(decrypted_q, clear_a / clear_b);
assert_eq!(decrypted_r, clear_a % clear_b);
let mut c = a.clone();
c /= &b;
let decrypted: u32 = c.decrypt(&cks);
assert_eq!(decrypted, clear_a / clear_b);
let mut c = a.clone();
c %= &b;
let decrypted: u32 = c.decrypt(&cks);
assert_eq!(decrypted, clear_a % clear_b);
}
// clear div/rem
{
let c = &a / clear_b;
let decrypted: u32 = c.decrypt(&cks);
assert_eq!(decrypted, clear_a / clear_b);
let c = &a % clear_b;
let decrypted: u32 = c.decrypt(&cks);
assert_eq!(decrypted, clear_a % clear_b);
let (q, r) = (&a).div_rem(clear_b);
let decrypted_q: u32 = q.decrypt(&cks);
let decrypted_r: u32 = r.decrypt(&cks);
assert_eq!(decrypted_q, clear_a / clear_b);
assert_eq!(decrypted_r, clear_a % clear_b);
let mut c = a.clone();
c /= clear_b;
let decrypted: u32 = c.decrypt(&cks);
assert_eq!(decrypted, clear_a / clear_b);
let mut c = a;
c %= clear_b;
let decrypted: u32 = c.decrypt(&cks);
assert_eq!(decrypted, clear_a % clear_b);
}
}
#[test]
fn test_uint32_div_rem() {
let config = ConfigBuilder::all_disabled()
.enable_default_integers()
.build();
fhe_uint32_div_rem(config);
}
#[test]
fn test_multi_div_rem() {
let config = ConfigBuilder::all_disabled()
.enable_custom_integers(
crate::shortint::parameters::PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS,
None,
)
.build();
fhe_uint32_div_rem(config);
}
#[test]
fn test_uint64() {
let config = ConfigBuilder::all_disabled()

View File

@@ -1,7 +1,7 @@
use std::borrow::Borrow;
use std::ops::{
Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Mul, MulAssign,
Neg, Not, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign,
Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Div, DivAssign,
Mul, MulAssign, Neg, Not, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign,
};
use crate::errors::{
@@ -14,8 +14,9 @@ use crate::high_level_api::integers::IntegerServerKey;
use crate::high_level_api::internal_traits::{DecryptionKey, TypeIdentifier};
use crate::high_level_api::keys::{CompressedPublicKey, RefKeyFromKeyChain};
use crate::high_level_api::traits::{
FheBootstrap, FheDecrypt, FheEq, FheMax, FheMin, FheOrd, FheTrivialEncrypt, FheTryEncrypt,
FheTryTrivialEncrypt, RotateLeft, RotateLeftAssign, RotateRight, RotateRightAssign,
DivRem, FheBootstrap, FheDecrypt, FheEq, FheMax, FheMin, FheOrd, FheTrivialEncrypt,
FheTryEncrypt, FheTryTrivialEncrypt, RotateLeft, RotateLeftAssign, RotateRight,
RotateRightAssign,
};
use crate::high_level_api::{ClientKey, PublicKey};
use crate::integer::block_decomposition::DecomposableInto;
@@ -669,6 +670,96 @@ where
}
}
impl<P, Clear> DivRem<Clear> for GenericInteger<P>
where
P: IntegerParameter,
Clear: Into<u64>,
P::Id: WithGlobalKey<Key = IntegerServerKey>,
{
type Output = (Self, Self);
fn div_rem(self, rhs: Clear) -> Self::Output {
<&Self as DivRem<Clear>>::div_rem(&self, rhs)
}
}
impl<P, Clear> DivRem<Clear> for &GenericInteger<P>
where
P: IntegerParameter,
Clear: Into<u64>,
P::Id: WithGlobalKey<Key = IntegerServerKey>,
{
type Output = (GenericInteger<P>, GenericInteger<P>);
fn div_rem(self, rhs: Clear) -> Self::Output {
let (q, r) = self.id.with_unwrapped_global(|integer_key| {
integer_key
.pbs_key()
.scalar_div_rem_parallelized(&self.ciphertext, rhs.into())
});
(
GenericInteger::<P>::new(q, self.id),
GenericInteger::<P>::new(r, self.id),
)
}
}
impl<P> DivRem<GenericInteger<P>> for GenericInteger<P>
where
P: IntegerParameter,
P::Id: WithGlobalKey<Key = IntegerServerKey>,
{
type Output = (GenericInteger<P>, GenericInteger<P>);
fn div_rem(self, rhs: GenericInteger<P>) -> Self::Output {
<Self as DivRem<&GenericInteger<P>>>::div_rem(self, &rhs)
}
}
impl<P> DivRem<&GenericInteger<P>> for GenericInteger<P>
where
P: IntegerParameter,
P::Id: WithGlobalKey<Key = IntegerServerKey>,
{
type Output = (GenericInteger<P>, GenericInteger<P>);
fn div_rem(self, rhs: &GenericInteger<P>) -> Self::Output {
<&Self as DivRem<&GenericInteger<P>>>::div_rem(&self, rhs)
}
}
impl<P> DivRem<GenericInteger<P>> for &GenericInteger<P>
where
P: IntegerParameter,
P::Id: WithGlobalKey<Key = IntegerServerKey>,
{
type Output = (GenericInteger<P>, GenericInteger<P>);
fn div_rem(self, rhs: GenericInteger<P>) -> Self::Output {
<Self as DivRem<&GenericInteger<P>>>::div_rem(self, &rhs)
}
}
impl<P> DivRem<&GenericInteger<P>> for &GenericInteger<P>
where
P: IntegerParameter,
P::Id: WithGlobalKey<Key = IntegerServerKey>,
{
type Output = (GenericInteger<P>, GenericInteger<P>);
fn div_rem(self, rhs: &GenericInteger<P>) -> Self::Output {
let (q, r) = self.id.with_unwrapped_global(|integer_key| {
integer_key
.pbs_key()
.div_rem_parallelized(&self.ciphertext, &rhs.ciphertext)
});
(
GenericInteger::<P>::new(q, self.id),
GenericInteger::<P>::new(r, self.id),
)
}
}
macro_rules! generic_integer_impl_operation (
($rust_trait_name:ident($rust_trait_method:ident) => $key_method:ident) => {
@@ -793,6 +884,8 @@ generic_integer_impl_operation!(Shl(shl) => left_shift_parallelized);
generic_integer_impl_operation!(Shr(shr) => right_shift_parallelized);
generic_integer_impl_operation!(RotateLeft(rotate_left) => rotate_left_parallelized);
generic_integer_impl_operation!(RotateRight(rotate_right) => rotate_right_parallelized);
generic_integer_impl_operation!(Div(div) => div_parallelized);
generic_integer_impl_operation!(Rem(rem) => rem_parallelized);
generic_integer_impl_operation_assign!(AddAssign(add_assign) => add_assign_parallelized);
generic_integer_impl_operation_assign!(SubAssign(sub_assign) => sub_assign_parallelized);
@@ -804,22 +897,34 @@ generic_integer_impl_operation_assign!(ShlAssign(shl_assign) => left_shift_assig
generic_integer_impl_operation_assign!(ShrAssign(shr_assign) => right_shift_assign_parallelized);
generic_integer_impl_operation_assign!(RotateLeftAssign(rotate_left_assign) => rotate_left_assign_parallelized);
generic_integer_impl_operation_assign!(RotateRightAssign(rotate_right_assign) => rotate_right_assign_parallelized);
generic_integer_impl_operation_assign!(DivAssign(div_assign) => div_assign_parallelized);
generic_integer_impl_operation_assign!(RemAssign(rem_assign) => rem_assign_parallelized);
generic_integer_impl_scalar_operation!(Add(add) => scalar_add_parallelized(u8, u16, u32, u64));
generic_integer_impl_scalar_operation!(Sub(sub) => scalar_sub_parallelized(u8, u16, u32, u64));
generic_integer_impl_scalar_operation!(Mul(mul) => scalar_mul_parallelized(u8, u16, u32, u64));
generic_integer_impl_scalar_operation!(BitAnd(bitand) => scalar_bitand_parallelized(u8, u16, u32, u64));
generic_integer_impl_scalar_operation!(BitOr(bitor) => scalar_bitor_parallelized(u8, u16, u32, u64));
generic_integer_impl_scalar_operation!(BitXor(bitxor) => scalar_bitxor_parallelized(u8, u16, u32, u64));
generic_integer_impl_scalar_operation!(Shl(shl) => scalar_left_shift_parallelized(u8, u16, u32, u64));
generic_integer_impl_scalar_operation!(Shr(shr) => scalar_right_shift_parallelized(u8, u16, u32, u64));
generic_integer_impl_scalar_operation!(RotateLeft(rotate_left) => scalar_rotate_left_parallelized(u8, u16, u32, u64));
generic_integer_impl_scalar_operation!(RotateRight(rotate_right) => scalar_rotate_right_parallelized(u8, u16, u32, u64));
generic_integer_impl_scalar_operation!(Div(div) => scalar_div_parallelized(u8, u16, u32, u64));
generic_integer_impl_scalar_operation!(Rem(rem) => scalar_rem_parallelized(u8, u16, u32, u64));
generic_integer_impl_scalar_operation_assign!(AddAssign(add_assign) => scalar_add_assign_parallelized(u8, u16, u32, u64));
generic_integer_impl_scalar_operation_assign!(SubAssign(sub_assign) => scalar_sub_assign_parallelized(u8, u16, u32, u64));
generic_integer_impl_scalar_operation_assign!(MulAssign(mul_assign) => scalar_mul_assign_parallelized(u8, u16, u32, u64));
generic_integer_impl_scalar_operation_assign!(BitAndAssign(bitand_assign) => scalar_bitand_assign_parallelized(u8, u16, u32, u64));
generic_integer_impl_scalar_operation_assign!(BitOrAssign(bitor_assign) => scalar_bitor_assign_parallelized(u8, u16, u32, u64));
generic_integer_impl_scalar_operation_assign!(BitXorAssign(bitxor_assign) => scalar_bitxor_assign_parallelized(u8, u16, u32, u64));
generic_integer_impl_scalar_operation_assign!(ShlAssign(shl_assign) => scalar_left_shift_assign_parallelized(u8, u16, u32, u64));
generic_integer_impl_scalar_operation_assign!(ShrAssign(shr_assign) => scalar_right_shift_assign_parallelized(u8, u16, u32, u64));
generic_integer_impl_scalar_operation_assign!(RotateLeftAssign(rotate_left_assign) => scalar_rotate_left_assign_parallelized(u8, u16, u32, u64));
generic_integer_impl_scalar_operation_assign!(RotateRightAssign(rotate_right_assign) => scalar_rotate_right_assign_parallelized(u8, u16, u32, u64));
generic_integer_impl_scalar_operation_assign!(DivAssign(div_assign) => scalar_div_assign_parallelized(u8, u16, u32, u64));
generic_integer_impl_scalar_operation_assign!(RemAssign(rem_assign) => scalar_rem_assign_parallelized(u8, u16, u32, u64));
impl<P> Neg for GenericInteger<P>
where

View File

@@ -6,7 +6,7 @@
//! use tfhe::prelude::*;
//! ```
pub use crate::high_level_api::traits::{
DynamicFheEncryptor, DynamicFheTrivialEncryptor, DynamicFheTryEncryptor, FheBootstrap,
DivRem, DynamicFheEncryptor, DynamicFheTrivialEncryptor, DynamicFheTryEncryptor, FheBootstrap,
FheDecrypt, FheEncrypt, FheEq, FheMax, FheMin, FheNumberConstant, FheOrd, FheTrivialEncrypt,
FheTryEncrypt, FheTryTrivialEncrypt, RotateLeft, RotateLeftAssign, RotateRight,
RotateRightAssign,

View File

@@ -164,3 +164,9 @@ pub trait RotateLeftAssign<Rhs = Self> {
pub trait RotateRightAssign<Rhs = Self> {
fn rotate_right_assign(&mut self, amount: Rhs);
}
pub trait DivRem<Rhs = Self> {
type Output;
fn div_rem(self, amount: Rhs) -> Self::Output;
}