mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-07 22:04:10 -05:00
chore(gpu): remove async entry points for abs, add, sub, aes
This commit is contained in:
@@ -47,22 +47,18 @@ pub mod cuda {
|
||||
const SBOX_PARALLELISM: usize = 16;
|
||||
let bench_id = format!("{param_name}::{NUM_AES_INPUTS}_input_encryption");
|
||||
|
||||
let round_keys = unsafe { sks.key_expansion_async(&d_key, &streams) };
|
||||
streams.synchronize();
|
||||
let round_keys = sks.key_expansion(&d_key, &streams);
|
||||
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
unsafe {
|
||||
black_box(sks.aes_encrypt_async(
|
||||
&d_iv,
|
||||
&round_keys,
|
||||
0,
|
||||
NUM_AES_INPUTS,
|
||||
SBOX_PARALLELISM,
|
||||
&streams,
|
||||
));
|
||||
}
|
||||
streams.synchronize();
|
||||
black_box(sks.aes_encrypt(
|
||||
&d_iv,
|
||||
&round_keys,
|
||||
0,
|
||||
NUM_AES_INPUTS,
|
||||
SBOX_PARALLELISM,
|
||||
&streams,
|
||||
));
|
||||
})
|
||||
});
|
||||
|
||||
@@ -82,10 +78,7 @@ pub mod cuda {
|
||||
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
unsafe {
|
||||
black_box(sks.key_expansion_async(&d_key, &streams));
|
||||
}
|
||||
streams.synchronize();
|
||||
black_box(sks.key_expansion(&d_key, &streams));
|
||||
})
|
||||
});
|
||||
|
||||
@@ -118,22 +111,18 @@ pub mod cuda {
|
||||
let d_key = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_key, &streams);
|
||||
let d_iv = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_iv, &streams);
|
||||
|
||||
let round_keys = unsafe { sks.key_expansion_async(&d_key, &streams) };
|
||||
streams.synchronize();
|
||||
let round_keys = sks.key_expansion(&d_key, &streams);
|
||||
|
||||
bench_group.bench_function(&bench_id, |b| {
|
||||
b.iter(|| {
|
||||
unsafe {
|
||||
black_box(sks.aes_encrypt_async(
|
||||
&d_iv,
|
||||
&round_keys,
|
||||
0,
|
||||
NUM_AES_INPUTS,
|
||||
SBOX_PARALLELISM,
|
||||
&streams,
|
||||
));
|
||||
}
|
||||
streams.synchronize();
|
||||
black_box(sks.aes_encrypt(
|
||||
&d_iv,
|
||||
&round_keys,
|
||||
0,
|
||||
NUM_AES_INPUTS,
|
||||
SBOX_PARALLELISM,
|
||||
&streams,
|
||||
));
|
||||
})
|
||||
});
|
||||
|
||||
|
||||
@@ -7708,7 +7708,7 @@ pub(crate) unsafe fn cuda_backend_aes_key_expansion<T: UnsignedInteger, B: Numer
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) unsafe fn cuda_backend_get_aes_key_expansion_size_on_gpu(
|
||||
pub(crate) fn cuda_backend_get_aes_key_expansion_size_on_gpu(
|
||||
streams: &CudaStreams,
|
||||
message_modulus: MessageModulus,
|
||||
carry_modulus: CarryModulus,
|
||||
@@ -7726,7 +7726,7 @@ pub(crate) unsafe fn cuda_backend_get_aes_key_expansion_size_on_gpu(
|
||||
let noise_reduction_type = resolve_noise_reduction_type(ms_noise_reduction_configuration);
|
||||
|
||||
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
|
||||
let size = {
|
||||
let size = unsafe {
|
||||
scratch_cuda_integer_key_expansion_64(
|
||||
streams.ffi(),
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
|
||||
@@ -5,11 +5,7 @@ use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey};
|
||||
use crate::integer::gpu::{cuda_backend_unchecked_signed_abs_assign, PBSType};
|
||||
|
||||
impl CudaServerKey {
|
||||
/// # Safety
|
||||
///
|
||||
/// - [CudaStreams::synchronize] __must__ be called after this function as soon as
|
||||
/// synchronization is required
|
||||
pub unsafe fn unchecked_abs_assign_async<T>(&self, ct: &mut T, streams: &CudaStreams)
|
||||
pub fn unchecked_abs_assign<T>(&self, ct: &mut T, streams: &CudaStreams)
|
||||
where
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
@@ -78,9 +74,8 @@ impl CudaServerKey {
|
||||
{
|
||||
let mut res = ct.duplicate(streams);
|
||||
if T::IS_SIGNED {
|
||||
unsafe { self.unchecked_abs_assign_async(&mut res, streams) };
|
||||
self.unchecked_abs_assign(&mut res, streams);
|
||||
}
|
||||
streams.synchronize();
|
||||
res
|
||||
}
|
||||
|
||||
@@ -139,9 +134,8 @@ impl CudaServerKey {
|
||||
self.full_propagate_assign(&mut res, streams);
|
||||
}
|
||||
if T::IS_SIGNED {
|
||||
unsafe { self.unchecked_abs_assign_async(&mut res, streams) };
|
||||
self.unchecked_abs_assign(&mut res, streams);
|
||||
}
|
||||
streams.synchronize();
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,10 +27,6 @@ impl CudaServerKey {
|
||||
/// example) has always the same performance characteristics from one call to another and
|
||||
/// guarantees correctness by pre-emptively clearing carries of output ciphertexts.
|
||||
///
|
||||
/// # Warning
|
||||
///
|
||||
/// - Multithreaded
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
@@ -86,11 +82,7 @@ impl CudaServerKey {
|
||||
self.get_add_assign_size_on_gpu(ct_left, ct_right, streams)
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until stream is synchronised
|
||||
pub unsafe fn add_assign_async<T: CudaIntegerRadixCiphertext>(
|
||||
pub fn add_assign<T: CudaIntegerRadixCiphertext>(
|
||||
&self,
|
||||
ct_left: &mut T,
|
||||
ct_right: &T,
|
||||
@@ -121,25 +113,8 @@ impl CudaServerKey {
|
||||
}
|
||||
};
|
||||
|
||||
let _carry = self.add_and_propagate_single_carry_assign_async(
|
||||
lhs,
|
||||
rhs,
|
||||
streams,
|
||||
None,
|
||||
OutputFlag::None,
|
||||
);
|
||||
}
|
||||
|
||||
pub fn add_assign<T: CudaIntegerRadixCiphertext>(
|
||||
&self,
|
||||
ct_left: &mut T,
|
||||
ct_right: &T,
|
||||
streams: &CudaStreams,
|
||||
) {
|
||||
unsafe {
|
||||
self.add_assign_async(ct_left, ct_right, streams);
|
||||
}
|
||||
streams.synchronize();
|
||||
let _carry =
|
||||
self.add_and_propagate_single_carry_assign(lhs, rhs, streams, None, OutputFlag::None);
|
||||
}
|
||||
|
||||
pub fn get_add_assign_size_on_gpu<T: CudaIntegerRadixCiphertext>(
|
||||
@@ -286,11 +261,7 @@ impl CudaServerKey {
|
||||
result
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until stream is synchronised
|
||||
pub unsafe fn unchecked_add_assign_async<T: CudaIntegerRadixCiphertext>(
|
||||
pub fn unchecked_add_assign<T: CudaIntegerRadixCiphertext>(
|
||||
&self,
|
||||
ct_left: &mut T,
|
||||
ct_right: &T,
|
||||
@@ -319,22 +290,7 @@ impl CudaServerKey {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn unchecked_add_assign<T: CudaIntegerRadixCiphertext>(
|
||||
&self,
|
||||
ct_left: &mut T,
|
||||
ct_right: &T,
|
||||
streams: &CudaStreams,
|
||||
) {
|
||||
unsafe {
|
||||
self.unchecked_add_assign_async(ct_left, ct_right, streams);
|
||||
}
|
||||
streams.synchronize();
|
||||
}
|
||||
/// # Safety
|
||||
///
|
||||
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until stream is synchronised
|
||||
pub unsafe fn unchecked_partial_sum_ciphertexts_assign_async<T: CudaIntegerRadixCiphertext>(
|
||||
pub fn unchecked_partial_sum_ciphertexts_assign<T: CudaIntegerRadixCiphertext>(
|
||||
&self,
|
||||
result: &mut T,
|
||||
ciphertexts: &[T],
|
||||
@@ -345,11 +301,14 @@ impl CudaServerKey {
|
||||
return;
|
||||
}
|
||||
|
||||
result.as_mut().d_blocks.0.d_vec.copy_from_gpu_async(
|
||||
&ciphertexts[0].as_ref().d_blocks.0.d_vec,
|
||||
streams,
|
||||
0,
|
||||
);
|
||||
unsafe {
|
||||
result.as_mut().d_blocks.0.d_vec.copy_from_gpu_async(
|
||||
&ciphertexts[0].as_ref().d_blocks.0.d_vec,
|
||||
streams,
|
||||
0,
|
||||
);
|
||||
streams.synchronize();
|
||||
}
|
||||
result.as_mut().info = ciphertexts[0].as_ref().info.clone();
|
||||
if ciphertexts.len() == 1 {
|
||||
return;
|
||||
@@ -365,7 +324,7 @@ impl CudaServerKey {
|
||||
);
|
||||
|
||||
if ciphertexts.len() == 2 {
|
||||
self.add_assign_async(result, &ciphertexts[1], streams);
|
||||
self.add_assign(result, &ciphertexts[1], streams);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -373,58 +332,60 @@ impl CudaServerKey {
|
||||
|
||||
let mut terms = CudaRadixCiphertext::from_radix_ciphertext_vec(ciphertexts, streams);
|
||||
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_partial_sum_ciphertexts_assign(
|
||||
streams,
|
||||
result.as_mut(),
|
||||
&mut terms,
|
||||
reduce_degrees_for_single_carry_propagation,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
num_blocks.0 as u32,
|
||||
radix_count_in_vec as u32,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_partial_sum_ciphertexts_assign(
|
||||
streams,
|
||||
result.as_mut(),
|
||||
&mut terms,
|
||||
reduce_degrees_for_single_carry_propagation,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
num_blocks.0 as u32,
|
||||
radix_count_in_vec as u32,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
);
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_partial_sum_ciphertexts_assign(
|
||||
streams,
|
||||
result.as_mut(),
|
||||
&mut terms,
|
||||
reduce_degrees_for_single_carry_propagation,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
num_blocks.0 as u32,
|
||||
radix_count_in_vec as u32,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_partial_sum_ciphertexts_assign(
|
||||
streams,
|
||||
result.as_mut(),
|
||||
&mut terms,
|
||||
reduce_degrees_for_single_carry_propagation,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
num_blocks.0 as u32,
|
||||
radix_count_in_vec as u32,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -433,23 +394,9 @@ impl CudaServerKey {
|
||||
&self,
|
||||
ciphertexts: &[T],
|
||||
streams: &CudaStreams,
|
||||
) -> T {
|
||||
let result = unsafe { self.unchecked_sum_ciphertexts_async(ciphertexts, streams) };
|
||||
streams.synchronize();
|
||||
result
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until stream is synchronised
|
||||
pub unsafe fn unchecked_sum_ciphertexts_async<T: CudaIntegerRadixCiphertext>(
|
||||
&self,
|
||||
ciphertexts: &[T],
|
||||
streams: &CudaStreams,
|
||||
) -> T {
|
||||
let mut result = self
|
||||
.unchecked_partial_sum_ciphertexts_async(ciphertexts, true, streams)
|
||||
.unchecked_partial_sum_ciphertexts(ciphertexts, true, streams)
|
||||
.unwrap();
|
||||
|
||||
self.propagate_single_carry_assign(&mut result, streams, None, OutputFlag::None);
|
||||
@@ -458,21 +405,6 @@ impl CudaServerKey {
|
||||
}
|
||||
|
||||
pub fn unchecked_partial_sum_ciphertexts<T: CudaIntegerRadixCiphertext>(
|
||||
&self,
|
||||
ciphertexts: &[T],
|
||||
streams: &CudaStreams,
|
||||
) -> Option<T> {
|
||||
let result =
|
||||
unsafe { self.unchecked_partial_sum_ciphertexts_async(ciphertexts, false, streams) };
|
||||
streams.synchronize();
|
||||
result
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until stream is synchronised
|
||||
pub unsafe fn unchecked_partial_sum_ciphertexts_async<T: CudaIntegerRadixCiphertext>(
|
||||
&self,
|
||||
ciphertexts: &[T],
|
||||
reduce_degrees_for_single_carry_propagation: bool,
|
||||
@@ -485,11 +417,10 @@ impl CudaServerKey {
|
||||
let mut result = ciphertexts[0].duplicate(streams);
|
||||
|
||||
if ciphertexts.len() == 1 {
|
||||
streams.synchronize();
|
||||
return Some(result);
|
||||
}
|
||||
|
||||
self.unchecked_partial_sum_ciphertexts_assign_async(
|
||||
self.unchecked_partial_sum_ciphertexts_assign(
|
||||
&mut result,
|
||||
ciphertexts,
|
||||
reduce_degrees_for_single_carry_propagation,
|
||||
@@ -500,20 +431,6 @@ impl CudaServerKey {
|
||||
}
|
||||
|
||||
pub fn sum_ciphertexts<T: CudaIntegerRadixCiphertext>(
|
||||
&self,
|
||||
ciphertexts: Vec<T>,
|
||||
streams: &CudaStreams,
|
||||
) -> Option<T> {
|
||||
let res = unsafe { self.sum_ciphertexts_async(ciphertexts, streams) };
|
||||
streams.synchronize();
|
||||
res
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until stream is synchronised
|
||||
pub unsafe fn sum_ciphertexts_async<T: CudaIntegerRadixCiphertext>(
|
||||
&self,
|
||||
mut ciphertexts: Vec<T>,
|
||||
streams: &CudaStreams,
|
||||
@@ -529,7 +446,7 @@ impl CudaServerKey {
|
||||
self.full_propagate_assign(&mut *ct, streams);
|
||||
});
|
||||
|
||||
Some(self.unchecked_sum_ciphertexts_async(&ciphertexts, streams))
|
||||
Some(self.unchecked_sum_ciphertexts(&ciphertexts, streams))
|
||||
}
|
||||
|
||||
/// ```rust
|
||||
@@ -620,38 +537,12 @@ impl CudaServerKey {
|
||||
lhs.as_ref().d_blocks.lwe_ciphertext_count().0,
|
||||
rhs.as_ref().d_blocks.lwe_ciphertext_count().0
|
||||
);
|
||||
let ct_res;
|
||||
let ct_overflowed;
|
||||
unsafe {
|
||||
(ct_res, ct_overflowed) =
|
||||
self.unchecked_unsigned_overflowing_add_async(lhs, rhs, stream);
|
||||
}
|
||||
stream.synchronize();
|
||||
|
||||
(ct_res, ct_overflowed)
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until stream is synchronised
|
||||
pub unsafe fn unchecked_unsigned_overflowing_add_async(
|
||||
&self,
|
||||
lhs: &CudaUnsignedRadixCiphertext,
|
||||
rhs: &CudaUnsignedRadixCiphertext,
|
||||
stream: &CudaStreams,
|
||||
) -> (CudaUnsignedRadixCiphertext, CudaBooleanBlock) {
|
||||
let output_flag = OutputFlag::from_signedness(CudaUnsignedRadixCiphertext::IS_SIGNED);
|
||||
|
||||
let mut ct_res = lhs.duplicate(stream);
|
||||
let mut carry_out: CudaUnsignedRadixCiphertext = self
|
||||
.add_and_propagate_single_carry_assign_async(
|
||||
&mut ct_res,
|
||||
rhs,
|
||||
stream,
|
||||
None,
|
||||
output_flag,
|
||||
);
|
||||
let mut carry_out: CudaUnsignedRadixCiphertext =
|
||||
self.add_and_propagate_single_carry_assign(&mut ct_res, rhs, stream, None, output_flag);
|
||||
|
||||
if lhs.as_ref().info.blocks.last().unwrap().noise_level == NoiseLevel::ZERO
|
||||
&& rhs.as_ref().info.blocks.last().unwrap().noise_level == NoiseLevel::ZERO
|
||||
@@ -666,28 +557,43 @@ impl CudaServerKey {
|
||||
(ct_res, ct_overflowed)
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until stream is synchronised
|
||||
pub unsafe fn unchecked_signed_overflowing_add_async(
|
||||
pub fn unchecked_signed_overflowing_add(
|
||||
&self,
|
||||
lhs: &CudaSignedRadixCiphertext,
|
||||
rhs: &CudaSignedRadixCiphertext,
|
||||
stream: &CudaStreams,
|
||||
) -> (CudaSignedRadixCiphertext, CudaBooleanBlock) {
|
||||
self.unchecked_signed_overflowing_add_with_input_carry(lhs, rhs, None, stream)
|
||||
}
|
||||
|
||||
pub fn unchecked_signed_overflowing_add_with_input_carry(
|
||||
&self,
|
||||
lhs: &CudaSignedRadixCiphertext,
|
||||
rhs: &CudaSignedRadixCiphertext,
|
||||
input_carry: Option<&CudaBooleanBlock>,
|
||||
stream: &CudaStreams,
|
||||
) -> (CudaSignedRadixCiphertext, CudaBooleanBlock) {
|
||||
assert_eq!(
|
||||
lhs.as_ref().d_blocks.lwe_ciphertext_count().0,
|
||||
rhs.as_ref().d_blocks.lwe_ciphertext_count().0,
|
||||
"lhs and rhs must have the name number of blocks ({} vs {})",
|
||||
lhs.as_ref().d_blocks.lwe_ciphertext_count().0,
|
||||
rhs.as_ref().d_blocks.lwe_ciphertext_count().0
|
||||
);
|
||||
assert!(
|
||||
lhs.as_ref().d_blocks.lwe_ciphertext_count().0 > 0,
|
||||
"inputs cannot be empty"
|
||||
);
|
||||
let output_flag = OutputFlag::from_signedness(CudaSignedRadixCiphertext::IS_SIGNED);
|
||||
|
||||
let mut ct_res = lhs.duplicate(stream);
|
||||
let carry_out: CudaSignedRadixCiphertext = self
|
||||
.add_and_propagate_single_carry_assign_async(
|
||||
&mut ct_res,
|
||||
rhs,
|
||||
stream,
|
||||
input_carry,
|
||||
output_flag,
|
||||
);
|
||||
let carry_out: CudaSignedRadixCiphertext = self.add_and_propagate_single_carry_assign(
|
||||
&mut ct_res,
|
||||
rhs,
|
||||
stream,
|
||||
input_carry,
|
||||
output_flag,
|
||||
);
|
||||
|
||||
let ct_overflowed = CudaBooleanBlock::from_cuda_radix_ciphertext(carry_out.ciphertext);
|
||||
|
||||
@@ -770,39 +676,7 @@ impl CudaServerKey {
|
||||
self.unchecked_signed_overflowing_add(lhs, rhs, stream)
|
||||
}
|
||||
|
||||
pub fn unchecked_signed_overflowing_add(
|
||||
&self,
|
||||
ct_left: &CudaSignedRadixCiphertext,
|
||||
ct_right: &CudaSignedRadixCiphertext,
|
||||
stream: &CudaStreams,
|
||||
) -> (CudaSignedRadixCiphertext, CudaBooleanBlock) {
|
||||
assert_eq!(
|
||||
ct_left.as_ref().d_blocks.lwe_ciphertext_count().0,
|
||||
ct_right.as_ref().d_blocks.lwe_ciphertext_count().0,
|
||||
"lhs and rhs must have the name number of blocks ({} vs {})",
|
||||
ct_left.as_ref().d_blocks.lwe_ciphertext_count().0,
|
||||
ct_right.as_ref().d_blocks.lwe_ciphertext_count().0
|
||||
);
|
||||
assert!(
|
||||
ct_left.as_ref().d_blocks.lwe_ciphertext_count().0 > 0,
|
||||
"inputs cannot be empty"
|
||||
);
|
||||
|
||||
let result;
|
||||
let overflowed;
|
||||
unsafe {
|
||||
(result, overflowed) =
|
||||
self.unchecked_signed_overflowing_add_async(ct_left, ct_right, None, stream);
|
||||
};
|
||||
stream.synchronize();
|
||||
(result, overflowed)
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronized
|
||||
pub(crate) unsafe fn add_and_propagate_single_carry_assign_async<T>(
|
||||
pub(crate) fn add_and_propagate_single_carry_assign<T>(
|
||||
&self,
|
||||
lhs: &mut T,
|
||||
rhs: &T,
|
||||
@@ -821,58 +695,60 @@ impl CudaServerKey {
|
||||
let in_carry: &CudaRadixCiphertext =
|
||||
input_carry.map_or_else(|| aux_block.as_ref(), |block| block.0.as_ref());
|
||||
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_add_and_propagate_single_carry_assign(
|
||||
streams,
|
||||
lhs.as_mut(),
|
||||
rhs.as_ref(),
|
||||
carry_out.as_mut(),
|
||||
in_carry,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
d_bsk.input_lwe_dimension(),
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
d_bsk.decomp_base_log(),
|
||||
num_blocks,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
requested_flag,
|
||||
uses_carry,
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_add_and_propagate_single_carry_assign(
|
||||
streams,
|
||||
lhs.as_mut(),
|
||||
rhs.as_ref(),
|
||||
carry_out.as_mut(),
|
||||
in_carry,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
num_blocks,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
requested_flag,
|
||||
uses_carry,
|
||||
None,
|
||||
);
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_add_and_propagate_single_carry_assign(
|
||||
streams,
|
||||
lhs.as_mut(),
|
||||
rhs.as_ref(),
|
||||
carry_out.as_mut(),
|
||||
in_carry,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
d_bsk.input_lwe_dimension(),
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
d_bsk.decomp_base_log(),
|
||||
num_blocks,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
requested_flag,
|
||||
uses_carry,
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_add_and_propagate_single_carry_assign(
|
||||
streams,
|
||||
lhs.as_mut(),
|
||||
rhs.as_ref(),
|
||||
carry_out.as_mut(),
|
||||
in_carry,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
num_blocks,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
requested_flag,
|
||||
uses_carry,
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
carry_out
|
||||
|
||||
@@ -133,18 +133,15 @@ impl CudaServerKey {
|
||||
self.get_aes_encrypt_size_on_gpu(num_aes_inputs, parallelism, streams);
|
||||
|
||||
if check_valid_cuda_malloc(aes_encrypt_size, streams.gpu_indexes[0]) {
|
||||
let round_keys = unsafe { self.key_expansion_async(key, streams) };
|
||||
let res = unsafe {
|
||||
self.aes_encrypt_async(
|
||||
iv,
|
||||
&round_keys,
|
||||
start_counter,
|
||||
num_aes_inputs,
|
||||
parallelism,
|
||||
streams,
|
||||
)
|
||||
};
|
||||
streams.synchronize();
|
||||
let round_keys = self.key_expansion(key, streams);
|
||||
let res = self.aes_encrypt(
|
||||
iv,
|
||||
&round_keys,
|
||||
start_counter,
|
||||
num_aes_inputs,
|
||||
parallelism,
|
||||
streams,
|
||||
);
|
||||
return res;
|
||||
}
|
||||
parallelism /= 2;
|
||||
@@ -176,26 +173,18 @@ impl CudaServerKey {
|
||||
self.get_aes_encrypt_size_on_gpu(num_aes_inputs, sbox_parallelism, streams);
|
||||
check_valid_cuda_malloc_assert_oom(aes_encrypt_size, gpu_index);
|
||||
|
||||
let round_keys = unsafe { self.key_expansion_async(key, streams) };
|
||||
let res = unsafe {
|
||||
self.aes_encrypt_async(
|
||||
iv,
|
||||
&round_keys,
|
||||
start_counter,
|
||||
num_aes_inputs,
|
||||
sbox_parallelism,
|
||||
streams,
|
||||
)
|
||||
};
|
||||
streams.synchronize();
|
||||
res
|
||||
let round_keys = self.key_expansion(key, streams);
|
||||
self.aes_encrypt(
|
||||
iv,
|
||||
&round_keys,
|
||||
start_counter,
|
||||
num_aes_inputs,
|
||||
sbox_parallelism,
|
||||
streams,
|
||||
)
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - [CudaStreams::synchronize] __must__ be called after this function as soon as
|
||||
/// synchronization is required
|
||||
pub unsafe fn aes_encrypt_async(
|
||||
pub fn aes_encrypt(
|
||||
&self,
|
||||
iv: &CudaUnsignedRadixCiphertext,
|
||||
round_keys: &CudaUnsignedRadixCiphertext,
|
||||
@@ -229,79 +218,64 @@ impl CudaServerKey {
|
||||
result.as_ref().d_blocks.lwe_ciphertext_count().0
|
||||
);
|
||||
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_aes_ctr_encrypt(
|
||||
streams,
|
||||
result.as_mut(),
|
||||
iv.as_ref(),
|
||||
round_keys.as_ref(),
|
||||
start_counter,
|
||||
num_aes_inputs as u32,
|
||||
sbox_parallelism as u32,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
d_bsk.input_lwe_dimension,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
LweBskGroupingFactor(0),
|
||||
PBSType::Classical,
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_aes_ctr_encrypt(
|
||||
streams,
|
||||
result.as_mut(),
|
||||
iv.as_ref(),
|
||||
round_keys.as_ref(),
|
||||
start_counter,
|
||||
num_aes_inputs as u32,
|
||||
sbox_parallelism as u32,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
d_multibit_bsk.input_lwe_dimension,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
PBSType::MultiBit,
|
||||
None,
|
||||
);
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_aes_ctr_encrypt(
|
||||
streams,
|
||||
result.as_mut(),
|
||||
iv.as_ref(),
|
||||
round_keys.as_ref(),
|
||||
start_counter,
|
||||
num_aes_inputs as u32,
|
||||
sbox_parallelism as u32,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
d_bsk.input_lwe_dimension,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
LweBskGroupingFactor(0),
|
||||
PBSType::Classical,
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_aes_ctr_encrypt(
|
||||
streams,
|
||||
result.as_mut(),
|
||||
iv.as_ref(),
|
||||
round_keys.as_ref(),
|
||||
start_counter,
|
||||
num_aes_inputs as u32,
|
||||
sbox_parallelism as u32,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
d_multibit_bsk.input_lwe_dimension,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
PBSType::MultiBit,
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
fn get_aes_encrypt_size_on_gpu(
|
||||
&self,
|
||||
num_aes_inputs: usize,
|
||||
sbox_parallelism: usize,
|
||||
streams: &CudaStreams,
|
||||
) -> u64 {
|
||||
let size = unsafe {
|
||||
self.get_aes_encrypt_size_on_gpu_async(num_aes_inputs, sbox_parallelism, streams)
|
||||
};
|
||||
streams.synchronize();
|
||||
size
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - [CudaStreams::synchronize] __must__ be called after this function as soon as
|
||||
/// synchronization is required
|
||||
unsafe fn get_aes_encrypt_size_on_gpu_async(
|
||||
pub fn get_aes_encrypt_size_on_gpu(
|
||||
&self,
|
||||
num_aes_inputs: usize,
|
||||
sbox_parallelism: usize,
|
||||
@@ -347,11 +321,7 @@ impl CudaServerKey {
|
||||
}
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - [CudaStreams::synchronize] __must__ be called after this function as soon as
|
||||
/// synchronization is required
|
||||
pub unsafe fn key_expansion_async(
|
||||
pub fn key_expansion(
|
||||
&self,
|
||||
key: &CudaUnsignedRadixCiphertext,
|
||||
streams: &CudaStreams,
|
||||
@@ -369,64 +339,56 @@ impl CudaServerKey {
|
||||
key.as_ref().d_blocks.lwe_ciphertext_count().0
|
||||
);
|
||||
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_aes_key_expansion(
|
||||
streams,
|
||||
expanded_keys.as_mut(),
|
||||
key.as_ref(),
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
d_bsk.input_lwe_dimension,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
LweBskGroupingFactor(0),
|
||||
PBSType::Classical,
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_aes_key_expansion(
|
||||
streams,
|
||||
expanded_keys.as_mut(),
|
||||
key.as_ref(),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
d_multibit_bsk.input_lwe_dimension,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
PBSType::MultiBit,
|
||||
None,
|
||||
);
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_aes_key_expansion(
|
||||
streams,
|
||||
expanded_keys.as_mut(),
|
||||
key.as_ref(),
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
d_bsk.input_lwe_dimension,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
LweBskGroupingFactor(0),
|
||||
PBSType::Classical,
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_aes_key_expansion(
|
||||
streams,
|
||||
expanded_keys.as_mut(),
|
||||
key.as_ref(),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
d_multibit_bsk.input_lwe_dimension,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
PBSType::MultiBit,
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
expanded_keys
|
||||
}
|
||||
|
||||
fn get_key_expansion_size_on_gpu(&self, streams: &CudaStreams) -> u64 {
|
||||
let size = unsafe { self.get_key_expansion_size_on_gpu_async(streams) };
|
||||
streams.synchronize();
|
||||
size
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - [CudaStreams::synchronize] __must__ be called after this function as soon as
|
||||
/// synchronization is required
|
||||
unsafe fn get_key_expansion_size_on_gpu_async(&self, streams: &CudaStreams) -> u64 {
|
||||
pub fn get_key_expansion_size_on_gpu(&self, streams: &CudaStreams) -> u64 {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => cuda_backend_get_aes_key_expansion_size_on_gpu(
|
||||
streams,
|
||||
|
||||
@@ -62,41 +62,12 @@ impl CudaServerKey {
|
||||
ct_left: &T,
|
||||
ct_right: &T,
|
||||
streams: &CudaStreams,
|
||||
) -> T {
|
||||
let result = unsafe { self.unchecked_sub_async(ct_left, ct_right, streams) };
|
||||
streams.synchronize();
|
||||
result
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until stream is synchronised
|
||||
pub unsafe fn unchecked_sub_async<T: CudaIntegerRadixCiphertext>(
|
||||
&self,
|
||||
ct_left: &T,
|
||||
ct_right: &T,
|
||||
streams: &CudaStreams,
|
||||
) -> T {
|
||||
let mut result = ct_left.duplicate(streams);
|
||||
self.unchecked_sub_assign_async(&mut result, ct_right, streams);
|
||||
self.unchecked_sub_assign(&mut result, ct_right, streams);
|
||||
result
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until stream is synchronised
|
||||
pub unsafe fn unchecked_sub_assign_async<T: CudaIntegerRadixCiphertext>(
|
||||
&self,
|
||||
ct_left: &mut T,
|
||||
ct_right: &T,
|
||||
streams: &CudaStreams,
|
||||
) {
|
||||
let neg = self.unchecked_neg_async(ct_right, streams);
|
||||
self.unchecked_add_assign_async(ct_left, &neg, streams);
|
||||
}
|
||||
|
||||
/// Computes homomorphically a subtraction between two ciphertexts encrypting integer values.
|
||||
///
|
||||
/// This function computes the subtraction without checking if it exceeds the capacity of the
|
||||
@@ -146,10 +117,8 @@ impl CudaServerKey {
|
||||
ct_right: &T,
|
||||
streams: &CudaStreams,
|
||||
) {
|
||||
unsafe {
|
||||
self.unchecked_sub_assign_async(ct_left, ct_right, streams);
|
||||
}
|
||||
streams.synchronize();
|
||||
let neg = self.unchecked_neg(ct_right, streams);
|
||||
self.unchecked_add_assign(ct_left, &neg, streams);
|
||||
}
|
||||
|
||||
/// Computes homomorphically the subtraction between ct_left and ct_right.
|
||||
@@ -205,8 +174,8 @@ impl CudaServerKey {
|
||||
ct_right: &T,
|
||||
streams: &CudaStreams,
|
||||
) -> T {
|
||||
let result = unsafe { self.sub_async(ct_left, ct_right, streams) };
|
||||
streams.synchronize();
|
||||
let mut result = ct_left.duplicate(streams);
|
||||
self.sub_assign(&mut result, ct_right, streams);
|
||||
result
|
||||
}
|
||||
|
||||
@@ -219,42 +188,11 @@ impl CudaServerKey {
|
||||
self.get_sub_assign_size_on_gpu(ct_left, ct_right, streams)
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until stream is synchronised
|
||||
pub unsafe fn sub_async<T: CudaIntegerRadixCiphertext>(
|
||||
&self,
|
||||
ct_left: &T,
|
||||
ct_right: &T,
|
||||
streams: &CudaStreams,
|
||||
) -> T {
|
||||
let mut result = ct_left.duplicate(streams);
|
||||
self.sub_assign_async(&mut result, ct_right, streams);
|
||||
result
|
||||
}
|
||||
|
||||
pub fn sub_assign<T: CudaIntegerRadixCiphertext>(
|
||||
&self,
|
||||
ct_left: &mut T,
|
||||
ct_right: &T,
|
||||
streams: &CudaStreams,
|
||||
) {
|
||||
unsafe {
|
||||
self.sub_assign_async(ct_left, ct_right, streams);
|
||||
}
|
||||
streams.synchronize();
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until stream is synchronised
|
||||
pub unsafe fn sub_assign_async<T: CudaIntegerRadixCiphertext>(
|
||||
&self,
|
||||
ct_left: &mut T,
|
||||
ct_right: &T,
|
||||
streams: &CudaStreams,
|
||||
) {
|
||||
let mut tmp_rhs;
|
||||
|
||||
@@ -345,27 +283,6 @@ impl CudaServerKey {
|
||||
lhs.as_ref().d_blocks.lwe_ciphertext_count().0,
|
||||
rhs.as_ref().d_blocks.lwe_ciphertext_count().0
|
||||
);
|
||||
let ct_res;
|
||||
let ct_overflowed;
|
||||
unsafe {
|
||||
(ct_res, ct_overflowed) =
|
||||
self.unchecked_unsigned_overflowing_sub_async(lhs, rhs, stream);
|
||||
}
|
||||
stream.synchronize();
|
||||
|
||||
(ct_res, ct_overflowed)
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until stream is synchronised
|
||||
pub unsafe fn unchecked_unsigned_overflowing_sub_async(
|
||||
&self,
|
||||
lhs: &CudaUnsignedRadixCiphertext,
|
||||
rhs: &CudaUnsignedRadixCiphertext,
|
||||
stream: &CudaStreams,
|
||||
) -> (CudaUnsignedRadixCiphertext, CudaBooleanBlock) {
|
||||
let mut ct_res = lhs.duplicate(stream);
|
||||
|
||||
let compute_overflow = true;
|
||||
@@ -380,56 +297,58 @@ impl CudaServerKey {
|
||||
let in_carry_dvec =
|
||||
INPUT_BORROW.map_or_else(|| aux_block.as_ref(), |block| block.as_ref().as_ref());
|
||||
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_unsigned_overflowing_sub_assign(
|
||||
stream,
|
||||
ciphertext,
|
||||
rhs.as_ref(),
|
||||
overflow_block.as_mut(),
|
||||
in_carry_dvec,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
d_bsk.input_lwe_dimension(),
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
d_bsk.decomp_base_log(),
|
||||
ciphertext.info.blocks.first().unwrap().message_modulus,
|
||||
ciphertext.info.blocks.first().unwrap().carry_modulus,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
compute_overflow,
|
||||
uses_input_borrow,
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_unsigned_overflowing_sub_assign(
|
||||
stream,
|
||||
ciphertext,
|
||||
rhs.as_ref(),
|
||||
overflow_block.as_mut(),
|
||||
in_carry_dvec,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
ciphertext.info.blocks.first().unwrap().message_modulus,
|
||||
ciphertext.info.blocks.first().unwrap().carry_modulus,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
compute_overflow,
|
||||
uses_input_borrow,
|
||||
None,
|
||||
);
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_unsigned_overflowing_sub_assign(
|
||||
stream,
|
||||
ciphertext,
|
||||
rhs.as_ref(),
|
||||
overflow_block.as_mut(),
|
||||
in_carry_dvec,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
d_bsk.input_lwe_dimension(),
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
d_bsk.decomp_base_log(),
|
||||
ciphertext.info.blocks.first().unwrap().message_modulus,
|
||||
ciphertext.info.blocks.first().unwrap().carry_modulus,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
compute_overflow,
|
||||
uses_input_borrow,
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_unsigned_overflowing_sub_assign(
|
||||
stream,
|
||||
ciphertext,
|
||||
rhs.as_ref(),
|
||||
overflow_block.as_mut(),
|
||||
in_carry_dvec,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
ciphertext.info.blocks.first().unwrap().message_modulus,
|
||||
ciphertext.info.blocks.first().unwrap().carry_modulus,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
compute_overflow,
|
||||
uses_input_borrow,
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
let ct_overflowed = CudaBooleanBlock::from_cuda_radix_ciphertext(overflow_block.ciphertext);
|
||||
@@ -437,11 +356,7 @@ impl CudaServerKey {
|
||||
(ct_res, ct_overflowed)
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronized
|
||||
pub(crate) unsafe fn sub_and_propagate_single_carry_assign<T>(
|
||||
pub(crate) fn sub_and_propagate_single_carry_assign<T>(
|
||||
&self,
|
||||
lhs: &mut T,
|
||||
rhs: &T,
|
||||
@@ -460,58 +375,60 @@ impl CudaServerKey {
|
||||
let in_carry: &CudaRadixCiphertext =
|
||||
input_carry.map_or_else(|| aux_block.as_ref(), |block| block.0.as_ref());
|
||||
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_sub_and_propagate_single_carry_assign(
|
||||
streams,
|
||||
lhs.as_mut(),
|
||||
rhs.as_ref(),
|
||||
carry_out.as_mut(),
|
||||
in_carry,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
d_bsk.input_lwe_dimension(),
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
d_bsk.decomp_base_log(),
|
||||
num_blocks,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
requested_flag,
|
||||
uses_carry,
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_sub_and_propagate_single_carry_assign(
|
||||
streams,
|
||||
lhs.as_mut(),
|
||||
rhs.as_ref(),
|
||||
carry_out.as_mut(),
|
||||
in_carry,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
num_blocks,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
requested_flag,
|
||||
uses_carry,
|
||||
None,
|
||||
);
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_sub_and_propagate_single_carry_assign(
|
||||
streams,
|
||||
lhs.as_mut(),
|
||||
rhs.as_ref(),
|
||||
carry_out.as_mut(),
|
||||
in_carry,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
d_bsk.input_lwe_dimension(),
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
d_bsk.decomp_base_log(),
|
||||
num_blocks,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
requested_flag,
|
||||
uses_carry,
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_sub_and_propagate_single_carry_assign(
|
||||
streams,
|
||||
lhs.as_mut(),
|
||||
rhs.as_ref(),
|
||||
carry_out.as_mut(),
|
||||
in_carry,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
num_blocks,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
requested_flag,
|
||||
uses_carry,
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
carry_out
|
||||
@@ -611,30 +528,11 @@ impl CudaServerKey {
|
||||
ct_left.as_ref().d_blocks.lwe_ciphertext_count().0 > 0,
|
||||
"inputs cannot be empty"
|
||||
);
|
||||
let result;
|
||||
let overflowed;
|
||||
unsafe {
|
||||
(result, overflowed) =
|
||||
self.unchecked_signed_overflowing_sub_async(ct_left, ct_right, stream);
|
||||
};
|
||||
stream.synchronize();
|
||||
(result, overflowed)
|
||||
}
|
||||
/// # Safety
|
||||
///
|
||||
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until stream is synchronised
|
||||
pub unsafe fn unchecked_signed_overflowing_sub_async(
|
||||
&self,
|
||||
ct_left: &CudaSignedRadixCiphertext,
|
||||
ct_right: &CudaSignedRadixCiphertext,
|
||||
stream: &CudaStreams,
|
||||
) -> (CudaSignedRadixCiphertext, CudaBooleanBlock) {
|
||||
let flipped_rhs = self.bitnot(ct_right, stream);
|
||||
let ct_input_carry: CudaUnsignedRadixCiphertext = self.create_trivial_radix(1, 1, stream);
|
||||
let input_carry = CudaBooleanBlock::from_cuda_radix_ciphertext(ct_input_carry.ciphertext);
|
||||
|
||||
self.unchecked_signed_overflowing_add_async(
|
||||
self.unchecked_signed_overflowing_add_with_input_carry(
|
||||
ct_left,
|
||||
&flipped_rhs,
|
||||
Some(&input_carry),
|
||||
|
||||
Reference in New Issue
Block a user