mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-09 14:47:56 -05:00
feat(gpu): Add signed_overflowing_scalar_add and signed_overflowing_scalar_sub
This commit is contained in:
committed by
Agnès Leroy
parent
230fa5a8f0
commit
95ef13f6ce
@@ -1801,6 +1801,18 @@ mod cuda {
|
||||
rng_func: default_signed_scalar
|
||||
);
|
||||
|
||||
define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
|
||||
method_name: signed_overflowing_scalar_add,
|
||||
display_name: overflowing_add,
|
||||
rng_func: default_signed_scalar
|
||||
);
|
||||
|
||||
define_cuda_server_key_bench_clean_input_scalar_signed_fn!(
|
||||
method_name: signed_overflowing_scalar_sub,
|
||||
display_name: overflowing_sub,
|
||||
rng_func: default_signed_scalar
|
||||
);
|
||||
|
||||
//===========================================
|
||||
// Default
|
||||
//===========================================
|
||||
@@ -2132,6 +2144,8 @@ mod cuda {
|
||||
cuda_scalar_le,
|
||||
cuda_scalar_min,
|
||||
cuda_scalar_max,
|
||||
cuda_signed_overflowing_scalar_add,
|
||||
cuda_signed_overflowing_scalar_sub,
|
||||
);
|
||||
|
||||
fn cuda_bench_server_key_signed_cast_function<F>(
|
||||
|
||||
@@ -140,9 +140,14 @@ where
|
||||
(FheInt::new(result), FheBool::new(overflow))
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(_) => {
|
||||
todo!("Cuda devices do not support signed integer");
|
||||
}
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let (result, overflow) = cuda_key.key.signed_overflowing_scalar_add(
|
||||
&self.ciphertext.on_gpu(),
|
||||
other,
|
||||
streams,
|
||||
);
|
||||
(FheInt::new(result), FheBool::new(overflow))
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -349,9 +354,14 @@ where
|
||||
(FheInt::new(result), FheBool::new(overflow))
|
||||
}
|
||||
#[cfg(feature = "gpu")]
|
||||
InternalServerKey::Cuda(_) => {
|
||||
todo!("Cuda devices do not support signed integer");
|
||||
}
|
||||
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
|
||||
let (result, overflow) = cuda_key.key.signed_overflowing_scalar_sub(
|
||||
&self.ciphertext.on_gpu(),
|
||||
other,
|
||||
streams,
|
||||
);
|
||||
(FheInt::new(result), FheBool::new(overflow))
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
use crate::core_crypto::gpu::vec::CudaVec;
|
||||
use crate::core_crypto::gpu::CudaStreams;
|
||||
use crate::core_crypto::prelude::SignedNumeric;
|
||||
use crate::integer::block_decomposition::{BlockDecomposer, DecomposableInto};
|
||||
use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
|
||||
use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext};
|
||||
use crate::integer::gpu::ciphertext::{
|
||||
CudaIntegerRadixCiphertext, CudaSignedRadixCiphertext, CudaUnsignedRadixCiphertext,
|
||||
};
|
||||
use crate::integer::gpu::scalar_addition_integer_radix_assign_async;
|
||||
use crate::integer::gpu::server_key::CudaServerKey;
|
||||
use crate::prelude::CastInto;
|
||||
@@ -279,4 +282,85 @@ impl CudaServerKey {
|
||||
CudaBooleanBlock::from_cuda_radix_ciphertext(carry_out.ciphertext)
|
||||
}
|
||||
}
|
||||
|
||||
/// ```rust
|
||||
/// use tfhe::core_crypto::gpu::CudaStreams;
|
||||
/// use tfhe::integer::gpu::ciphertext::{CudaSignedRadixCiphertext, CudaUnsignedRadixCiphertext};
|
||||
/// use tfhe::integer::gpu::gen_keys_radix_gpu;
|
||||
/// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
|
||||
///
|
||||
/// let gpu_index = 0;
|
||||
/// let streams = CudaStreams::new_single_gpu(gpu_index);
|
||||
///
|
||||
/// // Generate the client key and the server key:
|
||||
/// let num_blocks = 4;
|
||||
/// let (cks, sks) = gen_keys_radix_gpu(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks, &streams);
|
||||
///
|
||||
/// let msg: i8 = 120;
|
||||
/// let scalar: i8 = 8;
|
||||
///
|
||||
/// let ct1 = cks.encrypt_signed(msg);
|
||||
///
|
||||
/// // Copy to GPU
|
||||
/// let d_ct1 = CudaSignedRadixCiphertext::from_signed_radix_ciphertext(&ct1, &streams);
|
||||
///
|
||||
/// // Compute homomorphically an overflowing addition:
|
||||
/// let (d_ct_res, d_ct_overflowed) = sks.signed_overflowing_scalar_add(&d_ct1, scalar, &streams);
|
||||
///
|
||||
/// let ct_res = d_ct_res.to_signed_radix_ciphertext(&streams);
|
||||
/// let ct_overflowed = d_ct_overflowed.to_boolean_block(&streams);
|
||||
///
|
||||
/// // Decrypt:
|
||||
/// let dec_result: i8 = cks.decrypt_signed(&ct_res);
|
||||
/// let dec_overflowed: bool = cks.decrypt_bool(&ct_overflowed);
|
||||
/// let (clear_result, clear_overflowed) = msg.overflowing_add(scalar);
|
||||
/// assert_eq!(dec_result, clear_result);
|
||||
/// assert_eq!(dec_overflowed, clear_overflowed);
|
||||
/// ```
|
||||
pub fn signed_overflowing_scalar_add<Scalar>(
|
||||
&self,
|
||||
ct_left: &CudaSignedRadixCiphertext,
|
||||
scalar: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) -> (CudaSignedRadixCiphertext, CudaBooleanBlock)
|
||||
where
|
||||
Scalar: SignedNumeric + DecomposableInto<u64> + CastInto<u64>,
|
||||
{
|
||||
let mut tmp_lhs;
|
||||
unsafe {
|
||||
tmp_lhs = ct_left.duplicate_async(streams);
|
||||
if !tmp_lhs.block_carries_are_empty() {
|
||||
self.full_propagate_assign_async(&mut tmp_lhs, streams);
|
||||
}
|
||||
}
|
||||
|
||||
let trivial: CudaSignedRadixCiphertext = self.create_trivial_radix(
|
||||
scalar,
|
||||
ct_left.ciphertext.d_blocks.lwe_ciphertext_count().0,
|
||||
streams,
|
||||
);
|
||||
let (result, overflowed) = self.signed_overflowing_add(&tmp_lhs, &trivial, streams);
|
||||
|
||||
let mut extra_scalar_block_iter =
|
||||
BlockDecomposer::new(scalar, self.message_modulus.0.ilog2())
|
||||
.iter_as::<u64>()
|
||||
.skip(ct_left.ciphertext.d_blocks.lwe_ciphertext_count().0);
|
||||
|
||||
let extra_blocks_have_correct_value = if scalar < Scalar::ZERO {
|
||||
extra_scalar_block_iter.all(|block| block == (self.message_modulus.0 as u64 - 1))
|
||||
} else {
|
||||
extra_scalar_block_iter.all(|block| block == 0)
|
||||
};
|
||||
|
||||
if extra_blocks_have_correct_value {
|
||||
(result, overflowed)
|
||||
} else {
|
||||
let trivial_one: CudaSignedRadixCiphertext = self.create_trivial_radix(1, 1, streams);
|
||||
// Scalar has more blocks so addition counts as overflowing
|
||||
(
|
||||
result,
|
||||
CudaBooleanBlock::from_cuda_radix_ciphertext(trivial_one.ciphertext),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
use crate::core_crypto::gpu::CudaStreams;
|
||||
use crate::core_crypto::prelude::Numeric;
|
||||
use crate::integer::block_decomposition::DecomposableInto;
|
||||
use crate::integer::gpu::ciphertext::CudaIntegerRadixCiphertext;
|
||||
use crate::core_crypto::prelude::{Numeric, SignedNumeric};
|
||||
use crate::integer::block_decomposition::{BlockDecomposer, DecomposableInto};
|
||||
use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
|
||||
use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaSignedRadixCiphertext};
|
||||
use crate::integer::gpu::server_key::CudaServerKey;
|
||||
use crate::integer::server_key::TwosComplementNegation;
|
||||
use crate::prelude::CastInto;
|
||||
@@ -163,4 +164,85 @@ impl CudaServerKey {
|
||||
}
|
||||
stream.synchronize();
|
||||
}
|
||||
|
||||
/// ```rust
|
||||
/// use tfhe::core_crypto::gpu::CudaStreams;
|
||||
/// use tfhe::integer::gpu::ciphertext::{CudaSignedRadixCiphertext, CudaUnsignedRadixCiphertext};
|
||||
/// use tfhe::integer::gpu::gen_keys_radix_gpu;
|
||||
/// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
|
||||
///
|
||||
/// let gpu_index = 0;
|
||||
/// let streams = CudaStreams::new_single_gpu(gpu_index);
|
||||
///
|
||||
/// // Generate the client key and the server key:
|
||||
/// let num_blocks = 4;
|
||||
/// let (cks, sks) = gen_keys_radix_gpu(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks, &streams);
|
||||
///
|
||||
/// let msg: i8 = 120;
|
||||
/// let scalar: i8 = 8;
|
||||
///
|
||||
/// let ct1 = cks.encrypt_signed(msg);
|
||||
///
|
||||
/// // Copy to GPU
|
||||
/// let d_ct1 = CudaSignedRadixCiphertext::from_signed_radix_ciphertext(&ct1, &streams);
|
||||
///
|
||||
/// // Compute homomorphically an overflowing addition:
|
||||
/// let (d_ct_res, d_ct_overflowed) = sks.signed_overflowing_scalar_sub(&d_ct1, scalar, &streams);
|
||||
///
|
||||
/// let ct_res = d_ct_res.to_signed_radix_ciphertext(&streams);
|
||||
/// let ct_overflowed = d_ct_overflowed.to_boolean_block(&streams);
|
||||
///
|
||||
/// // Decrypt:
|
||||
/// let dec_result: i8 = cks.decrypt_signed(&ct_res);
|
||||
/// let dec_overflowed: bool = cks.decrypt_bool(&ct_overflowed);
|
||||
/// let (clear_result, clear_overflowed) = msg.overflowing_sub(scalar);
|
||||
/// assert_eq!(dec_result, clear_result);
|
||||
/// assert_eq!(dec_overflowed, clear_overflowed);
|
||||
/// ```
|
||||
pub fn signed_overflowing_scalar_sub<Scalar>(
|
||||
&self,
|
||||
ct_left: &CudaSignedRadixCiphertext,
|
||||
scalar: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) -> (CudaSignedRadixCiphertext, CudaBooleanBlock)
|
||||
where
|
||||
Scalar: SignedNumeric + DecomposableInto<u64> + CastInto<u64>,
|
||||
{
|
||||
let mut tmp_lhs;
|
||||
unsafe {
|
||||
tmp_lhs = ct_left.duplicate_async(streams);
|
||||
if !tmp_lhs.block_carries_are_empty() {
|
||||
self.full_propagate_assign_async(&mut tmp_lhs, streams);
|
||||
}
|
||||
}
|
||||
|
||||
let trivial: CudaSignedRadixCiphertext = self.create_trivial_radix(
|
||||
scalar,
|
||||
ct_left.ciphertext.d_blocks.lwe_ciphertext_count().0,
|
||||
streams,
|
||||
);
|
||||
let (result, overflowed) = self.signed_overflowing_sub(&tmp_lhs, &trivial, streams);
|
||||
|
||||
let mut extra_scalar_block_iter =
|
||||
BlockDecomposer::new(scalar, self.message_modulus.0.ilog2())
|
||||
.iter_as::<u64>()
|
||||
.skip(ct_left.ciphertext.d_blocks.lwe_ciphertext_count().0);
|
||||
|
||||
let extra_blocks_have_correct_value = if scalar < Scalar::ZERO {
|
||||
extra_scalar_block_iter.all(|block| block == (self.message_modulus.0 as u64 - 1))
|
||||
} else {
|
||||
extra_scalar_block_iter.all(|block| block == 0)
|
||||
};
|
||||
|
||||
if extra_blocks_have_correct_value {
|
||||
(result, overflowed)
|
||||
} else {
|
||||
let trivial_one: CudaSignedRadixCiphertext = self.create_trivial_radix(1, 1, streams);
|
||||
// Scalar has more blocks so addition counts as overflowing
|
||||
(
|
||||
result,
|
||||
CudaBooleanBlock::from_cuda_radix_ciphertext(trivial_one.ciphertext),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -389,3 +389,40 @@ where
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// for signed overflowing scalar ops
|
||||
impl<'a, F>
|
||||
FunctionExecutor<(&'a SignedRadixCiphertext, i64), (SignedRadixCiphertext, BooleanBlock)>
|
||||
for GpuFunctionExecutor<F>
|
||||
where
|
||||
F: Fn(
|
||||
&CudaServerKey,
|
||||
&CudaSignedRadixCiphertext,
|
||||
i64,
|
||||
&CudaStreams,
|
||||
) -> (CudaSignedRadixCiphertext, CudaBooleanBlock),
|
||||
{
|
||||
fn setup(&mut self, cks: &RadixClientKey, sks: Arc<ServerKey>) {
|
||||
self.setup_from_keys(cks, &sks);
|
||||
}
|
||||
|
||||
fn execute(
|
||||
&mut self,
|
||||
input: (&'a SignedRadixCiphertext, i64),
|
||||
) -> (SignedRadixCiphertext, BooleanBlock) {
|
||||
let context = self
|
||||
.context
|
||||
.as_ref()
|
||||
.expect("setup was not properly called");
|
||||
|
||||
let d_ctxt_1: CudaSignedRadixCiphertext =
|
||||
CudaSignedRadixCiphertext::from_signed_radix_ciphertext(input.0, &context.streams);
|
||||
|
||||
let (d_res, d_res_bool) = (self.func)(&context.sks, &d_ctxt_1, input.1, &context.streams);
|
||||
|
||||
(
|
||||
d_res.to_signed_radix_ciphertext(&context.streams),
|
||||
d_res_bool.to_boolean_block(&context.streams),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,12 +3,14 @@ use crate::integer::gpu::server_key::radix::tests_unsigned::{
|
||||
};
|
||||
use crate::integer::gpu::CudaServerKey;
|
||||
use crate::integer::server_key::radix_parallel::tests_signed::test_scalar_add::{
|
||||
signed_default_scalar_add_test, signed_unchecked_scalar_add_test,
|
||||
signed_default_overflowing_scalar_add_test, signed_default_scalar_add_test,
|
||||
signed_unchecked_scalar_add_test,
|
||||
};
|
||||
use crate::shortint::parameters::*;
|
||||
|
||||
create_gpu_parametrized_test!(integer_signed_unchecked_scalar_add);
|
||||
create_gpu_parametrized_test!(integer_signed_scalar_add);
|
||||
create_gpu_parametrized_test!(integer_signed_overflowing_scalar_add);
|
||||
|
||||
fn integer_signed_unchecked_scalar_add<P>(param: P)
|
||||
where
|
||||
@@ -25,3 +27,11 @@ where
|
||||
let executor = GpuFunctionExecutor::new(&CudaServerKey::scalar_add);
|
||||
signed_default_scalar_add_test(param, executor);
|
||||
}
|
||||
|
||||
fn integer_signed_overflowing_scalar_add<P>(param: P)
|
||||
where
|
||||
P: Into<PBSParameters>,
|
||||
{
|
||||
let executor = GpuFunctionExecutor::new(&CudaServerKey::signed_overflowing_scalar_add);
|
||||
signed_default_overflowing_scalar_add_test(param, executor);
|
||||
}
|
||||
|
||||
@@ -2,10 +2,13 @@ use crate::integer::gpu::server_key::radix::tests_unsigned::{
|
||||
create_gpu_parametrized_test, GpuFunctionExecutor,
|
||||
};
|
||||
use crate::integer::gpu::CudaServerKey;
|
||||
use crate::integer::server_key::radix_parallel::tests_signed::test_scalar_sub::signed_unchecked_scalar_sub_test;
|
||||
use crate::integer::server_key::radix_parallel::tests_signed::test_scalar_sub::{
|
||||
signed_default_overflowing_scalar_sub_test, signed_unchecked_scalar_sub_test,
|
||||
};
|
||||
use crate::shortint::parameters::*;
|
||||
|
||||
create_gpu_parametrized_test!(integer_signed_unchecked_scalar_sub);
|
||||
create_gpu_parametrized_test!(integer_signed_overflowing_scalar_sub);
|
||||
|
||||
fn integer_signed_unchecked_scalar_sub<P>(param: P)
|
||||
where
|
||||
@@ -14,3 +17,11 @@ where
|
||||
let executor = GpuFunctionExecutor::new(&CudaServerKey::unchecked_scalar_sub);
|
||||
signed_unchecked_scalar_sub_test(param, executor);
|
||||
}
|
||||
|
||||
fn integer_signed_overflowing_scalar_sub<P>(param: P)
|
||||
where
|
||||
P: Into<PBSParameters>,
|
||||
{
|
||||
let executor = GpuFunctionExecutor::new(&CudaServerKey::signed_overflowing_scalar_sub);
|
||||
signed_default_overflowing_scalar_sub_test(param, executor);
|
||||
}
|
||||
|
||||
@@ -195,7 +195,7 @@ where
|
||||
let ctxt_0 = cks.encrypt_signed(clear_0);
|
||||
|
||||
let (ct_res, result_overflowed) = executor.execute((&ctxt_0, clear_1));
|
||||
let (tmp_ct, tmp_o) = sks.signed_overflowing_scalar_add_parallelized(&ctxt_0, clear_1);
|
||||
let (tmp_ct, tmp_o) = executor.execute((&ctxt_0, clear_1));
|
||||
assert!(ct_res.block_carries_are_empty());
|
||||
assert_eq!(ct_res, tmp_ct, "Failed determinism check");
|
||||
assert_eq!(tmp_o, result_overflowed, "Failed determinism check");
|
||||
|
||||
Reference in New Issue
Block a user