chore(gpu): remove async entry points for abs, add, sub, aes

Agnes Leroy
2025-10-17 12:21:50 +02:00
committed by Agnès Leroy
parent 70b0c0ff19
commit c30835fc30
6 changed files with 411 additions and 692 deletions
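
The pattern removed throughout this commit is the split between a `pub unsafe fn *_async` entry point and a safe wrapper whose only job was to call it and then synchronize the streams. A minimal before/after sketch of the calling convention, using the abs entry point touched below (`sks`, `d_ct`, and `streams` are assumed to be set up as in the examples further down; this is an illustration, not code from the diff):

    // Before: the caller owned synchronization, and the safety contract
    // required `streams` to be synchronized before results were read or
    // inputs dropped.
    unsafe {
        sks.unchecked_abs_assign_async(&mut d_ct, &streams);
    }
    streams.synchronize();

    // After: a single safe entry point; the `unsafe` FFI dispatch (and any
    // fencing it needs) now lives inside the function body.
    sks.unchecked_abs_assign(&mut d_ct, &streams);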

View File

@@ -47,22 +47,18 @@ pub mod cuda {
         const SBOX_PARALLELISM: usize = 16;
         let bench_id = format!("{param_name}::{NUM_AES_INPUTS}_input_encryption");
-        let round_keys = unsafe { sks.key_expansion_async(&d_key, &streams) };
-        streams.synchronize();
+        let round_keys = sks.key_expansion(&d_key, &streams);
         bench_group.bench_function(&bench_id, |b| {
             b.iter(|| {
-                unsafe {
-                    black_box(sks.aes_encrypt_async(
-                        &d_iv,
-                        &round_keys,
-                        0,
-                        NUM_AES_INPUTS,
-                        SBOX_PARALLELISM,
-                        &streams,
-                    ));
-                }
-                streams.synchronize();
+                black_box(sks.aes_encrypt(
+                    &d_iv,
+                    &round_keys,
+                    0,
+                    NUM_AES_INPUTS,
+                    SBOX_PARALLELISM,
+                    &streams,
+                ));
             })
         });
@@ -82,10 +78,7 @@ pub mod cuda {
         bench_group.bench_function(&bench_id, |b| {
             b.iter(|| {
-                unsafe {
-                    black_box(sks.key_expansion_async(&d_key, &streams));
-                }
-                streams.synchronize();
+                black_box(sks.key_expansion(&d_key, &streams));
             })
         });
@@ -118,22 +111,18 @@ pub mod cuda {
         let d_key = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_key, &streams);
         let d_iv = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_iv, &streams);
-        let round_keys = unsafe { sks.key_expansion_async(&d_key, &streams) };
-        streams.synchronize();
+        let round_keys = sks.key_expansion(&d_key, &streams);
         bench_group.bench_function(&bench_id, |b| {
             b.iter(|| {
-                unsafe {
-                    black_box(sks.aes_encrypt_async(
-                        &d_iv,
-                        &round_keys,
-                        0,
-                        NUM_AES_INPUTS,
-                        SBOX_PARALLELISM,
-                        &streams,
-                    ));
-                }
-                streams.synchronize();
+                black_box(sks.aes_encrypt(
+                    &d_iv,
+                    &round_keys,
+                    0,
+                    NUM_AES_INPUTS,
+                    SBOX_PARALLELISM,
+                    &streams,
+                ));
             })
         });
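
With the synchronous entry points the Criterion loop body collapses to a single call; the explicit per-iteration `streams.synchronize()` disappears along with it, presumably because the new entry points leave no un-fenced work behind when they return. A condensed sketch of the new benchmark body (identifiers as in the hunks above):

    let round_keys = sks.key_expansion(&d_key, &streams);
    bench_group.bench_function(&bench_id, |b| {
        b.iter(|| {
            black_box(sks.aes_encrypt(
                &d_iv,
                &round_keys,
                0,                // start_counter
                NUM_AES_INPUTS,
                SBOX_PARALLELISM,
                &streams,
            ));
        })
    });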

View File

@@ -7708,7 +7708,7 @@ pub(crate) unsafe fn cuda_backend_aes_key_expansion<T: UnsignedInteger, B: Numer
 }
 #[allow(clippy::too_many_arguments)]
-pub(crate) unsafe fn cuda_backend_get_aes_key_expansion_size_on_gpu(
+pub(crate) fn cuda_backend_get_aes_key_expansion_size_on_gpu(
     streams: &CudaStreams,
     message_modulus: MessageModulus,
     carry_modulus: CarryModulus,
@@ -7726,7 +7726,7 @@ pub(crate) unsafe fn cuda_backend_get_aes_key_expansion_size_on_gpu(
     let noise_reduction_type = resolve_noise_reduction_type(ms_noise_reduction_configuration);
     let mut mem_ptr: *mut i8 = std::ptr::null_mut();
-    let size = {
+    let size = unsafe {
         scratch_cuda_integer_key_expansion_64(
             streams.ffi(),
             std::ptr::addr_of_mut!(mem_ptr),
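
This is the general pattern applied to the FFI shims in this file: `unsafe` moves off the function signature and onto the one expression that actually crosses the FFI boundary. In miniature (a sketch; `ffi_scratch` is a hypothetical stand-in for `scratch_cuda_integer_key_expansion_64`):

    extern "C" {
        // Hypothetical C entry point returning a scratch-buffer size.
        fn ffi_scratch(mem_ptr: *mut *mut i8) -> u64;
    }

    pub(crate) fn get_scratch_size() -> u64 {
        let mut mem_ptr: *mut i8 = std::ptr::null_mut();
        // SAFETY: `mem_ptr` is a valid out-parameter for the duration of the call.
        unsafe { ffi_scratch(std::ptr::addr_of_mut!(mem_ptr)) }
    }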

View File

@@ -5,11 +5,7 @@ use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey};
 use crate::integer::gpu::{cuda_backend_unchecked_signed_abs_assign, PBSType};
 impl CudaServerKey {
-    /// # Safety
-    ///
-    /// - [CudaStreams::synchronize] __must__ be called after this function as soon as
-    ///   synchronization is required
-    pub unsafe fn unchecked_abs_assign_async<T>(&self, ct: &mut T, streams: &CudaStreams)
+    pub fn unchecked_abs_assign<T>(&self, ct: &mut T, streams: &CudaStreams)
     where
         T: CudaIntegerRadixCiphertext,
     {
@@ -78,9 +74,8 @@ impl CudaServerKey {
     {
         let mut res = ct.duplicate(streams);
         if T::IS_SIGNED {
-            unsafe { self.unchecked_abs_assign_async(&mut res, streams) };
+            self.unchecked_abs_assign(&mut res, streams);
         }
-        streams.synchronize();
         res
     }
@@ -139,9 +134,8 @@ impl CudaServerKey {
             self.full_propagate_assign(&mut res, streams);
         }
         if T::IS_SIGNED {
-            unsafe { self.unchecked_abs_assign_async(&mut res, streams) };
+            self.unchecked_abs_assign(&mut res, streams);
         }
-        streams.synchronize();
         res
     }
 }
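
Usage of the resulting API, as a sketch (the `tfhe::...` import paths are an assumption based on the crate layout visible in the `use` lines above and may differ between tfhe-rs versions; `duplicate` and `unchecked_abs_assign` are the methods from this file):

    use tfhe::core_crypto::gpu::CudaStreams;
    use tfhe::integer::gpu::ciphertext::CudaSignedRadixCiphertext;
    use tfhe::integer::gpu::server_key::CudaServerKey;

    // Out-of-place absolute value on device; returns once the result is usable.
    fn gpu_abs(
        sks: &CudaServerKey,
        ct: &CudaSignedRadixCiphertext,
        streams: &CudaStreams,
    ) -> CudaSignedRadixCiphertext {
        let mut res = ct.duplicate(streams);
        sks.unchecked_abs_assign(&mut res, streams);
        res
    }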

View File

@@ -27,10 +27,6 @@ impl CudaServerKey {
     /// example) has always the same performance characteristics from one call to another and
     /// guarantees correctness by pre-emptively clearing carries of output ciphertexts.
     ///
-    /// # Warning
-    ///
-    /// - Multithreaded
-    ///
     /// # Example
     ///
     /// ```rust
@@ -86,11 +82,7 @@ impl CudaServerKey {
         self.get_add_assign_size_on_gpu(ct_left, ct_right, streams)
     }
-    /// # Safety
-    ///
-    /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
-    ///   not be dropped until stream is synchronised
-    pub unsafe fn add_assign_async<T: CudaIntegerRadixCiphertext>(
+    pub fn add_assign<T: CudaIntegerRadixCiphertext>(
         &self,
         ct_left: &mut T,
         ct_right: &T,
@@ -121,25 +113,8 @@ impl CudaServerKey {
             }
         };
-        let _carry = self.add_and_propagate_single_carry_assign_async(
-            lhs,
-            rhs,
-            streams,
-            None,
-            OutputFlag::None,
-        );
-    }
-    pub fn add_assign<T: CudaIntegerRadixCiphertext>(
-        &self,
-        ct_left: &mut T,
-        ct_right: &T,
-        streams: &CudaStreams,
-    ) {
-        unsafe {
-            self.add_assign_async(ct_left, ct_right, streams);
-        }
-        streams.synchronize();
+        let _carry =
+            self.add_and_propagate_single_carry_assign(lhs, rhs, streams, None, OutputFlag::None);
     }
     pub fn get_add_assign_size_on_gpu<T: CudaIntegerRadixCiphertext>(
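
Call-site sketch for the collapsed addition API (import paths assumed as in the earlier example; `get_add_assign_size_on_gpu` is the pre-flight memory query kept by this commit):

    use tfhe::core_crypto::gpu::CudaStreams;
    use tfhe::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
    use tfhe::integer::gpu::server_key::CudaServerKey;

    fn gpu_add_in_place(
        sks: &CudaServerKey,
        lhs: &mut CudaUnsignedRadixCiphertext,
        rhs: &CudaUnsignedRadixCiphertext,
        streams: &CudaStreams,
    ) {
        // Optional pre-flight check of the temporary-buffer footprint.
        let _size_on_gpu = sks.get_add_assign_size_on_gpu(lhs, rhs, streams);
        // Carry propagation included; no *_async + synchronize pair any more.
        sks.add_assign(lhs, rhs, streams);
    }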
@@ -286,11 +261,7 @@ impl CudaServerKey {
         result
     }
-    /// # Safety
-    ///
-    /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
-    ///   not be dropped until stream is synchronised
-    pub unsafe fn unchecked_add_assign_async<T: CudaIntegerRadixCiphertext>(
+    pub fn unchecked_add_assign<T: CudaIntegerRadixCiphertext>(
         &self,
         ct_left: &mut T,
         ct_right: &T,
@@ -319,22 +290,7 @@ impl CudaServerKey {
         }
     }
-    pub fn unchecked_add_assign<T: CudaIntegerRadixCiphertext>(
-        &self,
-        ct_left: &mut T,
-        ct_right: &T,
-        streams: &CudaStreams,
-    ) {
-        unsafe {
-            self.unchecked_add_assign_async(ct_left, ct_right, streams);
-        }
-        streams.synchronize();
-    }
-    /// # Safety
-    ///
-    /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
-    ///   not be dropped until stream is synchronised
-    pub unsafe fn unchecked_partial_sum_ciphertexts_assign_async<T: CudaIntegerRadixCiphertext>(
+    pub fn unchecked_partial_sum_ciphertexts_assign<T: CudaIntegerRadixCiphertext>(
         &self,
         result: &mut T,
         ciphertexts: &[T],
@@ -345,11 +301,14 @@ impl CudaServerKey {
             return;
         }
-        result.as_mut().d_blocks.0.d_vec.copy_from_gpu_async(
-            &ciphertexts[0].as_ref().d_blocks.0.d_vec,
-            streams,
-            0,
-        );
+        unsafe {
+            result.as_mut().d_blocks.0.d_vec.copy_from_gpu_async(
+                &ciphertexts[0].as_ref().d_blocks.0.d_vec,
+                streams,
+                0,
+            );
+            streams.synchronize();
+        }
         result.as_mut().info = ciphertexts[0].as_ref().info.clone();
         if ciphertexts.len() == 1 {
             return;
@@ -365,7 +324,7 @@ impl CudaServerKey {
         );
         if ciphertexts.len() == 2 {
-            self.add_assign_async(result, &ciphertexts[1], streams);
+            self.add_assign(result, &ciphertexts[1], streams);
             return;
         }
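
The device-to-device copy above is one place where an explicit `streams.synchronize()` survives: `copy_from_gpu_async` is still an `unsafe` asynchronous primitive, so the now-safe wrapper launches and fences it inside a single `unsafe` block, ensuring no un-synchronized work on the borrowed source buffer escapes the function. The recurring shape, in miniature (`dst` and `src` are hypothetical device buffers; the index argument mirrors the `0` in the hunk above):

    unsafe {
        dst.copy_from_gpu_async(&src, streams, 0); // async copy, as in the diff
        streams.synchronize();                     // fence before the borrow of `src` ends
    }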
@@ -373,58 +332,60 @@
         let mut terms = CudaRadixCiphertext::from_radix_ciphertext_vec(ciphertexts, streams);
-        match &self.bootstrapping_key {
-            CudaBootstrappingKey::Classic(d_bsk) => {
-                cuda_backend_unchecked_partial_sum_ciphertexts_assign(
-                    streams,
-                    result.as_mut(),
-                    &mut terms,
-                    reduce_degrees_for_single_carry_propagation,
-                    &d_bsk.d_vec,
-                    &self.key_switching_key.d_vec,
-                    self.message_modulus,
-                    self.carry_modulus,
-                    d_bsk.glwe_dimension,
-                    d_bsk.polynomial_size,
-                    self.key_switching_key
-                        .output_key_lwe_size()
-                        .to_lwe_dimension(),
-                    self.key_switching_key.decomposition_level_count(),
-                    self.key_switching_key.decomposition_base_log(),
-                    d_bsk.decomp_level_count,
-                    d_bsk.decomp_base_log,
-                    num_blocks.0 as u32,
-                    radix_count_in_vec as u32,
-                    PBSType::Classical,
-                    LweBskGroupingFactor(0),
-                    d_bsk.ms_noise_reduction_configuration.as_ref(),
-                );
-            }
-            CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
-                cuda_backend_unchecked_partial_sum_ciphertexts_assign(
-                    streams,
-                    result.as_mut(),
-                    &mut terms,
-                    reduce_degrees_for_single_carry_propagation,
-                    &d_multibit_bsk.d_vec,
-                    &self.key_switching_key.d_vec,
-                    self.message_modulus,
-                    self.carry_modulus,
-                    d_multibit_bsk.glwe_dimension,
-                    d_multibit_bsk.polynomial_size,
-                    self.key_switching_key
-                        .output_key_lwe_size()
-                        .to_lwe_dimension(),
-                    self.key_switching_key.decomposition_level_count(),
-                    self.key_switching_key.decomposition_base_log(),
-                    d_multibit_bsk.decomp_level_count,
-                    d_multibit_bsk.decomp_base_log,
-                    num_blocks.0 as u32,
-                    radix_count_in_vec as u32,
-                    PBSType::MultiBit,
-                    d_multibit_bsk.grouping_factor,
-                    None,
-                );
-            }
-        }
+        unsafe {
+            match &self.bootstrapping_key {
+                CudaBootstrappingKey::Classic(d_bsk) => {
+                    cuda_backend_unchecked_partial_sum_ciphertexts_assign(
+                        streams,
+                        result.as_mut(),
+                        &mut terms,
+                        reduce_degrees_for_single_carry_propagation,
+                        &d_bsk.d_vec,
+                        &self.key_switching_key.d_vec,
+                        self.message_modulus,
+                        self.carry_modulus,
+                        d_bsk.glwe_dimension,
+                        d_bsk.polynomial_size,
+                        self.key_switching_key
+                            .output_key_lwe_size()
+                            .to_lwe_dimension(),
+                        self.key_switching_key.decomposition_level_count(),
+                        self.key_switching_key.decomposition_base_log(),
+                        d_bsk.decomp_level_count,
+                        d_bsk.decomp_base_log,
+                        num_blocks.0 as u32,
+                        radix_count_in_vec as u32,
+                        PBSType::Classical,
+                        LweBskGroupingFactor(0),
+                        d_bsk.ms_noise_reduction_configuration.as_ref(),
+                    );
+                }
+                CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
+                    cuda_backend_unchecked_partial_sum_ciphertexts_assign(
+                        streams,
+                        result.as_mut(),
+                        &mut terms,
+                        reduce_degrees_for_single_carry_propagation,
+                        &d_multibit_bsk.d_vec,
+                        &self.key_switching_key.d_vec,
+                        self.message_modulus,
+                        self.carry_modulus,
+                        d_multibit_bsk.glwe_dimension,
+                        d_multibit_bsk.polynomial_size,
+                        self.key_switching_key
+                            .output_key_lwe_size()
+                            .to_lwe_dimension(),
+                        self.key_switching_key.decomposition_level_count(),
+                        self.key_switching_key.decomposition_base_log(),
+                        d_multibit_bsk.decomp_level_count,
+                        d_multibit_bsk.decomp_base_log,
+                        num_blocks.0 as u32,
+                        radix_count_in_vec as u32,
+                        PBSType::MultiBit,
+                        d_multibit_bsk.grouping_factor,
+                        None,
+                    );
+                }
+            }
+        }
     }
@@ -433,23 +394,9 @@
         &self,
         ciphertexts: &[T],
         streams: &CudaStreams,
-    ) -> T {
-        let result = unsafe { self.unchecked_sum_ciphertexts_async(ciphertexts, streams) };
-        streams.synchronize();
-        result
-    }
-    /// # Safety
-    ///
-    /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
-    ///   not be dropped until stream is synchronised
-    pub unsafe fn unchecked_sum_ciphertexts_async<T: CudaIntegerRadixCiphertext>(
-        &self,
-        ciphertexts: &[T],
-        streams: &CudaStreams,
     ) -> T {
         let mut result = self
-            .unchecked_partial_sum_ciphertexts_async(ciphertexts, true, streams)
+            .unchecked_partial_sum_ciphertexts(ciphertexts, true, streams)
             .unwrap();
         self.propagate_single_carry_assign(&mut result, streams, None, OutputFlag::None);
@@ -458,21 +405,6 @@
     }
     pub fn unchecked_partial_sum_ciphertexts<T: CudaIntegerRadixCiphertext>(
-        &self,
-        ciphertexts: &[T],
-        streams: &CudaStreams,
-    ) -> Option<T> {
-        let result =
-            unsafe { self.unchecked_partial_sum_ciphertexts_async(ciphertexts, false, streams) };
-        streams.synchronize();
-        result
-    }
-    /// # Safety
-    ///
-    /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
-    ///   not be dropped until stream is synchronised
-    pub unsafe fn unchecked_partial_sum_ciphertexts_async<T: CudaIntegerRadixCiphertext>(
         &self,
         ciphertexts: &[T],
         reduce_degrees_for_single_carry_propagation: bool,
@@ -485,11 +417,10 @@
         let mut result = ciphertexts[0].duplicate(streams);
         if ciphertexts.len() == 1 {
-            streams.synchronize();
             return Some(result);
         }
-        self.unchecked_partial_sum_ciphertexts_assign_async(
+        self.unchecked_partial_sum_ciphertexts_assign(
             &mut result,
             ciphertexts,
             reduce_degrees_for_single_carry_propagation,
@@ -500,20 +431,6 @@
     }
     pub fn sum_ciphertexts<T: CudaIntegerRadixCiphertext>(
-        &self,
-        ciphertexts: Vec<T>,
-        streams: &CudaStreams,
-    ) -> Option<T> {
-        let res = unsafe { self.sum_ciphertexts_async(ciphertexts, streams) };
-        streams.synchronize();
-        res
-    }
-    /// # Safety
-    ///
-    /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
-    ///   not be dropped until stream is synchronised
-    pub unsafe fn sum_ciphertexts_async<T: CudaIntegerRadixCiphertext>(
        &self,
        mut ciphertexts: Vec<T>,
        streams: &CudaStreams,
@@ -529,7 +446,7 @@
             self.full_propagate_assign(&mut *ct, streams);
         });
-        Some(self.unchecked_sum_ciphertexts_async(&ciphertexts, streams))
+        Some(self.unchecked_sum_ciphertexts(&ciphertexts, streams))
     }
     /// ```rust
@@ -620,38 +537,12 @@
             lhs.as_ref().d_blocks.lwe_ciphertext_count().0,
             rhs.as_ref().d_blocks.lwe_ciphertext_count().0
         );
-        let ct_res;
-        let ct_overflowed;
-        unsafe {
-            (ct_res, ct_overflowed) =
-                self.unchecked_unsigned_overflowing_add_async(lhs, rhs, stream);
-        }
-        stream.synchronize();
-        (ct_res, ct_overflowed)
-    }
-    /// # Safety
-    ///
-    /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
-    ///   not be dropped until stream is synchronised
-    pub unsafe fn unchecked_unsigned_overflowing_add_async(
-        &self,
-        lhs: &CudaUnsignedRadixCiphertext,
-        rhs: &CudaUnsignedRadixCiphertext,
-        stream: &CudaStreams,
-    ) -> (CudaUnsignedRadixCiphertext, CudaBooleanBlock) {
         let output_flag = OutputFlag::from_signedness(CudaUnsignedRadixCiphertext::IS_SIGNED);
         let mut ct_res = lhs.duplicate(stream);
-        let mut carry_out: CudaUnsignedRadixCiphertext = self
-            .add_and_propagate_single_carry_assign_async(
-                &mut ct_res,
-                rhs,
-                stream,
-                None,
-                output_flag,
-            );
+        let mut carry_out: CudaUnsignedRadixCiphertext =
+            self.add_and_propagate_single_carry_assign(&mut ct_res, rhs, stream, None, output_flag);
         if lhs.as_ref().info.blocks.last().unwrap().noise_level == NoiseLevel::ZERO
             && rhs.as_ref().info.blocks.last().unwrap().noise_level == NoiseLevel::ZERO
@@ -666,28 +557,43 @@
         (ct_res, ct_overflowed)
     }
-    /// # Safety
-    ///
-    /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
-    ///   not be dropped until stream is synchronised
-    pub unsafe fn unchecked_signed_overflowing_add_async(
+    pub fn unchecked_signed_overflowing_add(
+        &self,
+        lhs: &CudaSignedRadixCiphertext,
+        rhs: &CudaSignedRadixCiphertext,
+        stream: &CudaStreams,
+    ) -> (CudaSignedRadixCiphertext, CudaBooleanBlock) {
+        self.unchecked_signed_overflowing_add_with_input_carry(lhs, rhs, None, stream)
+    }
+    pub fn unchecked_signed_overflowing_add_with_input_carry(
         &self,
         lhs: &CudaSignedRadixCiphertext,
         rhs: &CudaSignedRadixCiphertext,
         input_carry: Option<&CudaBooleanBlock>,
         stream: &CudaStreams,
     ) -> (CudaSignedRadixCiphertext, CudaBooleanBlock) {
+        assert_eq!(
+            lhs.as_ref().d_blocks.lwe_ciphertext_count().0,
+            rhs.as_ref().d_blocks.lwe_ciphertext_count().0,
+            "lhs and rhs must have the name number of blocks ({} vs {})",
+            lhs.as_ref().d_blocks.lwe_ciphertext_count().0,
+            rhs.as_ref().d_blocks.lwe_ciphertext_count().0
+        );
+        assert!(
+            lhs.as_ref().d_blocks.lwe_ciphertext_count().0 > 0,
+            "inputs cannot be empty"
+        );
         let output_flag = OutputFlag::from_signedness(CudaSignedRadixCiphertext::IS_SIGNED);
         let mut ct_res = lhs.duplicate(stream);
-        let carry_out: CudaSignedRadixCiphertext = self
-            .add_and_propagate_single_carry_assign_async(
-                &mut ct_res,
-                rhs,
-                stream,
-                input_carry,
-                output_flag,
-            );
+        let carry_out: CudaSignedRadixCiphertext = self.add_and_propagate_single_carry_assign(
+            &mut ct_res,
+            rhs,
+            stream,
+            input_carry,
+            output_flag,
+        );
         let ct_overflowed = CudaBooleanBlock::from_cuda_radix_ciphertext(carry_out.ciphertext);
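
Sketch of the overflow-reporting variant after the rename (types as in the diff; the import paths are assumed; the boolean block encrypts 1 when the signed addition overflowed):

    use tfhe::core_crypto::gpu::CudaStreams;
    use tfhe::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
    use tfhe::integer::gpu::ciphertext::CudaSignedRadixCiphertext;
    use tfhe::integer::gpu::server_key::CudaServerKey;

    fn gpu_overflowing_add(
        sks: &CudaServerKey,
        lhs: &CudaSignedRadixCiphertext,
        rhs: &CudaSignedRadixCiphertext,
        streams: &CudaStreams,
    ) -> (CudaSignedRadixCiphertext, CudaBooleanBlock) {
        // Delegates to ..._with_input_carry(lhs, rhs, None, streams) internally.
        sks.unchecked_signed_overflowing_add(lhs, rhs, streams)
    }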
@@ -770,39 +676,7 @@
         self.unchecked_signed_overflowing_add(lhs, rhs, stream)
     }
-    pub fn unchecked_signed_overflowing_add(
-        &self,
-        ct_left: &CudaSignedRadixCiphertext,
-        ct_right: &CudaSignedRadixCiphertext,
-        stream: &CudaStreams,
-    ) -> (CudaSignedRadixCiphertext, CudaBooleanBlock) {
-        assert_eq!(
-            ct_left.as_ref().d_blocks.lwe_ciphertext_count().0,
-            ct_right.as_ref().d_blocks.lwe_ciphertext_count().0,
-            "lhs and rhs must have the name number of blocks ({} vs {})",
-            ct_left.as_ref().d_blocks.lwe_ciphertext_count().0,
-            ct_right.as_ref().d_blocks.lwe_ciphertext_count().0
-        );
-        assert!(
-            ct_left.as_ref().d_blocks.lwe_ciphertext_count().0 > 0,
-            "inputs cannot be empty"
-        );
-        let result;
-        let overflowed;
-        unsafe {
-            (result, overflowed) =
-                self.unchecked_signed_overflowing_add_async(ct_left, ct_right, None, stream);
-        };
-        stream.synchronize();
-        (result, overflowed)
-    }
-    /// # Safety
-    ///
-    /// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
-    ///   not be dropped until streams is synchronized
-    pub(crate) unsafe fn add_and_propagate_single_carry_assign_async<T>(
+    pub(crate) fn add_and_propagate_single_carry_assign<T>(
         &self,
         lhs: &mut T,
         rhs: &T,
@@ -821,58 +695,60 @@
         let in_carry: &CudaRadixCiphertext =
             input_carry.map_or_else(|| aux_block.as_ref(), |block| block.0.as_ref());
-        match &self.bootstrapping_key {
-            CudaBootstrappingKey::Classic(d_bsk) => {
-                cuda_backend_add_and_propagate_single_carry_assign(
-                    streams,
-                    lhs.as_mut(),
-                    rhs.as_ref(),
-                    carry_out.as_mut(),
-                    in_carry,
-                    &d_bsk.d_vec,
-                    &self.key_switching_key.d_vec,
-                    d_bsk.input_lwe_dimension(),
-                    d_bsk.glwe_dimension(),
-                    d_bsk.polynomial_size(),
-                    self.key_switching_key.decomposition_level_count(),
-                    self.key_switching_key.decomposition_base_log(),
-                    d_bsk.decomp_level_count(),
-                    d_bsk.decomp_base_log(),
-                    num_blocks,
-                    self.message_modulus,
-                    self.carry_modulus,
-                    PBSType::Classical,
-                    LweBskGroupingFactor(0),
-                    requested_flag,
-                    uses_carry,
-                    d_bsk.ms_noise_reduction_configuration.as_ref(),
-                );
-            }
-            CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
-                cuda_backend_add_and_propagate_single_carry_assign(
-                    streams,
-                    lhs.as_mut(),
-                    rhs.as_ref(),
-                    carry_out.as_mut(),
-                    in_carry,
-                    &d_multibit_bsk.d_vec,
-                    &self.key_switching_key.d_vec,
-                    d_multibit_bsk.input_lwe_dimension(),
-                    d_multibit_bsk.glwe_dimension(),
-                    d_multibit_bsk.polynomial_size(),
-                    self.key_switching_key.decomposition_level_count(),
-                    self.key_switching_key.decomposition_base_log(),
-                    d_multibit_bsk.decomp_level_count(),
-                    d_multibit_bsk.decomp_base_log(),
-                    num_blocks,
-                    self.message_modulus,
-                    self.carry_modulus,
-                    PBSType::MultiBit,
-                    d_multibit_bsk.grouping_factor,
-                    requested_flag,
-                    uses_carry,
-                    None,
-                );
-            }
-        }
+        unsafe {
+            match &self.bootstrapping_key {
+                CudaBootstrappingKey::Classic(d_bsk) => {
+                    cuda_backend_add_and_propagate_single_carry_assign(
+                        streams,
+                        lhs.as_mut(),
+                        rhs.as_ref(),
+                        carry_out.as_mut(),
+                        in_carry,
+                        &d_bsk.d_vec,
+                        &self.key_switching_key.d_vec,
+                        d_bsk.input_lwe_dimension(),
+                        d_bsk.glwe_dimension(),
+                        d_bsk.polynomial_size(),
+                        self.key_switching_key.decomposition_level_count(),
+                        self.key_switching_key.decomposition_base_log(),
+                        d_bsk.decomp_level_count(),
+                        d_bsk.decomp_base_log(),
+                        num_blocks,
+                        self.message_modulus,
+                        self.carry_modulus,
+                        PBSType::Classical,
+                        LweBskGroupingFactor(0),
+                        requested_flag,
+                        uses_carry,
+                        d_bsk.ms_noise_reduction_configuration.as_ref(),
+                    );
+                }
+                CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
+                    cuda_backend_add_and_propagate_single_carry_assign(
+                        streams,
+                        lhs.as_mut(),
+                        rhs.as_ref(),
+                        carry_out.as_mut(),
+                        in_carry,
+                        &d_multibit_bsk.d_vec,
+                        &self.key_switching_key.d_vec,
+                        d_multibit_bsk.input_lwe_dimension(),
+                        d_multibit_bsk.glwe_dimension(),
+                        d_multibit_bsk.polynomial_size(),
+                        self.key_switching_key.decomposition_level_count(),
+                        self.key_switching_key.decomposition_base_log(),
+                        d_multibit_bsk.decomp_level_count(),
+                        d_multibit_bsk.decomp_base_log(),
+                        num_blocks,
+                        self.message_modulus,
+                        self.carry_modulus,
+                        PBSType::MultiBit,
+                        d_multibit_bsk.grouping_factor,
+                        requested_flag,
+                        uses_carry,
+                        None,
+                    );
+                }
+            }
+        }
        carry_out
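
After this commit every addition variant funnels through the single pub(crate) helper; a condensed view of the call graph (method names from the hunks above, signatures abridged):

    // add_assign(lhs, rhs)
    //   └─ add_and_propagate_single_carry_assign(lhs, rhs, None, OutputFlag::None)
    //
    // unchecked_unsigned_overflowing_add(lhs, rhs)
    //   └─ add_and_propagate_single_carry_assign(&mut ct_res, rhs, None,
    //          OutputFlag::from_signedness(false))   // carry out becomes the flag
    //
    // unchecked_signed_overflowing_add(lhs, rhs)
    //   └─ unchecked_signed_overflowing_add_with_input_carry(lhs, rhs, None)
    //        └─ add_and_propagate_single_carry_assign(&mut ct_res, rhs, input_carry,
    //               OutputFlag::from_signedness(true))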

View File

@@ -133,18 +133,15 @@
                 self.get_aes_encrypt_size_on_gpu(num_aes_inputs, parallelism, streams);
             if check_valid_cuda_malloc(aes_encrypt_size, streams.gpu_indexes[0]) {
-                let round_keys = unsafe { self.key_expansion_async(key, streams) };
-                let res = unsafe {
-                    self.aes_encrypt_async(
-                        iv,
-                        &round_keys,
-                        start_counter,
-                        num_aes_inputs,
-                        parallelism,
-                        streams,
-                    )
-                };
-                streams.synchronize();
+                let round_keys = self.key_expansion(key, streams);
+                let res = self.aes_encrypt(
+                    iv,
+                    &round_keys,
+                    start_counter,
+                    num_aes_inputs,
+                    parallelism,
+                    streams,
+                );
                 return res;
             }
             parallelism /= 2;
@@ -176,26 +173,18 @@
             self.get_aes_encrypt_size_on_gpu(num_aes_inputs, sbox_parallelism, streams);
         check_valid_cuda_malloc_assert_oom(aes_encrypt_size, gpu_index);
-        let round_keys = unsafe { self.key_expansion_async(key, streams) };
-        let res = unsafe {
-            self.aes_encrypt_async(
-                iv,
-                &round_keys,
-                start_counter,
-                num_aes_inputs,
-                sbox_parallelism,
-                streams,
-            )
-        };
-        streams.synchronize();
-        res
+        let round_keys = self.key_expansion(key, streams);
+        self.aes_encrypt(
+            iv,
+            &round_keys,
+            start_counter,
+            num_aes_inputs,
+            sbox_parallelism,
+            streams,
+        )
     }
-    /// # Safety
-    ///
-    /// - [CudaStreams::synchronize] __must__ be called after this function as soon as
-    ///   synchronization is required
-    pub unsafe fn aes_encrypt_async(
+    pub fn aes_encrypt(
         &self,
         iv: &CudaUnsignedRadixCiphertext,
         round_keys: &CudaUnsignedRadixCiphertext,
@@ -229,79 +218,64 @@
             result.as_ref().d_blocks.lwe_ciphertext_count().0
         );
-        match &self.bootstrapping_key {
-            CudaBootstrappingKey::Classic(d_bsk) => {
-                cuda_backend_unchecked_aes_ctr_encrypt(
-                    streams,
-                    result.as_mut(),
-                    iv.as_ref(),
-                    round_keys.as_ref(),
-                    start_counter,
-                    num_aes_inputs as u32,
-                    sbox_parallelism as u32,
-                    &d_bsk.d_vec,
-                    &self.key_switching_key.d_vec,
-                    self.message_modulus,
-                    self.carry_modulus,
-                    d_bsk.glwe_dimension,
-                    d_bsk.polynomial_size,
-                    d_bsk.input_lwe_dimension,
-                    self.key_switching_key.decomposition_level_count(),
-                    self.key_switching_key.decomposition_base_log(),
-                    d_bsk.decomp_level_count,
-                    d_bsk.decomp_base_log,
-                    LweBskGroupingFactor(0),
-                    PBSType::Classical,
-                    d_bsk.ms_noise_reduction_configuration.as_ref(),
-                );
-            }
-            CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
-                cuda_backend_unchecked_aes_ctr_encrypt(
-                    streams,
-                    result.as_mut(),
-                    iv.as_ref(),
-                    round_keys.as_ref(),
-                    start_counter,
-                    num_aes_inputs as u32,
-                    sbox_parallelism as u32,
-                    &d_multibit_bsk.d_vec,
-                    &self.key_switching_key.d_vec,
-                    self.message_modulus,
-                    self.carry_modulus,
-                    d_multibit_bsk.glwe_dimension,
-                    d_multibit_bsk.polynomial_size,
-                    d_multibit_bsk.input_lwe_dimension,
-                    self.key_switching_key.decomposition_level_count(),
-                    self.key_switching_key.decomposition_base_log(),
-                    d_multibit_bsk.decomp_level_count,
-                    d_multibit_bsk.decomp_base_log,
-                    d_multibit_bsk.grouping_factor,
-                    PBSType::MultiBit,
-                    None,
-                );
-            }
-        }
+        unsafe {
+            match &self.bootstrapping_key {
+                CudaBootstrappingKey::Classic(d_bsk) => {
+                    cuda_backend_unchecked_aes_ctr_encrypt(
+                        streams,
+                        result.as_mut(),
+                        iv.as_ref(),
+                        round_keys.as_ref(),
+                        start_counter,
+                        num_aes_inputs as u32,
+                        sbox_parallelism as u32,
+                        &d_bsk.d_vec,
+                        &self.key_switching_key.d_vec,
+                        self.message_modulus,
+                        self.carry_modulus,
+                        d_bsk.glwe_dimension,
+                        d_bsk.polynomial_size,
+                        d_bsk.input_lwe_dimension,
+                        self.key_switching_key.decomposition_level_count(),
+                        self.key_switching_key.decomposition_base_log(),
+                        d_bsk.decomp_level_count,
+                        d_bsk.decomp_base_log,
+                        LweBskGroupingFactor(0),
+                        PBSType::Classical,
+                        d_bsk.ms_noise_reduction_configuration.as_ref(),
+                    );
+                }
+                CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
+                    cuda_backend_unchecked_aes_ctr_encrypt(
+                        streams,
+                        result.as_mut(),
+                        iv.as_ref(),
+                        round_keys.as_ref(),
+                        start_counter,
+                        num_aes_inputs as u32,
+                        sbox_parallelism as u32,
+                        &d_multibit_bsk.d_vec,
+                        &self.key_switching_key.d_vec,
+                        self.message_modulus,
+                        self.carry_modulus,
+                        d_multibit_bsk.glwe_dimension,
+                        d_multibit_bsk.polynomial_size,
+                        d_multibit_bsk.input_lwe_dimension,
+                        self.key_switching_key.decomposition_level_count(),
+                        self.key_switching_key.decomposition_base_log(),
+                        d_multibit_bsk.decomp_level_count,
+                        d_multibit_bsk.decomp_base_log,
+                        d_multibit_bsk.grouping_factor,
+                        PBSType::MultiBit,
+                        None,
+                    );
+                }
+            }
+        }
         result
     }
-    fn get_aes_encrypt_size_on_gpu(
-        &self,
-        num_aes_inputs: usize,
-        sbox_parallelism: usize,
-        streams: &CudaStreams,
-    ) -> u64 {
-        let size = unsafe {
-            self.get_aes_encrypt_size_on_gpu_async(num_aes_inputs, sbox_parallelism, streams)
-        };
-        streams.synchronize();
-        size
-    }
-    /// # Safety
-    ///
-    /// - [CudaStreams::synchronize] __must__ be called after this function as soon as
-    ///   synchronization is required
-    unsafe fn get_aes_encrypt_size_on_gpu_async(
+    pub fn get_aes_encrypt_size_on_gpu(
         &self,
         num_aes_inputs: usize,
         sbox_parallelism: usize,
@@ -347,11 +321,7 @@
         }
     }
-    /// # Safety
-    ///
-    /// - [CudaStreams::synchronize] __must__ be called after this function as soon as
-    ///   synchronization is required
-    pub unsafe fn key_expansion_async(
+    pub fn key_expansion(
         &self,
         key: &CudaUnsignedRadixCiphertext,
         streams: &CudaStreams,
@@ -369,64 +339,56 @@
             key.as_ref().d_blocks.lwe_ciphertext_count().0
         );
-        match &self.bootstrapping_key {
-            CudaBootstrappingKey::Classic(d_bsk) => {
-                cuda_backend_aes_key_expansion(
-                    streams,
-                    expanded_keys.as_mut(),
-                    key.as_ref(),
-                    &d_bsk.d_vec,
-                    &self.key_switching_key.d_vec,
-                    self.message_modulus,
-                    self.carry_modulus,
-                    d_bsk.glwe_dimension,
-                    d_bsk.polynomial_size,
-                    d_bsk.input_lwe_dimension,
-                    self.key_switching_key.decomposition_level_count(),
-                    self.key_switching_key.decomposition_base_log(),
-                    d_bsk.decomp_level_count,
-                    d_bsk.decomp_base_log,
-                    LweBskGroupingFactor(0),
-                    PBSType::Classical,
-                    d_bsk.ms_noise_reduction_configuration.as_ref(),
-                );
-            }
-            CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
-                cuda_backend_aes_key_expansion(
-                    streams,
-                    expanded_keys.as_mut(),
-                    key.as_ref(),
-                    &d_multibit_bsk.d_vec,
-                    &self.key_switching_key.d_vec,
-                    self.message_modulus,
-                    self.carry_modulus,
-                    d_multibit_bsk.glwe_dimension,
-                    d_multibit_bsk.polynomial_size,
-                    d_multibit_bsk.input_lwe_dimension,
-                    self.key_switching_key.decomposition_level_count(),
-                    self.key_switching_key.decomposition_base_log(),
-                    d_multibit_bsk.decomp_level_count,
-                    d_multibit_bsk.decomp_base_log,
-                    d_multibit_bsk.grouping_factor,
-                    PBSType::MultiBit,
-                    None,
-                );
-            }
-        }
+        unsafe {
+            match &self.bootstrapping_key {
+                CudaBootstrappingKey::Classic(d_bsk) => {
+                    cuda_backend_aes_key_expansion(
+                        streams,
+                        expanded_keys.as_mut(),
+                        key.as_ref(),
+                        &d_bsk.d_vec,
+                        &self.key_switching_key.d_vec,
+                        self.message_modulus,
+                        self.carry_modulus,
+                        d_bsk.glwe_dimension,
+                        d_bsk.polynomial_size,
+                        d_bsk.input_lwe_dimension,
+                        self.key_switching_key.decomposition_level_count(),
+                        self.key_switching_key.decomposition_base_log(),
+                        d_bsk.decomp_level_count,
+                        d_bsk.decomp_base_log,
+                        LweBskGroupingFactor(0),
+                        PBSType::Classical,
+                        d_bsk.ms_noise_reduction_configuration.as_ref(),
+                    );
+                }
+                CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
+                    cuda_backend_aes_key_expansion(
+                        streams,
+                        expanded_keys.as_mut(),
+                        key.as_ref(),
+                        &d_multibit_bsk.d_vec,
+                        &self.key_switching_key.d_vec,
+                        self.message_modulus,
+                        self.carry_modulus,
+                        d_multibit_bsk.glwe_dimension,
+                        d_multibit_bsk.polynomial_size,
+                        d_multibit_bsk.input_lwe_dimension,
+                        self.key_switching_key.decomposition_level_count(),
+                        self.key_switching_key.decomposition_base_log(),
+                        d_multibit_bsk.decomp_level_count,
+                        d_multibit_bsk.decomp_base_log,
+                        d_multibit_bsk.grouping_factor,
+                        PBSType::MultiBit,
+                        None,
+                    );
+                }
+            }
+        }
         expanded_keys
     }
-    fn get_key_expansion_size_on_gpu(&self, streams: &CudaStreams) -> u64 {
-        let size = unsafe { self.get_key_expansion_size_on_gpu_async(streams) };
-        streams.synchronize();
-        size
-    }
-    /// # Safety
-    ///
-    /// - [CudaStreams::synchronize] __must__ be called after this function as soon as
-    ///   synchronization is required
-    unsafe fn get_key_expansion_size_on_gpu_async(&self, streams: &CudaStreams) -> u64 {
+    pub fn get_key_expansion_size_on_gpu(&self, streams: &CudaStreams) -> u64 {
         match &self.bootstrapping_key {
             CudaBootstrappingKey::Classic(d_bsk) => cuda_backend_get_aes_key_expansion_size_on_gpu(
                 streams,
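
End-to-end sketch of the synchronous AES-CTR flow (constants are illustrative, mirroring the benchmark at the top of this commit; import paths and the return type are assumptions inferred from the diff):

    use tfhe::core_crypto::gpu::CudaStreams;
    use tfhe::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
    use tfhe::integer::gpu::server_key::CudaServerKey;

    fn gpu_aes_ctr(
        sks: &CudaServerKey,
        key: &CudaUnsignedRadixCiphertext, // encrypted 128-bit AES key
        iv: &CudaUnsignedRadixCiphertext,  // encrypted 128-bit counter/IV
        streams: &CudaStreams,
    ) -> CudaUnsignedRadixCiphertext {
        const NUM_AES_INPUTS: usize = 16;
        const SBOX_PARALLELISM: usize = 16;
        // The memory footprint can be pre-checked, as aes_ctr_encrypt does above,
        // via get_aes_encrypt_size_on_gpu / check_valid_cuda_malloc.
        let round_keys = sks.key_expansion(key, streams);
        sks.aes_encrypt(
            iv,
            &round_keys,
            0, // start_counter
            NUM_AES_INPUTS,
            SBOX_PARALLELISM,
            streams,
        )
    }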

View File

@@ -62,41 +62,12 @@
         ct_left: &T,
         ct_right: &T,
         streams: &CudaStreams,
-    ) -> T {
-        let result = unsafe { self.unchecked_sub_async(ct_left, ct_right, streams) };
-        streams.synchronize();
-        result
-    }
-    /// # Safety
-    ///
-    /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
-    ///   not be dropped until stream is synchronised
-    pub unsafe fn unchecked_sub_async<T: CudaIntegerRadixCiphertext>(
-        &self,
-        ct_left: &T,
-        ct_right: &T,
-        streams: &CudaStreams,
     ) -> T {
         let mut result = ct_left.duplicate(streams);
-        self.unchecked_sub_assign_async(&mut result, ct_right, streams);
+        self.unchecked_sub_assign(&mut result, ct_right, streams);
         result
     }
-    /// # Safety
-    ///
-    /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
-    ///   not be dropped until stream is synchronised
-    pub unsafe fn unchecked_sub_assign_async<T: CudaIntegerRadixCiphertext>(
-        &self,
-        ct_left: &mut T,
-        ct_right: &T,
-        streams: &CudaStreams,
-    ) {
-        let neg = self.unchecked_neg_async(ct_right, streams);
-        self.unchecked_add_assign_async(ct_left, &neg, streams);
-    }
     /// Computes homomorphically a subtraction between two ciphertexts encrypting integer values.
     ///
     /// This function computes the subtraction without checking if it exceeds the capacity of the
@@ -146,10 +117,8 @@
         ct_right: &T,
         streams: &CudaStreams,
     ) {
-        unsafe {
-            self.unchecked_sub_assign_async(ct_left, ct_right, streams);
-        }
-        streams.synchronize();
+        let neg = self.unchecked_neg(ct_right, streams);
+        self.unchecked_add_assign(ct_left, &neg, streams);
     }
     /// Computes homomorphically the subtraction between ct_left and ct_right.
@@ -205,8 +174,8 @@
         ct_right: &T,
         streams: &CudaStreams,
     ) -> T {
-        let result = unsafe { self.sub_async(ct_left, ct_right, streams) };
-        streams.synchronize();
+        let mut result = ct_left.duplicate(streams);
+        self.sub_assign(&mut result, ct_right, streams);
         result
     }
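
Subtraction keeps the classic two's-complement identity a - b = a + !b + 1 (visible at the bottom of this file, where the signed overflowing sub feeds bitnot(rhs) plus an input carry of 1 into the add path), and the plain path is now literally unchecked_neg followed by unchecked_add_assign. A call-site sketch (paths assumed as in the earlier examples; the enclosing method is presumably `sub`, whose body the hunk above rewrites):

    use tfhe::core_crypto::gpu::CudaStreams;
    use tfhe::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
    use tfhe::integer::gpu::server_key::CudaServerKey;

    fn gpu_sub(
        sks: &CudaServerKey,
        lhs: &CudaUnsignedRadixCiphertext,
        rhs: &CudaUnsignedRadixCiphertext,
        streams: &CudaStreams,
    ) -> CudaUnsignedRadixCiphertext {
        // Internally: duplicate lhs, then sub_assign = add the negation of rhs.
        sks.sub(lhs, rhs, streams)
    }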
@@ -219,42 +188,11 @@
         self.get_sub_assign_size_on_gpu(ct_left, ct_right, streams)
     }
-    /// # Safety
-    ///
-    /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
-    ///   not be dropped until stream is synchronised
-    pub unsafe fn sub_async<T: CudaIntegerRadixCiphertext>(
-        &self,
-        ct_left: &T,
-        ct_right: &T,
-        streams: &CudaStreams,
-    ) -> T {
-        let mut result = ct_left.duplicate(streams);
-        self.sub_assign_async(&mut result, ct_right, streams);
-        result
-    }
     pub fn sub_assign<T: CudaIntegerRadixCiphertext>(
         &self,
         ct_left: &mut T,
         ct_right: &T,
         streams: &CudaStreams,
-    ) {
-        unsafe {
-            self.sub_assign_async(ct_left, ct_right, streams);
-        }
-        streams.synchronize();
-    }
-    /// # Safety
-    ///
-    /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
-    ///   not be dropped until stream is synchronised
-    pub unsafe fn sub_assign_async<T: CudaIntegerRadixCiphertext>(
-        &self,
-        ct_left: &mut T,
-        ct_right: &T,
-        streams: &CudaStreams,
     ) {
         let mut tmp_rhs;
@@ -345,27 +283,6 @@
             lhs.as_ref().d_blocks.lwe_ciphertext_count().0,
             rhs.as_ref().d_blocks.lwe_ciphertext_count().0
         );
-        let ct_res;
-        let ct_overflowed;
-        unsafe {
-            (ct_res, ct_overflowed) =
-                self.unchecked_unsigned_overflowing_sub_async(lhs, rhs, stream);
-        }
-        stream.synchronize();
-        (ct_res, ct_overflowed)
-    }
-    /// # Safety
-    ///
-    /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
-    ///   not be dropped until stream is synchronised
-    pub unsafe fn unchecked_unsigned_overflowing_sub_async(
-        &self,
-        lhs: &CudaUnsignedRadixCiphertext,
-        rhs: &CudaUnsignedRadixCiphertext,
-        stream: &CudaStreams,
-    ) -> (CudaUnsignedRadixCiphertext, CudaBooleanBlock) {
         let mut ct_res = lhs.duplicate(stream);
         let compute_overflow = true;
@@ -380,56 +297,58 @@
         let in_carry_dvec =
             INPUT_BORROW.map_or_else(|| aux_block.as_ref(), |block| block.as_ref().as_ref());
-        match &self.bootstrapping_key {
-            CudaBootstrappingKey::Classic(d_bsk) => {
-                cuda_backend_unchecked_unsigned_overflowing_sub_assign(
-                    stream,
-                    ciphertext,
-                    rhs.as_ref(),
-                    overflow_block.as_mut(),
-                    in_carry_dvec,
-                    &d_bsk.d_vec,
-                    &self.key_switching_key.d_vec,
-                    d_bsk.input_lwe_dimension(),
-                    d_bsk.glwe_dimension(),
-                    d_bsk.polynomial_size(),
-                    self.key_switching_key.decomposition_level_count(),
-                    self.key_switching_key.decomposition_base_log(),
-                    d_bsk.decomp_level_count(),
-                    d_bsk.decomp_base_log(),
-                    ciphertext.info.blocks.first().unwrap().message_modulus,
-                    ciphertext.info.blocks.first().unwrap().carry_modulus,
-                    PBSType::Classical,
-                    LweBskGroupingFactor(0),
-                    compute_overflow,
-                    uses_input_borrow,
-                    d_bsk.ms_noise_reduction_configuration.as_ref(),
-                );
-            }
-            CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
-                cuda_backend_unchecked_unsigned_overflowing_sub_assign(
-                    stream,
-                    ciphertext,
-                    rhs.as_ref(),
-                    overflow_block.as_mut(),
-                    in_carry_dvec,
-                    &d_multibit_bsk.d_vec,
-                    &self.key_switching_key.d_vec,
-                    d_multibit_bsk.input_lwe_dimension(),
-                    d_multibit_bsk.glwe_dimension(),
-                    d_multibit_bsk.polynomial_size(),
-                    self.key_switching_key.decomposition_level_count(),
-                    self.key_switching_key.decomposition_base_log(),
-                    d_multibit_bsk.decomp_level_count(),
-                    d_multibit_bsk.decomp_base_log(),
-                    ciphertext.info.blocks.first().unwrap().message_modulus,
-                    ciphertext.info.blocks.first().unwrap().carry_modulus,
-                    PBSType::MultiBit,
-                    d_multibit_bsk.grouping_factor,
-                    compute_overflow,
-                    uses_input_borrow,
-                    None,
-                );
-            }
-        }
+        unsafe {
+            match &self.bootstrapping_key {
+                CudaBootstrappingKey::Classic(d_bsk) => {
+                    cuda_backend_unchecked_unsigned_overflowing_sub_assign(
+                        stream,
+                        ciphertext,
+                        rhs.as_ref(),
+                        overflow_block.as_mut(),
+                        in_carry_dvec,
+                        &d_bsk.d_vec,
+                        &self.key_switching_key.d_vec,
+                        d_bsk.input_lwe_dimension(),
+                        d_bsk.glwe_dimension(),
+                        d_bsk.polynomial_size(),
+                        self.key_switching_key.decomposition_level_count(),
+                        self.key_switching_key.decomposition_base_log(),
+                        d_bsk.decomp_level_count(),
+                        d_bsk.decomp_base_log(),
+                        ciphertext.info.blocks.first().unwrap().message_modulus,
+                        ciphertext.info.blocks.first().unwrap().carry_modulus,
+                        PBSType::Classical,
+                        LweBskGroupingFactor(0),
+                        compute_overflow,
+                        uses_input_borrow,
+                        d_bsk.ms_noise_reduction_configuration.as_ref(),
+                    );
+                }
+                CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
+                    cuda_backend_unchecked_unsigned_overflowing_sub_assign(
+                        stream,
+                        ciphertext,
+                        rhs.as_ref(),
+                        overflow_block.as_mut(),
+                        in_carry_dvec,
+                        &d_multibit_bsk.d_vec,
+                        &self.key_switching_key.d_vec,
+                        d_multibit_bsk.input_lwe_dimension(),
+                        d_multibit_bsk.glwe_dimension(),
+                        d_multibit_bsk.polynomial_size(),
+                        self.key_switching_key.decomposition_level_count(),
+                        self.key_switching_key.decomposition_base_log(),
+                        d_multibit_bsk.decomp_level_count(),
+                        d_multibit_bsk.decomp_base_log(),
+                        ciphertext.info.blocks.first().unwrap().message_modulus,
+                        ciphertext.info.blocks.first().unwrap().carry_modulus,
+                        PBSType::MultiBit,
+                        d_multibit_bsk.grouping_factor,
+                        compute_overflow,
+                        uses_input_borrow,
+                        None,
+                    );
+                }
+            }
+        }
         let ct_overflowed = CudaBooleanBlock::from_cuda_radix_ciphertext(overflow_block.ciphertext);
@@ -437,11 +356,7 @@
         (ct_res, ct_overflowed)
     }
-    /// # Safety
-    ///
-    /// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
-    ///   not be dropped until streams is synchronized
-    pub(crate) unsafe fn sub_and_propagate_single_carry_assign<T>(
+    pub(crate) fn sub_and_propagate_single_carry_assign<T>(
         &self,
         lhs: &mut T,
         rhs: &T,
@@ -460,58 +375,60 @@
         let in_carry: &CudaRadixCiphertext =
             input_carry.map_or_else(|| aux_block.as_ref(), |block| block.0.as_ref());
-        match &self.bootstrapping_key {
-            CudaBootstrappingKey::Classic(d_bsk) => {
-                cuda_backend_sub_and_propagate_single_carry_assign(
-                    streams,
-                    lhs.as_mut(),
-                    rhs.as_ref(),
-                    carry_out.as_mut(),
-                    in_carry,
-                    &d_bsk.d_vec,
-                    &self.key_switching_key.d_vec,
-                    d_bsk.input_lwe_dimension(),
-                    d_bsk.glwe_dimension(),
-                    d_bsk.polynomial_size(),
-                    self.key_switching_key.decomposition_level_count(),
-                    self.key_switching_key.decomposition_base_log(),
-                    d_bsk.decomp_level_count(),
-                    d_bsk.decomp_base_log(),
-                    num_blocks,
-                    self.message_modulus,
-                    self.carry_modulus,
-                    PBSType::Classical,
-                    LweBskGroupingFactor(0),
-                    requested_flag,
-                    uses_carry,
-                    d_bsk.ms_noise_reduction_configuration.as_ref(),
-                );
-            }
-            CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
-                cuda_backend_sub_and_propagate_single_carry_assign(
-                    streams,
-                    lhs.as_mut(),
-                    rhs.as_ref(),
-                    carry_out.as_mut(),
-                    in_carry,
-                    &d_multibit_bsk.d_vec,
-                    &self.key_switching_key.d_vec,
-                    d_multibit_bsk.input_lwe_dimension(),
-                    d_multibit_bsk.glwe_dimension(),
-                    d_multibit_bsk.polynomial_size(),
-                    self.key_switching_key.decomposition_level_count(),
-                    self.key_switching_key.decomposition_base_log(),
-                    d_multibit_bsk.decomp_level_count(),
-                    d_multibit_bsk.decomp_base_log(),
-                    num_blocks,
-                    self.message_modulus,
-                    self.carry_modulus,
-                    PBSType::MultiBit,
-                    d_multibit_bsk.grouping_factor,
-                    requested_flag,
-                    uses_carry,
-                    None,
-                );
-            }
-        }
+        unsafe {
+            match &self.bootstrapping_key {
+                CudaBootstrappingKey::Classic(d_bsk) => {
+                    cuda_backend_sub_and_propagate_single_carry_assign(
+                        streams,
+                        lhs.as_mut(),
+                        rhs.as_ref(),
+                        carry_out.as_mut(),
+                        in_carry,
+                        &d_bsk.d_vec,
+                        &self.key_switching_key.d_vec,
+                        d_bsk.input_lwe_dimension(),
+                        d_bsk.glwe_dimension(),
+                        d_bsk.polynomial_size(),
+                        self.key_switching_key.decomposition_level_count(),
+                        self.key_switching_key.decomposition_base_log(),
+                        d_bsk.decomp_level_count(),
+                        d_bsk.decomp_base_log(),
+                        num_blocks,
+                        self.message_modulus,
+                        self.carry_modulus,
+                        PBSType::Classical,
+                        LweBskGroupingFactor(0),
+                        requested_flag,
+                        uses_carry,
+                        d_bsk.ms_noise_reduction_configuration.as_ref(),
+                    );
+                }
+                CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
+                    cuda_backend_sub_and_propagate_single_carry_assign(
+                        streams,
+                        lhs.as_mut(),
+                        rhs.as_ref(),
+                        carry_out.as_mut(),
+                        in_carry,
+                        &d_multibit_bsk.d_vec,
+                        &self.key_switching_key.d_vec,
+                        d_multibit_bsk.input_lwe_dimension(),
+                        d_multibit_bsk.glwe_dimension(),
+                        d_multibit_bsk.polynomial_size(),
+                        self.key_switching_key.decomposition_level_count(),
+                        self.key_switching_key.decomposition_base_log(),
+                        d_multibit_bsk.decomp_level_count(),
+                        d_multibit_bsk.decomp_base_log(),
+                        num_blocks,
+                        self.message_modulus,
+                        self.carry_modulus,
+                        PBSType::MultiBit,
+                        d_multibit_bsk.grouping_factor,
+                        requested_flag,
+                        uses_carry,
+                        None,
+                    );
+                }
+            }
+        }
         carry_out
@@ -611,30 +528,11 @@
             ct_left.as_ref().d_blocks.lwe_ciphertext_count().0 > 0,
             "inputs cannot be empty"
         );
-        let result;
-        let overflowed;
-        unsafe {
-            (result, overflowed) =
-                self.unchecked_signed_overflowing_sub_async(ct_left, ct_right, stream);
-        };
-        stream.synchronize();
-        (result, overflowed)
-    }
-    /// # Safety
-    ///
-    /// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
-    ///   not be dropped until stream is synchronised
-    pub unsafe fn unchecked_signed_overflowing_sub_async(
-        &self,
-        ct_left: &CudaSignedRadixCiphertext,
-        ct_right: &CudaSignedRadixCiphertext,
-        stream: &CudaStreams,
-    ) -> (CudaSignedRadixCiphertext, CudaBooleanBlock) {
         let flipped_rhs = self.bitnot(ct_right, stream);
         let ct_input_carry: CudaUnsignedRadixCiphertext = self.create_trivial_radix(1, 1, stream);
         let input_carry = CudaBooleanBlock::from_cuda_radix_ciphertext(ct_input_carry.ciphertext);
-        self.unchecked_signed_overflowing_add_async(
+        self.unchecked_signed_overflowing_add_with_input_carry(
             ct_left,
             &flipped_rhs,
             Some(&input_carry),