chore(gpu): remove async entry points for abs, add, sub, aes

Author: Agnes Leroy
Date: 2025-10-17 12:21:50 +02:00
Committed by: Agnès Leroy
Parent: 70b0c0ff19
Commit: c30835fc30
6 changed files with 411 additions and 692 deletions
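
Across all six files the change follows the same caller-facing pattern: each `unsafe fn *_async` entry point, which required a follow-up `streams.synchronize()`, is merged into a single safe function. A minimal before/after sketch of the migration, using `unchecked_add_assign` from the hunks below (the surrounding key and ciphertext setup is assumed, not shown):

```rust
// Before this commit: caller-side unsafety plus an explicit synchronize.
unsafe {
    sks.unchecked_add_assign_async(&mut ct_left, &ct_right, &streams);
}
streams.synchronize();

// After this commit: one safe call; the `unsafe` now lives inside the
// wrapper, confined to the backend FFI invocation.
sks.unchecked_add_assign(&mut ct_left, &ct_right, &streams);
```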

View File

@@ -47,22 +47,18 @@ pub mod cuda {
const SBOX_PARALLELISM: usize = 16;
let bench_id = format!("{param_name}::{NUM_AES_INPUTS}_input_encryption");
let round_keys = unsafe { sks.key_expansion_async(&d_key, &streams) };
streams.synchronize();
let round_keys = sks.key_expansion(&d_key, &streams);
bench_group.bench_function(&bench_id, |b| {
b.iter(|| {
unsafe {
black_box(sks.aes_encrypt_async(
&d_iv,
&round_keys,
0,
NUM_AES_INPUTS,
SBOX_PARALLELISM,
&streams,
));
}
streams.synchronize();
black_box(sks.aes_encrypt(
&d_iv,
&round_keys,
0,
NUM_AES_INPUTS,
SBOX_PARALLELISM,
&streams,
));
})
});
@@ -82,10 +78,7 @@ pub mod cuda {
bench_group.bench_function(&bench_id, |b| {
b.iter(|| {
unsafe {
black_box(sks.key_expansion_async(&d_key, &streams));
}
streams.synchronize();
black_box(sks.key_expansion(&d_key, &streams));
})
});
@@ -118,22 +111,18 @@ pub mod cuda {
let d_key = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_key, &streams);
let d_iv = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_iv, &streams);
let round_keys = unsafe { sks.key_expansion_async(&d_key, &streams) };
streams.synchronize();
let round_keys = sks.key_expansion(&d_key, &streams);
bench_group.bench_function(&bench_id, |b| {
b.iter(|| {
unsafe {
black_box(sks.aes_encrypt_async(
&d_iv,
&round_keys,
0,
NUM_AES_INPUTS,
SBOX_PARALLELISM,
&streams,
));
}
streams.synchronize();
black_box(sks.aes_encrypt(
&d_iv,
&round_keys,
0,
NUM_AES_INPUTS,
SBOX_PARALLELISM,
&streams,
));
})
});

View File

@@ -7708,7 +7708,7 @@ pub(crate) unsafe fn cuda_backend_aes_key_expansion<T: UnsignedInteger, B: Numer
}
#[allow(clippy::too_many_arguments)]
pub(crate) unsafe fn cuda_backend_get_aes_key_expansion_size_on_gpu(
pub(crate) fn cuda_backend_get_aes_key_expansion_size_on_gpu(
streams: &CudaStreams,
message_modulus: MessageModulus,
carry_modulus: CarryModulus,
@@ -7726,7 +7726,7 @@ pub(crate) unsafe fn cuda_backend_get_aes_key_expansion_size_on_gpu(
let noise_reduction_type = resolve_noise_reduction_type(ms_noise_reduction_configuration);
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
let size = {
let size = unsafe {
scratch_cuda_integer_key_expansion_64(
streams.ffi(),
std::ptr::addr_of_mut!(mem_ptr),
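
This hunk shows the core idiom of the commit in its smallest form: the `unsafe` qualifier moves off the function signature and onto the block around the raw backend call, so callers no longer need `unsafe` at every call site. A self-contained sketch of the same idiom with stand-in names (nothing here is the real tfhe-rs API):

```rust
// Stand-in for an FFI scratch-size query such as the
// `scratch_cuda_integer_key_expansion_64` call above.
unsafe fn ffi_scratch_size(mem_ptr: *mut *mut i8) -> u64 {
    let _ = mem_ptr;
    42
}

// Safe entry point: the wrapper owns the unsafety instead of the caller.
pub fn get_scratch_size_on_gpu() -> u64 {
    let mut mem_ptr: *mut i8 = std::ptr::null_mut();
    // SAFETY: `mem_ptr` is a valid, live out-pointer for the whole call.
    unsafe { ffi_scratch_size(std::ptr::addr_of_mut!(mem_ptr)) }
}

fn main() {
    assert_eq!(get_scratch_size_on_gpu(), 42);
}
```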

View File

@@ -5,11 +5,7 @@ use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey};
use crate::integer::gpu::{cuda_backend_unchecked_signed_abs_assign, PBSType};
impl CudaServerKey {
/// # Safety
///
/// - [CudaStreams::synchronize] __must__ be called after this function as soon as
/// synchronization is required
pub unsafe fn unchecked_abs_assign_async<T>(&self, ct: &mut T, streams: &CudaStreams)
pub fn unchecked_abs_assign<T>(&self, ct: &mut T, streams: &CudaStreams)
where
T: CudaIntegerRadixCiphertext,
{
@@ -78,9 +74,8 @@ impl CudaServerKey {
{
let mut res = ct.duplicate(streams);
if T::IS_SIGNED {
unsafe { self.unchecked_abs_assign_async(&mut res, streams) };
self.unchecked_abs_assign(&mut res, streams);
}
streams.synchronize();
res
}
@@ -139,9 +134,8 @@ impl CudaServerKey {
self.full_propagate_assign(&mut res, streams);
}
if T::IS_SIGNED {
unsafe { self.unchecked_abs_assign_async(&mut res, streams) };
self.unchecked_abs_assign(&mut res, streams);
}
streams.synchronize();
res
}
}
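
For reference, the signed absolute-value entry point after these hunks reduces to the shape below. This is a reconstruction from the context lines above (the outer function name is inferred, not shown in the diff), not verbatim source:

```rust
pub fn unchecked_abs<T>(&self, ct: &T, streams: &CudaStreams) -> T
where
    T: CudaIntegerRadixCiphertext,
{
    let mut res = ct.duplicate(streams);
    if T::IS_SIGNED {
        // Safe now: replaces the former `unchecked_abs_assign_async`
        // call plus trailing `streams.synchronize()`.
        self.unchecked_abs_assign(&mut res, streams);
    }
    res
}
```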

View File

@@ -27,10 +27,6 @@ impl CudaServerKey {
example) always has the same performance characteristics from one call to another and
/// guarantees correctness by pre-emptively clearing carries of output ciphertexts.
///
/// # Warning
///
/// - Multithreaded
///
/// # Example
///
/// ```rust
@@ -86,11 +82,7 @@ impl CudaServerKey {
self.get_add_assign_size_on_gpu(ct_left, ct_right, streams)
}
/// # Safety
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronised
pub unsafe fn add_assign_async<T: CudaIntegerRadixCiphertext>(
pub fn add_assign<T: CudaIntegerRadixCiphertext>(
&self,
ct_left: &mut T,
ct_right: &T,
@@ -121,25 +113,8 @@ impl CudaServerKey {
}
};
let _carry = self.add_and_propagate_single_carry_assign_async(
lhs,
rhs,
streams,
None,
OutputFlag::None,
);
}
pub fn add_assign<T: CudaIntegerRadixCiphertext>(
&self,
ct_left: &mut T,
ct_right: &T,
streams: &CudaStreams,
) {
unsafe {
self.add_assign_async(ct_left, ct_right, streams);
}
streams.synchronize();
let _carry =
self.add_and_propagate_single_carry_assign(lhs, rhs, streams, None, OutputFlag::None);
}
pub fn get_add_assign_size_on_gpu<T: CudaIntegerRadixCiphertext>(
@@ -286,11 +261,7 @@ impl CudaServerKey {
result
}
/// # Safety
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronised
pub unsafe fn unchecked_add_assign_async<T: CudaIntegerRadixCiphertext>(
pub fn unchecked_add_assign<T: CudaIntegerRadixCiphertext>(
&self,
ct_left: &mut T,
ct_right: &T,
@@ -319,22 +290,7 @@ impl CudaServerKey {
}
}
pub fn unchecked_add_assign<T: CudaIntegerRadixCiphertext>(
&self,
ct_left: &mut T,
ct_right: &T,
streams: &CudaStreams,
) {
unsafe {
self.unchecked_add_assign_async(ct_left, ct_right, streams);
}
streams.synchronize();
}
/// # Safety
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronised
pub unsafe fn unchecked_partial_sum_ciphertexts_assign_async<T: CudaIntegerRadixCiphertext>(
pub fn unchecked_partial_sum_ciphertexts_assign<T: CudaIntegerRadixCiphertext>(
&self,
result: &mut T,
ciphertexts: &[T],
@@ -345,11 +301,14 @@ impl CudaServerKey {
return;
}
result.as_mut().d_blocks.0.d_vec.copy_from_gpu_async(
&ciphertexts[0].as_ref().d_blocks.0.d_vec,
streams,
0,
);
unsafe {
result.as_mut().d_blocks.0.d_vec.copy_from_gpu_async(
&ciphertexts[0].as_ref().d_blocks.0.d_vec,
streams,
0,
);
streams.synchronize();
}
result.as_mut().info = ciphertexts[0].as_ref().info.clone();
if ciphertexts.len() == 1 {
return;
@@ -365,7 +324,7 @@ impl CudaServerKey {
);
if ciphertexts.len() == 2 {
self.add_assign_async(result, &ciphertexts[1], streams);
self.add_assign(result, &ciphertexts[1], streams);
return;
}
@@ -373,58 +332,60 @@ impl CudaServerKey {
let mut terms = CudaRadixCiphertext::from_radix_ciphertext_vec(ciphertexts, streams);
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_partial_sum_ciphertexts_assign(
streams,
result.as_mut(),
&mut terms,
reduce_degrees_for_single_carry_propagation,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
num_blocks.0 as u32,
radix_count_in_vec as u32,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_partial_sum_ciphertexts_assign(
streams,
result.as_mut(),
&mut terms,
reduce_degrees_for_single_carry_propagation,
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
num_blocks.0 as u32,
radix_count_in_vec as u32,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
);
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_partial_sum_ciphertexts_assign(
streams,
result.as_mut(),
&mut terms,
reduce_degrees_for_single_carry_propagation,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
num_blocks.0 as u32,
radix_count_in_vec as u32,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_partial_sum_ciphertexts_assign(
streams,
result.as_mut(),
&mut terms,
reduce_degrees_for_single_carry_propagation,
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
num_blocks.0 as u32,
radix_count_in_vec as u32,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
);
}
}
}
}
@@ -433,23 +394,9 @@ impl CudaServerKey {
&self,
ciphertexts: &[T],
streams: &CudaStreams,
) -> T {
let result = unsafe { self.unchecked_sum_ciphertexts_async(ciphertexts, streams) };
streams.synchronize();
result
}
/// # Safety
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronised
pub unsafe fn unchecked_sum_ciphertexts_async<T: CudaIntegerRadixCiphertext>(
&self,
ciphertexts: &[T],
streams: &CudaStreams,
) -> T {
let mut result = self
.unchecked_partial_sum_ciphertexts_async(ciphertexts, true, streams)
.unchecked_partial_sum_ciphertexts(ciphertexts, true, streams)
.unwrap();
self.propagate_single_carry_assign(&mut result, streams, None, OutputFlag::None);
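
Assembled from the hunks above, the summation path is now a straight chain of safe calls: partial sum, then single-carry propagation. The function name and trailing return are inferred from surrounding context; treat this as a reconstruction, not verbatim source:

```rust
pub fn unchecked_sum_ciphertexts<T: CudaIntegerRadixCiphertext>(
    &self,
    ciphertexts: &[T],
    streams: &CudaStreams,
) -> T {
    let mut result = self
        .unchecked_partial_sum_ciphertexts(ciphertexts, true, streams)
        .unwrap();
    self.propagate_single_carry_assign(&mut result, streams, None, OutputFlag::None);
    result
}
```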
@@ -458,21 +405,6 @@ impl CudaServerKey {
}
pub fn unchecked_partial_sum_ciphertexts<T: CudaIntegerRadixCiphertext>(
&self,
ciphertexts: &[T],
streams: &CudaStreams,
) -> Option<T> {
let result =
unsafe { self.unchecked_partial_sum_ciphertexts_async(ciphertexts, false, streams) };
streams.synchronize();
result
}
/// # Safety
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronised
pub unsafe fn unchecked_partial_sum_ciphertexts_async<T: CudaIntegerRadixCiphertext>(
&self,
ciphertexts: &[T],
reduce_degrees_for_single_carry_propagation: bool,
@@ -485,11 +417,10 @@ impl CudaServerKey {
let mut result = ciphertexts[0].duplicate(streams);
if ciphertexts.len() == 1 {
streams.synchronize();
return Some(result);
}
self.unchecked_partial_sum_ciphertexts_assign_async(
self.unchecked_partial_sum_ciphertexts_assign(
&mut result,
ciphertexts,
reduce_degrees_for_single_carry_propagation,
@@ -500,20 +431,6 @@ impl CudaServerKey {
}
pub fn sum_ciphertexts<T: CudaIntegerRadixCiphertext>(
&self,
ciphertexts: Vec<T>,
streams: &CudaStreams,
) -> Option<T> {
let res = unsafe { self.sum_ciphertexts_async(ciphertexts, streams) };
streams.synchronize();
res
}
/// # Safety
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronised
pub unsafe fn sum_ciphertexts_async<T: CudaIntegerRadixCiphertext>(
&self,
mut ciphertexts: Vec<T>,
streams: &CudaStreams,
@@ -529,7 +446,7 @@ impl CudaServerKey {
self.full_propagate_assign(&mut *ct, streams);
});
Some(self.unchecked_sum_ciphertexts_async(&ciphertexts, streams))
Some(self.unchecked_sum_ciphertexts(&ciphertexts, streams))
}
/// ```rust
@@ -620,38 +537,12 @@ impl CudaServerKey {
lhs.as_ref().d_blocks.lwe_ciphertext_count().0,
rhs.as_ref().d_blocks.lwe_ciphertext_count().0
);
let ct_res;
let ct_overflowed;
unsafe {
(ct_res, ct_overflowed) =
self.unchecked_unsigned_overflowing_add_async(lhs, rhs, stream);
}
stream.synchronize();
(ct_res, ct_overflowed)
}
/// # Safety
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronised
pub unsafe fn unchecked_unsigned_overflowing_add_async(
&self,
lhs: &CudaUnsignedRadixCiphertext,
rhs: &CudaUnsignedRadixCiphertext,
stream: &CudaStreams,
) -> (CudaUnsignedRadixCiphertext, CudaBooleanBlock) {
let output_flag = OutputFlag::from_signedness(CudaUnsignedRadixCiphertext::IS_SIGNED);
let mut ct_res = lhs.duplicate(stream);
let mut carry_out: CudaUnsignedRadixCiphertext = self
.add_and_propagate_single_carry_assign_async(
&mut ct_res,
rhs,
stream,
None,
output_flag,
);
let mut carry_out: CudaUnsignedRadixCiphertext =
self.add_and_propagate_single_carry_assign(&mut ct_res, rhs, stream, None, output_flag);
if lhs.as_ref().info.blocks.last().unwrap().noise_level == NoiseLevel::ZERO
&& rhs.as_ref().info.blocks.last().unwrap().noise_level == NoiseLevel::ZERO
@@ -666,28 +557,43 @@ impl CudaServerKey {
(ct_res, ct_overflowed)
}
/// # Safety
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronised
pub unsafe fn unchecked_signed_overflowing_add_async(
pub fn unchecked_signed_overflowing_add(
&self,
lhs: &CudaSignedRadixCiphertext,
rhs: &CudaSignedRadixCiphertext,
stream: &CudaStreams,
) -> (CudaSignedRadixCiphertext, CudaBooleanBlock) {
self.unchecked_signed_overflowing_add_with_input_carry(lhs, rhs, None, stream)
}
pub fn unchecked_signed_overflowing_add_with_input_carry(
&self,
lhs: &CudaSignedRadixCiphertext,
rhs: &CudaSignedRadixCiphertext,
input_carry: Option<&CudaBooleanBlock>,
stream: &CudaStreams,
) -> (CudaSignedRadixCiphertext, CudaBooleanBlock) {
assert_eq!(
lhs.as_ref().d_blocks.lwe_ciphertext_count().0,
rhs.as_ref().d_blocks.lwe_ciphertext_count().0,
"lhs and rhs must have the name number of blocks ({} vs {})",
lhs.as_ref().d_blocks.lwe_ciphertext_count().0,
rhs.as_ref().d_blocks.lwe_ciphertext_count().0
);
assert!(
lhs.as_ref().d_blocks.lwe_ciphertext_count().0 > 0,
"inputs cannot be empty"
);
let output_flag = OutputFlag::from_signedness(CudaSignedRadixCiphertext::IS_SIGNED);
let mut ct_res = lhs.duplicate(stream);
let carry_out: CudaSignedRadixCiphertext = self
.add_and_propagate_single_carry_assign_async(
&mut ct_res,
rhs,
stream,
input_carry,
output_flag,
);
let carry_out: CudaSignedRadixCiphertext = self.add_and_propagate_single_carry_assign(
&mut ct_res,
rhs,
stream,
input_carry,
output_flag,
);
let ct_overflowed = CudaBooleanBlock::from_cuda_radix_ciphertext(carry_out.ciphertext);
@@ -770,39 +676,7 @@ impl CudaServerKey {
self.unchecked_signed_overflowing_add(lhs, rhs, stream)
}
pub fn unchecked_signed_overflowing_add(
&self,
ct_left: &CudaSignedRadixCiphertext,
ct_right: &CudaSignedRadixCiphertext,
stream: &CudaStreams,
) -> (CudaSignedRadixCiphertext, CudaBooleanBlock) {
assert_eq!(
ct_left.as_ref().d_blocks.lwe_ciphertext_count().0,
ct_right.as_ref().d_blocks.lwe_ciphertext_count().0,
"lhs and rhs must have the name number of blocks ({} vs {})",
ct_left.as_ref().d_blocks.lwe_ciphertext_count().0,
ct_right.as_ref().d_blocks.lwe_ciphertext_count().0
);
assert!(
ct_left.as_ref().d_blocks.lwe_ciphertext_count().0 > 0,
"inputs cannot be empty"
);
let result;
let overflowed;
unsafe {
(result, overflowed) =
self.unchecked_signed_overflowing_add_async(ct_left, ct_right, None, stream);
};
stream.synchronize();
(result, overflowed)
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronized
pub(crate) unsafe fn add_and_propagate_single_carry_assign_async<T>(
pub(crate) fn add_and_propagate_single_carry_assign<T>(
&self,
lhs: &mut T,
rhs: &T,
@@ -821,58 +695,60 @@ impl CudaServerKey {
let in_carry: &CudaRadixCiphertext =
input_carry.map_or_else(|| aux_block.as_ref(), |block| block.0.as_ref());
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_add_and_propagate_single_carry_assign(
streams,
lhs.as_mut(),
rhs.as_ref(),
carry_out.as_mut(),
in_carry,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
d_bsk.input_lwe_dimension(),
d_bsk.glwe_dimension(),
d_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count(),
d_bsk.decomp_base_log(),
num_blocks,
self.message_modulus,
self.carry_modulus,
PBSType::Classical,
LweBskGroupingFactor(0),
requested_flag,
uses_carry,
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_add_and_propagate_single_carry_assign(
streams,
lhs.as_mut(),
rhs.as_ref(),
carry_out.as_mut(),
in_carry,
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
d_multibit_bsk.input_lwe_dimension(),
d_multibit_bsk.glwe_dimension(),
d_multibit_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count(),
d_multibit_bsk.decomp_base_log(),
num_blocks,
self.message_modulus,
self.carry_modulus,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
requested_flag,
uses_carry,
None,
);
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_add_and_propagate_single_carry_assign(
streams,
lhs.as_mut(),
rhs.as_ref(),
carry_out.as_mut(),
in_carry,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
d_bsk.input_lwe_dimension(),
d_bsk.glwe_dimension(),
d_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count(),
d_bsk.decomp_base_log(),
num_blocks,
self.message_modulus,
self.carry_modulus,
PBSType::Classical,
LweBskGroupingFactor(0),
requested_flag,
uses_carry,
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_add_and_propagate_single_carry_assign(
streams,
lhs.as_mut(),
rhs.as_ref(),
carry_out.as_mut(),
in_carry,
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
d_multibit_bsk.input_lwe_dimension(),
d_multibit_bsk.glwe_dimension(),
d_multibit_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count(),
d_multibit_bsk.decomp_base_log(),
num_blocks,
self.message_modulus,
self.carry_modulus,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
requested_flag,
uses_carry,
None,
);
}
}
}
carry_out

View File

@@ -133,18 +133,15 @@ impl CudaServerKey {
self.get_aes_encrypt_size_on_gpu(num_aes_inputs, parallelism, streams);
if check_valid_cuda_malloc(aes_encrypt_size, streams.gpu_indexes[0]) {
let round_keys = unsafe { self.key_expansion_async(key, streams) };
let res = unsafe {
self.aes_encrypt_async(
iv,
&round_keys,
start_counter,
num_aes_inputs,
parallelism,
streams,
)
};
streams.synchronize();
let round_keys = self.key_expansion(key, streams);
let res = self.aes_encrypt(
iv,
&round_keys,
start_counter,
num_aes_inputs,
parallelism,
streams,
);
return res;
}
parallelism /= 2;
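
The `parallelism /= 2;` line above is the tail of a retry loop: if the scratch allocation for the requested S-box parallelism would not fit on the GPU, the parallelism is halved and the size check runs again. A self-contained sketch of that degradation strategy, with hypothetical names standing in for `check_valid_cuda_malloc` and `get_aes_encrypt_size_on_gpu`:

```rust
/// Halve the requested parallelism until the scratch size fits in free GPU
/// memory, mirroring the loop around `check_valid_cuda_malloc` above.
fn pick_sbox_parallelism(max: usize, fits: impl Fn(usize) -> bool) -> Option<usize> {
    let mut parallelism = max;
    while parallelism > 0 {
        if fits(parallelism) {
            return Some(parallelism);
        }
        parallelism /= 2;
    }
    None
}

fn main() {
    // Pretend only parallelism <= 4 fits in free GPU memory.
    assert_eq!(pick_sbox_parallelism(16, |p| p <= 4), Some(4));
    assert_eq!(pick_sbox_parallelism(16, |_| false), None);
}
```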
@@ -176,26 +173,18 @@ impl CudaServerKey {
self.get_aes_encrypt_size_on_gpu(num_aes_inputs, sbox_parallelism, streams);
check_valid_cuda_malloc_assert_oom(aes_encrypt_size, gpu_index);
let round_keys = unsafe { self.key_expansion_async(key, streams) };
let res = unsafe {
self.aes_encrypt_async(
iv,
&round_keys,
start_counter,
num_aes_inputs,
sbox_parallelism,
streams,
)
};
streams.synchronize();
res
let round_keys = self.key_expansion(key, streams);
self.aes_encrypt(
iv,
&round_keys,
start_counter,
num_aes_inputs,
sbox_parallelism,
streams,
)
}
/// # Safety
///
/// - [CudaStreams::synchronize] __must__ be called after this function as soon as
/// synchronization is required
pub unsafe fn aes_encrypt_async(
pub fn aes_encrypt(
&self,
iv: &CudaUnsignedRadixCiphertext,
round_keys: &CudaUnsignedRadixCiphertext,
@@ -229,79 +218,64 @@ impl CudaServerKey {
result.as_ref().d_blocks.lwe_ciphertext_count().0
);
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_aes_ctr_encrypt(
streams,
result.as_mut(),
iv.as_ref(),
round_keys.as_ref(),
start_counter,
num_aes_inputs as u32,
sbox_parallelism as u32,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
d_bsk.input_lwe_dimension,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
LweBskGroupingFactor(0),
PBSType::Classical,
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_aes_ctr_encrypt(
streams,
result.as_mut(),
iv.as_ref(),
round_keys.as_ref(),
start_counter,
num_aes_inputs as u32,
sbox_parallelism as u32,
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
d_multibit_bsk.input_lwe_dimension,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
d_multibit_bsk.grouping_factor,
PBSType::MultiBit,
None,
);
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_aes_ctr_encrypt(
streams,
result.as_mut(),
iv.as_ref(),
round_keys.as_ref(),
start_counter,
num_aes_inputs as u32,
sbox_parallelism as u32,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
d_bsk.input_lwe_dimension,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
LweBskGroupingFactor(0),
PBSType::Classical,
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_aes_ctr_encrypt(
streams,
result.as_mut(),
iv.as_ref(),
round_keys.as_ref(),
start_counter,
num_aes_inputs as u32,
sbox_parallelism as u32,
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
d_multibit_bsk.input_lwe_dimension,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
d_multibit_bsk.grouping_factor,
PBSType::MultiBit,
None,
);
}
}
}
result
}
fn get_aes_encrypt_size_on_gpu(
&self,
num_aes_inputs: usize,
sbox_parallelism: usize,
streams: &CudaStreams,
) -> u64 {
let size = unsafe {
self.get_aes_encrypt_size_on_gpu_async(num_aes_inputs, sbox_parallelism, streams)
};
streams.synchronize();
size
}
/// # Safety
///
/// - [CudaStreams::synchronize] __must__ be called after this function as soon as
/// synchronization is required
unsafe fn get_aes_encrypt_size_on_gpu_async(
pub fn get_aes_encrypt_size_on_gpu(
&self,
num_aes_inputs: usize,
sbox_parallelism: usize,
@@ -347,11 +321,7 @@ impl CudaServerKey {
}
}
/// # Safety
///
/// - [CudaStreams::synchronize] __must__ be called after this function as soon as
/// synchronization is required
pub unsafe fn key_expansion_async(
pub fn key_expansion(
&self,
key: &CudaUnsignedRadixCiphertext,
streams: &CudaStreams,
@@ -369,64 +339,56 @@ impl CudaServerKey {
key.as_ref().d_blocks.lwe_ciphertext_count().0
);
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_aes_key_expansion(
streams,
expanded_keys.as_mut(),
key.as_ref(),
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
d_bsk.input_lwe_dimension,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
LweBskGroupingFactor(0),
PBSType::Classical,
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_aes_key_expansion(
streams,
expanded_keys.as_mut(),
key.as_ref(),
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
d_multibit_bsk.input_lwe_dimension,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
d_multibit_bsk.grouping_factor,
PBSType::MultiBit,
None,
);
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_aes_key_expansion(
streams,
expanded_keys.as_mut(),
key.as_ref(),
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
d_bsk.input_lwe_dimension,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
LweBskGroupingFactor(0),
PBSType::Classical,
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_aes_key_expansion(
streams,
expanded_keys.as_mut(),
key.as_ref(),
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
d_multibit_bsk.input_lwe_dimension,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
d_multibit_bsk.grouping_factor,
PBSType::MultiBit,
None,
);
}
}
}
expanded_keys
}
fn get_key_expansion_size_on_gpu(&self, streams: &CudaStreams) -> u64 {
let size = unsafe { self.get_key_expansion_size_on_gpu_async(streams) };
streams.synchronize();
size
}
/// # Safety
///
/// - [CudaStreams::synchronize] __must__ be called after this function as soon as
/// synchronization is required
unsafe fn get_key_expansion_size_on_gpu_async(&self, streams: &CudaStreams) -> u64 {
pub fn get_key_expansion_size_on_gpu(&self, streams: &CudaStreams) -> u64 {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => cuda_backend_get_aes_key_expansion_size_on_gpu(
streams,

View File

@@ -62,41 +62,12 @@ impl CudaServerKey {
ct_left: &T,
ct_right: &T,
streams: &CudaStreams,
) -> T {
let result = unsafe { self.unchecked_sub_async(ct_left, ct_right, streams) };
streams.synchronize();
result
}
/// # Safety
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronised
pub unsafe fn unchecked_sub_async<T: CudaIntegerRadixCiphertext>(
&self,
ct_left: &T,
ct_right: &T,
streams: &CudaStreams,
) -> T {
let mut result = ct_left.duplicate(streams);
self.unchecked_sub_assign_async(&mut result, ct_right, streams);
self.unchecked_sub_assign(&mut result, ct_right, streams);
result
}
/// # Safety
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronised
pub unsafe fn unchecked_sub_assign_async<T: CudaIntegerRadixCiphertext>(
&self,
ct_left: &mut T,
ct_right: &T,
streams: &CudaStreams,
) {
let neg = self.unchecked_neg_async(ct_right, streams);
self.unchecked_add_assign_async(ct_left, &neg, streams);
}
/// Computes homomorphically a subtraction between two ciphertexts encrypting integer values.
///
/// This function computes the subtraction without checking if it exceeds the capacity of the
@@ -146,10 +117,8 @@ impl CudaServerKey {
ct_right: &T,
streams: &CudaStreams,
) {
unsafe {
self.unchecked_sub_assign_async(ct_left, ct_right, streams);
}
streams.synchronize();
let neg = self.unchecked_neg(ct_right, streams);
self.unchecked_add_assign(ct_left, &neg, streams);
}
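
As the hunk above makes explicit, GPU subtraction is negation followed by addition, which is why making `unchecked_add_assign` safe carries `sub` along with it. The identity being relied on is simply the one below:

```rust
// a - b == a + (-b); the ciphertext version above applies the same identity
// via `unchecked_neg` + `unchecked_add_assign`.
fn sub_via_neg_add(a: i64, b: i64) -> i64 {
    a + (-b)
}

fn main() {
    assert_eq!(sub_via_neg_add(7, 3), 4);
}
```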
/// Computes homomorphically the subtraction between ct_left and ct_right.
@@ -205,8 +174,8 @@ impl CudaServerKey {
ct_right: &T,
streams: &CudaStreams,
) -> T {
let result = unsafe { self.sub_async(ct_left, ct_right, streams) };
streams.synchronize();
let mut result = ct_left.duplicate(streams);
self.sub_assign(&mut result, ct_right, streams);
result
}
@@ -219,42 +188,11 @@ impl CudaServerKey {
self.get_sub_assign_size_on_gpu(ct_left, ct_right, streams)
}
/// # Safety
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronised
pub unsafe fn sub_async<T: CudaIntegerRadixCiphertext>(
&self,
ct_left: &T,
ct_right: &T,
streams: &CudaStreams,
) -> T {
let mut result = ct_left.duplicate(streams);
self.sub_assign_async(&mut result, ct_right, streams);
result
}
pub fn sub_assign<T: CudaIntegerRadixCiphertext>(
&self,
ct_left: &mut T,
ct_right: &T,
streams: &CudaStreams,
) {
unsafe {
self.sub_assign_async(ct_left, ct_right, streams);
}
streams.synchronize();
}
/// # Safety
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronised
pub unsafe fn sub_assign_async<T: CudaIntegerRadixCiphertext>(
&self,
ct_left: &mut T,
ct_right: &T,
streams: &CudaStreams,
) {
let mut tmp_rhs;
@@ -345,27 +283,6 @@ impl CudaServerKey {
lhs.as_ref().d_blocks.lwe_ciphertext_count().0,
rhs.as_ref().d_blocks.lwe_ciphertext_count().0
);
let ct_res;
let ct_overflowed;
unsafe {
(ct_res, ct_overflowed) =
self.unchecked_unsigned_overflowing_sub_async(lhs, rhs, stream);
}
stream.synchronize();
(ct_res, ct_overflowed)
}
/// # Safety
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronised
pub unsafe fn unchecked_unsigned_overflowing_sub_async(
&self,
lhs: &CudaUnsignedRadixCiphertext,
rhs: &CudaUnsignedRadixCiphertext,
stream: &CudaStreams,
) -> (CudaUnsignedRadixCiphertext, CudaBooleanBlock) {
let mut ct_res = lhs.duplicate(stream);
let compute_overflow = true;
@@ -380,56 +297,58 @@ impl CudaServerKey {
let in_carry_dvec =
INPUT_BORROW.map_or_else(|| aux_block.as_ref(), |block| block.as_ref().as_ref());
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_unsigned_overflowing_sub_assign(
stream,
ciphertext,
rhs.as_ref(),
overflow_block.as_mut(),
in_carry_dvec,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
d_bsk.input_lwe_dimension(),
d_bsk.glwe_dimension(),
d_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count(),
d_bsk.decomp_base_log(),
ciphertext.info.blocks.first().unwrap().message_modulus,
ciphertext.info.blocks.first().unwrap().carry_modulus,
PBSType::Classical,
LweBskGroupingFactor(0),
compute_overflow,
uses_input_borrow,
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_unsigned_overflowing_sub_assign(
stream,
ciphertext,
rhs.as_ref(),
overflow_block.as_mut(),
in_carry_dvec,
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
d_multibit_bsk.input_lwe_dimension(),
d_multibit_bsk.glwe_dimension(),
d_multibit_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count(),
d_multibit_bsk.decomp_base_log(),
ciphertext.info.blocks.first().unwrap().message_modulus,
ciphertext.info.blocks.first().unwrap().carry_modulus,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
compute_overflow,
uses_input_borrow,
None,
);
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_unsigned_overflowing_sub_assign(
stream,
ciphertext,
rhs.as_ref(),
overflow_block.as_mut(),
in_carry_dvec,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
d_bsk.input_lwe_dimension(),
d_bsk.glwe_dimension(),
d_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count(),
d_bsk.decomp_base_log(),
ciphertext.info.blocks.first().unwrap().message_modulus,
ciphertext.info.blocks.first().unwrap().carry_modulus,
PBSType::Classical,
LweBskGroupingFactor(0),
compute_overflow,
uses_input_borrow,
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_unsigned_overflowing_sub_assign(
stream,
ciphertext,
rhs.as_ref(),
overflow_block.as_mut(),
in_carry_dvec,
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
d_multibit_bsk.input_lwe_dimension(),
d_multibit_bsk.glwe_dimension(),
d_multibit_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count(),
d_multibit_bsk.decomp_base_log(),
ciphertext.info.blocks.first().unwrap().message_modulus,
ciphertext.info.blocks.first().unwrap().carry_modulus,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
compute_overflow,
uses_input_borrow,
None,
);
}
}
}
let ct_overflowed = CudaBooleanBlock::from_cuda_radix_ciphertext(overflow_block.ciphertext);
@@ -437,11 +356,7 @@ impl CudaServerKey {
(ct_res, ct_overflowed)
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronized
pub(crate) unsafe fn sub_and_propagate_single_carry_assign<T>(
pub(crate) fn sub_and_propagate_single_carry_assign<T>(
&self,
lhs: &mut T,
rhs: &T,
@@ -460,58 +375,60 @@ impl CudaServerKey {
let in_carry: &CudaRadixCiphertext =
input_carry.map_or_else(|| aux_block.as_ref(), |block| block.0.as_ref());
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_sub_and_propagate_single_carry_assign(
streams,
lhs.as_mut(),
rhs.as_ref(),
carry_out.as_mut(),
in_carry,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
d_bsk.input_lwe_dimension(),
d_bsk.glwe_dimension(),
d_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count(),
d_bsk.decomp_base_log(),
num_blocks,
self.message_modulus,
self.carry_modulus,
PBSType::Classical,
LweBskGroupingFactor(0),
requested_flag,
uses_carry,
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_sub_and_propagate_single_carry_assign(
streams,
lhs.as_mut(),
rhs.as_ref(),
carry_out.as_mut(),
in_carry,
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
d_multibit_bsk.input_lwe_dimension(),
d_multibit_bsk.glwe_dimension(),
d_multibit_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count(),
d_multibit_bsk.decomp_base_log(),
num_blocks,
self.message_modulus,
self.carry_modulus,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
requested_flag,
uses_carry,
None,
);
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_sub_and_propagate_single_carry_assign(
streams,
lhs.as_mut(),
rhs.as_ref(),
carry_out.as_mut(),
in_carry,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
d_bsk.input_lwe_dimension(),
d_bsk.glwe_dimension(),
d_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count(),
d_bsk.decomp_base_log(),
num_blocks,
self.message_modulus,
self.carry_modulus,
PBSType::Classical,
LweBskGroupingFactor(0),
requested_flag,
uses_carry,
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_sub_and_propagate_single_carry_assign(
streams,
lhs.as_mut(),
rhs.as_ref(),
carry_out.as_mut(),
in_carry,
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
d_multibit_bsk.input_lwe_dimension(),
d_multibit_bsk.glwe_dimension(),
d_multibit_bsk.polynomial_size(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count(),
d_multibit_bsk.decomp_base_log(),
num_blocks,
self.message_modulus,
self.carry_modulus,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
requested_flag,
uses_carry,
None,
);
}
}
}
carry_out
@@ -611,30 +528,11 @@ impl CudaServerKey {
ct_left.as_ref().d_blocks.lwe_ciphertext_count().0 > 0,
"inputs cannot be empty"
);
let result;
let overflowed;
unsafe {
(result, overflowed) =
self.unchecked_signed_overflowing_sub_async(ct_left, ct_right, stream);
};
stream.synchronize();
(result, overflowed)
}
/// # Safety
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronised
pub unsafe fn unchecked_signed_overflowing_sub_async(
&self,
ct_left: &CudaSignedRadixCiphertext,
ct_right: &CudaSignedRadixCiphertext,
stream: &CudaStreams,
) -> (CudaSignedRadixCiphertext, CudaBooleanBlock) {
let flipped_rhs = self.bitnot(ct_right, stream);
let ct_input_carry: CudaUnsignedRadixCiphertext = self.create_trivial_radix(1, 1, stream);
let input_carry = CudaBooleanBlock::from_cuda_radix_ciphertext(ct_input_carry.ciphertext);
self.unchecked_signed_overflowing_add_async(
self.unchecked_signed_overflowing_add_with_input_carry(
ct_left,
&flipped_rhs,
Some(&input_carry),