chore(gpu): remove remaining async functions from the integer gpu api

Authored by Agnes Leroy on 2025-10-20 10:45:35 +02:00, committed by Agnès Leroy
parent 20b7b06ffb
commit 42644349ef
12 changed files with 1318 additions and 2600 deletions
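The refactor applies one pattern across all twelve files: each public `unsafe fn *_async` (whose safety contract required the caller to synchronize the CUDA streams and keep inputs alive) is merged into its safe synchronous counterpart, which now scopes `unsafe` to the backend call itself and synchronizes before returning. A minimal self-contained sketch of the before/after shape, using placeholder types rather than the real tfhe-rs ones:

```rust
// Placeholder stand-ins for the real CudaStreams / backend entry points,
// so the sketch compiles on its own.
struct Streams;
impl Streams {
    fn synchronize(&self) { /* block until queued GPU work completes */ }
}

/// Stand-in for a raw backend call: enqueues work on the streams.
/// Unsafe because inputs must outlive the queued work.
unsafe fn backend_op(_streams: &Streams, _data: &mut [u64]) {}

struct ServerKey;

impl ServerKey {
    // Before: `pub unsafe fn op_assign_async(..)` enqueued the work, and a
    // safe `op_assign` wrapper called it inside `unsafe` and then
    // synchronized. After: a single safe function with `unsafe` scoped to
    // the backend call, synchronizing before it returns.
    pub fn op_assign(&self, data: &mut [u64], streams: &Streams) {
        unsafe { backend_op(streams, data) };
        streams.synchronize();
    }
}

fn main() {
    let (sk, streams) = (ServerKey, Streams);
    let mut blocks = vec![0u64; 4];
    sk.op_assign(&mut blocks, &streams); // no `unsafe` at the call site
}
```

The file sections below repeat this shape for multiplication, oblivious pseudo-random generation, rotates, scalar addition, scalar bitwise ops, scalar comparisons, and scalar division.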

View File

@@ -64,11 +64,7 @@ impl CudaServerKey {
result
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronised
pub unsafe fn unchecked_mul_assign_async<T: CudaIntegerRadixCiphertext>(
pub fn unchecked_mul_assign<T: CudaIntegerRadixCiphertext>(
&self,
ct_left: &mut T,
ct_right: &T,
@@ -78,68 +74,58 @@ impl CudaServerKey {
let is_boolean_left = ct_left.holds_boolean_value();
let is_boolean_right = ct_right.holds_boolean_value();
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_mul_assign(
streams,
ct_left.as_mut(),
is_boolean_left,
ct_right.as_ref(),
is_boolean_right,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension(),
d_bsk.input_lwe_dimension(),
d_bsk.polynomial_size(),
d_bsk.decomp_base_log(),
d_bsk.decomp_level_count(),
self.key_switching_key.decomposition_base_log(),
self.key_switching_key.decomposition_level_count(),
num_blocks,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_mul_assign(
streams,
ct_left.as_mut(),
is_boolean_left,
ct_right.as_ref(),
is_boolean_right,
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension(),
d_multibit_bsk.input_lwe_dimension(),
d_multibit_bsk.polynomial_size(),
d_multibit_bsk.decomp_base_log(),
d_multibit_bsk.decomp_level_count(),
self.key_switching_key.decomposition_base_log(),
self.key_switching_key.decomposition_level_count(),
num_blocks,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
);
}
}
}
pub fn unchecked_mul_assign<T: CudaIntegerRadixCiphertext>(
&self,
ct_left: &mut T,
ct_right: &T,
streams: &CudaStreams,
) {
unsafe {
self.unchecked_mul_assign_async(ct_left, ct_right, streams);
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_mul_assign(
streams,
ct_left.as_mut(),
is_boolean_left,
ct_right.as_ref(),
is_boolean_right,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension(),
d_bsk.input_lwe_dimension(),
d_bsk.polynomial_size(),
d_bsk.decomp_base_log(),
d_bsk.decomp_level_count(),
self.key_switching_key.decomposition_base_log(),
self.key_switching_key.decomposition_level_count(),
num_blocks,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_mul_assign(
streams,
ct_left.as_mut(),
is_boolean_left,
ct_right.as_ref(),
is_boolean_right,
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension(),
d_multibit_bsk.input_lwe_dimension(),
d_multibit_bsk.polynomial_size(),
d_multibit_bsk.decomp_base_log(),
d_multibit_bsk.decomp_level_count(),
self.key_switching_key.decomposition_base_log(),
self.key_switching_key.decomposition_level_count(),
num_blocks,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
);
}
}
}
streams.synchronize();
}
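Both match arms above call the same backend entry point; only the key source, the PBS type tag, the grouping factor, and the optional modulus-switch noise-reduction configuration differ. A distilled sketch of that dispatch, with placeholder types standing in for the real keys:

```rust
#[derive(Debug)]
enum PbsType { Classical, MultiBit }

struct NoiseReductionConfig;

enum BootstrappingKey {
    Classic { noise_cfg: Option<NoiseReductionConfig> },
    MultiBit { grouping_factor: u32 },
}

/// Placeholder for `cuda_backend_unchecked_mul_assign` and friends.
unsafe fn backend_call(pbs: PbsType, grouping: u32, _noise: Option<&NoiseReductionConfig>) {
    println!("enqueued {pbs:?} PBS with grouping factor {grouping}");
}

fn dispatch(key: &BootstrappingKey) {
    match key {
        BootstrappingKey::Classic { noise_cfg } => unsafe {
            // Classic PBS: grouping factor is always 0.
            backend_call(PbsType::Classical, 0, noise_cfg.as_ref());
        },
        BootstrappingKey::MultiBit { grouping_factor } => unsafe {
            // Multi-bit PBS: its own grouping factor, no noise-reduction config.
            backend_call(PbsType::MultiBit, *grouping_factor, None);
        },
    }
}

fn main() {
    dispatch(&BootstrappingKey::Classic { noise_cfg: Some(NoiseReductionConfig) });
    dispatch(&BootstrappingKey::MultiBit { grouping_factor: 3 });
}
```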
/// Computes homomorphically a multiplication between two ciphertexts encrypting integer values.
@@ -198,11 +184,7 @@ impl CudaServerKey {
result
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronised
pub unsafe fn mul_assign_async<T: CudaIntegerRadixCiphertext>(
pub fn mul_assign<T: CudaIntegerRadixCiphertext>(
&self,
ct_left: &mut T,
ct_right: &T,
@@ -233,22 +215,10 @@ impl CudaServerKey {
}
};
self.unchecked_mul_assign_async(lhs, rhs, streams);
self.unchecked_mul_assign(lhs, rhs, streams);
// Carries are cleaned internally in the mul algorithm
}
pub fn mul_assign<T: CudaIntegerRadixCiphertext>(
&self,
ct_left: &mut T,
ct_right: &T,
streams: &CudaStreams,
) {
unsafe {
self.mul_assign_async(ct_left, ct_right, streams);
}
streams.synchronize();
}
pub fn get_mul_size_on_gpu<T: CudaIntegerRadixCiphertext>(
&self,
ct_left: &T,

View File

@@ -51,11 +51,7 @@ impl CudaServerKey {
num_blocks: u64,
streams: &CudaStreams,
) -> CudaUnsignedRadixCiphertext {
let result = unsafe {
self.generate_oblivious_pseudo_random_unbounded_integer_async(seed, num_blocks, streams)
};
streams.synchronize();
result
self.generate_oblivious_pseudo_random_unbounded_integer(seed, num_blocks, streams)
}
/// Generates an encrypted `num_block` blocks unsigned integer
@@ -97,24 +93,20 @@ impl CudaServerKey {
num_blocks: u64,
streams: &CudaStreams,
) -> CudaUnsignedRadixCiphertext {
let result = unsafe {
let message_bits_count = self.message_modulus.0.ilog2() as u64;
let range_log_size = message_bits_count * num_blocks;
let message_bits_count = self.message_modulus.0.ilog2() as u64;
let range_log_size = message_bits_count * num_blocks;
assert!(
assert!(
random_bits_count <= range_log_size,
"The range asked for a random value (=[0, 2^{random_bits_count}[) does not fit in the available range [0, 2^{range_log_size}[",
);
self.generate_oblivious_pseudo_random_bounded_integer_async(
seed,
random_bits_count,
num_blocks,
streams,
)
};
streams.synchronize();
result
self.generate_oblivious_pseudo_random_bounded_integer(
seed,
random_bits_count,
num_blocks,
streams,
)
}
/// Generates an encrypted `num_block` blocks signed integer
@@ -150,11 +142,7 @@ impl CudaServerKey {
num_blocks: u64,
streams: &CudaStreams,
) -> CudaSignedRadixCiphertext {
let result = unsafe {
self.generate_oblivious_pseudo_random_unbounded_integer_async(seed, num_blocks, streams)
};
streams.synchronize();
result
self.generate_oblivious_pseudo_random_unbounded_integer(seed, num_blocks, streams)
}
/// Generates an encrypted `num_block` blocks signed integer
@@ -198,28 +186,24 @@ impl CudaServerKey {
num_blocks: u64,
streams: &CudaStreams,
) -> CudaSignedRadixCiphertext {
let result = unsafe {
let message_bits_count = self.message_modulus.0.ilog2() as u64;
let range_log_size = message_bits_count * num_blocks;
let message_bits_count = self.message_modulus.0.ilog2() as u64;
let range_log_size = message_bits_count * num_blocks;
#[allow(clippy::int_plus_one)]
{
assert!(
random_bits_count + 1 <= range_log_size,
"The range asked for a random value (=[0, 2^{}[) does not fit in the available range [-2^{}, 2^{}[",
random_bits_count, range_log_size - 1, range_log_size - 1,
);
}
#[allow(clippy::int_plus_one)]
{
assert!(
random_bits_count + 1 <= range_log_size,
"The range asked for a random value (=[0, 2^{}[) does not fit in the available range [-2^{}, 2^{}[",
random_bits_count, range_log_size - 1, range_log_size - 1,
);
}
self.generate_oblivious_pseudo_random_bounded_integer_async(
seed,
random_bits_count,
num_blocks,
streams,
)
};
streams.synchronize();
result
self.generate_oblivious_pseudo_random_bounded_integer(
seed,
random_bits_count,
num_blocks,
streams,
)
}
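The two asserts encode the range checks: with `message_bits_count = message_modulus.ilog2()` bits per block, `num_blocks` blocks span `range_log_size` bits in total; the signed variant reserves one bit for the sign, hence the `+ 1`. A small worked example:

```rust
// Worked example of the range checks in the bounded OPRF entry points.
// With 2-bit message blocks (message_modulus = 4) and 8 blocks, the unsigned
// range is [0, 2^16) and the signed range is [-2^15, 2^15).
fn main() {
    let message_modulus: u64 = 4;
    let num_blocks: u64 = 8;
    let message_bits_count = message_modulus.ilog2() as u64; // 2
    let range_log_size = message_bits_count * num_blocks;    // 16

    let random_bits_count = 12;
    // Unsigned check, as in the diff:
    assert!(random_bits_count <= range_log_size);
    // Signed check: one extra bit is reserved for the sign, hence the +1.
    assert!(random_bits_count + 1 <= range_log_size);
    println!("random value drawn from [0, 2^{random_bits_count})");
}
```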
// Generic interface to generate a single-block oblivious pseudo-random integer.
@@ -246,22 +230,13 @@ impl CudaServerKey {
"The number of random bits asked for (={random_bits_count}) is bigger than carry_bits_count (={carry_bits_count}) + message_bits_count(={message_bits_count})",
);
let result = unsafe {
self.generate_oblivious_pseudo_random_bounded_integer_async(
seed,
random_bits_count,
1,
streams,
)
};
streams.synchronize();
result
self.generate_oblivious_pseudo_random_bounded_integer(seed, random_bits_count, 1, streams)
}
// Generic internal implementation for unbounded pseudo-random generation.
// It calls the core implementation with parameters for the unbounded case.
//
unsafe fn generate_oblivious_pseudo_random_unbounded_integer_async<T>(
fn generate_oblivious_pseudo_random_unbounded_integer<T>(
&self,
seed: Seed,
num_blocks: u64,
@@ -280,7 +255,7 @@ impl CudaServerKey {
return result;
}
self.generate_multiblocks_oblivious_pseudo_random_async(
self.generate_multiblocks_oblivious_pseudo_random(
result.as_mut(),
seed,
num_blocks,
@@ -294,7 +269,7 @@ impl CudaServerKey {
// Generic internal implementation for bounded pseudo-random generation.
// It calls the core implementation with parameters for the bounded case.
//
unsafe fn generate_oblivious_pseudo_random_bounded_integer_async<T>(
fn generate_oblivious_pseudo_random_bounded_integer<T>(
&self,
seed: Seed,
random_bits_count: u64,
@@ -318,7 +293,7 @@ impl CudaServerKey {
return result;
}
self.generate_multiblocks_oblivious_pseudo_random_async(
self.generate_multiblocks_oblivious_pseudo_random(
result.as_mut(),
seed,
num_active_blocks,
@@ -331,7 +306,7 @@ impl CudaServerKey {
// Core private implementation that calls the OPRF backend.
// This function contains the main logic for both bounded and unbounded generation.
//
unsafe fn generate_multiblocks_oblivious_pseudo_random_async(
fn generate_multiblocks_oblivious_pseudo_random(
&self,
result: &mut CudaRadixCiphertext,
seed: Seed,
@@ -369,57 +344,62 @@ impl CudaServerKey {
})
.collect();
let mut d_seeded_lwe_input = CudaVec::<u64>::new_async(h_seeded_lwe_list.len(), streams, 0);
d_seeded_lwe_input.copy_from_cpu_async(&h_seeded_lwe_list, streams, 0);
let mut d_seeded_lwe_input =
unsafe { CudaVec::<u64>::new_async(h_seeded_lwe_list.len(), streams, 0) };
unsafe {
d_seeded_lwe_input.copy_from_cpu_async(&h_seeded_lwe_list, streams, 0);
}
let message_bits_count = self.message_modulus.0.ilog2();
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_grouped_oprf(
streams,
result,
&d_seeded_lwe_input,
num_active_blocks as u32,
&d_bsk.d_vec,
d_bsk.input_lwe_dimension,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
LweBskGroupingFactor(0),
self.message_modulus,
self.carry_modulus,
PBSType::Classical,
message_bits_count,
total_random_bits as u32,
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_bsk) => {
cuda_backend_grouped_oprf(
streams,
result,
&d_seeded_lwe_input,
num_active_blocks as u32,
&d_bsk.d_vec,
d_bsk.input_lwe_dimension,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
d_bsk.grouping_factor,
self.message_modulus,
self.carry_modulus,
PBSType::MultiBit,
message_bits_count,
total_random_bits as u32,
None,
);
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_grouped_oprf(
streams,
result,
&d_seeded_lwe_input,
num_active_blocks as u32,
&d_bsk.d_vec,
d_bsk.input_lwe_dimension,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
LweBskGroupingFactor(0),
self.message_modulus,
self.carry_modulus,
PBSType::Classical,
message_bits_count,
total_random_bits as u32,
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_bsk) => {
cuda_backend_grouped_oprf(
streams,
result,
&d_seeded_lwe_input,
num_active_blocks as u32,
&d_bsk.d_vec,
d_bsk.input_lwe_dimension,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
d_bsk.grouping_factor,
self.message_modulus,
self.carry_modulus,
PBSType::MultiBit,
message_bits_count,
total_random_bits as u32,
None,
);
}
}
}
}
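Schematic sketch of the transfer pattern used above, assuming a placeholder `DeviceVec` in place of the real `CudaVec` API: the host-to-device copy stays asynchronous, but each unsafe call now sits in its own narrow `unsafe` block inside an otherwise safe function, instead of the whole function being `unsafe`.

```rust
struct DeviceVec(Vec<u64>);

impl DeviceVec {
    /// Unsafe: the allocation is only enqueued on the stream.
    unsafe fn new_async(len: usize) -> Self {
        DeviceVec(vec![0; len])
    }
    /// Unsafe: `src` must stay alive until the queued copy completes.
    unsafe fn copy_from_cpu_async(&mut self, src: &[u64]) {
        self.0.copy_from_slice(src);
    }
}

fn upload(host: &[u64]) -> DeviceVec {
    let mut d = unsafe { DeviceVec::new_async(host.len()) };
    unsafe { d.copy_from_cpu_async(host) };
    // `host` outlives the call; per the old API's safety comments, the
    // streams are synchronized before results are handed back to the caller.
    d
}

fn main() {
    let seeded = vec![1u64, 2, 3];
    let device = upload(&seeded);
    assert_eq!(device.0.len(), 3);
}
```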

View File

@@ -9,11 +9,7 @@ use crate::integer::gpu::{
};
impl CudaServerKey {
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronised
pub unsafe fn unchecked_rotate_right_assign_async<T>(
pub fn unchecked_rotate_right_assign<T>(
&self,
ct: &mut T,
rotate: &CudaUnsignedRadixCiphertext,
@@ -24,216 +20,79 @@ impl CudaServerKey {
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
let is_signed = T::IS_SIGNED;
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_rotate_right_assign(
streams,
ct.as_mut(),
rotate.as_ref(),
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
is_signed,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_rotate_right_assign(
streams,
ct.as_mut(),
rotate.as_ref(),
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
is_signed,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
);
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_rotate_right_assign(
streams,
ct.as_mut(),
rotate.as_ref(),
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
is_signed,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_rotate_right_assign(
streams,
ct.as_mut(),
rotate.as_ref(),
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
is_signed,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
);
}
}
}
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronised
pub unsafe fn unchecked_rotate_right_async<T>(
&self,
ct: &T,
rotate: &CudaUnsignedRadixCiphertext,
streams: &CudaStreams,
) -> T
where
T: CudaIntegerRadixCiphertext,
{
let mut result = ct.duplicate(streams);
self.unchecked_rotate_right_assign_async(&mut result, rotate, streams);
result
}
pub fn unchecked_rotate_right<T>(
&self,
ct: &T,
rotate: &CudaUnsignedRadixCiphertext,
streams: &CudaStreams,
) -> T
where
T: CudaIntegerRadixCiphertext,
{
let result = unsafe { self.unchecked_rotate_right_async(ct, rotate, streams) };
streams.synchronize();
result
}
pub fn unchecked_rotate_right_assign<T>(
&self,
ct: &mut T,
rotate: &CudaUnsignedRadixCiphertext,
streams: &CudaStreams,
) where
T: CudaIntegerRadixCiphertext,
{
unsafe {
self.unchecked_rotate_right_assign_async(ct, rotate, streams);
}
streams.synchronize();
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronised
pub unsafe fn unchecked_rotate_left_assign_async<T>(
&self,
ct: &mut T,
rotate: &CudaUnsignedRadixCiphertext,
streams: &CudaStreams,
) where
T: CudaIntegerRadixCiphertext,
{
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
let is_signed = T::IS_SIGNED;
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_rotate_left_assign(
streams,
ct.as_mut(),
rotate.as_ref(),
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
is_signed,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_rotate_left_assign(
streams,
ct.as_mut(),
rotate.as_ref(),
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
is_signed,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
);
}
}
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronised
pub unsafe fn unchecked_rotate_left_async<T>(
&self,
ct: &T,
rotate: &CudaUnsignedRadixCiphertext,
streams: &CudaStreams,
) -> T
where
T: CudaIntegerRadixCiphertext,
{
let mut result = ct.duplicate(streams);
self.unchecked_rotate_left_assign_async(&mut result, rotate, streams);
result
}
pub fn unchecked_rotate_left<T>(
&self,
ct: &T,
rotate: &CudaUnsignedRadixCiphertext,
streams: &CudaStreams,
) -> T
where
T: CudaIntegerRadixCiphertext,
{
let result = unsafe { self.unchecked_rotate_left_async(ct, rotate, streams) };
streams.synchronize();
self.unchecked_rotate_right_assign(&mut result, rotate, streams);
result
}
@@ -245,17 +104,72 @@ impl CudaServerKey {
) where
T: CudaIntegerRadixCiphertext,
{
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
let is_signed = T::IS_SIGNED;
unsafe {
self.unchecked_rotate_left_assign_async(ct, rotate, streams);
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_rotate_left_assign(
streams,
ct.as_mut(),
rotate.as_ref(),
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
is_signed,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_rotate_left_assign(
streams,
ct.as_mut(),
rotate.as_ref(),
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
is_signed,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
);
}
}
}
streams.synchronize();
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronised
pub unsafe fn rotate_right_async<T>(
pub fn unchecked_rotate_left<T>(
&self,
ct: &T,
rotate: &CudaUnsignedRadixCiphertext,
@@ -264,82 +178,11 @@ impl CudaServerKey {
where
T: CudaIntegerRadixCiphertext,
{
let mut tmp_lhs: T;
let mut tmp_rhs: CudaUnsignedRadixCiphertext;
let (lhs, rhs) = match (
ct.block_carries_are_empty(),
rotate.block_carries_are_empty(),
) {
(true, true) => (ct, rotate),
(true, false) => {
tmp_rhs = rotate.duplicate(streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
(ct, &tmp_rhs)
}
(false, true) => {
tmp_lhs = ct.duplicate(streams);
self.full_propagate_assign(&mut tmp_lhs, streams);
(&tmp_lhs, rotate)
}
(false, false) => {
tmp_lhs = ct.duplicate(streams);
tmp_rhs = rotate.duplicate(streams);
self.full_propagate_assign(&mut tmp_lhs, streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
(&tmp_lhs, &tmp_rhs)
}
};
let mut result = lhs.duplicate(streams);
self.unchecked_rotate_right_assign_async(&mut result, rhs, streams);
let mut result = ct.duplicate(streams);
self.unchecked_rotate_left_assign(&mut result, rotate, streams);
result
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronised
pub unsafe fn rotate_right_assign_async<T>(
&self,
ct: &mut T,
rotate: &CudaUnsignedRadixCiphertext,
streams: &CudaStreams,
) where
T: CudaIntegerRadixCiphertext,
{
let mut tmp_lhs: T;
let mut tmp_rhs: CudaUnsignedRadixCiphertext;
let (lhs, rhs) = match (
ct.block_carries_are_empty(),
rotate.block_carries_are_empty(),
) {
(true, true) => (ct, rotate),
(true, false) => {
tmp_rhs = rotate.duplicate(streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
(ct, &tmp_rhs)
}
(false, true) => {
tmp_lhs = ct.duplicate(streams);
self.full_propagate_assign(&mut tmp_lhs, streams);
(&mut tmp_lhs, rotate)
}
(false, false) => {
tmp_lhs = ct.duplicate(streams);
tmp_rhs = rotate.duplicate(streams);
self.full_propagate_assign(&mut tmp_lhs, streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
(&mut tmp_lhs, &tmp_rhs)
}
};
self.unchecked_rotate_right_assign_async(lhs, rhs, streams);
}
/// Computes homomorphically a right rotate by an encrypted amount
///
/// The result is returned as a new ciphertext.
@@ -385,36 +228,6 @@ impl CudaServerKey {
rotate: &CudaUnsignedRadixCiphertext,
streams: &CudaStreams,
) -> T
where
T: CudaIntegerRadixCiphertext,
{
let result = unsafe { self.rotate_right_async(ct, rotate, streams) };
streams.synchronize();
result
}
pub fn rotate_right_assign<T>(
&self,
ct: &mut T,
rotate: &CudaUnsignedRadixCiphertext,
streams: &CudaStreams,
) where
T: CudaIntegerRadixCiphertext,
{
unsafe { self.rotate_right_assign_async(ct, rotate, streams) };
streams.synchronize();
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronised
pub unsafe fn rotate_left_async<T>(
&self,
ct: &T,
rotate: &CudaUnsignedRadixCiphertext,
streams: &CudaStreams,
) -> T
where
T: CudaIntegerRadixCiphertext,
{
@@ -447,15 +260,11 @@ impl CudaServerKey {
};
let mut result = lhs.duplicate(streams);
self.unchecked_rotate_left_assign_async(&mut result, rhs, streams);
self.unchecked_rotate_right_assign(&mut result, rhs, streams);
result
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronised
pub unsafe fn rotate_left_assign_async<T>(
pub fn rotate_right_assign<T>(
&self,
ct: &mut T,
rotate: &CudaUnsignedRadixCiphertext,
@@ -491,7 +300,7 @@ impl CudaServerKey {
}
};
self.unchecked_rotate_left_assign_async(lhs, rhs, streams);
self.unchecked_rotate_right_assign(lhs, rhs, streams);
}
/// Computes homomorphically a right rotate by an encrypted amount
@@ -542,10 +351,39 @@ impl CudaServerKey {
where
T: CudaIntegerRadixCiphertext,
{
let result = unsafe { self.rotate_left_async(ct, rotate, streams) };
streams.synchronize();
let mut tmp_lhs: T;
let mut tmp_rhs: CudaUnsignedRadixCiphertext;
let (lhs, rhs) = match (
ct.block_carries_are_empty(),
rotate.block_carries_are_empty(),
) {
(true, true) => (ct, rotate),
(true, false) => {
tmp_rhs = rotate.duplicate(streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
(ct, &tmp_rhs)
}
(false, true) => {
tmp_lhs = ct.duplicate(streams);
self.full_propagate_assign(&mut tmp_lhs, streams);
(&tmp_lhs, rotate)
}
(false, false) => {
tmp_lhs = ct.duplicate(streams);
tmp_rhs = rotate.duplicate(streams);
self.full_propagate_assign(&mut tmp_lhs, streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
(&tmp_lhs, &tmp_rhs)
}
};
let mut result = lhs.duplicate(streams);
self.unchecked_rotate_left_assign(&mut result, rhs, streams);
result
}
pub fn rotate_left_assign<T>(
&self,
ct: &mut T,
@@ -554,8 +392,35 @@ impl CudaServerKey {
) where
T: CudaIntegerRadixCiphertext,
{
unsafe { self.rotate_left_assign_async(ct, rotate, streams) };
streams.synchronize();
let mut tmp_lhs: T;
let mut tmp_rhs: CudaUnsignedRadixCiphertext;
let (lhs, rhs) = match (
ct.block_carries_are_empty(),
rotate.block_carries_are_empty(),
) {
(true, true) => (ct, rotate),
(true, false) => {
tmp_rhs = rotate.duplicate(streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
(ct, &tmp_rhs)
}
(false, true) => {
tmp_lhs = ct.duplicate(streams);
self.full_propagate_assign(&mut tmp_lhs, streams);
(&mut tmp_lhs, rotate)
}
(false, false) => {
tmp_lhs = ct.duplicate(streams);
tmp_rhs = rotate.duplicate(streams);
self.full_propagate_assign(&mut tmp_lhs, streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
(&mut tmp_lhs, &tmp_rhs)
}
};
self.unchecked_rotate_left_assign(lhs, rhs, streams);
}
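The four-arm match that now appears in `rotate_left_assign` above (and in the right-rotate counterparts) normalizes carries before the unchecked operation: each operand whose carry buffers are not empty is duplicated and carry-propagated, so the unchecked rotate always sees clean inputs. A distilled sketch with a placeholder ciphertext type, factored into a helper for brevity where the diff spells out the four cases:

```rust
#[derive(Clone)]
struct Ct { carries_empty: bool }

impl Ct {
    fn block_carries_are_empty(&self) -> bool { self.carries_empty }
    fn duplicate(&self) -> Ct { self.clone() }
}

fn full_propagate_assign(ct: &mut Ct) { ct.carries_empty = true; }

// Returns the operand itself if its carries are clean, otherwise a
// propagated duplicate stored in `tmp` (mirrors the tmp_lhs/tmp_rhs locals).
fn normalized<'a>(ct: &'a Ct, tmp: &'a mut Option<Ct>) -> &'a Ct {
    if ct.block_carries_are_empty() {
        ct
    } else {
        let mut dup = ct.duplicate();
        full_propagate_assign(&mut dup);
        tmp.insert(dup)
    }
}

fn main() {
    let lhs = Ct { carries_empty: false };
    let rhs = Ct { carries_empty: true };
    let (mut t1, mut t2) = (None, None);
    let (lhs, rhs) = (normalized(&lhs, &mut t1), normalized(&rhs, &mut t2));
    assert!(lhs.block_carries_are_empty() && rhs.block_carries_are_empty());
}
```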
pub fn get_rotate_left_size_on_gpu<T: CudaIntegerRadixCiphertext>(

View File

@@ -69,11 +69,7 @@ impl CudaServerKey {
result
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronised
pub unsafe fn unchecked_scalar_add_assign_async<Scalar, T>(
pub fn unchecked_scalar_add_assign<Scalar, T>(
&self,
ct: &mut T,
scalar: Scalar,
@@ -84,47 +80,33 @@ impl CudaServerKey {
{
if scalar != Scalar::ZERO {
let bits_in_message = self.message_modulus.0.ilog2();
let mut d_decomposed_scalar = CudaVec::<u64>::new_async(
ct.as_ref().d_blocks.lwe_ciphertext_count().0,
streams,
0,
);
let mut d_decomposed_scalar = unsafe {
CudaVec::<u64>::new_async(ct.as_ref().d_blocks.lwe_ciphertext_count().0, streams, 0)
};
let decomposed_scalar =
BlockDecomposer::with_early_stop_at_zero(scalar, bits_in_message)
.iter_as::<u64>()
.take(d_decomposed_scalar.len())
.collect::<Vec<_>>();
d_decomposed_scalar.copy_from_cpu_async(decomposed_scalar.as_slice(), streams, 0);
unsafe {
d_decomposed_scalar.copy_from_cpu_async(decomposed_scalar.as_slice(), streams, 0);
}
// If the scalar is decomposed using less than the number of blocks our ciphertext
// has, we just don't touch ciphertext's last blocks
cuda_backend_scalar_addition_assign(
streams,
ct.as_mut(),
&d_decomposed_scalar,
&decomposed_scalar,
decomposed_scalar.len() as u32,
self.message_modulus.0 as u32,
self.carry_modulus.0 as u32,
);
unsafe {
cuda_backend_scalar_addition_assign(
streams,
ct.as_mut(),
&d_decomposed_scalar,
&decomposed_scalar,
decomposed_scalar.len() as u32,
self.message_modulus.0 as u32,
self.carry_modulus.0 as u32,
);
}
}
}
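Worked example of the scalar decomposition above: the scalar is split into radix blocks of `bits_in_message` bits each, least significant first, stopping early once the remaining value is zero, and the block list is truncated to the ciphertext's block count. Hand-rolled here so the snippet is self-contained; the diff uses `BlockDecomposer::with_early_stop_at_zero`.

```rust
fn decompose(mut scalar: u64, bits_in_message: u32, num_blocks: usize) -> Vec<u64> {
    let mask = (1u64 << bits_in_message) - 1;
    let mut blocks = Vec::new();
    while scalar != 0 && blocks.len() < num_blocks {
        blocks.push(scalar & mask);
        scalar >>= bits_in_message;
    }
    blocks
}

fn main() {
    // message_modulus = 4 => 2 bits per block; 13 = 0b1101 => [0b01, 0b11].
    assert_eq!(decompose(13, 2, 8), vec![1, 3]);
    // If the scalar needs fewer blocks than the ciphertext has, the
    // remaining ciphertext blocks are simply left untouched.
}
```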
pub fn unchecked_scalar_add_assign<Scalar, T>(
&self,
ct: &mut T,
scalar: Scalar,
streams: &CudaStreams,
) where
Scalar: DecomposableInto<u8> + CastInto<u64>,
T: CudaIntegerRadixCiphertext,
{
unsafe {
self.unchecked_scalar_add_assign_async(ct, scalar, streams);
}
streams.synchronize();
}
/// Computes homomorphically an addition between a scalar and a ciphertext.
///
/// This function computes the operation without checking if it exceeds the capacity of the
@@ -180,16 +162,8 @@ impl CudaServerKey {
self.get_scalar_add_assign_size_on_gpu(ct, streams)
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronised
pub unsafe fn scalar_add_assign_async<Scalar, T>(
&self,
ct: &mut T,
scalar: Scalar,
streams: &CudaStreams,
) where
pub fn scalar_add_assign<Scalar, T>(&self, ct: &mut T, scalar: Scalar, streams: &CudaStreams)
where
Scalar: DecomposableInto<u8> + CastInto<u64>,
T: CudaIntegerRadixCiphertext,
{
@@ -197,7 +171,7 @@ impl CudaServerKey {
self.full_propagate_assign(ct, streams);
}
self.unchecked_scalar_add_assign_async(ct, scalar, streams);
self.unchecked_scalar_add_assign(ct, scalar, streams);
let _carry = self.propagate_single_carry_assign(ct, streams, None, OutputFlag::None);
}
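The new `scalar_add_assign` keeps the three-step protocol of the old async pair: fully propagate stale carries if any, do the raw scalar add, then propagate the single carry the add can produce. A schematic sketch with placeholder types:

```rust
struct Ct { carries_empty: bool }
struct Key;

impl Key {
    fn full_propagate_assign(&self, ct: &mut Ct) { ct.carries_empty = true; }
    fn unchecked_scalar_add_assign(&self, ct: &mut Ct, _scalar: u64) {
        ct.carries_empty = false; // the raw add may produce carries
    }
    fn propagate_single_carry_assign(&self, ct: &mut Ct) { ct.carries_empty = true; }

    fn scalar_add_assign(&self, ct: &mut Ct, scalar: u64) {
        if !ct.carries_empty {
            self.full_propagate_assign(ct);
        }
        self.unchecked_scalar_add_assign(ct, scalar);
        self.propagate_single_carry_assign(ct);
    }
}

fn main() {
    let key = Key;
    let mut ct = Ct { carries_empty: false };
    key.scalar_add_assign(&mut ct, 42);
    assert!(ct.carries_empty);
}
```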
@@ -290,17 +264,6 @@ impl CudaServerKey {
full_prop_mem.max(single_carry_mem)
}
pub fn scalar_add_assign<Scalar, T>(&self, ct: &mut T, scalar: Scalar, streams: &CudaStreams)
where
Scalar: DecomposableInto<u8> + CastInto<u64>,
T: CudaIntegerRadixCiphertext,
{
unsafe {
self.scalar_add_assign_async(ct, scalar, streams);
}
streams.synchronize();
}
pub fn unsigned_overflowing_scalar_add<Scalar>(
&self,
ct_left: &CudaUnsignedRadixCiphertext,

View File

@@ -10,11 +10,7 @@ use crate::integer::gpu::{
};
impl CudaServerKey {
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronised
pub unsafe fn unchecked_scalar_bitop_assign_async<Scalar, T>(
pub fn unchecked_scalar_bitop_assign<Scalar, T>(
&self,
ct: &mut T,
rhs: Scalar,
@@ -32,66 +28,67 @@ impl CudaServerKey {
.map(|x| x as u64)
.collect::<Vec<_>>();
let clear_blocks = CudaVec::from_cpu_async(&h_clear_blocks, streams, 0);
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_scalar_bitop_assign(
streams,
ct.as_mut(),
&clear_blocks,
&h_clear_blocks,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
op,
lwe_ciphertext_count.0 as u32,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_scalar_bitop_assign(
streams,
ct.as_mut(),
&clear_blocks,
&h_clear_blocks,
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
op,
lwe_ciphertext_count.0 as u32,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
);
let clear_blocks = unsafe { CudaVec::from_cpu_async(&h_clear_blocks, streams, 0) };
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_scalar_bitop_assign(
streams,
ct.as_mut(),
&clear_blocks,
&h_clear_blocks,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
op,
lwe_ciphertext_count.0 as u32,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_scalar_bitop_assign(
streams,
ct.as_mut(),
&clear_blocks,
&h_clear_blocks,
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
op,
lwe_ciphertext_count.0 as u32,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
);
}
}
}
}
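As the following hunks show, the scalar bitwise assigns differ only in the `BitOpType` they pass to this shared entry point. A sketch of that funneling with placeholder types (function names abbreviated, not the exact tfhe-rs signatures):

```rust
#[derive(Debug, Clone, Copy)]
enum BitOpType { ScalarAnd, ScalarOr, ScalarXor }

struct Key;
struct Ct;

impl Key {
    fn unchecked_scalar_bitop_assign(&self, _ct: &mut Ct, rhs: u64, op: BitOpType) {
        println!("backend bitop {op:?} with rhs {rhs}");
    }
    fn bitand_assign(&self, ct: &mut Ct, rhs: u64) {
        self.unchecked_scalar_bitop_assign(ct, rhs, BitOpType::ScalarAnd);
    }
    fn bitor_assign(&self, ct: &mut Ct, rhs: u64) {
        self.unchecked_scalar_bitop_assign(ct, rhs, BitOpType::ScalarOr);
    }
    fn bitxor_assign(&self, ct: &mut Ct, rhs: u64) {
        self.unchecked_scalar_bitop_assign(ct, rhs, BitOpType::ScalarXor);
    }
}

fn main() {
    let (key, mut ct) = (Key, Ct);
    key.bitand_assign(&mut ct, 0b1010);
    key.bitor_assign(&mut ct, 0b0101);
    key.bitxor_assign(&mut ct, 0b1111);
}
```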
@@ -120,10 +117,7 @@ impl CudaServerKey {
Scalar: DecomposableInto<u8>,
T: CudaIntegerRadixCiphertext,
{
unsafe {
self.unchecked_scalar_bitop_assign_async(ct, rhs, BitOpType::ScalarAnd, streams);
}
streams.synchronize();
self.unchecked_scalar_bitop_assign(ct, rhs, BitOpType::ScalarAnd, streams);
}
pub fn unchecked_scalar_bitor<Scalar, T>(&self, ct: &T, rhs: Scalar, streams: &CudaStreams) -> T
@@ -145,10 +139,7 @@ impl CudaServerKey {
Scalar: DecomposableInto<u8>,
T: CudaIntegerRadixCiphertext,
{
unsafe {
self.unchecked_scalar_bitop_assign_async(ct, rhs, BitOpType::ScalarOr, streams);
}
streams.synchronize();
self.unchecked_scalar_bitop_assign(ct, rhs, BitOpType::ScalarOr, streams);
}
pub fn unchecked_scalar_bitxor<Scalar, T>(
@@ -175,29 +166,7 @@ impl CudaServerKey {
Scalar: DecomposableInto<u8>,
T: CudaIntegerRadixCiphertext,
{
unsafe {
self.unchecked_scalar_bitop_assign_async(ct, rhs, BitOpType::ScalarXor, streams);
}
streams.synchronize();
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronised
pub unsafe fn scalar_bitand_assign_async<Scalar, T>(
&self,
ct: &mut T,
rhs: Scalar,
streams: &CudaStreams,
) where
Scalar: DecomposableInto<u8>,
T: CudaIntegerRadixCiphertext,
{
if !ct.block_carries_are_empty() {
self.full_propagate_assign(ct, streams);
}
self.unchecked_scalar_bitop_assign_async(ct, rhs, BitOpType::ScalarAnd, streams);
self.unchecked_scalar_bitop_assign(ct, rhs, BitOpType::ScalarXor, streams);
}
pub fn scalar_bitand_assign<Scalar, T>(&self, ct: &mut T, rhs: Scalar, streams: &CudaStreams)
@@ -205,10 +174,10 @@ impl CudaServerKey {
Scalar: DecomposableInto<u8>,
T: CudaIntegerRadixCiphertext,
{
unsafe {
self.scalar_bitand_assign_async(ct, rhs, streams);
if !ct.block_carries_are_empty() {
self.full_propagate_assign(ct, streams);
}
streams.synchronize();
self.unchecked_scalar_bitop_assign(ct, rhs, BitOpType::ScalarAnd, streams);
}
pub fn scalar_bitand<Scalar, T>(&self, ct: &T, rhs: Scalar, streams: &CudaStreams) -> T
@@ -221,34 +190,15 @@ impl CudaServerKey {
result
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronised
pub unsafe fn scalar_bitor_assign_async<Scalar, T>(
&self,
ct: &mut T,
rhs: Scalar,
streams: &CudaStreams,
) where
pub fn scalar_bitor_assign<Scalar, T>(&self, ct: &mut T, rhs: Scalar, streams: &CudaStreams)
where
Scalar: DecomposableInto<u8>,
T: CudaIntegerRadixCiphertext,
{
if !ct.block_carries_are_empty() {
self.full_propagate_assign(ct, streams);
}
self.unchecked_scalar_bitop_assign_async(ct, rhs, BitOpType::ScalarOr, streams);
}
pub fn scalar_bitor_assign<Scalar, T>(&self, ct: &mut T, rhs: Scalar, streams: &CudaStreams)
where
Scalar: DecomposableInto<u8>,
T: CudaIntegerRadixCiphertext,
{
unsafe {
self.scalar_bitor_assign_async(ct, rhs, streams);
}
streams.synchronize();
self.unchecked_scalar_bitop_assign(ct, rhs, BitOpType::ScalarOr, streams);
}
pub fn scalar_bitor<Scalar, T>(&self, ct: &T, rhs: Scalar, streams: &CudaStreams) -> T
@@ -261,34 +211,15 @@ impl CudaServerKey {
result
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronised
pub unsafe fn scalar_bitxor_assign_async<Scalar, T>(
&self,
ct: &mut T,
rhs: Scalar,
streams: &CudaStreams,
) where
pub fn scalar_bitxor_assign<Scalar, T>(&self, ct: &mut T, rhs: Scalar, streams: &CudaStreams)
where
Scalar: DecomposableInto<u8>,
T: CudaIntegerRadixCiphertext,
{
if !ct.block_carries_are_empty() {
self.full_propagate_assign(ct, streams);
}
self.unchecked_scalar_bitop_assign_async(ct, rhs, BitOpType::ScalarXor, streams);
}
pub fn scalar_bitxor_assign<Scalar, T>(&self, ct: &mut T, rhs: Scalar, streams: &CudaStreams)
where
Scalar: DecomposableInto<u8>,
T: CudaIntegerRadixCiphertext,
{
unsafe {
self.scalar_bitxor_assign_async(ct, rhs, streams);
}
streams.synchronize();
self.unchecked_scalar_bitop_assign(ct, rhs, BitOpType::ScalarXor, streams);
}
pub fn scalar_bitxor<Scalar, T>(&self, ct: &T, rhs: Scalar, streams: &CudaStreams) -> T

View File

@@ -102,11 +102,7 @@ impl CudaServerKey {
None
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronised
pub unsafe fn unchecked_signed_and_unsigned_scalar_comparison_async<Scalar, T>(
pub fn unchecked_signed_and_unsigned_scalar_comparison<Scalar, T>(
&self,
ct: &T,
scalar: Scalar,
@@ -155,7 +151,8 @@ impl CudaServerKey {
// as we will handle them separately.
scalar_blocks.truncate(ct.as_ref().d_blocks.lwe_ciphertext_count().0);
let d_scalar_blocks: CudaVec<u64> = CudaVec::from_cpu_async(&scalar_blocks, streams, 0);
let d_scalar_blocks: CudaVec<u64> =
unsafe { CudaVec::from_cpu_async(&scalar_blocks, streams, 0) };
let block = CudaLweCiphertextList::new(
ct.as_ref().d_blocks.lwe_dimension(),
@@ -171,78 +168,76 @@ impl CudaServerKey {
let mut result =
CudaBooleanBlock::from_cuda_radix_ciphertext(CudaRadixCiphertext::new(block, ct_info));
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_scalar_comparison(
streams,
result.as_mut().as_mut(),
ct.as_ref(),
&d_scalar_blocks,
&scalar_blocks,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
scalar_blocks.len() as u32,
op,
signed_with_positive_scalar,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_scalar_comparison(
streams,
result.as_mut().as_mut(),
ct.as_ref(),
&d_scalar_blocks,
&scalar_blocks,
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
scalar_blocks.len() as u32,
op,
signed_with_positive_scalar,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
);
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_scalar_comparison(
streams,
result.as_mut().as_mut(),
ct.as_ref(),
&d_scalar_blocks,
&scalar_blocks,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
scalar_blocks.len() as u32,
op,
signed_with_positive_scalar,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_scalar_comparison(
streams,
result.as_mut().as_mut(),
ct.as_ref(),
&d_scalar_blocks,
&scalar_blocks,
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
scalar_blocks.len() as u32,
op,
signed_with_positive_scalar,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
);
}
}
}
result
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronised
pub unsafe fn unchecked_scalar_comparison_async<Scalar, T>(
pub fn unchecked_scalar_comparison<Scalar, T>(
&self,
ct: &T,
scalar: Scalar,
@@ -292,25 +287,18 @@ impl CudaServerKey {
}
if scalar >= Scalar::ZERO {
self.unchecked_signed_and_unsigned_scalar_comparison_async(
ct, scalar, op, true, streams,
)
self.unchecked_signed_and_unsigned_scalar_comparison(ct, scalar, op, true, streams)
} else {
let scalar_as_trivial = self.create_trivial_radix(scalar, num_blocks, streams);
self.unchecked_comparison(ct, &scalar_as_trivial, op, streams)
}
} else {
// Unsigned
self.unchecked_signed_and_unsigned_scalar_comparison_async(
ct, scalar, op, false, streams,
)
self.unchecked_signed_and_unsigned_scalar_comparison(ct, scalar, op, false, streams)
}
}
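The dispatch above routes on signedness and on the scalar's sign: a non-negative scalar against a signed ciphertext takes the direct scalar path with `signed_with_positive_scalar = true`, a negative scalar is first lifted into a trivial radix ciphertext and compared ciphertext-to-ciphertext, and unsigned inputs always take the scalar path with the flag false. A toy sketch of that routing (placeholder types, not the real API):

```rust
#[allow(dead_code)]
#[derive(Debug, Clone, Copy)]
enum ComparisonType { EQ, NE, GT, GE, LT, LE, MAX, MIN }

fn unchecked_scalar_comparison(is_signed: bool, scalar: i64, _op: ComparisonType) -> &'static str {
    if is_signed {
        if scalar >= 0 {
            "scalar path, signed_with_positive_scalar = true"
        } else {
            // create_trivial_radix(scalar, ...) then unchecked_comparison(...)
            "trivial-radix path"
        }
    } else {
        "scalar path, signed_with_positive_scalar = false"
    }
}

fn main() {
    assert_eq!(
        unchecked_scalar_comparison(true, -5, ComparisonType::LT),
        "trivial-radix path"
    );
    assert_eq!(
        unchecked_scalar_comparison(false, 5, ComparisonType::GE),
        "scalar path, signed_with_positive_scalar = false"
    );
}
```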
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronised
pub unsafe fn unchecked_scalar_minmax_async<Scalar, T>(
pub fn unchecked_scalar_minmax<Scalar, T>(
&self,
ct: &T,
scalar: Scalar,
@@ -328,72 +316,75 @@ impl CudaServerKey {
.iter_as::<u64>()
.collect::<Vec<_>>();
let d_scalar_blocks: CudaVec<u64> = CudaVec::from_cpu_async(&scalar_blocks, streams, 0);
let d_scalar_blocks: CudaVec<u64> =
unsafe { CudaVec::from_cpu_async(&scalar_blocks, streams, 0) };
let mut result = ct.duplicate(streams);
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_scalar_comparison(
streams,
result.as_mut(),
ct.as_ref(),
&d_scalar_blocks,
&scalar_blocks,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
scalar_blocks.len() as u32,
op,
T::IS_SIGNED,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_scalar_comparison(
streams,
result.as_mut(),
ct.as_ref(),
&d_scalar_blocks,
&scalar_blocks,
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
scalar_blocks.len() as u32,
op,
T::IS_SIGNED,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
);
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_scalar_comparison(
streams,
result.as_mut(),
ct.as_ref(),
&d_scalar_blocks,
&scalar_blocks,
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
scalar_blocks.len() as u32,
op,
T::IS_SIGNED,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_scalar_comparison(
streams,
result.as_mut(),
ct.as_ref(),
&d_scalar_blocks,
&scalar_blocks,
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
scalar_blocks.len() as u32,
op,
T::IS_SIGNED,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
);
}
}
}
result
@@ -465,7 +456,6 @@ impl CudaServerKey {
}
}
}
streams.synchronize();
boolean_res
}
@@ -535,27 +525,9 @@ impl CudaServerKey {
}
}
}
streams.synchronize();
boolean_res
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronised
pub unsafe fn unchecked_scalar_eq_async<Scalar, T>(
&self,
ct: &T,
scalar: Scalar,
streams: &CudaStreams,
) -> CudaBooleanBlock
where
T: CudaIntegerRadixCiphertext,
Scalar: DecomposableInto<u64>,
{
self.unchecked_scalar_comparison_async(ct, scalar, ComparisonType::EQ, streams)
}
pub fn unchecked_scalar_eq<Scalar, T>(
&self,
ct: &T,
@@ -566,35 +538,7 @@ impl CudaServerKey {
T: CudaIntegerRadixCiphertext,
Scalar: DecomposableInto<u64>,
{
let result = unsafe { self.unchecked_scalar_eq_async(ct, scalar, streams) };
streams.synchronize();
result
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronised
pub unsafe fn scalar_eq_async<Scalar, T>(
&self,
ct: &T,
scalar: Scalar,
streams: &CudaStreams,
) -> CudaBooleanBlock
where
T: CudaIntegerRadixCiphertext,
Scalar: DecomposableInto<u64>,
{
let mut tmp_lhs;
let lhs = if ct.block_carries_are_empty() {
ct
} else {
tmp_lhs = ct.duplicate(streams);
self.full_propagate_assign(&mut tmp_lhs, streams);
&tmp_lhs
};
self.unchecked_scalar_eq_async(lhs, scalar, streams)
self.unchecked_scalar_comparison(ct, scalar, ComparisonType::EQ, streams)
}
/// Compares for equality 2 ciphertexts
@@ -643,25 +587,6 @@ impl CudaServerKey {
scalar: Scalar,
streams: &CudaStreams,
) -> CudaBooleanBlock
where
T: CudaIntegerRadixCiphertext,
Scalar: DecomposableInto<u64>,
{
let result = unsafe { self.scalar_eq_async(ct, scalar, streams) };
streams.synchronize();
result
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronised
pub unsafe fn scalar_ne_async<Scalar, T>(
&self,
ct: &T,
scalar: Scalar,
streams: &CudaStreams,
) -> CudaBooleanBlock
where
T: CudaIntegerRadixCiphertext,
Scalar: DecomposableInto<u64>,
@@ -675,7 +600,7 @@ impl CudaServerKey {
&tmp_lhs
};
self.unchecked_scalar_ne_async(lhs, scalar, streams)
self.unchecked_scalar_eq(lhs, scalar, streams)
}
/// Compares for equality 2 ciphertexts
@@ -725,29 +650,19 @@ impl CudaServerKey {
streams: &CudaStreams,
) -> CudaBooleanBlock
where
Scalar: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
Scalar: DecomposableInto<u64>,
{
let result = unsafe { self.scalar_ne_async(ct, scalar, streams) };
streams.synchronize();
result
}
let mut tmp_lhs;
let lhs = if ct.block_carries_are_empty() {
ct
} else {
tmp_lhs = ct.duplicate(streams);
self.full_propagate_assign(&mut tmp_lhs, streams);
&tmp_lhs
};
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronised
pub unsafe fn unchecked_scalar_ne_async<Scalar, T>(
&self,
ct: &T,
scalar: Scalar,
streams: &CudaStreams,
) -> CudaBooleanBlock
where
T: CudaIntegerRadixCiphertext,
Scalar: DecomposableInto<u64>,
{
self.unchecked_scalar_comparison_async(ct, scalar, ComparisonType::NE, streams)
self.unchecked_scalar_ne(lhs, scalar, streams)
}
pub fn unchecked_scalar_ne<Scalar, T>(
@@ -760,26 +675,7 @@ impl CudaServerKey {
T: CudaIntegerRadixCiphertext,
Scalar: DecomposableInto<u64>,
{
let result = unsafe { self.unchecked_scalar_ne_async(ct, scalar, streams) };
streams.synchronize();
result
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronised
pub unsafe fn unchecked_scalar_gt_async<Scalar, T>(
&self,
ct: &T,
scalar: Scalar,
streams: &CudaStreams,
) -> CudaBooleanBlock
where
Scalar: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
{
self.unchecked_scalar_comparison_async(ct, scalar, ComparisonType::GT, streams)
self.unchecked_scalar_comparison(ct, scalar, ComparisonType::NE, streams)
}
pub fn unchecked_scalar_gt<Scalar, T>(
@@ -792,26 +688,7 @@ impl CudaServerKey {
Scalar: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
{
let result = unsafe { self.unchecked_scalar_gt_async(ct, scalar, streams) };
streams.synchronize();
result
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronised
pub unsafe fn unchecked_scalar_ge_async<Scalar, T>(
&self,
ct: &T,
scalar: Scalar,
streams: &CudaStreams,
) -> CudaBooleanBlock
where
Scalar: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
{
self.unchecked_scalar_comparison_async(ct, scalar, ComparisonType::GE, streams)
self.unchecked_scalar_comparison(ct, scalar, ComparisonType::GT, streams)
}
pub fn unchecked_scalar_ge<Scalar, T>(
@@ -824,26 +701,7 @@ impl CudaServerKey {
Scalar: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
{
let result = unsafe { self.unchecked_scalar_ge_async(ct, scalar, streams) };
streams.synchronize();
result
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronised
pub unsafe fn unchecked_scalar_lt_async<Scalar, T>(
&self,
ct: &T,
scalar: Scalar,
streams: &CudaStreams,
) -> CudaBooleanBlock
where
Scalar: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
{
self.unchecked_scalar_comparison_async(ct, scalar, ComparisonType::LT, streams)
self.unchecked_scalar_comparison(ct, scalar, ComparisonType::GE, streams)
}
pub fn unchecked_scalar_lt<Scalar, T>(
@@ -856,26 +714,7 @@ impl CudaServerKey {
Scalar: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
{
let result = unsafe { self.unchecked_scalar_lt_async(ct, scalar, streams) };
streams.synchronize();
result
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronised
pub unsafe fn unchecked_scalar_le_async<Scalar, T>(
&self,
ct: &T,
scalar: Scalar,
streams: &CudaStreams,
) -> CudaBooleanBlock
where
Scalar: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
{
self.unchecked_scalar_comparison_async(ct, scalar, ComparisonType::LE, streams)
self.unchecked_scalar_comparison(ct, scalar, ComparisonType::LT, streams)
}
pub fn unchecked_scalar_le<Scalar, T>(
@@ -888,34 +727,7 @@ impl CudaServerKey {
Scalar: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
{
let result = unsafe { self.unchecked_scalar_le_async(ct, scalar, streams) };
streams.synchronize();
result
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronised
pub unsafe fn scalar_gt_async<Scalar, T>(
&self,
ct: &T,
scalar: Scalar,
streams: &CudaStreams,
) -> CudaBooleanBlock
where
Scalar: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
{
let mut tmp_lhs;
let lhs = if ct.block_carries_are_empty() {
ct
} else {
tmp_lhs = ct.duplicate(streams);
self.full_propagate_assign(&mut tmp_lhs, streams);
&tmp_lhs
};
self.unchecked_scalar_gt_async(lhs, scalar, streams)
self.unchecked_scalar_comparison(ct, scalar, ComparisonType::LE, streams)
}
pub fn scalar_gt<Scalar, T>(
@@ -924,25 +736,6 @@ impl CudaServerKey {
scalar: Scalar,
streams: &CudaStreams,
) -> CudaBooleanBlock
where
Scalar: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
{
let result = unsafe { self.scalar_gt_async(ct, scalar, streams) };
streams.synchronize();
result
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronised
pub unsafe fn scalar_ge_async<Scalar, T>(
&self,
ct: &T,
scalar: Scalar,
streams: &CudaStreams,
) -> CudaBooleanBlock
where
Scalar: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
@@ -956,7 +749,7 @@ impl CudaServerKey {
&tmp_lhs
};
self.unchecked_scalar_ge_async(lhs, scalar, streams)
self.unchecked_scalar_gt(lhs, scalar, streams)
}
pub fn scalar_ge<Scalar, T>(
@@ -965,25 +758,6 @@ impl CudaServerKey {
scalar: Scalar,
streams: &CudaStreams,
) -> CudaBooleanBlock
where
Scalar: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
{
let result = unsafe { self.scalar_ge_async(ct, scalar, streams) };
streams.synchronize();
result
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronised
pub unsafe fn scalar_lt_async<Scalar, T>(
&self,
ct: &T,
scalar: Scalar,
streams: &CudaStreams,
) -> CudaBooleanBlock
where
Scalar: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
@@ -997,7 +771,7 @@ impl CudaServerKey {
&tmp_lhs
};
self.unchecked_scalar_lt_async(lhs, scalar, streams)
self.unchecked_scalar_ge(lhs, scalar, streams)
}
pub fn scalar_lt<Scalar, T>(
@@ -1010,15 +784,19 @@ impl CudaServerKey {
Scalar: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
{
let result = unsafe { self.scalar_lt_async(ct, scalar, streams) };
streams.synchronize();
result
let mut tmp_lhs;
let lhs = if ct.block_carries_are_empty() {
ct
} else {
tmp_lhs = ct.duplicate(streams);
self.full_propagate_assign(&mut tmp_lhs, streams);
&tmp_lhs
};
self.unchecked_scalar_lt(lhs, scalar, streams)
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronised
pub unsafe fn scalar_le_async<Scalar, T>(
pub fn scalar_le<Scalar, T>(
&self,
ct: &T,
scalar: Scalar,
@@ -1037,22 +815,7 @@ impl CudaServerKey {
&tmp_lhs
};
self.unchecked_scalar_le_async(lhs, scalar, streams)
}
pub fn scalar_le<Scalar, T>(
&self,
ct: &T,
scalar: Scalar,
streams: &CudaStreams,
) -> CudaBooleanBlock
where
Scalar: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
{
let result = unsafe { self.scalar_le_async(ct, scalar, streams) };
streams.synchronize();
result
self.unchecked_scalar_le(lhs, scalar, streams)
}
pub fn get_scalar_eq_size_on_gpu<T: CudaIntegerRadixCiphertext>(
@@ -1103,23 +866,6 @@ impl CudaServerKey {
self.get_comparison_size_on_gpu(ct_left, ct_left, ComparisonType::LE, streams)
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronised
pub unsafe fn unchecked_scalar_max_async<Scalar, T>(
&self,
ct: &T,
scalar: Scalar,
streams: &CudaStreams,
) -> T
where
Scalar: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
{
self.unchecked_scalar_minmax_async(ct, scalar, ComparisonType::MAX, streams)
}
pub fn unchecked_scalar_max<Scalar, T>(
&self,
ct: &T,
@@ -1130,26 +876,7 @@ impl CudaServerKey {
Scalar: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
{
let result = unsafe { self.unchecked_scalar_max_async(ct, scalar, streams) };
streams.synchronize();
result
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronised
pub unsafe fn unchecked_scalar_min_async<Scalar, T>(
&self,
ct: &T,
scalar: Scalar,
streams: &CudaStreams,
) -> T
where
Scalar: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
{
self.unchecked_scalar_minmax_async(ct, scalar, ComparisonType::MIN, streams)
self.unchecked_scalar_minmax(ct, scalar, ComparisonType::MAX, streams)
}
pub fn unchecked_scalar_min<Scalar, T>(
@@ -1162,35 +889,7 @@ impl CudaServerKey {
Scalar: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
{
let result = unsafe { self.unchecked_scalar_min_async(ct, scalar, streams) };
streams.synchronize();
result
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronised
pub unsafe fn scalar_max_async<Scalar, T>(
&self,
ct: &T,
scalar: Scalar,
streams: &CudaStreams,
) -> T
where
Scalar: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
{
let mut tmp_lhs;
let lhs = if ct.block_carries_are_empty() {
ct
} else {
tmp_lhs = ct.duplicate(streams);
self.full_propagate_assign(&mut tmp_lhs, streams);
&tmp_lhs
};
self.unchecked_scalar_max_async(lhs, scalar, streams)
self.unchecked_scalar_minmax(ct, scalar, ComparisonType::MIN, streams)
}
pub fn scalar_max<Scalar, T>(&self, ct: &T, scalar: Scalar, streams: &CudaStreams) -> T
@@ -1198,21 +897,19 @@ impl CudaServerKey {
Scalar: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
{
let result = unsafe { self.scalar_max_async(ct, scalar, streams) };
streams.synchronize();
result
let mut tmp_lhs;
let lhs = if ct.block_carries_are_empty() {
ct
} else {
tmp_lhs = ct.duplicate(streams);
self.full_propagate_assign(&mut tmp_lhs, streams);
&tmp_lhs
};
self.unchecked_scalar_max(lhs, scalar, streams)
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronised
pub unsafe fn scalar_min_async<Scalar, T>(
&self,
ct: &T,
scalar: Scalar,
streams: &CudaStreams,
) -> T
pub fn scalar_min<Scalar, T>(&self, ct: &T, scalar: Scalar, streams: &CudaStreams) -> T
where
Scalar: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
@@ -1226,18 +923,9 @@ impl CudaServerKey {
&tmp_lhs
};
self.unchecked_scalar_min_async(lhs, scalar, streams)
self.unchecked_scalar_min(lhs, scalar, streams)
}
pub fn scalar_min<Scalar, T>(&self, ct: &T, scalar: Scalar, streams: &CudaStreams) -> T
where
Scalar: DecomposableInto<u64>,
T: CudaIntegerRadixCiphertext,
{
let result = unsafe { self.scalar_min_async(ct, scalar, streams) };
streams.synchronize();
result
}
pub fn get_scalar_max_size_on_gpu<T: CudaIntegerRadixCiphertext>(
&self,
ct_left: &T,

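For reference, a minimal host-side sketch of the now-synchronous comparison API. The key-generation helper, parameter-set name, stream constructor and import paths below are assumptions borrowed from the crate's doc examples and vary between tfhe-rs versions; only `scalar_gt` itself comes from this file.

// Hypothetical setup; exact import paths and parameter names differ by version.
use tfhe::core_crypto::gpu::{vec::GpuIndex, CudaStreams};
use tfhe::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
use tfhe::integer::gpu::gen_keys_radix_gpu;
use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;

let streams = CudaStreams::new_single_gpu(GpuIndex::new(0));
let num_blocks = 4;
let (cks, sks) = gen_keys_radix_gpu(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks, &streams);

let d_ct = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&cks.encrypt(12u64), &streams);
// No `unsafe`, no manual streams.synchronize(): the call blocks until done.
let d_gt = sks.scalar_gt(&d_ct, 7u64, &streams);
assert!(cks.decrypt_bool(&d_gt.to_boolean_block(&streams)));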
View File

@@ -67,24 +67,6 @@ impl CudaServerKey {
divisor: Scalar,
streams: &CudaStreams,
) -> CudaUnsignedRadixCiphertext
where
Scalar: Reciprocable,
{
let res = unsafe { self.unchecked_scalar_div_async(numerator, divisor, streams) };
streams.synchronize();
res
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronized
pub unsafe fn unchecked_scalar_div_async<Scalar>(
&self,
numerator: &CudaUnsignedRadixCiphertext,
divisor: Scalar,
streams: &CudaStreams,
) -> CudaUnsignedRadixCiphertext
where
Scalar: Reciprocable,
{
@@ -104,48 +86,50 @@ impl CudaServerKey {
let mut quotient = numerator.duplicate(streams);
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_unsigned_scalar_div_assign(
streams,
quotient.as_mut(),
divisor,
&self.key_switching_key.d_vec,
&d_bsk.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
d_bsk.input_lwe_dimension,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
LweBskGroupingFactor(0),
PBSType::Classical,
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_unsigned_scalar_div_assign(
streams,
quotient.as_mut(),
divisor,
&self.key_switching_key.d_vec,
&d_multibit_bsk.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
d_multibit_bsk.input_lwe_dimension,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
d_multibit_bsk.grouping_factor,
PBSType::MultiBit,
None,
);
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_unsigned_scalar_div_assign(
streams,
quotient.as_mut(),
divisor,
&self.key_switching_key.d_vec,
&d_bsk.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
d_bsk.input_lwe_dimension,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
LweBskGroupingFactor(0),
PBSType::Classical,
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_unsigned_scalar_div_assign(
streams,
quotient.as_mut(),
divisor,
&self.key_switching_key.d_vec,
&d_multibit_bsk.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
d_multibit_bsk.input_lwe_dimension,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
d_multibit_bsk.grouping_factor,
PBSType::MultiBit,
None,
);
}
}
}
@@ -194,24 +178,6 @@ impl CudaServerKey {
divisor: Scalar,
streams: &CudaStreams,
) -> CudaUnsignedRadixCiphertext
where
Scalar: Reciprocable,
{
let res = unsafe { self.unchecked_scalar_div_async(numerator, divisor, streams) };
streams.synchronize();
res
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronized
pub unsafe fn scalar_div_async<Scalar>(
&self,
numerator: &CudaUnsignedRadixCiphertext,
divisor: Scalar,
streams: &CudaStreams,
) -> CudaUnsignedRadixCiphertext
where
Scalar: Reciprocable,
{
@@ -224,7 +190,7 @@ impl CudaServerKey {
&tmp_numerator
};
self.unchecked_scalar_div_async(numerator, divisor, streams)
self.unchecked_scalar_div(numerator, divisor, streams)
}
pub fn unchecked_scalar_div_rem<Scalar>(
@@ -233,27 +199,6 @@ impl CudaServerKey {
divisor: Scalar,
streams: &CudaStreams,
) -> (CudaUnsignedRadixCiphertext, CudaUnsignedRadixCiphertext)
where
Scalar: Reciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
{
let (quotient, remainder) =
unsafe { self.unchecked_scalar_div_rem_async(numerator, divisor, streams) };
streams.synchronize();
(quotient, remainder)
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronized
pub unsafe fn unchecked_scalar_div_rem_async<Scalar>(
&self,
numerator: &CudaUnsignedRadixCiphertext,
divisor: Scalar,
streams: &CudaStreams,
) -> (CudaUnsignedRadixCiphertext, CudaUnsignedRadixCiphertext)
where
Scalar: Reciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
{
@@ -276,50 +221,52 @@ impl CudaServerKey {
streams,
);
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_unsigned_scalar_div_rem(
streams,
quotient.as_mut(),
remainder.as_mut(),
divisor,
&self.key_switching_key.d_vec,
&d_bsk.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
d_bsk.input_lwe_dimension,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
LweBskGroupingFactor(0),
PBSType::Classical,
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_unsigned_scalar_div_rem(
streams,
quotient.as_mut(),
remainder.as_mut(),
divisor,
&self.key_switching_key.d_vec,
&d_multibit_bsk.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
d_multibit_bsk.input_lwe_dimension,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
d_multibit_bsk.grouping_factor,
PBSType::MultiBit,
None,
);
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_unsigned_scalar_div_rem(
streams,
quotient.as_mut(),
remainder.as_mut(),
divisor,
&self.key_switching_key.d_vec,
&d_bsk.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
d_bsk.input_lwe_dimension,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
LweBskGroupingFactor(0),
PBSType::Classical,
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_unsigned_scalar_div_rem(
streams,
quotient.as_mut(),
remainder.as_mut(),
divisor,
&self.key_switching_key.d_vec,
&d_multibit_bsk.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
d_multibit_bsk.input_lwe_dimension,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
d_multibit_bsk.grouping_factor,
PBSType::MultiBit,
None,
);
}
}
}
@@ -370,24 +317,6 @@ impl CudaServerKey {
divisor: Scalar,
streams: &CudaStreams,
) -> (CudaUnsignedRadixCiphertext, CudaUnsignedRadixCiphertext)
where
Scalar: Reciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
{
let res = unsafe { self.unchecked_scalar_div_rem_async(numerator, divisor, streams) };
streams.synchronize();
res
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronized
pub unsafe fn scalar_div_rem_async<Scalar>(
&self,
numerator: &CudaUnsignedRadixCiphertext,
divisor: Scalar,
streams: &CudaStreams,
) -> (CudaUnsignedRadixCiphertext, CudaUnsignedRadixCiphertext)
where
Scalar: Reciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
{
@@ -400,7 +329,7 @@ impl CudaServerKey {
&tmp_numerator
};
self.unchecked_scalar_div_rem_async(numerator, divisor, streams)
self.unchecked_scalar_div_rem(numerator, divisor, streams)
}
pub fn unchecked_scalar_rem<Scalar>(
@@ -412,26 +341,7 @@ impl CudaServerKey {
where
Scalar: Reciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
{
let res = unsafe { self.unchecked_scalar_rem_async(numerator, divisor, streams) };
streams.synchronize();
res
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronized
pub unsafe fn unchecked_scalar_rem_async<Scalar>(
&self,
numerator: &CudaUnsignedRadixCiphertext,
divisor: Scalar,
streams: &CudaStreams,
) -> CudaUnsignedRadixCiphertext
where
Scalar: Reciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
{
self.unchecked_scalar_div_rem_async(numerator, divisor, streams)
.1
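// div_rem produces both values in one pass; the remainder is its
// second component and the quotient is simply discarded.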
self.unchecked_scalar_div_rem(numerator, divisor, streams).1
}
/// Computes homomorphically a division between a ciphertext and a scalar.
@@ -473,24 +383,6 @@ impl CudaServerKey {
divisor: Scalar,
streams: &CudaStreams,
) -> CudaUnsignedRadixCiphertext
where
Scalar: Reciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
{
let res = unsafe { self.unchecked_scalar_rem_async(numerator, divisor, streams) };
streams.synchronize();
res
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronized
pub unsafe fn scalar_rem_async<Scalar>(
&self,
numerator: &CudaUnsignedRadixCiphertext,
divisor: Scalar,
streams: &CudaStreams,
) -> CudaUnsignedRadixCiphertext
where
Scalar: Reciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
{
@@ -503,7 +395,7 @@ impl CudaServerKey {
&tmp_numerator
};
self.unchecked_scalar_rem_async(numerator, divisor, streams)
self.unchecked_scalar_rem(numerator, divisor, streams)
}
pub fn unchecked_signed_scalar_div<Scalar>(
@@ -512,25 +404,6 @@ impl CudaServerKey {
divisor: Scalar,
streams: &CudaStreams,
) -> CudaSignedRadixCiphertext
where
Scalar: SignedReciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
<<Scalar as SignedReciprocable>::Unsigned as Reciprocable>::DoublePrecision: Send,
{
let res = unsafe { self.unchecked_signed_scalar_div_async(numerator, divisor, streams) };
streams.synchronize();
res
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronized
pub unsafe fn unchecked_signed_scalar_div_async<Scalar>(
&self,
numerator: &CudaSignedRadixCiphertext,
divisor: Scalar,
streams: &CudaStreams,
) -> CudaSignedRadixCiphertext
where
Scalar: SignedReciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
<<Scalar as SignedReciprocable>::Unsigned as Reciprocable>::DoublePrecision: Send,
@@ -548,48 +421,50 @@ impl CudaServerKey {
let mut quotient: CudaSignedRadixCiphertext = numerator.duplicate(streams);
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_signed_scalar_div_assign(
streams,
quotient.as_mut(),
divisor,
&self.key_switching_key.d_vec,
&d_bsk.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
d_bsk.input_lwe_dimension,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
LweBskGroupingFactor(0),
PBSType::Classical,
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_signed_scalar_div_assign(
streams,
quotient.as_mut(),
divisor,
&self.key_switching_key.d_vec,
&d_multibit_bsk.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
d_multibit_bsk.input_lwe_dimension,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
d_multibit_bsk.grouping_factor,
PBSType::MultiBit,
None,
);
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_signed_scalar_div_assign(
streams,
quotient.as_mut(),
divisor,
&self.key_switching_key.d_vec,
&d_bsk.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
d_bsk.input_lwe_dimension,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
LweBskGroupingFactor(0),
PBSType::Classical,
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_signed_scalar_div_assign(
streams,
quotient.as_mut(),
divisor,
&self.key_switching_key.d_vec,
&d_multibit_bsk.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
d_multibit_bsk.input_lwe_dimension,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
d_multibit_bsk.grouping_factor,
PBSType::MultiBit,
None,
);
}
}
}
@@ -638,25 +513,6 @@ impl CudaServerKey {
divisor: Scalar,
streams: &CudaStreams,
) -> CudaSignedRadixCiphertext
where
Scalar: SignedReciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
<<Scalar as SignedReciprocable>::Unsigned as Reciprocable>::DoublePrecision: Send,
{
let res = unsafe { self.signed_scalar_div_async(numerator, divisor, streams) };
streams.synchronize();
res
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronized
pub unsafe fn signed_scalar_div_async<Scalar>(
&self,
numerator: &CudaSignedRadixCiphertext,
divisor: Scalar,
streams: &CudaStreams,
) -> CudaSignedRadixCiphertext
where
Scalar: SignedReciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
<<Scalar as SignedReciprocable>::Unsigned as Reciprocable>::DoublePrecision: Send,
@@ -670,7 +526,7 @@ impl CudaServerKey {
&tmp_numerator
};
self.unchecked_signed_scalar_div_async(numerator, divisor, streams)
self.unchecked_signed_scalar_div(numerator, divisor, streams)
}
pub fn unchecked_signed_scalar_div_rem<Scalar>(
@@ -679,26 +535,6 @@ impl CudaServerKey {
divisor: Scalar,
streams: &CudaStreams,
) -> (CudaSignedRadixCiphertext, CudaSignedRadixCiphertext)
where
Scalar: SignedReciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
<<Scalar as SignedReciprocable>::Unsigned as Reciprocable>::DoublePrecision: Send,
{
let res =
unsafe { self.unchecked_signed_scalar_div_rem_async(numerator, divisor, streams) };
streams.synchronize();
res
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronized
pub unsafe fn unchecked_signed_scalar_div_rem_async<Scalar>(
&self,
numerator: &CudaSignedRadixCiphertext,
divisor: Scalar,
streams: &CudaStreams,
) -> (CudaSignedRadixCiphertext, CudaSignedRadixCiphertext)
where
Scalar: SignedReciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
<<Scalar as SignedReciprocable>::Unsigned as Reciprocable>::DoublePrecision: Send,
@@ -720,53 +556,54 @@ impl CudaServerKey {
streams,
);
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_signed_scalar_div_rem_assign(
streams,
quotient.as_mut(),
remainder.as_mut(),
divisor,
&self.key_switching_key.d_vec,
&d_bsk.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
d_bsk.input_lwe_dimension,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
LweBskGroupingFactor(0),
PBSType::Classical,
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_signed_scalar_div_rem_assign(
streams,
quotient.as_mut(),
remainder.as_mut(),
divisor,
&self.key_switching_key.d_vec,
&d_multibit_bsk.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
d_multibit_bsk.input_lwe_dimension,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
d_multibit_bsk.grouping_factor,
PBSType::MultiBit,
None,
);
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_signed_scalar_div_rem_assign(
streams,
quotient.as_mut(),
remainder.as_mut(),
divisor,
&self.key_switching_key.d_vec,
&d_bsk.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
d_bsk.input_lwe_dimension,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
LweBskGroupingFactor(0),
PBSType::Classical,
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_signed_scalar_div_rem_assign(
streams,
quotient.as_mut(),
remainder.as_mut(),
divisor,
&self.key_switching_key.d_vec,
&d_multibit_bsk.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
d_multibit_bsk.input_lwe_dimension,
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
d_multibit_bsk.grouping_factor,
PBSType::MultiBit,
None,
);
}
}
}
(quotient, remainder)
}
@@ -812,25 +649,6 @@ impl CudaServerKey {
divisor: Scalar,
streams: &CudaStreams,
) -> (CudaSignedRadixCiphertext, CudaSignedRadixCiphertext)
where
Scalar: SignedReciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
<<Scalar as SignedReciprocable>::Unsigned as Reciprocable>::DoublePrecision: Send,
{
let res = unsafe { self.signed_scalar_div_rem_async(numerator, divisor, streams) };
streams.synchronize();
res
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronized
pub unsafe fn signed_scalar_div_rem_async<Scalar>(
&self,
numerator: &CudaSignedRadixCiphertext,
divisor: Scalar,
streams: &CudaStreams,
) -> (CudaSignedRadixCiphertext, CudaSignedRadixCiphertext)
where
Scalar: SignedReciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
<<Scalar as SignedReciprocable>::Unsigned as Reciprocable>::DoublePrecision: Send,
@@ -844,7 +662,7 @@ impl CudaServerKey {
&tmp_numerator
};
self.unchecked_signed_scalar_div_rem_async(numerator, divisor, streams)
self.unchecked_signed_scalar_div_rem(numerator, divisor, streams)
}
pub fn unchecked_signed_scalar_rem<Scalar>(
@@ -853,27 +671,6 @@ impl CudaServerKey {
divisor: Scalar,
streams: &CudaStreams,
) -> CudaSignedRadixCiphertext
where
Scalar: SignedReciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
<<Scalar as SignedReciprocable>::Unsigned as Reciprocable>::DoublePrecision: Send,
{
let remainder =
unsafe { self.unchecked_signed_scalar_rem_async(numerator, divisor, streams) };
streams.synchronize();
remainder
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronized
pub unsafe fn unchecked_signed_scalar_rem_async<Scalar>(
&self,
numerator: &CudaSignedRadixCiphertext,
divisor: Scalar,
streams: &CudaStreams,
) -> CudaSignedRadixCiphertext
where
Scalar: SignedReciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
<<Scalar as SignedReciprocable>::Unsigned as Reciprocable>::DoublePrecision: Send,
@@ -922,25 +719,6 @@ impl CudaServerKey {
divisor: Scalar,
streams: &CudaStreams,
) -> CudaSignedRadixCiphertext
where
Scalar: SignedReciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
<<Scalar as SignedReciprocable>::Unsigned as Reciprocable>::DoublePrecision: Send,
{
let res = unsafe { self.signed_scalar_rem_async(numerator, divisor, streams) };
streams.synchronize();
res
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronized
pub unsafe fn signed_scalar_rem_async<Scalar>(
&self,
numerator: &CudaSignedRadixCiphertext,
divisor: Scalar,
streams: &CudaStreams,
) -> CudaSignedRadixCiphertext
where
Scalar: SignedReciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
<<Scalar as SignedReciprocable>::Unsigned as Reciprocable>::DoublePrecision: Send,
@@ -954,7 +732,7 @@ impl CudaServerKey {
&tmp_numerator
};
self.unchecked_signed_scalar_rem_async(numerator, divisor, streams)
self.unchecked_signed_scalar_rem(numerator, divisor, streams)
}
pub fn get_scalar_div_size_on_gpu<Scalar>(

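For intuition, the quotient/remainder pair returned by these entry points satisfies the plain Euclidean invariant; a clear-text sketch on ordinary Rust integers, no FHE types involved:

fn main() {
    let (n, d) = (27u64, 5u64);
    // What unchecked_scalar_div_rem computes homomorphically:
    let (q, r) = (n / d, n % d);
    assert_eq!(n, q * d + r);
    assert!(r < d);
    // The signed variants truncate toward zero, so the remainder
    // takes the sign of the numerator:
    let (sn, sd) = (-27i64, 5i64);
    assert_eq!((sn / sd, sn % sd), (-5, -2));
}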
View File

@@ -63,11 +63,7 @@ impl CudaServerKey {
result
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronised
pub unsafe fn unchecked_scalar_mul_assign_async<Scalar, T>(
pub fn unchecked_scalar_mul_assign<Scalar, T>(
&self,
ct: &mut T,
scalar: Scalar,
@@ -77,7 +73,7 @@ impl CudaServerKey {
T: CudaIntegerRadixCiphertext,
{
if scalar == Scalar::ZERO {
ct.as_mut().d_blocks.0.d_vec.memset_async(0, streams, 0);
ct.as_mut().d_blocks.0.d_vec.memset(0, streams, 0);
return;
}
@@ -90,7 +86,7 @@ impl CudaServerKey {
if scalar.is_power_of_two() {
// Shifting costs one bivariate PBS, so it's always faster
// than multiplying
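// e.g. scalar = 8: 8u64.ilog2() == 3 and ct * 8 == ct << 3,
// so the product reduces to a single left shift.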
self.unchecked_scalar_left_shift_assign_async(ct, scalar.ilog2() as u64, streams);
self.unchecked_scalar_left_shift_assign(ct, scalar.ilog2() as u64, streams);
return;
}
let msg_bits = self.message_modulus.0.ilog2() as usize;
@@ -112,73 +108,60 @@ impl CudaServerKey {
return;
}
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_scalar_mul(
streams,
ct.as_mut(),
decomposed_scalar.as_slice(),
has_at_least_one_set.as_slice(),
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
d_bsk.decomp_base_log,
d_bsk.decomp_level_count,
self.key_switching_key.decomposition_base_log(),
self.key_switching_key.decomposition_level_count(),
decomposed_scalar.len() as u32,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_scalar_mul(
streams,
ct.as_mut(),
decomposed_scalar.as_slice(),
has_at_least_one_set.as_slice(),
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
d_multibit_bsk.decomp_base_log,
d_multibit_bsk.decomp_level_count,
self.key_switching_key.decomposition_base_log(),
self.key_switching_key.decomposition_level_count(),
decomposed_scalar.len() as u32,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
);
}
}
}
pub fn unchecked_scalar_mul_assign<Scalar, T>(
&self,
ct: &mut T,
scalar: Scalar,
streams: &CudaStreams,
) where
Scalar: ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
T: CudaIntegerRadixCiphertext,
{
unsafe {
self.unchecked_scalar_mul_assign_async(ct, scalar, streams);
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_scalar_mul(
streams,
ct.as_mut(),
decomposed_scalar.as_slice(),
has_at_least_one_set.as_slice(),
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
d_bsk.decomp_base_log,
d_bsk.decomp_level_count,
self.key_switching_key.decomposition_base_log(),
self.key_switching_key.decomposition_level_count(),
decomposed_scalar.len() as u32,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_scalar_mul(
streams,
ct.as_mut(),
decomposed_scalar.as_slice(),
has_at_least_one_set.as_slice(),
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
d_multibit_bsk.decomp_base_log,
d_multibit_bsk.decomp_level_count,
self.key_switching_key.decomposition_base_log(),
self.key_switching_key.decomposition_level_count(),
decomposed_scalar.len() as u32,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
);
}
}
}
streams.synchronize();
}
/// Computes homomorphically a multiplication between a scalar and a ciphertext.
@@ -227,16 +210,8 @@ impl CudaServerKey {
result
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronised
pub unsafe fn scalar_mul_assign_async<Scalar, T>(
&self,
ct: &mut T,
scalar: Scalar,
streams: &CudaStreams,
) where
pub fn scalar_mul_assign<Scalar, T>(&self, ct: &mut T, scalar: Scalar, streams: &CudaStreams)
where
Scalar: ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
T: CudaIntegerRadixCiphertext,
{
@@ -244,19 +219,9 @@ impl CudaServerKey {
self.full_propagate_assign(ct, streams);
}
self.unchecked_scalar_mul_assign_async(ct, scalar, streams);
self.unchecked_scalar_mul_assign(ct, scalar, streams);
}
pub fn scalar_mul_assign<Scalar, T>(&self, ct: &mut T, scalar: Scalar, streams: &CudaStreams)
where
Scalar: ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
T: CudaIntegerRadixCiphertext,
{
unsafe {
self.scalar_mul_assign_async(ct, scalar, streams);
}
streams.synchronize();
}
pub fn get_scalar_mul_size_on_gpu<Scalar, T>(
&self,
ct: &T,

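For intuition, a plain-Rust sketch of the shape of the scalar decomposition the backend consumes: the scalar is split into `msg_bits`-wide words, least significant first. The helper below is illustrative only, not the crate's actual decomposition routine:

fn decompose(mut scalar: u64, msg_bits: u32, num_words: usize) -> Vec<u64> {
    let mask = (1u64 << msg_bits) - 1;
    (0..num_words)
        .map(|_| {
            let word = scalar & mask;
            scalar >>= msg_bits;
            word
        })
        .collect()
}

fn main() {
    // message_modulus = 4 => msg_bits = 2; 0b11_01_10 splits into [2, 1, 3]
    assert_eq!(decompose(0b110110, 2, 3), vec![0b10, 0b01, 0b11]);
}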
View File

@@ -11,140 +11,23 @@ use crate::integer::gpu::{
};
impl CudaServerKey {
/// # Safety
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronised
pub unsafe fn unchecked_scalar_rotate_left_async<Scalar, T>(
&self,
ct: &T,
n: Scalar,
stream: &CudaStreams,
) -> T
where
T: CudaIntegerRadixCiphertext,
Scalar: CastFrom<u32>,
u32: CastFrom<Scalar>,
{
let mut result = ct.duplicate(stream);
self.unchecked_scalar_rotate_left_assign_async(&mut result, n, stream);
result
}
/// # Safety
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronised
pub unsafe fn unchecked_scalar_rotate_left_assign_async<Scalar, T>(
&self,
ct: &mut T,
n: Scalar,
stream: &CudaStreams,
) where
T: CudaIntegerRadixCiphertext,
Scalar: CastFrom<u32>,
u32: CastFrom<Scalar>,
{
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_scalar_rotate_left_assign(
stream,
ct.as_mut(),
u32::cast_from(n),
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_scalar_rotate_left_assign(
stream,
ct.as_mut(),
u32::cast_from(n),
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
);
}
}
}
pub fn unchecked_scalar_rotate_left<Scalar, T>(
&self,
ct: &T,
n: Scalar,
stream: &CudaStreams,
) -> T
where
T: CudaIntegerRadixCiphertext,
Scalar: CastFrom<u32>,
u32: CastFrom<Scalar>,
{
let result = unsafe { self.unchecked_scalar_rotate_left_async(ct, n, stream) };
stream.synchronize();
result
}
/// # Safety
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronised
pub unsafe fn unchecked_scalar_rotate_right_async<Scalar, T>(
&self,
ct: &T,
n: Scalar,
stream: &CudaStreams,
) -> T
where
T: CudaIntegerRadixCiphertext,
Scalar: CastFrom<u32>,
u32: CastFrom<Scalar>,
{
let mut result = ct.duplicate(stream);
self.unchecked_scalar_rotate_right_assign_async(&mut result, n, stream);
self.unchecked_scalar_rotate_left_assign(&mut result, n, stream);
result
}
/// # Safety
///
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until stream is synchronised
pub unsafe fn unchecked_scalar_rotate_right_assign_async<Scalar, T>(
pub fn unchecked_scalar_rotate_left_assign<Scalar, T>(
&self,
ct: &mut T,
n: Scalar,
@@ -155,60 +38,62 @@ impl CudaServerKey {
u32: CastFrom<Scalar>,
{
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_scalar_rotate_right_assign(
stream,
ct.as_mut(),
u32::cast_from(n),
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_scalar_rotate_right_assign(
stream,
ct.as_mut(),
u32::cast_from(n),
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
);
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_scalar_rotate_left_assign(
stream,
ct.as_mut(),
u32::cast_from(n),
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_scalar_rotate_left_assign(
stream,
ct.as_mut(),
u32::cast_from(n),
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
);
}
}
}
}
@@ -224,11 +109,82 @@ impl CudaServerKey {
Scalar: CastFrom<u32>,
u32: CastFrom<Scalar>,
{
let result = unsafe { self.unchecked_scalar_rotate_right_async(ct, n, stream) };
stream.synchronize();
let mut result = ct.duplicate(stream);
self.unchecked_scalar_rotate_right_assign(&mut result, n, stream);
result
}
pub fn unchecked_scalar_rotate_right_assign<Scalar, T>(
&self,
ct: &mut T,
n: Scalar,
stream: &CudaStreams,
) where
T: CudaIntegerRadixCiphertext,
Scalar: CastFrom<u32>,
u32: CastFrom<Scalar>,
{
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_scalar_rotate_right_assign(
stream,
ct.as_mut(),
u32::cast_from(n),
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_scalar_rotate_right_assign(
stream,
ct.as_mut(),
u32::cast_from(n),
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
);
}
}
}
}
pub fn scalar_rotate_left_assign<Scalar, T>(&self, ct: &mut T, n: Scalar, stream: &CudaStreams)
where
T: CudaIntegerRadixCiphertext,
@@ -239,8 +195,7 @@ impl CudaServerKey {
self.full_propagate_assign(ct, stream);
}
unsafe { self.unchecked_scalar_rotate_left_assign_async(ct, n, stream) };
stream.synchronize();
self.unchecked_scalar_rotate_left_assign(ct, n, stream);
}
pub fn scalar_rotate_right_assign<Scalar, T>(&self, ct: &mut T, n: Scalar, stream: &CudaStreams)
@@ -253,8 +208,7 @@ impl CudaServerKey {
self.full_propagate_assign(ct, stream);
}
unsafe { self.unchecked_scalar_rotate_right_assign_async(ct, n, stream) };
stream.synchronize();
self.unchecked_scalar_rotate_right_assign(ct, n, stream);
}
pub fn scalar_rotate_left<Scalar, T>(&self, ct: &T, shift: Scalar, stream: &CudaStreams) -> T

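The rotations above act on the radix ciphertext's full bit width rather than on a native machine word. A clear-text model of a width-parameterised left rotation (plain Rust, illustrative helper):

fn rotate_left_w(x: u64, n: u32, width: u32) -> u64 {
    let mask = if width == 64 { u64::MAX } else { (1u64 << width) - 1 };
    let n = n % width;
    if n == 0 {
        return x & mask;
    }
    ((x << n) | (x >> (width - n))) & mask
}

fn main() {
    // An 8-bit 0b1000_0001 rotated left by 1 becomes 0b0000_0011.
    assert_eq!(rotate_left_w(0b1000_0001, 1, 8), 0b0000_0011);
    // Rotation amounts reduce modulo the width, as in the GPU path.
    assert_eq!(rotate_left_w(0b1000_0001, 9, 8), 0b0000_0011);
}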
View File

@@ -13,100 +13,6 @@ use crate::integer::gpu::{
};
impl CudaServerKey {
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronised
pub unsafe fn unchecked_scalar_left_shift_async<Scalar, T>(
&self,
ct: &T,
shift: Scalar,
streams: &CudaStreams,
) -> T
where
Scalar: CastFrom<u32>,
u32: CastFrom<Scalar>,
T: CudaIntegerRadixCiphertext,
{
let mut result = ct.duplicate(streams);
self.unchecked_scalar_left_shift_assign_async(&mut result, shift, streams);
result
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronised
pub unsafe fn unchecked_scalar_left_shift_assign_async<Scalar, T>(
&self,
ct: &mut T,
shift: Scalar,
streams: &CudaStreams,
) where
Scalar: CastFrom<u32>,
u32: CastFrom<Scalar>,
T: CudaIntegerRadixCiphertext,
{
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_scalar_left_shift_assign(
streams,
ct.as_mut(),
u32::cast_from(shift),
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_scalar_left_shift_assign(
streams,
ct.as_mut(),
u32::cast_from(shift),
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
);
}
}
}
/// Computes homomorphically a left shift by a scalar.
///
/// The result is returned as a new ciphertext.
@@ -149,41 +55,17 @@ impl CudaServerKey {
shift: Scalar,
streams: &CudaStreams,
) -> T
where
Scalar: CastFrom<u32>,
u32: CastFrom<Scalar>,
T: CudaIntegerRadixCiphertext,
{
let result = unsafe { self.unchecked_scalar_left_shift_async(ct, shift, streams) };
streams.synchronize();
result
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronised
pub unsafe fn unchecked_scalar_right_shift_async<Scalar, T>(
&self,
ct: &T,
shift: Scalar,
streams: &CudaStreams,
) -> T
where
Scalar: CastFrom<u32>,
u32: CastFrom<Scalar>,
T: CudaIntegerRadixCiphertext,
{
let mut result = ct.duplicate(streams);
self.unchecked_scalar_right_shift_assign_async(&mut result, shift, streams);
self.unchecked_scalar_left_shift_assign(&mut result, shift, streams);
result
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronised
pub unsafe fn unchecked_scalar_right_shift_assign_async<Scalar, T>(
pub fn unchecked_scalar_left_shift_assign<Scalar, T>(
&self,
ct: &mut T,
shift: Scalar,
@@ -195,65 +77,10 @@ impl CudaServerKey {
{
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
if T::IS_SIGNED {
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_scalar_arithmetic_right_shift_assign(
streams,
ct.as_mut(),
u32::cast_from(shift),
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_scalar_arithmetic_right_shift_assign(
streams,
ct.as_mut(),
u32::cast_from(shift),
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
);
}
}
} else {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_scalar_logical_right_shift_assign(
cuda_backend_unchecked_scalar_left_shift_assign(
streams,
ct.as_mut(),
u32::cast_from(shift),
@@ -280,7 +107,7 @@ impl CudaServerKey {
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_scalar_logical_right_shift_assign(
cuda_backend_unchecked_scalar_left_shift_assign(
streams,
ct.as_mut(),
u32::cast_from(shift),
@@ -357,16 +184,141 @@ impl CudaServerKey {
u32: CastFrom<Scalar>,
T: CudaIntegerRadixCiphertext,
{
let result = unsafe { self.unchecked_scalar_right_shift_async(ct, shift, streams) };
streams.synchronize();
let mut result = ct.duplicate(streams);
self.unchecked_scalar_right_shift_assign(&mut result, shift, streams);
result
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronised
pub unsafe fn scalar_right_shift_assign_async<Scalar, T>(
pub fn unchecked_scalar_right_shift_assign<Scalar, T>(
&self,
ct: &mut T,
shift: Scalar,
streams: &CudaStreams,
) where
Scalar: CastFrom<u32>,
u32: CastFrom<Scalar>,
T: CudaIntegerRadixCiphertext,
{
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
unsafe {
if T::IS_SIGNED {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_scalar_arithmetic_right_shift_assign(
streams,
ct.as_mut(),
u32::cast_from(shift),
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_scalar_arithmetic_right_shift_assign(
streams,
ct.as_mut(),
u32::cast_from(shift),
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
);
}
}
} else {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_scalar_logical_right_shift_assign(
streams,
ct.as_mut(),
u32::cast_from(shift),
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_scalar_logical_right_shift_assign(
streams,
ct.as_mut(),
u32::cast_from(shift),
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
);
}
}
}
}
}
pub fn scalar_right_shift_assign<Scalar, T>(
&self,
ct: &mut T,
shift: Scalar,
@@ -380,27 +332,7 @@ impl CudaServerKey {
self.full_propagate_assign(ct, streams);
}
self.unchecked_scalar_right_shift_assign_async(ct, shift, streams);
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronised
pub unsafe fn scalar_right_shift_async<Scalar, T>(
&self,
ct: &T,
shift: Scalar,
streams: &CudaStreams,
) -> T
where
Scalar: CastFrom<u32>,
u32: CastFrom<Scalar>,
T: CudaIntegerRadixCiphertext,
{
let mut result = ct.duplicate(streams);
self.scalar_right_shift_assign_async(&mut result, shift, streams);
result
self.unchecked_scalar_right_shift_assign(ct, shift, streams);
}
/// Computes homomorphically a right shift by a scalar.
@@ -440,54 +372,13 @@ impl CudaServerKey {
/// assert_eq!(dec_result, msg >> shift);
/// ```
pub fn scalar_right_shift<Scalar, T>(&self, ct: &T, shift: Scalar, streams: &CudaStreams) -> T
where
Scalar: CastFrom<u32>,
u32: CastFrom<Scalar>,
T: CudaIntegerRadixCiphertext,
{
let result = unsafe { self.scalar_right_shift_async(ct, shift, streams) };
streams.synchronize();
result
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronised
pub unsafe fn scalar_left_shift_assign_async<Scalar, T>(
&self,
ct: &mut T,
shift: Scalar,
streams: &CudaStreams,
) where
Scalar: CastFrom<u32>,
u32: CastFrom<Scalar>,
T: CudaIntegerRadixCiphertext,
{
if !ct.block_carries_are_empty() {
self.full_propagate_assign(ct, streams);
}
self.unchecked_scalar_left_shift_assign_async(ct, shift, streams);
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronised
pub unsafe fn scalar_left_shift_async<Scalar, T>(
&self,
ct: &T,
shift: Scalar,
streams: &CudaStreams,
) -> T
where
Scalar: CastFrom<u32>,
u32: CastFrom<Scalar>,
T: CudaIntegerRadixCiphertext,
{
let mut result = ct.duplicate(streams);
self.scalar_left_shift_assign_async(&mut result, shift, streams);
self.scalar_right_shift_assign(&mut result, shift, streams);
result
}
@@ -533,8 +424,8 @@ impl CudaServerKey {
u32: CastFrom<Scalar>,
T: CudaIntegerRadixCiphertext,
{
let result = unsafe { self.scalar_left_shift_async(ct, shift, streams) };
streams.synchronize();
let mut result = ct.duplicate(streams);
self.scalar_left_shift_assign(&mut result, shift, streams);
result
}
@@ -552,104 +443,7 @@ impl CudaServerKey {
self.full_propagate_assign(ct, streams);
}
unsafe {
self.unchecked_scalar_left_shift_assign_async(ct, shift, streams);
};
streams.synchronize();
}
pub fn scalar_right_shift_assign<Scalar, T>(
&self,
ct: &mut T,
shift: Scalar,
streams: &CudaStreams,
) where
Scalar: CastFrom<u32>,
u32: CastFrom<Scalar>,
T: CudaIntegerRadixCiphertext,
{
if !ct.block_carries_are_empty() {
self.full_propagate_assign(ct, streams);
}
unsafe {
self.unchecked_scalar_right_shift_assign_async(ct, shift, streams);
};
streams.synchronize();
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronised
pub unsafe fn unchecked_scalar_right_shift_logical_assign_async<Scalar, T>(
&self,
ct: &mut T,
shift: Scalar,
streams: &CudaStreams,
) where
Scalar: CastFrom<u32>,
u32: CastFrom<Scalar>,
T: CudaIntegerRadixCiphertext,
{
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_scalar_logical_right_shift_assign(
streams,
ct.as_mut(),
u32::cast_from(shift),
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_scalar_logical_right_shift_assign(
streams,
ct.as_mut(),
u32::cast_from(shift),
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
);
}
}
self.unchecked_scalar_left_shift_assign(ct, shift, streams);
}
pub fn get_scalar_left_shift_size_on_gpu<T>(&self, ct: &T, streams: &CudaStreams) -> u64

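The `T::IS_SIGNED` branch above selects between the arithmetic and logical right-shift backends; the distinction is the standard one, shown here on native integers:

fn main() {
    // Logical right shift (unsigned types): zeros enter from the left.
    assert_eq!(0b1000_0000u8 >> 2, 0b0010_0000u8);
    // Arithmetic right shift (signed types): the sign bit is replicated.
    assert_eq!((-96i8) >> 2, -24i8); // 0b1010_0000 -> 0b1110_1000
}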
View File

@@ -61,23 +61,6 @@ impl CudaServerKey {
result
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronised
pub unsafe fn unchecked_scalar_sub_assign_async<Scalar, T>(
&self,
ct: &mut T,
scalar: Scalar,
streams: &CudaStreams,
) where
Scalar: DecomposableInto<u8> + Numeric + TwosComplementNegation + CastInto<u64>,
T: CudaIntegerRadixCiphertext,
{
let negated_scalar = scalar.twos_complement_negation();
self.unchecked_scalar_add_assign_async(ct, negated_scalar, streams);
}
pub fn unchecked_scalar_sub_assign<Scalar, T>(
&self,
ct: &mut T,
@@ -87,10 +70,8 @@ impl CudaServerKey {
Scalar: DecomposableInto<u8> + Numeric + TwosComplementNegation + CastInto<u64>,
T: CudaIntegerRadixCiphertext,
{
unsafe {
self.unchecked_scalar_sub_assign_async(ct, scalar, streams);
}
streams.synchronize();
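// Subtracting a scalar is adding its two's complement:
// ct - s == ct + (!s + 1) modulo the radix width.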
let negated_scalar = scalar.twos_complement_negation();
self.unchecked_scalar_add_assign(ct, negated_scalar, streams);
}
/// Computes homomorphically a subtraction between a ciphertext and a scalar.
@@ -148,16 +129,8 @@ impl CudaServerKey {
self.get_scalar_sub_assign_size_on_gpu(ct, streams)
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronised
pub unsafe fn scalar_sub_assign_async<Scalar, T>(
&self,
ct: &mut T,
scalar: Scalar,
streams: &CudaStreams,
) where
pub fn scalar_sub_assign<Scalar, T>(&self, ct: &mut T, scalar: Scalar, streams: &CudaStreams)
where
Scalar: DecomposableInto<u8> + Numeric + TwosComplementNegation + CastInto<u64>,
T: CudaIntegerRadixCiphertext,
{
@@ -165,21 +138,10 @@ impl CudaServerKey {
self.full_propagate_assign(ct, streams);
}
self.unchecked_scalar_sub_assign_async(ct, scalar, streams);
self.unchecked_scalar_sub_assign(ct, scalar, streams);
let _carry = self.propagate_single_carry_assign(ct, streams, None, OutputFlag::None);
}
pub fn scalar_sub_assign<Scalar, T>(&self, ct: &mut T, scalar: Scalar, streams: &CudaStreams)
where
Scalar: DecomposableInto<u8> + Numeric + TwosComplementNegation + CastInto<u64>,
T: CudaIntegerRadixCiphertext,
{
unsafe {
self.scalar_sub_assign_async(ct, scalar, streams);
}
streams.synchronize();
}
pub fn get_scalar_sub_assign_size_on_gpu<T>(&self, ct: &T, streams: &CudaStreams) -> u64
where
T: CudaIntegerRadixCiphertext,

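A clear-text check of the two's-complement identity that unchecked_scalar_sub_assign relies on (plain Rust):

fn main() {
    let (x, s) = (200u8, 58u8);
    let negated = (!s).wrapping_add(1); // two's complement negation of s
    assert_eq!(x.wrapping_sub(s), x.wrapping_add(negated));
}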
View File

@@ -9,11 +9,7 @@ use crate::integer::gpu::{
};
impl CudaServerKey {
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until streams is synchronised
pub unsafe fn unchecked_right_shift_assign_async<T>(
pub fn unchecked_right_shift_assign<T>(
&self,
ct: &mut T,
shift: &CudaUnsignedRadixCiphertext,
@@ -24,214 +20,79 @@ impl CudaServerKey {
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
let is_signed = T::IS_SIGNED;
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_right_shift_assign(
streams,
ct.as_mut(),
shift.as_ref(),
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
is_signed,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_right_shift_assign(
streams,
ct.as_mut(),
shift.as_ref(),
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
is_signed,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
);
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_right_shift_assign(
streams,
ct.as_mut(),
shift.as_ref(),
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
is_signed,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_right_shift_assign(
streams,
ct.as_mut(),
shift.as_ref(),
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
is_signed,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
);
}
}
}
}
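// Illustrative sketch of the `is_signed` flag passed above, on clear values
// (no GPU): the backend performs a logical shift (zero fill) for unsigned
// radix ciphertexts and an arithmetic shift (sign-bit fill) for signed ones.
fn right_shift_semantics_demo() {
    let unsigned: u8 = 0b1001_0000;
    let signed: i8 = unsigned as i8; // same bits, now read as a negative value
    assert_eq!(unsigned >> 2, 0b0010_0100); // zeros shifted in
    assert_eq!(signed >> 2, 0b1110_0100u8 as i8); // sign bit replicated
}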
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until `streams` is synchronized
pub unsafe fn unchecked_right_shift_async<T>(
&self,
ct: &T,
shift: &CudaUnsignedRadixCiphertext,
streams: &CudaStreams,
) -> T
where
T: CudaIntegerRadixCiphertext,
{
let mut result = ct.duplicate(streams);
self.unchecked_right_shift_assign_async(&mut result, shift, streams);
result
}
pub fn unchecked_right_shift<T>(
&self,
ct: &T,
shift: &CudaUnsignedRadixCiphertext,
streams: &CudaStreams,
) -> T
where
T: CudaIntegerRadixCiphertext,
{
let result = unsafe { self.unchecked_right_shift_async(ct, shift, streams) };
streams.synchronize();
result
}
pub fn unchecked_right_shift_assign<T>(
&self,
ct: &mut T,
shift: &CudaUnsignedRadixCiphertext,
streams: &CudaStreams,
) where
T: CudaIntegerRadixCiphertext,
{
unsafe { self.unchecked_right_shift_assign_async(ct, shift, streams) };
streams.synchronize();
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until `streams` is synchronized
pub unsafe fn unchecked_left_shift_assign_async<T>(
&self,
ct: &mut T,
shift: &CudaUnsignedRadixCiphertext,
streams: &CudaStreams,
) where
T: CudaIntegerRadixCiphertext,
{
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
let is_signed = T::IS_SIGNED;
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_left_shift_assign(
streams,
ct.as_mut(),
shift.as_ref(),
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
is_signed,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_left_shift_assign(
streams,
ct.as_mut(),
shift.as_ref(),
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
is_signed,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
);
}
}
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until `streams` is synchronized
pub unsafe fn unchecked_left_shift_async<T>(
&self,
ct: &T,
shift: &CudaUnsignedRadixCiphertext,
streams: &CudaStreams,
) -> T
where
T: CudaIntegerRadixCiphertext,
{
let mut result = ct.duplicate(streams);
self.unchecked_left_shift_assign_async(&mut result, shift, streams);
result
}
pub fn unchecked_left_shift<T>(
&self,
ct: &T,
shift: &CudaUnsignedRadixCiphertext,
streams: &CudaStreams,
) -> T
where
T: CudaIntegerRadixCiphertext,
{
let result = unsafe { self.unchecked_left_shift_async(ct, shift, streams) };
streams.synchronize();
self.unchecked_right_shift_assign(&mut result, shift, streams);
result
}
@@ -243,15 +104,72 @@ impl CudaServerKey {
) where
T: CudaIntegerRadixCiphertext,
{
unsafe { self.unchecked_left_shift_assign_async(ct, shift, streams) };
streams.synchronize();
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
let is_signed = T::IS_SIGNED;
unsafe {
match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
cuda_backend_unchecked_left_shift_assign(
streams,
ct.as_mut(),
shift.as_ref(),
&d_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_bsk.glwe_dimension,
d_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_bsk.decomp_level_count,
d_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
is_signed,
PBSType::Classical,
LweBskGroupingFactor(0),
d_bsk.ms_noise_reduction_configuration.as_ref(),
);
}
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
cuda_backend_unchecked_left_shift_assign(
streams,
ct.as_mut(),
shift.as_ref(),
&d_multibit_bsk.d_vec,
&self.key_switching_key.d_vec,
self.message_modulus,
self.carry_modulus,
d_multibit_bsk.glwe_dimension,
d_multibit_bsk.polynomial_size,
self.key_switching_key
.input_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key
.output_key_lwe_size()
.to_lwe_dimension(),
self.key_switching_key.decomposition_level_count(),
self.key_switching_key.decomposition_base_log(),
d_multibit_bsk.decomp_level_count,
d_multibit_bsk.decomp_base_log,
lwe_ciphertext_count.0 as u32,
is_signed,
PBSType::MultiBit,
d_multibit_bsk.grouping_factor,
None,
);
}
}
}
}
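// Illustrative sketch of the left-shift semantics above, on clear values
// (no GPU): shifting left by s is multiplication by 2^s modulo the integer
// width, identical for signed and unsigned two's complement encodings.
fn left_shift_semantics_demo() {
    let x: u8 = 0b1100_0101;
    let s = 3;
    assert_eq!(x << s, x.wrapping_mul(1 << s)); // high bits fall off the top
}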
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until `streams` is synchronized
pub unsafe fn right_shift_async<T>(
pub fn unchecked_left_shift<T>(
&self,
ct: &T,
shift: &CudaUnsignedRadixCiphertext,
@@ -260,82 +178,11 @@ impl CudaServerKey {
where
T: CudaIntegerRadixCiphertext,
{
let mut tmp_lhs: T;
let mut tmp_rhs: CudaUnsignedRadixCiphertext;
let (lhs, rhs) = match (
ct.block_carries_are_empty(),
shift.block_carries_are_empty(),
) {
(true, true) => (ct, shift),
(true, false) => {
tmp_rhs = shift.duplicate(streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
(ct, &tmp_rhs)
}
(false, true) => {
tmp_lhs = ct.duplicate(streams);
self.full_propagate_assign(&mut tmp_lhs, streams);
(&tmp_lhs, shift)
}
(false, false) => {
tmp_lhs = ct.duplicate(streams);
tmp_rhs = shift.duplicate(streams);
self.full_propagate_assign(&mut tmp_lhs, streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
(&tmp_lhs, &tmp_rhs)
}
};
let mut result = lhs.duplicate(streams);
self.unchecked_right_shift_assign_async(&mut result, rhs, streams);
let mut result = ct.duplicate(streams);
self.unchecked_left_shift_assign(&mut result, shift, streams);
result
}
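// The by-value entry points above all share one shape: duplicate the operand
// on the GPU, then run the in-place variant on the copy. A minimal sketch of
// that shape on a plain value, where the local copy stands in for
// ct.duplicate(streams):
fn shl_assign_demo(x: &mut u64, s: u32) {
    *x = x.wrapping_shl(s);
}

fn shl_demo(x: &u64, s: u32) -> u64 {
    let mut result = *x; // the "duplicate"
    shl_assign_demo(&mut result, s);
    result
}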
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until `streams` is synchronized
pub unsafe fn right_shift_assign_async<T>(
&self,
ct: &mut T,
shift: &CudaUnsignedRadixCiphertext,
streams: &CudaStreams,
) where
T: CudaIntegerRadixCiphertext,
{
let mut tmp_rhs: CudaUnsignedRadixCiphertext;
let (lhs, rhs) = match (
ct.block_carries_are_empty(),
shift.block_carries_are_empty(),
) {
(true, true) => (ct, shift),
(true, false) => {
tmp_rhs = shift.duplicate(streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
(ct, &tmp_rhs)
}
(false, true) => {
self.full_propagate_assign(ct, streams);
(ct, shift)
}
(false, false) => {
tmp_rhs = shift.duplicate(streams);
self.full_propagate_assign(ct, streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
(ct, &tmp_rhs)
}
};
self.unchecked_right_shift_assign_async(lhs, rhs, streams);
}
/// Computes homomorphically a right shift by an encrypted amount
///
/// The result is returned as a new ciphertext.
@@ -380,36 +227,6 @@ impl CudaServerKey {
shift: &CudaUnsignedRadixCiphertext,
streams: &CudaStreams,
) -> T
where
T: CudaIntegerRadixCiphertext,
{
let result = unsafe { self.right_shift_async(ct, shift, streams) };
streams.synchronize();
result
}
pub fn right_shift_assign<T>(
&self,
ct: &mut T,
shift: &CudaUnsignedRadixCiphertext,
streams: &CudaStreams,
) where
T: CudaIntegerRadixCiphertext,
{
unsafe { self.right_shift_assign_async(ct, shift, streams) };
streams.synchronize();
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until `streams` is synchronized
pub unsafe fn left_shift_async<T>(
&self,
ct: &T,
shift: &CudaUnsignedRadixCiphertext,
streams: &CudaStreams,
) -> T
where
T: CudaIntegerRadixCiphertext,
{
@@ -442,15 +259,11 @@ impl CudaServerKey {
};
let mut result = lhs.duplicate(streams);
self.unchecked_left_shift_assign_async(&mut result, rhs, streams);
self.unchecked_right_shift_assign(&mut result, rhs, streams);
result
}
/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
/// not be dropped until `streams` is synchronized
pub unsafe fn left_shift_assign_async<T>(
pub fn right_shift_assign<T>(
&self,
ct: &mut T,
shift: &CudaUnsignedRadixCiphertext,
@@ -486,7 +299,7 @@ impl CudaServerKey {
}
};
self.unchecked_left_shift_assign_async(lhs, rhs, streams);
self.unchecked_right_shift_assign(lhs, rhs, streams);
}
/// Computes homomorphically a left shift by an encrypted amount
@@ -536,8 +349,36 @@ impl CudaServerKey {
where
T: CudaIntegerRadixCiphertext,
{
let result = unsafe { self.left_shift_async(ct, shift, streams) };
streams.synchronize();
let mut tmp_lhs: T;
let mut tmp_rhs: CudaUnsignedRadixCiphertext;
let (lhs, rhs) = match (
ct.block_carries_are_empty(),
shift.block_carries_are_empty(),
) {
(true, true) => (ct, shift),
(true, false) => {
tmp_rhs = shift.duplicate(streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
(ct, &tmp_rhs)
}
(false, true) => {
tmp_lhs = ct.duplicate(streams);
self.full_propagate_assign(&mut tmp_lhs, streams);
(&tmp_lhs, shift)
}
(false, false) => {
tmp_lhs = ct.duplicate(streams);
tmp_rhs = shift.duplicate(streams);
self.full_propagate_assign(&mut tmp_lhs, streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
(&tmp_lhs, &tmp_rhs)
}
};
let mut result = lhs.duplicate(streams);
self.unchecked_left_shift_assign(&mut result, rhs, streams);
result
}
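// A usage sketch in the style of the crate's doc-tests. Exact import paths
// and the parameter-set name are assumptions that vary across tfhe-rs
// versions; treat this as illustrative, not normative. Note there is no
// `unsafe` and no explicit synchronize: the safe API blocks until the GPU
// work has completed.
//
// use tfhe::core_crypto::gpu::CudaStreams;
// use tfhe::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
// use tfhe::integer::gpu::gen_keys_radix_gpu;
// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
//
// let streams = CudaStreams::new_single_gpu(GpuIndex::new(0));
// let (cks, sks) = gen_keys_radix_gpu(PARAM_MESSAGE_2_CARRY_2_KS_PBS, 4, &streams);
// let ct = cks.encrypt(21u64);
// let shift = cks.encrypt(2u64);
// let d_ct = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct, &streams);
// let d_shift = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&shift, &streams);
// let d_res = sks.left_shift(&d_ct, &d_shift, &streams);
// let res: u64 = cks.decrypt(&d_res.to_radix_ciphertext(&streams));
// assert_eq!(res, 21 << 2);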
@@ -549,8 +390,35 @@ impl CudaServerKey {
) where
T: CudaIntegerRadixCiphertext,
{
unsafe { self.left_shift_assign_async(ct, shift, streams) };
streams.synchronize();
let mut tmp_rhs: CudaUnsignedRadixCiphertext;
let (lhs, rhs) = match (
ct.block_carries_are_empty(),
shift.block_carries_are_empty(),
) {
(true, true) => (ct, shift),
(true, false) => {
tmp_rhs = shift.duplicate(streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
(ct, &tmp_rhs)
}
(false, true) => {
self.full_propagate_assign(ct, streams);
(ct, shift)
}
(false, false) => {
tmp_rhs = shift.duplicate(streams);
self.full_propagate_assign(ct, streams);
self.full_propagate_assign(&mut tmp_rhs, streams);
(ct, &tmp_rhs)
}
};
self.unchecked_left_shift_assign(lhs, rhs, streams);
}
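// Illustrative decision table for the carry normalization above (clear-value
// sketch): the unchecked operation requires clean carries, and only dirty
// operands pay for propagation: `ct` in place, `shift` via a scratch copy
// because it is only borrowed immutably.
fn normalization_cases_demo() {
    for (ct_clean, shift_clean) in [(true, true), (true, false), (false, true), (false, false)] {
        let work = match (ct_clean, shift_clean) {
            (true, true) => "no propagation",
            (true, false) => "propagate a copy of shift",
            (false, true) => "propagate ct in place",
            (false, false) => "propagate ct in place and a copy of shift",
        };
        println!("ct clean: {ct_clean}, shift clean: {shift_clean} -> {work}");
    }
}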
pub fn get_left_shift_size_on_gpu<T: CudaIntegerRadixCiphertext>(