Mirror of https://github.com/zama-ai/tfhe-rs.git, synced 2026-01-09 14:47:56 -05:00
chore(gpu): remove remaining async functions from the integer gpu api
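The same mechanical change repeats in every hunk below: the public `unsafe ..._async` functions and their `# Safety` contracts are deleted, and the synchronous functions become the only entry points, keeping the unsafe backend launch and the final `streams.synchronize()` inside their bodies. A minimal, self-contained sketch of that shape, using hypothetical stand-in types rather than the real tfhe-rs ones:

// Sketch of the before/after pattern applied across the integer GPU server key.
// `Streams`, `Ciphertext`, `ServerKey` and `backend_mul_assign` are illustrative
// stand-ins, not the tfhe-rs types or backend entry points.

struct Streams;

impl Streams {
    fn synchronize(&self) {
        // In the real API this blocks until all queued GPU work has completed.
    }
}

struct Ciphertext(u64);

// Stand-in for an unsafe backend launch such as `cuda_backend_unchecked_mul_assign`.
unsafe fn backend_mul_assign(lhs: &mut Ciphertext, rhs: &Ciphertext, _streams: &Streams) {
    lhs.0 = lhs.0.wrapping_mul(rhs.0);
}

struct ServerKey;

impl ServerKey {
    // After the refactor: a single safe function, no `_async` twin and no
    // `# Safety` contract for the caller to uphold.
    fn unchecked_mul_assign(&self, lhs: &mut Ciphertext, rhs: &Ciphertext, streams: &Streams) {
        unsafe { backend_mul_assign(lhs, rhs, streams) };
        streams.synchronize();
    }
}

fn main() {
    let (key, streams) = (ServerKey, Streams);
    let mut a = Ciphertext(3);
    key.unchecked_mul_assign(&mut a, &Ciphertext(5), &streams);
    assert_eq!(a.0, 15);
}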
@@ -64,11 +64,7 @@ impl CudaServerKey {
         result
     }
 
-    /// # Safety
-    ///
-    /// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
-    ///   not be dropped until streams is synchronised
-    pub unsafe fn unchecked_mul_assign_async<T: CudaIntegerRadixCiphertext>(
+    pub fn unchecked_mul_assign<T: CudaIntegerRadixCiphertext>(
         &self,
         ct_left: &mut T,
         ct_right: &T,
@@ -78,68 +74,58 @@ impl CudaServerKey {
|
||||
|
||||
let is_boolean_left = ct_left.holds_boolean_value();
|
||||
let is_boolean_right = ct_right.holds_boolean_value();
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_mul_assign(
|
||||
streams,
|
||||
ct_left.as_mut(),
|
||||
is_boolean_left,
|
||||
ct_right.as_ref(),
|
||||
is_boolean_right,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.input_lwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
d_bsk.decomp_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
num_blocks,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_mul_assign(
|
||||
streams,
|
||||
ct_left.as_mut(),
|
||||
is_boolean_left,
|
||||
ct_right.as_ref(),
|
||||
is_boolean_right,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
num_blocks,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn unchecked_mul_assign<T: CudaIntegerRadixCiphertext>(
|
||||
&self,
|
||||
ct_left: &mut T,
|
||||
ct_right: &T,
|
||||
streams: &CudaStreams,
|
||||
) {
|
||||
unsafe {
|
||||
self.unchecked_mul_assign_async(ct_left, ct_right, streams);
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_mul_assign(
|
||||
streams,
|
||||
ct_left.as_mut(),
|
||||
is_boolean_left,
|
||||
ct_right.as_ref(),
|
||||
is_boolean_right,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension(),
|
||||
d_bsk.input_lwe_dimension(),
|
||||
d_bsk.polynomial_size(),
|
||||
d_bsk.decomp_base_log(),
|
||||
d_bsk.decomp_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
num_blocks,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_mul_assign(
|
||||
streams,
|
||||
ct_left.as_mut(),
|
||||
is_boolean_left,
|
||||
ct_right.as_ref(),
|
||||
is_boolean_right,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension(),
|
||||
d_multibit_bsk.input_lwe_dimension(),
|
||||
d_multibit_bsk.polynomial_size(),
|
||||
d_multibit_bsk.decomp_base_log(),
|
||||
d_multibit_bsk.decomp_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
num_blocks,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
streams.synchronize();
|
||||
}
|
||||
|
||||
/// Computes homomorphically a multiplication between two ciphertexts encrypting integer values.
|
||||
@@ -198,11 +184,7 @@ impl CudaServerKey {
         result
     }
 
-    /// # Safety
-    ///
-    /// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
-    ///   not be dropped until streams is synchronised
-    pub unsafe fn mul_assign_async<T: CudaIntegerRadixCiphertext>(
+    pub fn mul_assign<T: CudaIntegerRadixCiphertext>(
         &self,
         ct_left: &mut T,
         ct_right: &T,
@@ -233,22 +215,10 @@ impl CudaServerKey {
             }
         };
 
-        self.unchecked_mul_assign_async(lhs, rhs, streams);
+        self.unchecked_mul_assign(lhs, rhs, streams);
         // Carries are cleaned internally in the mul algorithm
     }
 
-    pub fn mul_assign<T: CudaIntegerRadixCiphertext>(
-        &self,
-        ct_left: &mut T,
-        ct_right: &T,
-        streams: &CudaStreams,
-    ) {
-        unsafe {
-            self.mul_assign_async(ct_left, ct_right, streams);
-        }
-        streams.synchronize();
-    }
-
     pub fn get_mul_size_on_gpu<T: CudaIntegerRadixCiphertext>(
         &self,
         ct_left: &T,

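The mul hunks above also show how the returning variants are layered on the synchronous in-place variants after this change: duplicate the input, run the assign operation, return the copy. A sketch of that layering with dummy types (`Ct`, `Key`, `Streams` are illustrative, not the tfhe-rs types):

#[derive(Clone)]
struct Ct(Vec<u64>);

struct Streams;
impl Streams {
    fn synchronize(&self) {}
}

struct Key;
impl Key {
    fn unchecked_mul_assign(&self, lhs: &mut Ct, rhs: &Ct, streams: &Streams) {
        for (l, r) in lhs.0.iter_mut().zip(&rhs.0) {
            *l = l.wrapping_mul(*r); // stand-in for the GPU kernel launch
        }
        streams.synchronize();
    }

    fn unchecked_mul(&self, lhs: &Ct, rhs: &Ct, streams: &Streams) -> Ct {
        let mut result = lhs.clone(); // plays the role of `ct.duplicate(streams)`
        self.unchecked_mul_assign(&mut result, rhs, streams);
        result
    }
}

fn main() {
    let (key, streams) = (Key, Streams);
    let out = key.unchecked_mul(&Ct(vec![2, 3]), &Ct(vec![5, 7]), &streams);
    assert_eq!(out.0, vec![10, 21]);
}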
@@ -51,11 +51,7 @@ impl CudaServerKey {
         num_blocks: u64,
         streams: &CudaStreams,
     ) -> CudaUnsignedRadixCiphertext {
-        let result = unsafe {
-            self.generate_oblivious_pseudo_random_unbounded_integer_async(seed, num_blocks, streams)
-        };
-        streams.synchronize();
-        result
+        self.generate_oblivious_pseudo_random_unbounded_integer(seed, num_blocks, streams)
     }
 
     /// Generates an encrypted `num_block` blocks unsigned integer
@@ -97,24 +93,20 @@ impl CudaServerKey {
|
||||
num_blocks: u64,
|
||||
streams: &CudaStreams,
|
||||
) -> CudaUnsignedRadixCiphertext {
|
||||
let result = unsafe {
|
||||
let message_bits_count = self.message_modulus.0.ilog2() as u64;
|
||||
let range_log_size = message_bits_count * num_blocks;
|
||||
let message_bits_count = self.message_modulus.0.ilog2() as u64;
|
||||
let range_log_size = message_bits_count * num_blocks;
|
||||
|
||||
assert!(
|
||||
assert!(
|
||||
random_bits_count <= range_log_size,
|
||||
"The range asked for a random value (=[0, 2^{random_bits_count}[) does not fit in the available range [0, 2^{range_log_size}[",
|
||||
);
|
||||
|
||||
self.generate_oblivious_pseudo_random_bounded_integer_async(
|
||||
seed,
|
||||
random_bits_count,
|
||||
num_blocks,
|
||||
streams,
|
||||
)
|
||||
};
|
||||
streams.synchronize();
|
||||
result
|
||||
self.generate_oblivious_pseudo_random_bounded_integer(
|
||||
seed,
|
||||
random_bits_count,
|
||||
num_blocks,
|
||||
streams,
|
||||
)
|
||||
}
|
||||
|
||||
/// Generates an encrypted `num_block` blocks signed integer
|
||||
@@ -150,11 +142,7 @@ impl CudaServerKey {
         num_blocks: u64,
         streams: &CudaStreams,
     ) -> CudaSignedRadixCiphertext {
-        let result = unsafe {
-            self.generate_oblivious_pseudo_random_unbounded_integer_async(seed, num_blocks, streams)
-        };
-        streams.synchronize();
-        result
+        self.generate_oblivious_pseudo_random_unbounded_integer(seed, num_blocks, streams)
     }
 
     /// Generates an encrypted `num_block` blocks signed integer
@@ -198,28 +186,24 @@ impl CudaServerKey {
|
||||
num_blocks: u64,
|
||||
streams: &CudaStreams,
|
||||
) -> CudaSignedRadixCiphertext {
|
||||
let result = unsafe {
|
||||
let message_bits_count = self.message_modulus.0.ilog2() as u64;
|
||||
let range_log_size = message_bits_count * num_blocks;
|
||||
let message_bits_count = self.message_modulus.0.ilog2() as u64;
|
||||
let range_log_size = message_bits_count * num_blocks;
|
||||
|
||||
#[allow(clippy::int_plus_one)]
|
||||
{
|
||||
assert!(
|
||||
random_bits_count + 1 <= range_log_size,
|
||||
"The range asked for a random value (=[0, 2^{}[) does not fit in the available range [-2^{}, 2^{}[",
|
||||
random_bits_count, range_log_size - 1, range_log_size - 1,
|
||||
);
|
||||
}
|
||||
#[allow(clippy::int_plus_one)]
|
||||
{
|
||||
assert!(
|
||||
random_bits_count + 1 <= range_log_size,
|
||||
"The range asked for a random value (=[0, 2^{}[) does not fit in the available range [-2^{}, 2^{}[",
|
||||
random_bits_count, range_log_size - 1, range_log_size - 1,
|
||||
);
|
||||
}
|
||||
|
||||
self.generate_oblivious_pseudo_random_bounded_integer_async(
|
||||
seed,
|
||||
random_bits_count,
|
||||
num_blocks,
|
||||
streams,
|
||||
)
|
||||
};
|
||||
streams.synchronize();
|
||||
result
|
||||
self.generate_oblivious_pseudo_random_bounded_integer(
|
||||
seed,
|
||||
random_bits_count,
|
||||
num_blocks,
|
||||
streams,
|
||||
)
|
||||
}
|
||||
|
||||
// Generic interface to generate a single-block oblivious pseudo-random integer.
|
||||
@@ -246,22 +230,13 @@ impl CudaServerKey {
             "The number of random bits asked for (={random_bits_count}) is bigger than carry_bits_count (={carry_bits_count}) + message_bits_count(={message_bits_count})",
         );
 
-        let result = unsafe {
-            self.generate_oblivious_pseudo_random_bounded_integer_async(
-                seed,
-                random_bits_count,
-                1,
-                streams,
-            )
-        };
-        streams.synchronize();
-        result
+        self.generate_oblivious_pseudo_random_bounded_integer(seed, random_bits_count, 1, streams)
     }
 
     // Generic internal implementation for unbounded pseudo-random generation.
     // It calls the core implementation with parameters for the unbounded case.
     //
-    unsafe fn generate_oblivious_pseudo_random_unbounded_integer_async<T>(
+    fn generate_oblivious_pseudo_random_unbounded_integer<T>(
         &self,
         seed: Seed,
         num_blocks: u64,
@@ -280,7 +255,7 @@ impl CudaServerKey {
             return result;
         }
 
-        self.generate_multiblocks_oblivious_pseudo_random_async(
+        self.generate_multiblocks_oblivious_pseudo_random(
             result.as_mut(),
             seed,
             num_blocks,
@@ -294,7 +269,7 @@ impl CudaServerKey {
     // Generic internal implementation for bounded pseudo-random generation.
     // It calls the core implementation with parameters for the bounded case.
     //
-    unsafe fn generate_oblivious_pseudo_random_bounded_integer_async<T>(
+    fn generate_oblivious_pseudo_random_bounded_integer<T>(
         &self,
         seed: Seed,
         random_bits_count: u64,
@@ -318,7 +293,7 @@ impl CudaServerKey {
             return result;
         }
 
-        self.generate_multiblocks_oblivious_pseudo_random_async(
+        self.generate_multiblocks_oblivious_pseudo_random(
             result.as_mut(),
             seed,
             num_active_blocks,
@@ -331,7 +306,7 @@ impl CudaServerKey {
     // Core private implementation that calls the OPRF backend.
     // This function contains the main logic for both bounded and unbounded generation.
    //
-    unsafe fn generate_multiblocks_oblivious_pseudo_random_async(
+    fn generate_multiblocks_oblivious_pseudo_random(
         &self,
         result: &mut CudaRadixCiphertext,
         seed: Seed,
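The bounded generators above assert that the requested number of random bits fits the radix range before dispatching to the core implementation in the next hunk. A standalone sketch of that arithmetic (parameter names mirror the diff, but the function itself is illustrative and not part of tfhe-rs):

fn check_random_bits_fit(
    message_modulus: u64,
    num_blocks: u64,
    random_bits_count: u64,
    signed: bool,
) {
    let message_bits_count = message_modulus.ilog2() as u64;
    let range_log_size = message_bits_count * num_blocks;
    if signed {
        // One bit of the representable range is taken by the sign.
        assert!(
            random_bits_count + 1 <= range_log_size,
            "requested [0, 2^{random_bits_count}[ does not fit in [-2^{}, 2^{}[",
            range_log_size - 1,
            range_log_size - 1,
        );
    } else {
        assert!(
            random_bits_count <= range_log_size,
            "requested [0, 2^{random_bits_count}[ does not fit in [0, 2^{range_log_size}[",
        );
    }
}

fn main() {
    // message_modulus = 4 gives 2 message bits per block; 8 blocks give a 16-bit range.
    check_random_bits_fit(4, 8, 16, false); // fits exactly in the unsigned case
    check_random_bits_fit(4, 8, 15, true); // the signed case needs one spare bit for the sign
}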
@@ -369,57 +344,62 @@ impl CudaServerKey {
|
||||
})
|
||||
.collect();
|
||||
|
||||
let mut d_seeded_lwe_input = CudaVec::<u64>::new_async(h_seeded_lwe_list.len(), streams, 0);
|
||||
d_seeded_lwe_input.copy_from_cpu_async(&h_seeded_lwe_list, streams, 0);
|
||||
let mut d_seeded_lwe_input =
|
||||
unsafe { CudaVec::<u64>::new_async(h_seeded_lwe_list.len(), streams, 0) };
|
||||
unsafe {
|
||||
d_seeded_lwe_input.copy_from_cpu_async(&h_seeded_lwe_list, streams, 0);
|
||||
}
|
||||
|
||||
let message_bits_count = self.message_modulus.0.ilog2();
|
||||
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_grouped_oprf(
|
||||
streams,
|
||||
result,
|
||||
&d_seeded_lwe_input,
|
||||
num_active_blocks as u32,
|
||||
&d_bsk.d_vec,
|
||||
d_bsk.input_lwe_dimension,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
LweBskGroupingFactor(0),
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
PBSType::Classical,
|
||||
message_bits_count,
|
||||
total_random_bits as u32,
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_bsk) => {
|
||||
cuda_backend_grouped_oprf(
|
||||
streams,
|
||||
result,
|
||||
&d_seeded_lwe_input,
|
||||
num_active_blocks as u32,
|
||||
&d_bsk.d_vec,
|
||||
d_bsk.input_lwe_dimension,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
d_bsk.grouping_factor,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
PBSType::MultiBit,
|
||||
message_bits_count,
|
||||
total_random_bits as u32,
|
||||
None,
|
||||
);
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_grouped_oprf(
|
||||
streams,
|
||||
result,
|
||||
&d_seeded_lwe_input,
|
||||
num_active_blocks as u32,
|
||||
&d_bsk.d_vec,
|
||||
d_bsk.input_lwe_dimension,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
LweBskGroupingFactor(0),
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
PBSType::Classical,
|
||||
message_bits_count,
|
||||
total_random_bits as u32,
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_bsk) => {
|
||||
cuda_backend_grouped_oprf(
|
||||
streams,
|
||||
result,
|
||||
&d_seeded_lwe_input,
|
||||
num_active_blocks as u32,
|
||||
&d_bsk.d_vec,
|
||||
d_bsk.input_lwe_dimension,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
d_bsk.grouping_factor,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
PBSType::MultiBit,
|
||||
message_bits_count,
|
||||
total_random_bits as u32,
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,11 +9,7 @@ use crate::integer::gpu::{
 };
 
 impl CudaServerKey {
-    /// # Safety
-    ///
-    /// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
-    ///   not be dropped until streams is synchronised
-    pub unsafe fn unchecked_rotate_right_assign_async<T>(
+    pub fn unchecked_rotate_right_assign<T>(
         &self,
         ct: &mut T,
         rotate: &CudaUnsignedRadixCiphertext,
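The checked rotate entry points later in this file keep their preparation step after the refactor: any operand whose block carries are not empty is duplicated and fully propagated before the unchecked operation runs. A simplified, self-contained sketch of that dispatch (the real code matches on both operands and avoids duplicating those whose carries are already empty; this sketch clones unconditionally for brevity, and `Ct`/`full_propagate` are stand-ins):

#[derive(Clone)]
struct Ct {
    carries_empty: bool,
}

impl Ct {
    fn block_carries_are_empty(&self) -> bool {
        self.carries_empty
    }
}

fn full_propagate(ct: &mut Ct) {
    ct.carries_empty = true; // stand-in for the GPU carry-propagation kernel
}

/// Returns operands whose carries are guaranteed to be empty.
fn prepare(lhs: &Ct, rhs: &Ct) -> (Ct, Ct) {
    let mut lhs = lhs.clone();
    let mut rhs = rhs.clone();
    if !lhs.block_carries_are_empty() {
        full_propagate(&mut lhs);
    }
    if !rhs.block_carries_are_empty() {
        full_propagate(&mut rhs);
    }
    (lhs, rhs)
}

fn main() {
    let (a, b) = (Ct { carries_empty: false }, Ct { carries_empty: true });
    let (a, b) = prepare(&a, &b);
    assert!(a.block_carries_are_empty() && b.block_carries_are_empty());
}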
@@ -24,216 +20,79 @@ impl CudaServerKey {
|
||||
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
|
||||
let is_signed = T::IS_SIGNED;
|
||||
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_rotate_right_assign(
|
||||
streams,
|
||||
ct.as_mut(),
|
||||
rotate.as_ref(),
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
is_signed,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_rotate_right_assign(
|
||||
streams,
|
||||
ct.as_mut(),
|
||||
rotate.as_ref(),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
is_signed,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
);
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_rotate_right_assign(
|
||||
streams,
|
||||
ct.as_mut(),
|
||||
rotate.as_ref(),
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
is_signed,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_rotate_right_assign(
|
||||
streams,
|
||||
ct.as_mut(),
|
||||
rotate.as_ref(),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
is_signed,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronised
|
||||
pub unsafe fn unchecked_rotate_right_async<T>(
|
||||
&self,
|
||||
ct: &T,
|
||||
rotate: &CudaUnsignedRadixCiphertext,
|
||||
streams: &CudaStreams,
|
||||
) -> T
|
||||
where
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let mut result = ct.duplicate(streams);
|
||||
self.unchecked_rotate_right_assign_async(&mut result, rotate, streams);
|
||||
result
|
||||
}
|
||||
|
||||
pub fn unchecked_rotate_right<T>(
|
||||
&self,
|
||||
ct: &T,
|
||||
rotate: &CudaUnsignedRadixCiphertext,
|
||||
streams: &CudaStreams,
|
||||
) -> T
|
||||
where
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let result = unsafe { self.unchecked_rotate_right_async(ct, rotate, streams) };
|
||||
streams.synchronize();
|
||||
result
|
||||
}
|
||||
|
||||
pub fn unchecked_rotate_right_assign<T>(
|
||||
&self,
|
||||
ct: &mut T,
|
||||
rotate: &CudaUnsignedRadixCiphertext,
|
||||
streams: &CudaStreams,
|
||||
) where
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
unsafe {
|
||||
self.unchecked_rotate_right_assign_async(ct, rotate, streams);
|
||||
}
|
||||
streams.synchronize();
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronised
|
||||
pub unsafe fn unchecked_rotate_left_assign_async<T>(
|
||||
&self,
|
||||
ct: &mut T,
|
||||
rotate: &CudaUnsignedRadixCiphertext,
|
||||
streams: &CudaStreams,
|
||||
) where
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
|
||||
let is_signed = T::IS_SIGNED;
|
||||
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_rotate_left_assign(
|
||||
streams,
|
||||
ct.as_mut(),
|
||||
rotate.as_ref(),
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
is_signed,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_rotate_left_assign(
|
||||
streams,
|
||||
ct.as_mut(),
|
||||
rotate.as_ref(),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
is_signed,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronised
|
||||
pub unsafe fn unchecked_rotate_left_async<T>(
|
||||
&self,
|
||||
ct: &T,
|
||||
rotate: &CudaUnsignedRadixCiphertext,
|
||||
streams: &CudaStreams,
|
||||
) -> T
|
||||
where
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let mut result = ct.duplicate(streams);
|
||||
self.unchecked_rotate_left_assign_async(&mut result, rotate, streams);
|
||||
result
|
||||
}
|
||||
|
||||
pub fn unchecked_rotate_left<T>(
|
||||
&self,
|
||||
ct: &T,
|
||||
rotate: &CudaUnsignedRadixCiphertext,
|
||||
streams: &CudaStreams,
|
||||
) -> T
|
||||
where
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let result = unsafe { self.unchecked_rotate_left_async(ct, rotate, streams) };
|
||||
streams.synchronize();
|
||||
self.unchecked_rotate_right_assign(&mut result, rotate, streams);
|
||||
result
|
||||
}
|
||||
|
||||
@@ -245,17 +104,72 @@ impl CudaServerKey {
|
||||
) where
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
|
||||
let is_signed = T::IS_SIGNED;
|
||||
|
||||
unsafe {
|
||||
self.unchecked_rotate_left_assign_async(ct, rotate, streams);
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_rotate_left_assign(
|
||||
streams,
|
||||
ct.as_mut(),
|
||||
rotate.as_ref(),
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
is_signed,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_rotate_left_assign(
|
||||
streams,
|
||||
ct.as_mut(),
|
||||
rotate.as_ref(),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
is_signed,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
streams.synchronize();
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronised
|
||||
pub unsafe fn rotate_right_async<T>(
|
||||
pub fn unchecked_rotate_left<T>(
|
||||
&self,
|
||||
ct: &T,
|
||||
rotate: &CudaUnsignedRadixCiphertext,
|
||||
@@ -264,82 +178,11 @@ impl CudaServerKey {
|
||||
where
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let mut tmp_lhs: T;
|
||||
let mut tmp_rhs: CudaUnsignedRadixCiphertext;
|
||||
|
||||
let (lhs, rhs) = match (
|
||||
ct.block_carries_are_empty(),
|
||||
rotate.block_carries_are_empty(),
|
||||
) {
|
||||
(true, true) => (ct, rotate),
|
||||
(true, false) => {
|
||||
tmp_rhs = rotate.duplicate(streams);
|
||||
self.full_propagate_assign(&mut tmp_rhs, streams);
|
||||
(ct, &tmp_rhs)
|
||||
}
|
||||
(false, true) => {
|
||||
tmp_lhs = ct.duplicate(streams);
|
||||
self.full_propagate_assign(&mut tmp_lhs, streams);
|
||||
(&tmp_lhs, rotate)
|
||||
}
|
||||
(false, false) => {
|
||||
tmp_lhs = ct.duplicate(streams);
|
||||
tmp_rhs = rotate.duplicate(streams);
|
||||
|
||||
self.full_propagate_assign(&mut tmp_lhs, streams);
|
||||
self.full_propagate_assign(&mut tmp_rhs, streams);
|
||||
(&tmp_lhs, &tmp_rhs)
|
||||
}
|
||||
};
|
||||
|
||||
let mut result = lhs.duplicate(streams);
|
||||
self.unchecked_rotate_right_assign_async(&mut result, rhs, streams);
|
||||
let mut result = ct.duplicate(streams);
|
||||
self.unchecked_rotate_left_assign(&mut result, rotate, streams);
|
||||
result
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronised
|
||||
pub unsafe fn rotate_right_assign_async<T>(
|
||||
&self,
|
||||
ct: &mut T,
|
||||
rotate: &CudaUnsignedRadixCiphertext,
|
||||
streams: &CudaStreams,
|
||||
) where
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let mut tmp_lhs: T;
|
||||
let mut tmp_rhs: CudaUnsignedRadixCiphertext;
|
||||
|
||||
let (lhs, rhs) = match (
|
||||
ct.block_carries_are_empty(),
|
||||
rotate.block_carries_are_empty(),
|
||||
) {
|
||||
(true, true) => (ct, rotate),
|
||||
(true, false) => {
|
||||
tmp_rhs = rotate.duplicate(streams);
|
||||
self.full_propagate_assign(&mut tmp_rhs, streams);
|
||||
(ct, &tmp_rhs)
|
||||
}
|
||||
(false, true) => {
|
||||
tmp_lhs = ct.duplicate(streams);
|
||||
self.full_propagate_assign(&mut tmp_lhs, streams);
|
||||
(&mut tmp_lhs, rotate)
|
||||
}
|
||||
(false, false) => {
|
||||
tmp_lhs = ct.duplicate(streams);
|
||||
tmp_rhs = rotate.duplicate(streams);
|
||||
|
||||
self.full_propagate_assign(&mut tmp_lhs, streams);
|
||||
self.full_propagate_assign(&mut tmp_rhs, streams);
|
||||
(&mut tmp_lhs, &tmp_rhs)
|
||||
}
|
||||
};
|
||||
|
||||
self.unchecked_rotate_right_assign_async(lhs, rhs, streams);
|
||||
}
|
||||
|
||||
/// Computes homomorphically a right rotate by an encrypted amount
|
||||
///
|
||||
/// The result is returned as a new ciphertext.
|
||||
@@ -385,36 +228,6 @@ impl CudaServerKey {
|
||||
rotate: &CudaUnsignedRadixCiphertext,
|
||||
streams: &CudaStreams,
|
||||
) -> T
|
||||
where
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let result = unsafe { self.rotate_right_async(ct, rotate, streams) };
|
||||
streams.synchronize();
|
||||
result
|
||||
}
|
||||
|
||||
pub fn rotate_right_assign<T>(
|
||||
&self,
|
||||
ct: &mut T,
|
||||
rotate: &CudaUnsignedRadixCiphertext,
|
||||
streams: &CudaStreams,
|
||||
) where
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
unsafe { self.rotate_right_assign_async(ct, rotate, streams) };
|
||||
streams.synchronize();
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronised
|
||||
pub unsafe fn rotate_left_async<T>(
|
||||
&self,
|
||||
ct: &T,
|
||||
rotate: &CudaUnsignedRadixCiphertext,
|
||||
streams: &CudaStreams,
|
||||
) -> T
|
||||
where
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
@@ -447,15 +260,11 @@ impl CudaServerKey {
         };
 
         let mut result = lhs.duplicate(streams);
-        self.unchecked_rotate_left_assign_async(&mut result, rhs, streams);
+        self.unchecked_rotate_right_assign(&mut result, rhs, streams);
         result
     }
 
-    /// # Safety
-    ///
-    /// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
-    ///   not be dropped until streams is synchronised
-    pub unsafe fn rotate_left_assign_async<T>(
+    pub fn rotate_right_assign<T>(
         &self,
         ct: &mut T,
         rotate: &CudaUnsignedRadixCiphertext,
@@ -491,7 +300,7 @@ impl CudaServerKey {
             }
         };
 
-        self.unchecked_rotate_left_assign_async(lhs, rhs, streams);
+        self.unchecked_rotate_right_assign(lhs, rhs, streams);
     }
 
     /// Computes homomorphically a right rotate by an encrypted amount
@@ -542,10 +351,39 @@ impl CudaServerKey {
|
||||
where
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let result = unsafe { self.rotate_left_async(ct, rotate, streams) };
|
||||
streams.synchronize();
|
||||
let mut tmp_lhs: T;
|
||||
let mut tmp_rhs: CudaUnsignedRadixCiphertext;
|
||||
|
||||
let (lhs, rhs) = match (
|
||||
ct.block_carries_are_empty(),
|
||||
rotate.block_carries_are_empty(),
|
||||
) {
|
||||
(true, true) => (ct, rotate),
|
||||
(true, false) => {
|
||||
tmp_rhs = rotate.duplicate(streams);
|
||||
self.full_propagate_assign(&mut tmp_rhs, streams);
|
||||
(ct, &tmp_rhs)
|
||||
}
|
||||
(false, true) => {
|
||||
tmp_lhs = ct.duplicate(streams);
|
||||
self.full_propagate_assign(&mut tmp_lhs, streams);
|
||||
(&tmp_lhs, rotate)
|
||||
}
|
||||
(false, false) => {
|
||||
tmp_lhs = ct.duplicate(streams);
|
||||
tmp_rhs = rotate.duplicate(streams);
|
||||
|
||||
self.full_propagate_assign(&mut tmp_lhs, streams);
|
||||
self.full_propagate_assign(&mut tmp_rhs, streams);
|
||||
(&tmp_lhs, &tmp_rhs)
|
||||
}
|
||||
};
|
||||
|
||||
let mut result = lhs.duplicate(streams);
|
||||
self.unchecked_rotate_left_assign(&mut result, rhs, streams);
|
||||
result
|
||||
}
|
||||
|
||||
pub fn rotate_left_assign<T>(
|
||||
&self,
|
||||
ct: &mut T,
|
||||
@@ -554,8 +392,35 @@ impl CudaServerKey {
|
||||
) where
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
unsafe { self.rotate_left_assign_async(ct, rotate, streams) };
|
||||
streams.synchronize();
|
||||
let mut tmp_lhs: T;
|
||||
let mut tmp_rhs: CudaUnsignedRadixCiphertext;
|
||||
|
||||
let (lhs, rhs) = match (
|
||||
ct.block_carries_are_empty(),
|
||||
rotate.block_carries_are_empty(),
|
||||
) {
|
||||
(true, true) => (ct, rotate),
|
||||
(true, false) => {
|
||||
tmp_rhs = rotate.duplicate(streams);
|
||||
self.full_propagate_assign(&mut tmp_rhs, streams);
|
||||
(ct, &tmp_rhs)
|
||||
}
|
||||
(false, true) => {
|
||||
tmp_lhs = ct.duplicate(streams);
|
||||
self.full_propagate_assign(&mut tmp_lhs, streams);
|
||||
(&mut tmp_lhs, rotate)
|
||||
}
|
||||
(false, false) => {
|
||||
tmp_lhs = ct.duplicate(streams);
|
||||
tmp_rhs = rotate.duplicate(streams);
|
||||
|
||||
self.full_propagate_assign(&mut tmp_lhs, streams);
|
||||
self.full_propagate_assign(&mut tmp_rhs, streams);
|
||||
(&mut tmp_lhs, &tmp_rhs)
|
||||
}
|
||||
};
|
||||
|
||||
self.unchecked_rotate_left_assign(lhs, rhs, streams);
|
||||
}
|
||||
|
||||
pub fn get_rotate_left_size_on_gpu<T: CudaIntegerRadixCiphertext>(
|
||||
|
||||
@@ -69,11 +69,7 @@ impl CudaServerKey {
         result
     }
 
-    /// # Safety
-    ///
-    /// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
-    ///   not be dropped until streams is synchronised
-    pub unsafe fn unchecked_scalar_add_assign_async<Scalar, T>(
+    pub fn unchecked_scalar_add_assign<Scalar, T>(
         &self,
         ct: &mut T,
         scalar: Scalar,
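The hunk below leaves the scalar-addition logic untouched and only scopes the device-vector allocation and copy inside `unsafe` blocks. The decomposition it relies on splits the clear scalar into radix blocks of `bits_in_message` bits, stopping once the remaining value is zero. A standalone sketch of that step (illustrative; the real code uses `BlockDecomposer::with_early_stop_at_zero`):

fn decompose_scalar(mut scalar: u64, bits_in_message: u32, max_blocks: usize) -> Vec<u64> {
    let mask = (1u64 << bits_in_message) - 1;
    let mut blocks = Vec::new();
    while scalar != 0 && blocks.len() < max_blocks {
        blocks.push(scalar & mask); // low-order message block first
        scalar >>= bits_in_message;
    }
    blocks
}

fn main() {
    // message_modulus = 4 => 2 bits per block: 0b10_11_01 splits into [1, 3, 2].
    assert_eq!(decompose_scalar(0b101101, 2, 8), vec![1, 3, 2]);
}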
@@ -84,47 +80,33 @@ impl CudaServerKey {
|
||||
{
|
||||
if scalar != Scalar::ZERO {
|
||||
let bits_in_message = self.message_modulus.0.ilog2();
|
||||
let mut d_decomposed_scalar = CudaVec::<u64>::new_async(
|
||||
ct.as_ref().d_blocks.lwe_ciphertext_count().0,
|
||||
streams,
|
||||
0,
|
||||
);
|
||||
let mut d_decomposed_scalar = unsafe {
|
||||
CudaVec::<u64>::new_async(ct.as_ref().d_blocks.lwe_ciphertext_count().0, streams, 0)
|
||||
};
|
||||
let decomposed_scalar =
|
||||
BlockDecomposer::with_early_stop_at_zero(scalar, bits_in_message)
|
||||
.iter_as::<u64>()
|
||||
.take(d_decomposed_scalar.len())
|
||||
.collect::<Vec<_>>();
|
||||
d_decomposed_scalar.copy_from_cpu_async(decomposed_scalar.as_slice(), streams, 0);
|
||||
|
||||
unsafe {
|
||||
d_decomposed_scalar.copy_from_cpu_async(decomposed_scalar.as_slice(), streams, 0);
|
||||
}
|
||||
// If the scalar is decomposed using less than the number of blocks our ciphertext
|
||||
// has, we just don't touch ciphertext's last blocks
|
||||
cuda_backend_scalar_addition_assign(
|
||||
streams,
|
||||
ct.as_mut(),
|
||||
&d_decomposed_scalar,
|
||||
&decomposed_scalar,
|
||||
decomposed_scalar.len() as u32,
|
||||
self.message_modulus.0 as u32,
|
||||
self.carry_modulus.0 as u32,
|
||||
);
|
||||
unsafe {
|
||||
cuda_backend_scalar_addition_assign(
|
||||
streams,
|
||||
ct.as_mut(),
|
||||
&d_decomposed_scalar,
|
||||
&decomposed_scalar,
|
||||
decomposed_scalar.len() as u32,
|
||||
self.message_modulus.0 as u32,
|
||||
self.carry_modulus.0 as u32,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn unchecked_scalar_add_assign<Scalar, T>(
|
||||
&self,
|
||||
ct: &mut T,
|
||||
scalar: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) where
|
||||
Scalar: DecomposableInto<u8> + CastInto<u64>,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
unsafe {
|
||||
self.unchecked_scalar_add_assign_async(ct, scalar, streams);
|
||||
}
|
||||
streams.synchronize();
|
||||
}
|
||||
|
||||
/// Computes homomorphically an addition between a scalar and a ciphertext.
|
||||
///
|
||||
/// This function computes the operation without checking if it exceeds the capacity of the
|
||||
@@ -180,16 +162,8 @@ impl CudaServerKey {
         self.get_scalar_add_assign_size_on_gpu(ct, streams)
     }
 
-    /// # Safety
-    ///
-    /// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
-    ///   not be dropped until streams is synchronised
-    pub unsafe fn scalar_add_assign_async<Scalar, T>(
-        &self,
-        ct: &mut T,
-        scalar: Scalar,
-        streams: &CudaStreams,
-    ) where
+    pub fn scalar_add_assign<Scalar, T>(&self, ct: &mut T, scalar: Scalar, streams: &CudaStreams)
+    where
         Scalar: DecomposableInto<u8> + CastInto<u64>,
         T: CudaIntegerRadixCiphertext,
     {
@@ -197,7 +171,7 @@ impl CudaServerKey {
             self.full_propagate_assign(ct, streams);
         }
 
-        self.unchecked_scalar_add_assign_async(ct, scalar, streams);
+        self.unchecked_scalar_add_assign(ct, scalar, streams);
         let _carry = self.propagate_single_carry_assign(ct, streams, None, OutputFlag::None);
     }
 
@@ -290,17 +264,6 @@ impl CudaServerKey {
         full_prop_mem.max(single_carry_mem)
     }
 
-    pub fn scalar_add_assign<Scalar, T>(&self, ct: &mut T, scalar: Scalar, streams: &CudaStreams)
-    where
-        Scalar: DecomposableInto<u8> + CastInto<u64>,
-        T: CudaIntegerRadixCiphertext,
-    {
-        unsafe {
-            self.scalar_add_assign_async(ct, scalar, streams);
-        }
-        streams.synchronize();
-    }
-
     pub fn unsigned_overflowing_scalar_add<Scalar>(
         &self,
         ct_left: &CudaUnsignedRadixCiphertext,

@@ -10,11 +10,7 @@ use crate::integer::gpu::{
 };
 
 impl CudaServerKey {
-    /// # Safety
-    ///
-    /// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
-    ///   not be dropped until streams is synchronised
-    pub unsafe fn unchecked_scalar_bitop_assign_async<Scalar, T>(
+    pub fn unchecked_scalar_bitop_assign<Scalar, T>(
         &self,
         ct: &mut T,
         rhs: Scalar,
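The hunk below likewise only moves the existing backend call inside an `unsafe` block. The operation itself is a blockwise bitwise op between the ciphertext blocks and the clear blocks decomposed from `rhs`; a plain-integer sketch of that blockwise shape (illustrative only, not the encrypted computation):

#[derive(Clone, Copy)]
enum BitOp {
    And,
    Or,
    Xor,
}

fn scalar_bitop_blocks(ct_blocks: &mut [u64], clear_blocks: &[u64], op: BitOp) {
    // Blocks beyond the decomposed scalar behave as if the scalar block were 0.
    for (i, block) in ct_blocks.iter_mut().enumerate() {
        let clear = clear_blocks.get(i).copied().unwrap_or(0);
        *block = match op {
            BitOp::And => *block & clear,
            BitOp::Or => *block | clear,
            BitOp::Xor => *block ^ clear,
        };
    }
}

fn main() {
    let mut ct = vec![0b11, 0b10, 0b01];
    scalar_bitop_blocks(&mut ct, &[0b01, 0b11], BitOp::And);
    assert_eq!(ct, vec![0b01, 0b10, 0b00]);
}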
@@ -32,66 +28,67 @@ impl CudaServerKey {
|
||||
.map(|x| x as u64)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let clear_blocks = CudaVec::from_cpu_async(&h_clear_blocks, streams, 0);
|
||||
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_scalar_bitop_assign(
|
||||
streams,
|
||||
ct.as_mut(),
|
||||
&clear_blocks,
|
||||
&h_clear_blocks,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
op,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_scalar_bitop_assign(
|
||||
streams,
|
||||
ct.as_mut(),
|
||||
&clear_blocks,
|
||||
&h_clear_blocks,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
op,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
);
|
||||
let clear_blocks = unsafe { CudaVec::from_cpu_async(&h_clear_blocks, streams, 0) };
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_scalar_bitop_assign(
|
||||
streams,
|
||||
ct.as_mut(),
|
||||
&clear_blocks,
|
||||
&h_clear_blocks,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
op,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_scalar_bitop_assign(
|
||||
streams,
|
||||
ct.as_mut(),
|
||||
&clear_blocks,
|
||||
&h_clear_blocks,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
op,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -120,10 +117,7 @@ impl CudaServerKey {
         Scalar: DecomposableInto<u8>,
         T: CudaIntegerRadixCiphertext,
     {
-        unsafe {
-            self.unchecked_scalar_bitop_assign_async(ct, rhs, BitOpType::ScalarAnd, streams);
-        }
-        streams.synchronize();
+        self.unchecked_scalar_bitop_assign(ct, rhs, BitOpType::ScalarAnd, streams);
     }
 
     pub fn unchecked_scalar_bitor<Scalar, T>(&self, ct: &T, rhs: Scalar, streams: &CudaStreams) -> T
@@ -145,10 +139,7 @@ impl CudaServerKey {
         Scalar: DecomposableInto<u8>,
         T: CudaIntegerRadixCiphertext,
     {
-        unsafe {
-            self.unchecked_scalar_bitop_assign_async(ct, rhs, BitOpType::ScalarOr, streams);
-        }
-        streams.synchronize();
+        self.unchecked_scalar_bitop_assign(ct, rhs, BitOpType::ScalarOr, streams);
     }
 
     pub fn unchecked_scalar_bitxor<Scalar, T>(
@@ -175,29 +166,7 @@ impl CudaServerKey {
|
||||
Scalar: DecomposableInto<u8>,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
unsafe {
|
||||
self.unchecked_scalar_bitop_assign_async(ct, rhs, BitOpType::ScalarXor, streams);
|
||||
}
|
||||
streams.synchronize();
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronised
|
||||
pub unsafe fn scalar_bitand_assign_async<Scalar, T>(
|
||||
&self,
|
||||
ct: &mut T,
|
||||
rhs: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) where
|
||||
Scalar: DecomposableInto<u8>,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
if !ct.block_carries_are_empty() {
|
||||
self.full_propagate_assign(ct, streams);
|
||||
}
|
||||
self.unchecked_scalar_bitop_assign_async(ct, rhs, BitOpType::ScalarAnd, streams);
|
||||
self.unchecked_scalar_bitop_assign(ct, rhs, BitOpType::ScalarXor, streams);
|
||||
}
|
||||
|
||||
pub fn scalar_bitand_assign<Scalar, T>(&self, ct: &mut T, rhs: Scalar, streams: &CudaStreams)
|
||||
@@ -205,10 +174,10 @@ impl CudaServerKey {
         Scalar: DecomposableInto<u8>,
         T: CudaIntegerRadixCiphertext,
     {
-        unsafe {
-            self.scalar_bitand_assign_async(ct, rhs, streams);
+        if !ct.block_carries_are_empty() {
+            self.full_propagate_assign(ct, streams);
         }
-        streams.synchronize();
+        self.unchecked_scalar_bitop_assign(ct, rhs, BitOpType::ScalarAnd, streams);
     }
 
     pub fn scalar_bitand<Scalar, T>(&self, ct: &T, rhs: Scalar, streams: &CudaStreams) -> T
@@ -221,34 +190,15 @@ impl CudaServerKey {
|
||||
result
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronised
|
||||
pub unsafe fn scalar_bitor_assign_async<Scalar, T>(
|
||||
&self,
|
||||
ct: &mut T,
|
||||
rhs: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) where
|
||||
pub fn scalar_bitor_assign<Scalar, T>(&self, ct: &mut T, rhs: Scalar, streams: &CudaStreams)
|
||||
where
|
||||
Scalar: DecomposableInto<u8>,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
if !ct.block_carries_are_empty() {
|
||||
self.full_propagate_assign(ct, streams);
|
||||
}
|
||||
self.unchecked_scalar_bitop_assign_async(ct, rhs, BitOpType::ScalarOr, streams);
|
||||
}
|
||||
|
||||
pub fn scalar_bitor_assign<Scalar, T>(&self, ct: &mut T, rhs: Scalar, streams: &CudaStreams)
|
||||
where
|
||||
Scalar: DecomposableInto<u8>,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
unsafe {
|
||||
self.scalar_bitor_assign_async(ct, rhs, streams);
|
||||
}
|
||||
streams.synchronize();
|
||||
self.unchecked_scalar_bitop_assign(ct, rhs, BitOpType::ScalarOr, streams);
|
||||
}
|
||||
|
||||
pub fn scalar_bitor<Scalar, T>(&self, ct: &T, rhs: Scalar, streams: &CudaStreams) -> T
|
||||
@@ -261,34 +211,15 @@ impl CudaServerKey {
|
||||
result
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronised
|
||||
pub unsafe fn scalar_bitxor_assign_async<Scalar, T>(
|
||||
&self,
|
||||
ct: &mut T,
|
||||
rhs: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) where
|
||||
pub fn scalar_bitxor_assign<Scalar, T>(&self, ct: &mut T, rhs: Scalar, streams: &CudaStreams)
|
||||
where
|
||||
Scalar: DecomposableInto<u8>,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
if !ct.block_carries_are_empty() {
|
||||
self.full_propagate_assign(ct, streams);
|
||||
}
|
||||
self.unchecked_scalar_bitop_assign_async(ct, rhs, BitOpType::ScalarXor, streams);
|
||||
}
|
||||
|
||||
pub fn scalar_bitxor_assign<Scalar, T>(&self, ct: &mut T, rhs: Scalar, streams: &CudaStreams)
|
||||
where
|
||||
Scalar: DecomposableInto<u8>,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
unsafe {
|
||||
self.scalar_bitxor_assign_async(ct, rhs, streams);
|
||||
}
|
||||
streams.synchronize();
|
||||
self.unchecked_scalar_bitop_assign(ct, rhs, BitOpType::ScalarXor, streams);
|
||||
}
|
||||
|
||||
pub fn scalar_bitxor<Scalar, T>(&self, ct: &T, rhs: Scalar, streams: &CudaStreams) -> T
|
||||
|
||||
@@ -102,11 +102,7 @@ impl CudaServerKey {
         None
     }
 
-    /// # Safety
-    ///
-    /// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
-    ///   not be dropped until streams is synchronised
-    pub unsafe fn unchecked_signed_and_unsigned_scalar_comparison_async<Scalar, T>(
+    pub fn unchecked_signed_and_unsigned_scalar_comparison<Scalar, T>(
         &self,
         ct: &T,
         scalar: Scalar,
@@ -155,7 +151,8 @@ impl CudaServerKey {
         // as we will handle them separately.
         scalar_blocks.truncate(ct.as_ref().d_blocks.lwe_ciphertext_count().0);
 
-        let d_scalar_blocks: CudaVec<u64> = CudaVec::from_cpu_async(&scalar_blocks, streams, 0);
+        let d_scalar_blocks: CudaVec<u64> =
+            unsafe { CudaVec::from_cpu_async(&scalar_blocks, streams, 0) };
 
         let block = CudaLweCiphertextList::new(
             ct.as_ref().d_blocks.lwe_dimension(),
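The truncation above is sound because the scalar blocks that extend past the ciphertext's width are handled separately: for an unsigned ciphertext, any non-zero high block already decides the comparison without a kernel launch. A plain-integer sketch of that split (an illustration of the idea, not the actual tfhe-rs handling):

/// Split the decomposed scalar into the part that lines up with the ciphertext blocks
/// and a flag telling whether the high part, with no ciphertext counterpart, is non-zero.
fn split_scalar_blocks(scalar_blocks: &[u64], num_ct_blocks: usize) -> (&[u64], bool) {
    let in_range = &scalar_blocks[..scalar_blocks.len().min(num_ct_blocks)];
    let high_part_is_nonzero = scalar_blocks
        .get(num_ct_blocks..)
        .is_some_and(|high| high.iter().any(|&b| b != 0));
    (in_range, high_part_is_nonzero)
}

fn main() {
    // A 3-block ciphertext compared against a scalar that decomposed into 5 blocks.
    let (low, scalar_exceeds_ct) = split_scalar_blocks(&[1, 2, 3, 0, 7], 3);
    assert_eq!(low, &[1, 2, 3][..]);
    assert!(scalar_exceeds_ct); // e.g. an unsigned `ct < scalar` is already decided
}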
@@ -171,78 +168,76 @@ impl CudaServerKey {
|
||||
let mut result =
|
||||
CudaBooleanBlock::from_cuda_radix_ciphertext(CudaRadixCiphertext::new(block, ct_info));
|
||||
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_scalar_comparison(
|
||||
streams,
|
||||
result.as_mut().as_mut(),
|
||||
ct.as_ref(),
|
||||
&d_scalar_blocks,
|
||||
&scalar_blocks,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
scalar_blocks.len() as u32,
|
||||
op,
|
||||
signed_with_positive_scalar,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_scalar_comparison(
|
||||
streams,
|
||||
result.as_mut().as_mut(),
|
||||
ct.as_ref(),
|
||||
&d_scalar_blocks,
|
||||
&scalar_blocks,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
scalar_blocks.len() as u32,
|
||||
op,
|
||||
signed_with_positive_scalar,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
);
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_scalar_comparison(
|
||||
streams,
|
||||
result.as_mut().as_mut(),
|
||||
ct.as_ref(),
|
||||
&d_scalar_blocks,
|
||||
&scalar_blocks,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
scalar_blocks.len() as u32,
|
||||
op,
|
||||
signed_with_positive_scalar,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_scalar_comparison(
|
||||
streams,
|
||||
result.as_mut().as_mut(),
|
||||
ct.as_ref(),
|
||||
&d_scalar_blocks,
|
||||
&scalar_blocks,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
scalar_blocks.len() as u32,
|
||||
op,
|
||||
signed_with_positive_scalar,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronised
|
||||
pub unsafe fn unchecked_scalar_comparison_async<Scalar, T>(
|
||||
pub fn unchecked_scalar_comparison<Scalar, T>(
|
||||
&self,
|
||||
ct: &T,
|
||||
scalar: Scalar,
|
||||
@@ -292,25 +287,18 @@ impl CudaServerKey {
|
||||
}
|
||||
|
||||
if scalar >= Scalar::ZERO {
|
||||
self.unchecked_signed_and_unsigned_scalar_comparison_async(
|
||||
ct, scalar, op, true, streams,
|
||||
)
|
||||
self.unchecked_signed_and_unsigned_scalar_comparison(ct, scalar, op, true, streams)
|
||||
} else {
|
||||
let scalar_as_trivial = self.create_trivial_radix(scalar, num_blocks, streams);
|
||||
self.unchecked_comparison(ct, &scalar_as_trivial, op, streams)
|
||||
}
|
||||
} else {
|
||||
// Unsigned
|
||||
self.unchecked_signed_and_unsigned_scalar_comparison_async(
|
||||
ct, scalar, op, false, streams,
|
||||
)
|
||||
self.unchecked_signed_and_unsigned_scalar_comparison(ct, scalar, op, false, streams)
|
||||
}
|
||||
}
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronised
|
||||
pub unsafe fn unchecked_scalar_minmax_async<Scalar, T>(
|
||||
|
||||
pub fn unchecked_scalar_minmax<Scalar, T>(
|
||||
&self,
|
||||
ct: &T,
|
||||
scalar: Scalar,
|
||||
@@ -328,72 +316,75 @@ impl CudaServerKey {
|
||||
.iter_as::<u64>()
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let d_scalar_blocks: CudaVec<u64> = CudaVec::from_cpu_async(&scalar_blocks, streams, 0);
|
||||
let d_scalar_blocks: CudaVec<u64> =
|
||||
unsafe { CudaVec::from_cpu_async(&scalar_blocks, streams, 0) };
|
||||
|
||||
let mut result = ct.duplicate(streams);
|
||||
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_scalar_comparison(
|
||||
streams,
|
||||
result.as_mut(),
|
||||
ct.as_ref(),
|
||||
&d_scalar_blocks,
|
||||
&scalar_blocks,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
scalar_blocks.len() as u32,
|
||||
op,
|
||||
T::IS_SIGNED,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_scalar_comparison(
|
||||
streams,
|
||||
result.as_mut(),
|
||||
ct.as_ref(),
|
||||
&d_scalar_blocks,
|
||||
&scalar_blocks,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
scalar_blocks.len() as u32,
|
||||
op,
|
||||
T::IS_SIGNED,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
);
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_scalar_comparison(
|
||||
streams,
|
||||
result.as_mut(),
|
||||
ct.as_ref(),
|
||||
&d_scalar_blocks,
|
||||
&scalar_blocks,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
scalar_blocks.len() as u32,
|
||||
op,
|
||||
T::IS_SIGNED,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_scalar_comparison(
|
||||
streams,
|
||||
result.as_mut(),
|
||||
ct.as_ref(),
|
||||
&d_scalar_blocks,
|
||||
&scalar_blocks,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
scalar_blocks.len() as u32,
|
||||
op,
|
||||
T::IS_SIGNED,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
result
@@ -465,7 +456,6 @@ impl CudaServerKey {
}
}
}
streams.synchronize();
boolean_res
}

@@ -535,27 +525,9 @@ impl CudaServerKey {
}
}
}
streams.synchronize();
boolean_res
}

/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
///   not be dropped until streams is synchronised
pub unsafe fn unchecked_scalar_eq_async<Scalar, T>(
&self,
ct: &T,
scalar: Scalar,
streams: &CudaStreams,
) -> CudaBooleanBlock
where
T: CudaIntegerRadixCiphertext,
Scalar: DecomposableInto<u64>,
{
self.unchecked_scalar_comparison_async(ct, scalar, ComparisonType::EQ, streams)
}

pub fn unchecked_scalar_eq<Scalar, T>(
&self,
ct: &T,
@@ -566,35 +538,7 @@ impl CudaServerKey {
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
Scalar: DecomposableInto<u64>,
|
||||
{
|
||||
let result = unsafe { self.unchecked_scalar_eq_async(ct, scalar, streams) };
|
||||
streams.synchronize();
|
||||
result
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronised
|
||||
pub unsafe fn scalar_eq_async<Scalar, T>(
|
||||
&self,
|
||||
ct: &T,
|
||||
scalar: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) -> CudaBooleanBlock
|
||||
where
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
Scalar: DecomposableInto<u64>,
|
||||
{
|
||||
let mut tmp_lhs;
|
||||
let lhs = if ct.block_carries_are_empty() {
|
||||
ct
|
||||
} else {
|
||||
tmp_lhs = ct.duplicate(streams);
|
||||
self.full_propagate_assign(&mut tmp_lhs, streams);
|
||||
&tmp_lhs
|
||||
};
|
||||
|
||||
self.unchecked_scalar_eq_async(lhs, scalar, streams)
|
||||
self.unchecked_scalar_comparison(ct, scalar, ComparisonType::EQ, streams)
|
||||
}
|
||||
|
||||
/// Compares for equality 2 ciphertexts
|
||||
@@ -643,25 +587,6 @@ impl CudaServerKey {
|
||||
scalar: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) -> CudaBooleanBlock
|
||||
where
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
Scalar: DecomposableInto<u64>,
|
||||
{
|
||||
let result = unsafe { self.scalar_eq_async(ct, scalar, streams) };
|
||||
streams.synchronize();
|
||||
result
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronised
|
||||
pub unsafe fn scalar_ne_async<Scalar, T>(
|
||||
&self,
|
||||
ct: &T,
|
||||
scalar: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) -> CudaBooleanBlock
|
||||
where
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
Scalar: DecomposableInto<u64>,
|
||||
@@ -675,7 +600,7 @@ impl CudaServerKey {
&tmp_lhs
};

self.unchecked_scalar_ne_async(lhs, scalar, streams)
self.unchecked_scalar_eq(lhs, scalar, streams)
}

/// Compares for equality 2 ciphertexts
|
||||
@@ -725,29 +650,19 @@ impl CudaServerKey {
|
||||
streams: &CudaStreams,
|
||||
) -> CudaBooleanBlock
|
||||
where
|
||||
Scalar: DecomposableInto<u64>,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
Scalar: DecomposableInto<u64>,
|
||||
{
|
||||
let result = unsafe { self.scalar_ne_async(ct, scalar, streams) };
|
||||
streams.synchronize();
|
||||
result
|
||||
}
|
||||
let mut tmp_lhs;
|
||||
let lhs = if ct.block_carries_are_empty() {
|
||||
ct
|
||||
} else {
|
||||
tmp_lhs = ct.duplicate(streams);
|
||||
self.full_propagate_assign(&mut tmp_lhs, streams);
|
||||
&tmp_lhs
|
||||
};
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronised
|
||||
pub unsafe fn unchecked_scalar_ne_async<Scalar, T>(
|
||||
&self,
|
||||
ct: &T,
|
||||
scalar: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) -> CudaBooleanBlock
|
||||
where
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
Scalar: DecomposableInto<u64>,
|
||||
{
|
||||
self.unchecked_scalar_comparison_async(ct, scalar, ComparisonType::NE, streams)
|
||||
self.unchecked_scalar_ne(lhs, scalar, streams)
|
||||
}
|
||||
|
||||
pub fn unchecked_scalar_ne<Scalar, T>(
|
||||
@@ -760,26 +675,7 @@ impl CudaServerKey {
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
Scalar: DecomposableInto<u64>,
|
||||
{
|
||||
let result = unsafe { self.unchecked_scalar_ne_async(ct, scalar, streams) };
|
||||
streams.synchronize();
|
||||
result
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronised
|
||||
pub unsafe fn unchecked_scalar_gt_async<Scalar, T>(
|
||||
&self,
|
||||
ct: &T,
|
||||
scalar: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) -> CudaBooleanBlock
|
||||
where
|
||||
Scalar: DecomposableInto<u64>,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
self.unchecked_scalar_comparison_async(ct, scalar, ComparisonType::GT, streams)
|
||||
self.unchecked_scalar_comparison(ct, scalar, ComparisonType::NE, streams)
|
||||
}
|
||||
|
||||
pub fn unchecked_scalar_gt<Scalar, T>(
|
||||
@@ -792,26 +688,7 @@ impl CudaServerKey {
|
||||
Scalar: DecomposableInto<u64>,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let result = unsafe { self.unchecked_scalar_gt_async(ct, scalar, streams) };
|
||||
streams.synchronize();
|
||||
result
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronised
|
||||
pub unsafe fn unchecked_scalar_ge_async<Scalar, T>(
|
||||
&self,
|
||||
ct: &T,
|
||||
scalar: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) -> CudaBooleanBlock
|
||||
where
|
||||
Scalar: DecomposableInto<u64>,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
self.unchecked_scalar_comparison_async(ct, scalar, ComparisonType::GE, streams)
|
||||
self.unchecked_scalar_comparison(ct, scalar, ComparisonType::GT, streams)
|
||||
}
|
||||
|
||||
pub fn unchecked_scalar_ge<Scalar, T>(
|
||||
@@ -824,26 +701,7 @@ impl CudaServerKey {
|
||||
Scalar: DecomposableInto<u64>,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let result = unsafe { self.unchecked_scalar_ge_async(ct, scalar, streams) };
|
||||
streams.synchronize();
|
||||
result
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronised
|
||||
pub unsafe fn unchecked_scalar_lt_async<Scalar, T>(
|
||||
&self,
|
||||
ct: &T,
|
||||
scalar: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) -> CudaBooleanBlock
|
||||
where
|
||||
Scalar: DecomposableInto<u64>,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
self.unchecked_scalar_comparison_async(ct, scalar, ComparisonType::LT, streams)
|
||||
self.unchecked_scalar_comparison(ct, scalar, ComparisonType::GE, streams)
|
||||
}
|
||||
|
||||
pub fn unchecked_scalar_lt<Scalar, T>(
|
||||
@@ -856,26 +714,7 @@ impl CudaServerKey {
|
||||
Scalar: DecomposableInto<u64>,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let result = unsafe { self.unchecked_scalar_lt_async(ct, scalar, streams) };
|
||||
streams.synchronize();
|
||||
result
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronised
|
||||
pub unsafe fn unchecked_scalar_le_async<Scalar, T>(
|
||||
&self,
|
||||
ct: &T,
|
||||
scalar: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) -> CudaBooleanBlock
|
||||
where
|
||||
Scalar: DecomposableInto<u64>,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
self.unchecked_scalar_comparison_async(ct, scalar, ComparisonType::LE, streams)
|
||||
self.unchecked_scalar_comparison(ct, scalar, ComparisonType::LT, streams)
|
||||
}
|
||||
|
||||
pub fn unchecked_scalar_le<Scalar, T>(
|
||||
@@ -888,34 +727,7 @@ impl CudaServerKey {
|
||||
Scalar: DecomposableInto<u64>,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let result = unsafe { self.unchecked_scalar_le_async(ct, scalar, streams) };
|
||||
streams.synchronize();
|
||||
result
|
||||
}
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronised
|
||||
pub unsafe fn scalar_gt_async<Scalar, T>(
|
||||
&self,
|
||||
ct: &T,
|
||||
scalar: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) -> CudaBooleanBlock
|
||||
where
|
||||
Scalar: DecomposableInto<u64>,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let mut tmp_lhs;
|
||||
let lhs = if ct.block_carries_are_empty() {
|
||||
ct
|
||||
} else {
|
||||
tmp_lhs = ct.duplicate(streams);
|
||||
self.full_propagate_assign(&mut tmp_lhs, streams);
|
||||
&tmp_lhs
|
||||
};
|
||||
|
||||
self.unchecked_scalar_gt_async(lhs, scalar, streams)
|
||||
self.unchecked_scalar_comparison(ct, scalar, ComparisonType::LE, streams)
|
||||
}
|
||||
|
||||
pub fn scalar_gt<Scalar, T>(
|
||||
@@ -924,25 +736,6 @@ impl CudaServerKey {
|
||||
scalar: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) -> CudaBooleanBlock
|
||||
where
|
||||
Scalar: DecomposableInto<u64>,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let result = unsafe { self.scalar_gt_async(ct, scalar, streams) };
|
||||
streams.synchronize();
|
||||
result
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronised
|
||||
pub unsafe fn scalar_ge_async<Scalar, T>(
|
||||
&self,
|
||||
ct: &T,
|
||||
scalar: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) -> CudaBooleanBlock
|
||||
where
|
||||
Scalar: DecomposableInto<u64>,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
@@ -956,7 +749,7 @@ impl CudaServerKey {
&tmp_lhs
};

self.unchecked_scalar_ge_async(lhs, scalar, streams)
self.unchecked_scalar_gt(lhs, scalar, streams)
}

pub fn scalar_ge<Scalar, T>(
|
||||
@@ -965,25 +758,6 @@ impl CudaServerKey {
|
||||
scalar: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) -> CudaBooleanBlock
|
||||
where
|
||||
Scalar: DecomposableInto<u64>,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let result = unsafe { self.scalar_ge_async(ct, scalar, streams) };
|
||||
streams.synchronize();
|
||||
result
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronised
|
||||
pub unsafe fn scalar_lt_async<Scalar, T>(
|
||||
&self,
|
||||
ct: &T,
|
||||
scalar: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) -> CudaBooleanBlock
|
||||
where
|
||||
Scalar: DecomposableInto<u64>,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
@@ -997,7 +771,7 @@ impl CudaServerKey {
|
||||
&tmp_lhs
|
||||
};
|
||||
|
||||
self.unchecked_scalar_lt_async(lhs, scalar, streams)
|
||||
self.unchecked_scalar_ge(lhs, scalar, streams)
|
||||
}
|
||||
|
||||
pub fn scalar_lt<Scalar, T>(
|
||||
@@ -1010,15 +784,19 @@ impl CudaServerKey {
|
||||
Scalar: DecomposableInto<u64>,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let result = unsafe { self.scalar_lt_async(ct, scalar, streams) };
|
||||
streams.synchronize();
|
||||
result
|
||||
let mut tmp_lhs;
|
||||
let lhs = if ct.block_carries_are_empty() {
|
||||
ct
|
||||
} else {
|
||||
tmp_lhs = ct.duplicate(streams);
|
||||
self.full_propagate_assign(&mut tmp_lhs, streams);
|
||||
&tmp_lhs
|
||||
};
|
||||
|
||||
self.unchecked_scalar_lt(lhs, scalar, streams)
|
||||
}
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronised
|
||||
pub unsafe fn scalar_le_async<Scalar, T>(
|
||||
|
||||
pub fn scalar_le<Scalar, T>(
|
||||
&self,
|
||||
ct: &T,
|
||||
scalar: Scalar,
|
||||
@@ -1037,22 +815,7 @@ impl CudaServerKey {
|
||||
&tmp_lhs
|
||||
};
|
||||
|
||||
self.unchecked_scalar_le_async(lhs, scalar, streams)
|
||||
}
|
||||
|
||||
pub fn scalar_le<Scalar, T>(
|
||||
&self,
|
||||
ct: &T,
|
||||
scalar: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) -> CudaBooleanBlock
|
||||
where
|
||||
Scalar: DecomposableInto<u64>,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let result = unsafe { self.scalar_le_async(ct, scalar, streams) };
|
||||
streams.synchronize();
|
||||
result
|
||||
self.unchecked_scalar_le(lhs, scalar, streams)
|
||||
}
|
||||
|
||||
pub fn get_scalar_eq_size_on_gpu<T: CudaIntegerRadixCiphertext>(
|
||||
@@ -1103,23 +866,6 @@ impl CudaServerKey {
|
||||
self.get_comparison_size_on_gpu(ct_left, ct_left, ComparisonType::LE, streams)
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronised
|
||||
pub unsafe fn unchecked_scalar_max_async<Scalar, T>(
|
||||
&self,
|
||||
ct: &T,
|
||||
scalar: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) -> T
|
||||
where
|
||||
Scalar: DecomposableInto<u64>,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
self.unchecked_scalar_minmax_async(ct, scalar, ComparisonType::MAX, streams)
|
||||
}
|
||||
|
||||
pub fn unchecked_scalar_max<Scalar, T>(
|
||||
&self,
|
||||
ct: &T,
|
||||
@@ -1130,26 +876,7 @@ impl CudaServerKey {
|
||||
Scalar: DecomposableInto<u64>,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let result = unsafe { self.unchecked_scalar_max_async(ct, scalar, streams) };
|
||||
streams.synchronize();
|
||||
result
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronised
|
||||
pub unsafe fn unchecked_scalar_min_async<Scalar, T>(
|
||||
&self,
|
||||
ct: &T,
|
||||
scalar: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) -> T
|
||||
where
|
||||
Scalar: DecomposableInto<u64>,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
self.unchecked_scalar_minmax_async(ct, scalar, ComparisonType::MIN, streams)
|
||||
self.unchecked_scalar_minmax(ct, scalar, ComparisonType::MAX, streams)
|
||||
}
|
||||
|
||||
pub fn unchecked_scalar_min<Scalar, T>(
|
||||
@@ -1162,35 +889,7 @@ impl CudaServerKey {
|
||||
Scalar: DecomposableInto<u64>,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let result = unsafe { self.unchecked_scalar_min_async(ct, scalar, streams) };
|
||||
streams.synchronize();
|
||||
result
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronised
|
||||
pub unsafe fn scalar_max_async<Scalar, T>(
|
||||
&self,
|
||||
ct: &T,
|
||||
scalar: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) -> T
|
||||
where
|
||||
Scalar: DecomposableInto<u64>,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let mut tmp_lhs;
|
||||
let lhs = if ct.block_carries_are_empty() {
|
||||
ct
|
||||
} else {
|
||||
tmp_lhs = ct.duplicate(streams);
|
||||
self.full_propagate_assign(&mut tmp_lhs, streams);
|
||||
&tmp_lhs
|
||||
};
|
||||
|
||||
self.unchecked_scalar_max_async(lhs, scalar, streams)
|
||||
self.unchecked_scalar_minmax(ct, scalar, ComparisonType::MIN, streams)
|
||||
}
|
||||
|
||||
pub fn scalar_max<Scalar, T>(&self, ct: &T, scalar: Scalar, streams: &CudaStreams) -> T
|
||||
@@ -1198,21 +897,19 @@ impl CudaServerKey {
|
||||
Scalar: DecomposableInto<u64>,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let result = unsafe { self.scalar_max_async(ct, scalar, streams) };
|
||||
streams.synchronize();
|
||||
result
|
||||
let mut tmp_lhs;
|
||||
let lhs = if ct.block_carries_are_empty() {
|
||||
ct
|
||||
} else {
|
||||
tmp_lhs = ct.duplicate(streams);
|
||||
self.full_propagate_assign(&mut tmp_lhs, streams);
|
||||
&tmp_lhs
|
||||
};
|
||||
|
||||
self.unchecked_scalar_max(lhs, scalar, streams)
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronised
|
||||
pub unsafe fn scalar_min_async<Scalar, T>(
|
||||
&self,
|
||||
ct: &T,
|
||||
scalar: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) -> T
|
||||
pub fn scalar_min<Scalar, T>(&self, ct: &T, scalar: Scalar, streams: &CudaStreams) -> T
|
||||
where
|
||||
Scalar: DecomposableInto<u64>,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
@@ -1226,18 +923,9 @@ impl CudaServerKey {
|
||||
&tmp_lhs
|
||||
};
|
||||
|
||||
self.unchecked_scalar_min_async(lhs, scalar, streams)
|
||||
self.unchecked_scalar_min(lhs, scalar, streams)
|
||||
}
|
||||
|
||||
pub fn scalar_min<Scalar, T>(&self, ct: &T, scalar: Scalar, streams: &CudaStreams) -> T
|
||||
where
|
||||
Scalar: DecomposableInto<u64>,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let result = unsafe { self.scalar_min_async(ct, scalar, streams) };
|
||||
streams.synchronize();
|
||||
result
|
||||
}
|
||||
pub fn get_scalar_max_size_on_gpu<T: CudaIntegerRadixCiphertext>(
|
||||
&self,
|
||||
ct_left: &T,
|
||||
|
||||
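With the `_async` comparison entry points removed above, callers use the synchronous wrappers directly; no `unsafe` block and no manual `streams.synchronize()` is needed on the caller side. A minimal usage sketch, assuming an already-built `CudaServerKey` named `gpu_key`, a `CudaStreams` handle named `streams`, and a radix ciphertext already uploaded to the GPU named `d_ct`; these names are illustrative and not part of the diff:

    // Scalar comparison now returns its CudaBooleanBlock synchronously.
    let is_gt = gpu_key.scalar_gt(&d_ct, 5u64, &streams);
    // Scalar min/max follow the same shape and return a ciphertext of the same type.
    let d_max = gpu_key.scalar_max(&d_ct, 5u64, &streams);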
@@ -67,24 +67,6 @@ impl CudaServerKey {
divisor: Scalar,
streams: &CudaStreams,
) -> CudaUnsignedRadixCiphertext
where
Scalar: Reciprocable,
{
let res = unsafe { self.unchecked_scalar_div_async(numerator, divisor, streams) };
streams.synchronize();
res
}

/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronized
|
||||
pub unsafe fn unchecked_scalar_div_async<Scalar>(
|
||||
&self,
|
||||
numerator: &CudaUnsignedRadixCiphertext,
|
||||
divisor: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) -> CudaUnsignedRadixCiphertext
|
||||
where
|
||||
Scalar: Reciprocable,
|
||||
{
|
||||
@@ -104,48 +86,50 @@ impl CudaServerKey {

let mut quotient = numerator.duplicate(streams);

match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_unsigned_scalar_div_assign(
|
||||
streams,
|
||||
quotient.as_mut(),
|
||||
divisor,
|
||||
&self.key_switching_key.d_vec,
|
||||
&d_bsk.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
d_bsk.input_lwe_dimension,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
LweBskGroupingFactor(0),
|
||||
PBSType::Classical,
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_unsigned_scalar_div_assign(
|
||||
streams,
|
||||
quotient.as_mut(),
|
||||
divisor,
|
||||
&self.key_switching_key.d_vec,
|
||||
&d_multibit_bsk.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
d_multibit_bsk.input_lwe_dimension,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
PBSType::MultiBit,
|
||||
None,
|
||||
);
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_unsigned_scalar_div_assign(
|
||||
streams,
|
||||
quotient.as_mut(),
|
||||
divisor,
|
||||
&self.key_switching_key.d_vec,
|
||||
&d_bsk.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
d_bsk.input_lwe_dimension,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
LweBskGroupingFactor(0),
|
||||
PBSType::Classical,
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_unsigned_scalar_div_assign(
|
||||
streams,
|
||||
quotient.as_mut(),
|
||||
divisor,
|
||||
&self.key_switching_key.d_vec,
|
||||
&d_multibit_bsk.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
d_multibit_bsk.input_lwe_dimension,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
PBSType::MultiBit,
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -194,24 +178,6 @@ impl CudaServerKey {
|
||||
divisor: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) -> CudaUnsignedRadixCiphertext
|
||||
where
|
||||
Scalar: Reciprocable,
|
||||
{
|
||||
let res = unsafe { self.unchecked_scalar_div_async(numerator, divisor, streams) };
|
||||
streams.synchronize();
|
||||
res
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronized
|
||||
pub unsafe fn scalar_div_async<Scalar>(
|
||||
&self,
|
||||
numerator: &CudaUnsignedRadixCiphertext,
|
||||
divisor: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) -> CudaUnsignedRadixCiphertext
|
||||
where
|
||||
Scalar: Reciprocable,
|
||||
{
|
||||
@@ -224,7 +190,7 @@ impl CudaServerKey {
|
||||
&tmp_numerator
|
||||
};
|
||||
|
||||
self.unchecked_scalar_div_async(numerator, divisor, streams)
|
||||
self.unchecked_scalar_div(numerator, divisor, streams)
|
||||
}
|
||||
|
||||
pub fn unchecked_scalar_div_rem<Scalar>(
|
||||
@@ -233,27 +199,6 @@ impl CudaServerKey {
|
||||
divisor: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) -> (CudaUnsignedRadixCiphertext, CudaUnsignedRadixCiphertext)
|
||||
where
|
||||
Scalar: Reciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
|
||||
{
|
||||
let (quotient, remainder) =
|
||||
unsafe { self.unchecked_scalar_div_rem_async(numerator, divisor, streams) };
|
||||
|
||||
streams.synchronize();
|
||||
|
||||
(quotient, remainder)
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronized
|
||||
pub unsafe fn unchecked_scalar_div_rem_async<Scalar>(
|
||||
&self,
|
||||
numerator: &CudaUnsignedRadixCiphertext,
|
||||
divisor: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) -> (CudaUnsignedRadixCiphertext, CudaUnsignedRadixCiphertext)
|
||||
where
|
||||
Scalar: Reciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
|
||||
{
|
||||
@@ -276,50 +221,52 @@ impl CudaServerKey {
|
||||
streams,
|
||||
);
|
||||
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_unsigned_scalar_div_rem(
|
||||
streams,
|
||||
quotient.as_mut(),
|
||||
remainder.as_mut(),
|
||||
divisor,
|
||||
&self.key_switching_key.d_vec,
|
||||
&d_bsk.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
d_bsk.input_lwe_dimension,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
LweBskGroupingFactor(0),
|
||||
PBSType::Classical,
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_unsigned_scalar_div_rem(
|
||||
streams,
|
||||
quotient.as_mut(),
|
||||
remainder.as_mut(),
|
||||
divisor,
|
||||
&self.key_switching_key.d_vec,
|
||||
&d_multibit_bsk.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
d_multibit_bsk.input_lwe_dimension,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
PBSType::MultiBit,
|
||||
None,
|
||||
);
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_unsigned_scalar_div_rem(
|
||||
streams,
|
||||
quotient.as_mut(),
|
||||
remainder.as_mut(),
|
||||
divisor,
|
||||
&self.key_switching_key.d_vec,
|
||||
&d_bsk.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
d_bsk.input_lwe_dimension,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
LweBskGroupingFactor(0),
|
||||
PBSType::Classical,
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_unsigned_scalar_div_rem(
|
||||
streams,
|
||||
quotient.as_mut(),
|
||||
remainder.as_mut(),
|
||||
divisor,
|
||||
&self.key_switching_key.d_vec,
|
||||
&d_multibit_bsk.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
d_multibit_bsk.input_lwe_dimension,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
PBSType::MultiBit,
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -370,24 +317,6 @@ impl CudaServerKey {
|
||||
divisor: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) -> (CudaUnsignedRadixCiphertext, CudaUnsignedRadixCiphertext)
|
||||
where
|
||||
Scalar: Reciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
|
||||
{
|
||||
let res = unsafe { self.unchecked_scalar_div_rem_async(numerator, divisor, streams) };
|
||||
streams.synchronize();
|
||||
res
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronized
|
||||
pub unsafe fn scalar_div_rem_async<Scalar>(
|
||||
&self,
|
||||
numerator: &CudaUnsignedRadixCiphertext,
|
||||
divisor: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) -> (CudaUnsignedRadixCiphertext, CudaUnsignedRadixCiphertext)
|
||||
where
|
||||
Scalar: Reciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
|
||||
{
|
||||
@@ -400,7 +329,7 @@ impl CudaServerKey {
|
||||
&tmp_numerator
|
||||
};
|
||||
|
||||
self.unchecked_scalar_div_rem_async(numerator, divisor, streams)
|
||||
self.unchecked_scalar_div_rem(numerator, divisor, streams)
|
||||
}
|
||||
|
||||
pub fn unchecked_scalar_rem<Scalar>(
|
||||
@@ -412,26 +341,7 @@ impl CudaServerKey {
|
||||
where
|
||||
Scalar: Reciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
|
||||
{
|
||||
let res = unsafe { self.unchecked_scalar_rem_async(numerator, divisor, streams) };
|
||||
streams.synchronize();
|
||||
res
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronized
|
||||
pub unsafe fn unchecked_scalar_rem_async<Scalar>(
|
||||
&self,
|
||||
numerator: &CudaUnsignedRadixCiphertext,
|
||||
divisor: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) -> CudaUnsignedRadixCiphertext
|
||||
where
|
||||
Scalar: Reciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
|
||||
{
|
||||
self.unchecked_scalar_div_rem_async(numerator, divisor, streams)
|
||||
.1
|
||||
self.unchecked_scalar_div_rem(numerator, divisor, streams).1
|
||||
}
|
||||
|
||||
/// Computes homomorphically a division between a ciphertext and a scalar.
|
||||
@@ -473,24 +383,6 @@ impl CudaServerKey {
|
||||
divisor: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) -> CudaUnsignedRadixCiphertext
|
||||
where
|
||||
Scalar: Reciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
|
||||
{
|
||||
let res = unsafe { self.unchecked_scalar_rem_async(numerator, divisor, streams) };
|
||||
streams.synchronize();
|
||||
res
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronized
|
||||
pub unsafe fn scalar_rem_async<Scalar>(
|
||||
&self,
|
||||
numerator: &CudaUnsignedRadixCiphertext,
|
||||
divisor: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) -> CudaUnsignedRadixCiphertext
|
||||
where
|
||||
Scalar: Reciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
|
||||
{
|
||||
@@ -503,7 +395,7 @@ impl CudaServerKey {
|
||||
&tmp_numerator
|
||||
};
|
||||
|
||||
self.unchecked_scalar_rem_async(numerator, divisor, streams)
|
||||
self.unchecked_scalar_rem(numerator, divisor, streams)
|
||||
}
|
||||
|
||||
pub fn unchecked_signed_scalar_div<Scalar>(
|
||||
@@ -512,25 +404,6 @@ impl CudaServerKey {
|
||||
divisor: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) -> CudaSignedRadixCiphertext
|
||||
where
|
||||
Scalar: SignedReciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
|
||||
<<Scalar as SignedReciprocable>::Unsigned as Reciprocable>::DoublePrecision: Send,
|
||||
{
|
||||
let res = unsafe { self.unchecked_signed_scalar_div_async(numerator, divisor, streams) };
|
||||
streams.synchronize();
|
||||
res
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronized
|
||||
pub unsafe fn unchecked_signed_scalar_div_async<Scalar>(
|
||||
&self,
|
||||
numerator: &CudaSignedRadixCiphertext,
|
||||
divisor: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) -> CudaSignedRadixCiphertext
|
||||
where
|
||||
Scalar: SignedReciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
|
||||
<<Scalar as SignedReciprocable>::Unsigned as Reciprocable>::DoublePrecision: Send,
|
||||
@@ -548,48 +421,50 @@ impl CudaServerKey {
|
||||
|
||||
let mut quotient: CudaSignedRadixCiphertext = numerator.duplicate(streams);
|
||||
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_signed_scalar_div_assign(
|
||||
streams,
|
||||
quotient.as_mut(),
|
||||
divisor,
|
||||
&self.key_switching_key.d_vec,
|
||||
&d_bsk.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
d_bsk.input_lwe_dimension,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
LweBskGroupingFactor(0),
|
||||
PBSType::Classical,
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_signed_scalar_div_assign(
|
||||
streams,
|
||||
quotient.as_mut(),
|
||||
divisor,
|
||||
&self.key_switching_key.d_vec,
|
||||
&d_multibit_bsk.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
d_multibit_bsk.input_lwe_dimension,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
PBSType::MultiBit,
|
||||
None,
|
||||
);
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_signed_scalar_div_assign(
|
||||
streams,
|
||||
quotient.as_mut(),
|
||||
divisor,
|
||||
&self.key_switching_key.d_vec,
|
||||
&d_bsk.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
d_bsk.input_lwe_dimension,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
LweBskGroupingFactor(0),
|
||||
PBSType::Classical,
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_signed_scalar_div_assign(
|
||||
streams,
|
||||
quotient.as_mut(),
|
||||
divisor,
|
||||
&self.key_switching_key.d_vec,
|
||||
&d_multibit_bsk.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
d_multibit_bsk.input_lwe_dimension,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
PBSType::MultiBit,
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -638,25 +513,6 @@ impl CudaServerKey {
|
||||
divisor: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) -> CudaSignedRadixCiphertext
|
||||
where
|
||||
Scalar: SignedReciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
|
||||
<<Scalar as SignedReciprocable>::Unsigned as Reciprocable>::DoublePrecision: Send,
|
||||
{
|
||||
let res = unsafe { self.signed_scalar_div_async(numerator, divisor, streams) };
|
||||
streams.synchronize();
|
||||
res
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronized
|
||||
pub unsafe fn signed_scalar_div_async<Scalar>(
|
||||
&self,
|
||||
numerator: &CudaSignedRadixCiphertext,
|
||||
divisor: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) -> CudaSignedRadixCiphertext
|
||||
where
|
||||
Scalar: SignedReciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
|
||||
<<Scalar as SignedReciprocable>::Unsigned as Reciprocable>::DoublePrecision: Send,
|
||||
@@ -670,7 +526,7 @@ impl CudaServerKey {
|
||||
&tmp_numerator
|
||||
};
|
||||
|
||||
self.unchecked_signed_scalar_div_async(numerator, divisor, streams)
|
||||
self.unchecked_signed_scalar_div(numerator, divisor, streams)
|
||||
}
|
||||
|
||||
pub fn unchecked_signed_scalar_div_rem<Scalar>(
|
||||
@@ -679,26 +535,6 @@ impl CudaServerKey {
|
||||
divisor: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) -> (CudaSignedRadixCiphertext, CudaSignedRadixCiphertext)
|
||||
where
|
||||
Scalar: SignedReciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
|
||||
<<Scalar as SignedReciprocable>::Unsigned as Reciprocable>::DoublePrecision: Send,
|
||||
{
|
||||
let res =
|
||||
unsafe { self.unchecked_signed_scalar_div_rem_async(numerator, divisor, streams) };
|
||||
streams.synchronize();
|
||||
res
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronized
|
||||
pub unsafe fn unchecked_signed_scalar_div_rem_async<Scalar>(
|
||||
&self,
|
||||
numerator: &CudaSignedRadixCiphertext,
|
||||
divisor: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) -> (CudaSignedRadixCiphertext, CudaSignedRadixCiphertext)
|
||||
where
|
||||
Scalar: SignedReciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
|
||||
<<Scalar as SignedReciprocable>::Unsigned as Reciprocable>::DoublePrecision: Send,
|
||||
@@ -720,53 +556,54 @@ impl CudaServerKey {
|
||||
streams,
|
||||
);
|
||||
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_signed_scalar_div_rem_assign(
|
||||
streams,
|
||||
quotient.as_mut(),
|
||||
remainder.as_mut(),
|
||||
divisor,
|
||||
&self.key_switching_key.d_vec,
|
||||
&d_bsk.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
d_bsk.input_lwe_dimension,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
LweBskGroupingFactor(0),
|
||||
PBSType::Classical,
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_signed_scalar_div_rem_assign(
|
||||
streams,
|
||||
quotient.as_mut(),
|
||||
remainder.as_mut(),
|
||||
divisor,
|
||||
&self.key_switching_key.d_vec,
|
||||
&d_multibit_bsk.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
d_multibit_bsk.input_lwe_dimension,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
PBSType::MultiBit,
|
||||
None,
|
||||
);
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_signed_scalar_div_rem_assign(
|
||||
streams,
|
||||
quotient.as_mut(),
|
||||
remainder.as_mut(),
|
||||
divisor,
|
||||
&self.key_switching_key.d_vec,
|
||||
&d_bsk.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
d_bsk.input_lwe_dimension,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
LweBskGroupingFactor(0),
|
||||
PBSType::Classical,
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_signed_scalar_div_rem_assign(
|
||||
streams,
|
||||
quotient.as_mut(),
|
||||
remainder.as_mut(),
|
||||
divisor,
|
||||
&self.key_switching_key.d_vec,
|
||||
&d_multibit_bsk.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
d_multibit_bsk.input_lwe_dimension,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
PBSType::MultiBit,
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
(quotient, remainder)
}

@@ -812,25 +649,6 @@ impl CudaServerKey {
|
||||
divisor: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) -> (CudaSignedRadixCiphertext, CudaSignedRadixCiphertext)
|
||||
where
|
||||
Scalar: SignedReciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
|
||||
<<Scalar as SignedReciprocable>::Unsigned as Reciprocable>::DoublePrecision: Send,
|
||||
{
|
||||
let res = unsafe { self.signed_scalar_div_rem_async(numerator, divisor, streams) };
|
||||
streams.synchronize();
|
||||
res
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronized
|
||||
pub unsafe fn signed_scalar_div_rem_async<Scalar>(
|
||||
&self,
|
||||
numerator: &CudaSignedRadixCiphertext,
|
||||
divisor: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) -> (CudaSignedRadixCiphertext, CudaSignedRadixCiphertext)
|
||||
where
|
||||
Scalar: SignedReciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
|
||||
<<Scalar as SignedReciprocable>::Unsigned as Reciprocable>::DoublePrecision: Send,
|
||||
@@ -844,7 +662,7 @@ impl CudaServerKey {
|
||||
&tmp_numerator
|
||||
};
|
||||
|
||||
self.unchecked_signed_scalar_div_rem_async(numerator, divisor, streams)
|
||||
self.unchecked_signed_scalar_div_rem(numerator, divisor, streams)
|
||||
}
|
||||
|
||||
pub fn unchecked_signed_scalar_rem<Scalar>(
|
||||
@@ -853,27 +671,6 @@ impl CudaServerKey {
|
||||
divisor: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) -> CudaSignedRadixCiphertext
|
||||
where
|
||||
Scalar: SignedReciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
|
||||
<<Scalar as SignedReciprocable>::Unsigned as Reciprocable>::DoublePrecision: Send,
|
||||
{
|
||||
let remainder =
|
||||
unsafe { self.unchecked_signed_scalar_rem_async(numerator, divisor, streams) };
|
||||
streams.synchronize();
|
||||
|
||||
remainder
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronized
|
||||
pub unsafe fn unchecked_signed_scalar_rem_async<Scalar>(
|
||||
&self,
|
||||
numerator: &CudaSignedRadixCiphertext,
|
||||
divisor: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) -> CudaSignedRadixCiphertext
|
||||
where
|
||||
Scalar: SignedReciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
|
||||
<<Scalar as SignedReciprocable>::Unsigned as Reciprocable>::DoublePrecision: Send,
|
||||
@@ -922,25 +719,6 @@ impl CudaServerKey {
|
||||
divisor: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) -> CudaSignedRadixCiphertext
|
||||
where
|
||||
Scalar: SignedReciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
|
||||
<<Scalar as SignedReciprocable>::Unsigned as Reciprocable>::DoublePrecision: Send,
|
||||
{
|
||||
let res = unsafe { self.signed_scalar_rem_async(numerator, divisor, streams) };
|
||||
streams.synchronize();
|
||||
res
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronized
|
||||
pub unsafe fn signed_scalar_rem_async<Scalar>(
|
||||
&self,
|
||||
numerator: &CudaSignedRadixCiphertext,
|
||||
divisor: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) -> CudaSignedRadixCiphertext
|
||||
where
|
||||
Scalar: SignedReciprocable + ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
|
||||
<<Scalar as SignedReciprocable>::Unsigned as Reciprocable>::DoublePrecision: Send,
|
||||
@@ -954,7 +732,7 @@ impl CudaServerKey {
|
||||
&tmp_numerator
|
||||
};
|
||||
|
||||
self.unchecked_signed_scalar_rem_async(numerator, divisor, streams)
|
||||
self.unchecked_signed_scalar_rem(numerator, divisor, streams)
|
||||
}
|
||||
|
||||
pub fn get_scalar_div_size_on_gpu<Scalar>(
|
||||
|
||||
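The scalar division and remainder entry points follow the same pattern after this change: the public functions are no longer `unsafe` async variants, and the backend dispatch moves inside an `unsafe` block within the `unchecked_*` bodies. A hedged sketch, reusing the assumed `gpu_key` and `streams` names from the earlier note, with `d_num` an assumed encrypted `CudaUnsignedRadixCiphertext`; the exact public wrapper names are inferred from the removed `_async` counterparts:

    // Quotient and remainder of an encrypted numerator by a clear divisor.
    let d_quot = gpu_key.scalar_div(&d_num, 3u64, &streams);
    let d_rem = gpu_key.scalar_rem(&d_num, 3u64, &streams);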
@@ -63,11 +63,7 @@ impl CudaServerKey {
|
||||
result
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronised
|
||||
pub unsafe fn unchecked_scalar_mul_assign_async<Scalar, T>(
|
||||
pub fn unchecked_scalar_mul_assign<Scalar, T>(
|
||||
&self,
|
||||
ct: &mut T,
|
||||
scalar: Scalar,
|
||||
@@ -77,7 +73,7 @@ impl CudaServerKey {
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
if scalar == Scalar::ZERO {
ct.as_mut().d_blocks.0.d_vec.memset_async(0, streams, 0);
ct.as_mut().d_blocks.0.d_vec.memset(0, streams, 0);
return;
}

@@ -90,7 +86,7 @@ impl CudaServerKey {
if scalar.is_power_of_two() {
// Shifting cost one bivariate PBS so its always faster
// than multiplying
self.unchecked_scalar_left_shift_assign_async(ct, scalar.ilog2() as u64, streams);
self.unchecked_scalar_left_shift_assign(ct, scalar.ilog2() as u64, streams);
return;
}
let msg_bits = self.message_modulus.0.ilog2() as usize;
@@ -112,73 +108,60 @@ impl CudaServerKey {
return;
}

match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_scalar_mul(
|
||||
streams,
|
||||
ct.as_mut(),
|
||||
decomposed_scalar.as_slice(),
|
||||
has_at_least_one_set.as_slice(),
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
d_bsk.decomp_base_log,
|
||||
d_bsk.decomp_level_count,
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
decomposed_scalar.len() as u32,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_scalar_mul(
|
||||
streams,
|
||||
ct.as_mut(),
|
||||
decomposed_scalar.as_slice(),
|
||||
has_at_least_one_set.as_slice(),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
decomposed_scalar.len() as u32,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn unchecked_scalar_mul_assign<Scalar, T>(
|
||||
&self,
|
||||
ct: &mut T,
|
||||
scalar: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) where
|
||||
Scalar: ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
unsafe {
|
||||
self.unchecked_scalar_mul_assign_async(ct, scalar, streams);
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_scalar_mul(
|
||||
streams,
|
||||
ct.as_mut(),
|
||||
decomposed_scalar.as_slice(),
|
||||
has_at_least_one_set.as_slice(),
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
d_bsk.decomp_base_log,
|
||||
d_bsk.decomp_level_count,
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
decomposed_scalar.len() as u32,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_scalar_mul(
|
||||
streams,
|
||||
ct.as_mut(),
|
||||
decomposed_scalar.as_slice(),
|
||||
has_at_least_one_set.as_slice(),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
decomposed_scalar.len() as u32,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
streams.synchronize();
}

/// Computes homomorphically a multiplication between a scalar and a ciphertext.
@@ -227,16 +210,8 @@ impl CudaServerKey {
result
}

/// # Safety
///
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
///   not be dropped until streams is synchronised
pub unsafe fn scalar_mul_assign_async<Scalar, T>(
&self,
ct: &mut T,
scalar: Scalar,
streams: &CudaStreams,
) where
pub fn scalar_mul_assign<Scalar, T>(&self, ct: &mut T, scalar: Scalar, streams: &CudaStreams)
where
Scalar: ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
T: CudaIntegerRadixCiphertext,
{
@@ -244,19 +219,9 @@ impl CudaServerKey {
|
||||
self.full_propagate_assign(ct, streams);
|
||||
}
|
||||
|
||||
self.unchecked_scalar_mul_assign_async(ct, scalar, streams);
|
||||
self.unchecked_scalar_mul_assign(ct, scalar, streams);
|
||||
}
|
||||
|
||||
pub fn scalar_mul_assign<Scalar, T>(&self, ct: &mut T, scalar: Scalar, streams: &CudaStreams)
|
||||
where
|
||||
Scalar: ScalarMultiplier + DecomposableInto<u8> + CastInto<u64>,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
unsafe {
|
||||
self.scalar_mul_assign_async(ct, scalar, streams);
|
||||
}
|
||||
streams.synchronize();
|
||||
}
|
||||
pub fn get_scalar_mul_size_on_gpu<Scalar, T>(
|
||||
&self,
|
||||
ct: &T,
|
||||
|
||||
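Scalar multiplication gets the same treatment: per the diff above, `scalar_mul_assign` propagates carries if needed, dispatches the backend call inside an internal `unsafe` block, and synchronizes the streams before returning. Sketch under the same assumptions as the previous notes, with `d_ct` now taken as mutable:

    // In-place multiplication by a clear scalar; no unsafe, no manual synchronize.
    gpu_key.scalar_mul_assign(&mut d_ct, 3u64, &streams);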
@@ -11,140 +11,23 @@ use crate::integer::gpu::{
};

impl CudaServerKey {
/// # Safety
|
||||
///
|
||||
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until stream is synchronised
|
||||
pub unsafe fn unchecked_scalar_rotate_left_async<Scalar, T>(
|
||||
&self,
|
||||
ct: &T,
|
||||
n: Scalar,
|
||||
stream: &CudaStreams,
|
||||
) -> T
|
||||
where
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
Scalar: CastFrom<u32>,
|
||||
u32: CastFrom<Scalar>,
|
||||
{
|
||||
let mut result = ct.duplicate(stream);
|
||||
self.unchecked_scalar_rotate_left_assign_async(&mut result, n, stream);
|
||||
result
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until stream is synchronised
|
||||
pub unsafe fn unchecked_scalar_rotate_left_assign_async<Scalar, T>(
|
||||
&self,
|
||||
ct: &mut T,
|
||||
n: Scalar,
|
||||
stream: &CudaStreams,
|
||||
) where
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
Scalar: CastFrom<u32>,
|
||||
u32: CastFrom<Scalar>,
|
||||
{
|
||||
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_scalar_rotate_left_assign(
|
||||
stream,
|
||||
ct.as_mut(),
|
||||
u32::cast_from(n),
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_scalar_rotate_left_assign(
|
||||
stream,
|
||||
ct.as_mut(),
|
||||
u32::cast_from(n),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn unchecked_scalar_rotate_left<Scalar, T>(
|
||||
&self,
|
||||
ct: &T,
|
||||
n: Scalar,
|
||||
stream: &CudaStreams,
|
||||
) -> T
|
||||
where
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
Scalar: CastFrom<u32>,
|
||||
u32: CastFrom<Scalar>,
|
||||
{
|
||||
let result = unsafe { self.unchecked_scalar_rotate_left_async(ct, n, stream) };
|
||||
stream.synchronize();
|
||||
result
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until stream is synchronised
|
||||
pub unsafe fn unchecked_scalar_rotate_right_async<Scalar, T>(
|
||||
&self,
|
||||
ct: &T,
|
||||
n: Scalar,
|
||||
stream: &CudaStreams,
|
||||
) -> T
|
||||
where
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
Scalar: CastFrom<u32>,
|
||||
u32: CastFrom<Scalar>,
|
||||
{
|
||||
let mut result = ct.duplicate(stream);
|
||||
self.unchecked_scalar_rotate_right_assign_async(&mut result, n, stream);
|
||||
self.unchecked_scalar_rotate_left_assign(&mut result, n, stream);
|
||||
result
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `stream` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until stream is synchronised
|
||||
pub unsafe fn unchecked_scalar_rotate_right_assign_async<Scalar, T>(
|
||||
pub fn unchecked_scalar_rotate_left_assign<Scalar, T>(
|
||||
&self,
|
||||
ct: &mut T,
|
||||
n: Scalar,
|
||||
@@ -155,60 +38,62 @@ impl CudaServerKey {
|
||||
u32: CastFrom<Scalar>,
|
||||
{
|
||||
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_scalar_rotate_right_assign(
|
||||
stream,
|
||||
ct.as_mut(),
|
||||
u32::cast_from(n),
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_scalar_rotate_right_assign(
|
||||
stream,
|
||||
ct.as_mut(),
|
||||
u32::cast_from(n),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
);
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_scalar_rotate_left_assign(
|
||||
stream,
|
||||
ct.as_mut(),
|
||||
u32::cast_from(n),
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_scalar_rotate_left_assign(
|
||||
stream,
|
||||
ct.as_mut(),
|
||||
u32::cast_from(n),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -224,11 +109,82 @@ impl CudaServerKey {
|
||||
Scalar: CastFrom<u32>,
|
||||
u32: CastFrom<Scalar>,
|
||||
{
|
||||
let result = unsafe { self.unchecked_scalar_rotate_right_async(ct, n, stream) };
|
||||
stream.synchronize();
|
||||
let mut result = ct.duplicate(stream);
|
||||
self.unchecked_scalar_rotate_right_assign(&mut result, n, stream);
|
||||
result
|
||||
}
|
||||
|
||||
pub fn unchecked_scalar_rotate_right_assign<Scalar, T>(
|
||||
&self,
|
||||
ct: &mut T,
|
||||
n: Scalar,
|
||||
stream: &CudaStreams,
|
||||
) where
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
Scalar: CastFrom<u32>,
|
||||
u32: CastFrom<Scalar>,
|
||||
{
|
||||
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_scalar_rotate_right_assign(
|
||||
stream,
|
||||
ct.as_mut(),
|
||||
u32::cast_from(n),
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_scalar_rotate_right_assign(
|
||||
stream,
|
||||
ct.as_mut(),
|
||||
u32::cast_from(n),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
);
|
||||
}
}
}
}

pub fn scalar_rotate_left_assign<Scalar, T>(&self, ct: &mut T, n: Scalar, stream: &CudaStreams)
where
T: CudaIntegerRadixCiphertext,
@@ -239,8 +195,7 @@ impl CudaServerKey {
self.full_propagate_assign(ct, stream);
}

unsafe { self.unchecked_scalar_rotate_left_assign_async(ct, n, stream) };
stream.synchronize();
self.unchecked_scalar_rotate_left_assign(ct, n, stream);
}

pub fn scalar_rotate_right_assign<Scalar, T>(&self, ct: &mut T, n: Scalar, stream: &CudaStreams)
@@ -253,8 +208,7 @@ impl CudaServerKey {
self.full_propagate_assign(ct, stream);
}

unsafe { self.unchecked_scalar_rotate_right_assign_async(ct, n, stream) };
stream.synchronize();
self.unchecked_scalar_rotate_right_assign(ct, n, stream);
}

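// Illustrative sketch, not part of this diff: after this change the checked
// entry points keep the same two-step shape, just without the `_async` layer —
// propagate carries when a block is dirty, then call the synchronous unchecked
// variant. Dropping the explicit `stream.synchronize()` assumes the backend
// call blocks until the kernels have finished. With `sks` and `d_ct` as
// placeholder bindings, a call site becomes:
//
//     sks.scalar_rotate_left_assign(&mut d_ct, 3u32, &streams);
//     // no `unsafe` block and no manual synchronize needed any more
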
pub fn scalar_rotate_left<Scalar, T>(&self, ct: &T, shift: Scalar, stream: &CudaStreams) -> T
|
||||
|
||||
@@ -13,100 +13,6 @@ use crate::integer::gpu::{
|
||||
};
|
||||
|
||||
impl CudaServerKey {
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronised
|
||||
pub unsafe fn unchecked_scalar_left_shift_async<Scalar, T>(
|
||||
&self,
|
||||
ct: &T,
|
||||
shift: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) -> T
|
||||
where
|
||||
Scalar: CastFrom<u32>,
|
||||
u32: CastFrom<Scalar>,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let mut result = ct.duplicate(streams);
|
||||
self.unchecked_scalar_left_shift_assign_async(&mut result, shift, streams);
|
||||
result
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronised
|
||||
pub unsafe fn unchecked_scalar_left_shift_assign_async<Scalar, T>(
|
||||
&self,
|
||||
ct: &mut T,
|
||||
shift: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) where
|
||||
Scalar: CastFrom<u32>,
|
||||
u32: CastFrom<Scalar>,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
|
||||
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_scalar_left_shift_assign(
|
||||
streams,
|
||||
ct.as_mut(),
|
||||
u32::cast_from(shift),
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_scalar_left_shift_assign(
|
||||
streams,
|
||||
ct.as_mut(),
|
||||
u32::cast_from(shift),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes homomorphically a left shift by a scalar.
|
||||
///
|
||||
/// The result is returned as a new ciphertext.
|
||||
@@ -149,41 +55,17 @@ impl CudaServerKey {
|
||||
shift: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) -> T
|
||||
where
|
||||
Scalar: CastFrom<u32>,
|
||||
u32: CastFrom<Scalar>,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let result = unsafe { self.unchecked_scalar_left_shift_async(ct, shift, streams) };
|
||||
streams.synchronize();
|
||||
result
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronised
|
||||
pub unsafe fn unchecked_scalar_right_shift_async<Scalar, T>(
|
||||
&self,
|
||||
ct: &T,
|
||||
shift: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) -> T
|
||||
where
|
||||
Scalar: CastFrom<u32>,
|
||||
u32: CastFrom<Scalar>,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let mut result = ct.duplicate(streams);
|
||||
self.unchecked_scalar_right_shift_assign_async(&mut result, shift, streams);
|
||||
self.unchecked_scalar_left_shift_assign(&mut result, shift, streams);
|
||||
result
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronised
|
||||
pub unsafe fn unchecked_scalar_right_shift_assign_async<Scalar, T>(
|
||||
pub fn unchecked_scalar_left_shift_assign<Scalar, T>(
|
||||
&self,
|
||||
ct: &mut T,
|
||||
shift: Scalar,
|
||||
@@ -195,65 +77,10 @@ impl CudaServerKey {
|
||||
{
|
||||
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
|
||||
|
||||
if T::IS_SIGNED {
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_scalar_arithmetic_right_shift_assign(
|
||||
streams,
|
||||
ct.as_mut(),
|
||||
u32::cast_from(shift),
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_scalar_arithmetic_right_shift_assign(
|
||||
streams,
|
||||
ct.as_mut(),
|
||||
u32::cast_from(shift),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_scalar_logical_right_shift_assign(
|
||||
cuda_backend_unchecked_scalar_left_shift_assign(
|
||||
streams,
|
||||
ct.as_mut(),
|
||||
u32::cast_from(shift),
|
||||
@@ -280,7 +107,7 @@ impl CudaServerKey {
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_scalar_logical_right_shift_assign(
|
||||
cuda_backend_unchecked_scalar_left_shift_assign(
|
||||
streams,
|
||||
ct.as_mut(),
|
||||
u32::cast_from(shift),
|
||||
@@ -357,16 +184,141 @@ impl CudaServerKey {
|
||||
u32: CastFrom<Scalar>,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let result = unsafe { self.unchecked_scalar_right_shift_async(ct, shift, streams) };
|
||||
streams.synchronize();
|
||||
let mut result = ct.duplicate(streams);
|
||||
self.unchecked_scalar_right_shift_assign(&mut result, shift, streams);
|
||||
result
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronised
|
||||
pub unsafe fn scalar_right_shift_assign_async<Scalar, T>(
|
||||
pub fn unchecked_scalar_right_shift_assign<Scalar, T>(
|
||||
&self,
|
||||
ct: &mut T,
|
||||
shift: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) where
|
||||
Scalar: CastFrom<u32>,
|
||||
u32: CastFrom<Scalar>,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
|
||||
|
||||
unsafe {
|
||||
if T::IS_SIGNED {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_scalar_arithmetic_right_shift_assign(
|
||||
streams,
|
||||
ct.as_mut(),
|
||||
u32::cast_from(shift),
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_scalar_arithmetic_right_shift_assign(
|
||||
streams,
|
||||
ct.as_mut(),
|
||||
u32::cast_from(shift),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_scalar_logical_right_shift_assign(
|
||||
streams,
|
||||
ct.as_mut(),
|
||||
u32::cast_from(shift),
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_scalar_logical_right_shift_assign(
|
||||
streams,
|
||||
ct.as_mut(),
|
||||
u32::cast_from(shift),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn scalar_right_shift_assign<Scalar, T>(
|
||||
&self,
|
||||
ct: &mut T,
|
||||
shift: Scalar,
|
||||
@@ -380,27 +332,7 @@ impl CudaServerKey {
|
||||
self.full_propagate_assign(ct, streams);
|
||||
}
|
||||
|
||||
self.unchecked_scalar_right_shift_assign_async(ct, shift, streams);
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronised
|
||||
pub unsafe fn scalar_right_shift_async<Scalar, T>(
|
||||
&self,
|
||||
ct: &T,
|
||||
shift: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) -> T
|
||||
where
|
||||
Scalar: CastFrom<u32>,
|
||||
u32: CastFrom<Scalar>,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let mut result = ct.duplicate(streams);
|
||||
self.scalar_right_shift_assign_async(&mut result, shift, streams);
|
||||
result
|
||||
self.unchecked_scalar_right_shift_assign(ct, shift, streams);
|
||||
}

/// Computes homomorphically a right shift by a scalar.
@@ -440,54 +372,13 @@ impl CudaServerKey {
/// assert_eq!(dec_result, msg >> shift);
/// ```
pub fn scalar_right_shift<Scalar, T>(&self, ct: &T, shift: Scalar, streams: &CudaStreams) -> T
where
Scalar: CastFrom<u32>,
u32: CastFrom<Scalar>,
T: CudaIntegerRadixCiphertext,
{
let result = unsafe { self.scalar_right_shift_async(ct, shift, streams) };
streams.synchronize();
result
}

/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronised
|
||||
pub unsafe fn scalar_left_shift_assign_async<Scalar, T>(
|
||||
&self,
|
||||
ct: &mut T,
|
||||
shift: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) where
|
||||
Scalar: CastFrom<u32>,
|
||||
u32: CastFrom<Scalar>,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
if !ct.block_carries_are_empty() {
|
||||
self.full_propagate_assign(ct, streams);
|
||||
}
|
||||
|
||||
self.unchecked_scalar_left_shift_assign_async(ct, shift, streams);
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronised
|
||||
pub unsafe fn scalar_left_shift_async<Scalar, T>(
|
||||
&self,
|
||||
ct: &T,
|
||||
shift: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) -> T
|
||||
where
|
||||
Scalar: CastFrom<u32>,
|
||||
u32: CastFrom<Scalar>,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let mut result = ct.duplicate(streams);
|
||||
self.scalar_left_shift_assign_async(&mut result, shift, streams);
|
||||
self.scalar_right_shift_assign(&mut result, shift, streams);
|
||||
result
|
||||
}
|
||||
|
||||
@@ -533,8 +424,8 @@ impl CudaServerKey {
|
||||
u32: CastFrom<Scalar>,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let result = unsafe { self.scalar_left_shift_async(ct, shift, streams) };
|
||||
streams.synchronize();
|
||||
let mut result = ct.duplicate(streams);
|
||||
self.scalar_left_shift_assign(&mut result, shift, streams);
|
||||
result
|
||||
}
|
||||
|
||||
@@ -552,104 +443,7 @@ impl CudaServerKey {
|
||||
self.full_propagate_assign(ct, streams);
|
||||
}
|
||||
|
||||
unsafe {
|
||||
self.unchecked_scalar_left_shift_assign_async(ct, shift, streams);
|
||||
};
|
||||
streams.synchronize();
|
||||
}
|
||||
|
||||
pub fn scalar_right_shift_assign<Scalar, T>(
|
||||
&self,
|
||||
ct: &mut T,
|
||||
shift: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) where
|
||||
Scalar: CastFrom<u32>,
|
||||
u32: CastFrom<Scalar>,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
if !ct.block_carries_are_empty() {
|
||||
self.full_propagate_assign(ct, streams);
|
||||
}
|
||||
|
||||
unsafe {
|
||||
self.unchecked_scalar_right_shift_assign_async(ct, shift, streams);
|
||||
};
|
||||
streams.synchronize();
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronised
|
||||
pub unsafe fn unchecked_scalar_right_shift_logical_assign_async<Scalar, T>(
|
||||
&self,
|
||||
ct: &mut T,
|
||||
shift: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) where
|
||||
Scalar: CastFrom<u32>,
|
||||
u32: CastFrom<Scalar>,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
|
||||
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_scalar_logical_right_shift_assign(
|
||||
streams,
|
||||
ct.as_mut(),
|
||||
u32::cast_from(shift),
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_scalar_logical_right_shift_assign(
|
||||
streams,
|
||||
ct.as_mut(),
|
||||
u32::cast_from(shift),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
self.unchecked_scalar_left_shift_assign(ct, shift, streams);
|
||||
}
|
||||
|
||||
pub fn get_scalar_left_shift_size_on_gpu<T>(&self, ct: &T, streams: &CudaStreams) -> u64
|
||||
|
||||
@@ -61,23 +61,6 @@ impl CudaServerKey {
|
||||
result
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronised
|
||||
pub unsafe fn unchecked_scalar_sub_assign_async<Scalar, T>(
|
||||
&self,
|
||||
ct: &mut T,
|
||||
scalar: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) where
|
||||
Scalar: DecomposableInto<u8> + Numeric + TwosComplementNegation + CastInto<u64>,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let negated_scalar = scalar.twos_complement_negation();
|
||||
self.unchecked_scalar_add_assign_async(ct, negated_scalar, streams);
|
||||
}
|
||||
|
||||
pub fn unchecked_scalar_sub_assign<Scalar, T>(
&self,
ct: &mut T,
@@ -87,10 +70,8 @@ impl CudaServerKey {
Scalar: DecomposableInto<u8> + Numeric + TwosComplementNegation + CastInto<u64>,
T: CudaIntegerRadixCiphertext,
{
unsafe {
self.unchecked_scalar_sub_assign_async(ct, scalar, streams);
}
streams.synchronize();
let negated_scalar = scalar.twos_complement_negation();
self.unchecked_scalar_add_assign(ct, negated_scalar, streams);
}

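// Illustrative sketch, not part of this diff: the scalar subtraction above is
// implemented as a scalar addition of the two's-complement negation, i.e. for
// a k-bit radix ciphertext `ct - s == ct + ((!s) + 1) (mod 2^k)`. A plain `u8`
// analogue of the identity being relied on:
//
//     let (msg, s) = (200u8, 13u8);
//     assert_eq!(msg.wrapping_sub(s), msg.wrapping_add((!s).wrapping_add(1)));
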
/// Computes homomorphically a subtraction between a ciphertext and a scalar.
|
||||
@@ -148,16 +129,8 @@ impl CudaServerKey {
|
||||
self.get_scalar_sub_assign_size_on_gpu(ct, streams)
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronised
|
||||
pub unsafe fn scalar_sub_assign_async<Scalar, T>(
|
||||
&self,
|
||||
ct: &mut T,
|
||||
scalar: Scalar,
|
||||
streams: &CudaStreams,
|
||||
) where
|
||||
pub fn scalar_sub_assign<Scalar, T>(&self, ct: &mut T, scalar: Scalar, streams: &CudaStreams)
|
||||
where
|
||||
Scalar: DecomposableInto<u8> + Numeric + TwosComplementNegation + CastInto<u64>,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
@@ -165,21 +138,10 @@ impl CudaServerKey {
|
||||
self.full_propagate_assign(ct, streams);
|
||||
}
|
||||
|
||||
self.unchecked_scalar_sub_assign_async(ct, scalar, streams);
|
||||
self.unchecked_scalar_sub_assign(ct, scalar, streams);
|
||||
let _carry = self.propagate_single_carry_assign(ct, streams, None, OutputFlag::None);
|
||||
}
|
||||
|
||||
pub fn scalar_sub_assign<Scalar, T>(&self, ct: &mut T, scalar: Scalar, streams: &CudaStreams)
|
||||
where
|
||||
Scalar: DecomposableInto<u8> + Numeric + TwosComplementNegation + CastInto<u64>,
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
unsafe {
|
||||
self.scalar_sub_assign_async(ct, scalar, streams);
|
||||
}
|
||||
streams.synchronize();
|
||||
}
|
||||
|
||||
pub fn get_scalar_sub_assign_size_on_gpu<T>(&self, ct: &T, streams: &CudaStreams) -> u64
|
||||
where
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
|
||||
@@ -9,11 +9,7 @@ use crate::integer::gpu::{
|
||||
};
|
||||
|
||||
impl CudaServerKey {
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronised
|
||||
pub unsafe fn unchecked_right_shift_assign_async<T>(
|
||||
pub fn unchecked_right_shift_assign<T>(
|
||||
&self,
|
||||
ct: &mut T,
|
||||
shift: &CudaUnsignedRadixCiphertext,
|
||||
@@ -24,214 +20,79 @@ impl CudaServerKey {
|
||||
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
|
||||
let is_signed = T::IS_SIGNED;
|
||||
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_right_shift_assign(
|
||||
streams,
|
||||
ct.as_mut(),
|
||||
shift.as_ref(),
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
is_signed,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_right_shift_assign(
|
||||
streams,
|
||||
ct.as_mut(),
|
||||
shift.as_ref(),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
is_signed,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
);
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_right_shift_assign(
|
||||
streams,
|
||||
ct.as_mut(),
|
||||
shift.as_ref(),
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
is_signed,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_right_shift_assign(
|
||||
streams,
|
||||
ct.as_mut(),
|
||||
shift.as_ref(),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
is_signed,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronised
|
||||
pub unsafe fn unchecked_right_shift_async<T>(
|
||||
&self,
|
||||
ct: &T,
|
||||
shift: &CudaUnsignedRadixCiphertext,
|
||||
streams: &CudaStreams,
|
||||
) -> T
|
||||
where
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let mut result = ct.duplicate(streams);
|
||||
self.unchecked_right_shift_assign_async(&mut result, shift, streams);
|
||||
result
|
||||
}
|
||||
|
||||
pub fn unchecked_right_shift<T>(
|
||||
&self,
|
||||
ct: &T,
|
||||
shift: &CudaUnsignedRadixCiphertext,
|
||||
streams: &CudaStreams,
|
||||
) -> T
|
||||
where
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let result = unsafe { self.unchecked_right_shift_async(ct, shift, streams) };
|
||||
streams.synchronize();
|
||||
result
|
||||
}
|
||||
|
||||
pub fn unchecked_right_shift_assign<T>(
|
||||
&self,
|
||||
ct: &mut T,
|
||||
shift: &CudaUnsignedRadixCiphertext,
|
||||
streams: &CudaStreams,
|
||||
) where
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
unsafe { self.unchecked_right_shift_assign_async(ct, shift, streams) };
|
||||
streams.synchronize();
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronised
|
||||
pub unsafe fn unchecked_left_shift_assign_async<T>(
|
||||
&self,
|
||||
ct: &mut T,
|
||||
shift: &CudaUnsignedRadixCiphertext,
|
||||
streams: &CudaStreams,
|
||||
) where
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
|
||||
let is_signed = T::IS_SIGNED;
|
||||
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_left_shift_assign(
|
||||
streams,
|
||||
ct.as_mut(),
|
||||
shift.as_ref(),
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
is_signed,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_left_shift_assign(
|
||||
streams,
|
||||
ct.as_mut(),
|
||||
shift.as_ref(),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
is_signed,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronised
|
||||
pub unsafe fn unchecked_left_shift_async<T>(
|
||||
&self,
|
||||
ct: &T,
|
||||
shift: &CudaUnsignedRadixCiphertext,
|
||||
streams: &CudaStreams,
|
||||
) -> T
|
||||
where
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let mut result = ct.duplicate(streams);
|
||||
self.unchecked_left_shift_assign_async(&mut result, shift, streams);
|
||||
result
|
||||
}
|
||||
|
||||
pub fn unchecked_left_shift<T>(
|
||||
&self,
|
||||
ct: &T,
|
||||
shift: &CudaUnsignedRadixCiphertext,
|
||||
streams: &CudaStreams,
|
||||
) -> T
|
||||
where
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let result = unsafe { self.unchecked_left_shift_async(ct, shift, streams) };
|
||||
streams.synchronize();
|
||||
self.unchecked_right_shift_assign(&mut result, shift, streams);
|
||||
result
|
||||
}
|
||||
|
||||
@@ -243,15 +104,72 @@ impl CudaServerKey {
|
||||
) where
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
unsafe { self.unchecked_left_shift_assign_async(ct, shift, streams) };
|
||||
streams.synchronize();
|
||||
let lwe_ciphertext_count = ct.as_ref().d_blocks.lwe_ciphertext_count();
|
||||
let is_signed = T::IS_SIGNED;
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_left_shift_assign(
|
||||
streams,
|
||||
ct.as_mut(),
|
||||
shift.as_ref(),
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
is_signed,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_left_shift_assign(
|
||||
streams,
|
||||
ct.as_mut(),
|
||||
shift.as_ref(),
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
lwe_ciphertext_count.0 as u32,
|
||||
is_signed,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronised
|
||||
pub unsafe fn right_shift_async<T>(
|
||||
pub fn unchecked_left_shift<T>(
|
||||
&self,
|
||||
ct: &T,
|
||||
shift: &CudaUnsignedRadixCiphertext,
|
||||
@@ -260,82 +178,11 @@ impl CudaServerKey {
|
||||
where
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let mut tmp_lhs: T;
|
||||
let mut tmp_rhs: CudaUnsignedRadixCiphertext;
|
||||
|
||||
let (lhs, rhs) = match (
|
||||
ct.block_carries_are_empty(),
|
||||
shift.block_carries_are_empty(),
|
||||
) {
|
||||
(true, true) => (ct, shift),
|
||||
(true, false) => {
|
||||
tmp_rhs = shift.duplicate(streams);
|
||||
self.full_propagate_assign(&mut tmp_rhs, streams);
|
||||
(ct, &tmp_rhs)
|
||||
}
|
||||
(false, true) => {
|
||||
tmp_lhs = ct.duplicate(streams);
|
||||
self.full_propagate_assign(&mut tmp_lhs, streams);
|
||||
(&tmp_lhs, shift)
|
||||
}
|
||||
(false, false) => {
|
||||
tmp_lhs = ct.duplicate(streams);
|
||||
tmp_rhs = shift.duplicate(streams);
|
||||
|
||||
self.full_propagate_assign(&mut tmp_lhs, streams);
|
||||
self.full_propagate_assign(&mut tmp_rhs, streams);
|
||||
(&tmp_lhs, &tmp_rhs)
|
||||
}
|
||||
};
|
||||
|
||||
let mut result = lhs.duplicate(streams);
|
||||
self.unchecked_right_shift_assign_async(&mut result, rhs, streams);
|
||||
let mut result = ct.duplicate(streams);
|
||||
self.unchecked_left_shift_assign(&mut result, shift, streams);
|
||||
result
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronised
|
||||
pub unsafe fn right_shift_assign_async<T>(
|
||||
&self,
|
||||
ct: &mut T,
|
||||
shift: &CudaUnsignedRadixCiphertext,
|
||||
streams: &CudaStreams,
|
||||
) where
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let mut tmp_lhs: T;
|
||||
let mut tmp_rhs: CudaUnsignedRadixCiphertext;
|
||||
|
||||
let (lhs, rhs) = match (
|
||||
ct.block_carries_are_empty(),
|
||||
shift.block_carries_are_empty(),
|
||||
) {
|
||||
(true, true) => (ct, shift),
|
||||
(true, false) => {
|
||||
tmp_rhs = shift.duplicate(streams);
|
||||
self.full_propagate_assign(&mut tmp_rhs, streams);
|
||||
(ct, &tmp_rhs)
|
||||
}
|
||||
(false, true) => {
|
||||
tmp_lhs = ct.duplicate(streams);
|
||||
self.full_propagate_assign(&mut tmp_lhs, streams);
|
||||
(&mut tmp_lhs, shift)
|
||||
}
|
||||
(false, false) => {
|
||||
tmp_lhs = ct.duplicate(streams);
|
||||
tmp_rhs = shift.duplicate(streams);
|
||||
|
||||
self.full_propagate_assign(&mut tmp_lhs, streams);
|
||||
self.full_propagate_assign(&mut tmp_rhs, streams);
|
||||
(&mut tmp_lhs, &tmp_rhs)
|
||||
}
|
||||
};
|
||||
|
||||
self.unchecked_right_shift_assign_async(lhs, rhs, streams);
|
||||
}

/// Computes homomorphically a right shift by an encrypted amount
///
/// The result is returned as a new ciphertext.
@@ -380,36 +227,6 @@ impl CudaServerKey {
shift: &CudaUnsignedRadixCiphertext,
streams: &CudaStreams,
) -> T
where
T: CudaIntegerRadixCiphertext,
{
let result = unsafe { self.right_shift_async(ct, shift, streams) };
streams.synchronize();
result
}

pub fn right_shift_assign<T>(
&self,
ct: &mut T,
shift: &CudaUnsignedRadixCiphertext,
streams: &CudaStreams,
) where
T: CudaIntegerRadixCiphertext,
{
unsafe { self.right_shift_assign_async(ct, shift, streams) };
streams.synchronize();
}

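// Illustrative sketch, not part of this diff: shifting by an encrypted amount
// takes a `CudaUnsignedRadixCiphertext` as the shift operand, so the checked
// paths first run `full_propagate_assign` on whichever operand still has dirty
// carry blocks before delegating to the unchecked assign (see the match on
// `block_carries_are_empty` further down). A hypothetical call, with `sks`,
// `d_ct` and `d_shift` as placeholder bindings:
//
//     let d_res = sks.right_shift(&d_ct, &d_shift, &streams);
//     // once decrypted, this is expected to match the plaintext `msg >> shift`
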
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronised
|
||||
pub unsafe fn left_shift_async<T>(
|
||||
&self,
|
||||
ct: &T,
|
||||
shift: &CudaUnsignedRadixCiphertext,
|
||||
streams: &CudaStreams,
|
||||
) -> T
|
||||
where
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
@@ -442,15 +259,11 @@ impl CudaServerKey {
|
||||
};
|
||||
|
||||
let mut result = lhs.duplicate(streams);
|
||||
self.unchecked_left_shift_assign_async(&mut result, rhs, streams);
|
||||
self.unchecked_right_shift_assign(&mut result, rhs, streams);
|
||||
result
|
||||
}
|
||||
|
||||
/// # Safety
|
||||
///
|
||||
/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must
|
||||
/// not be dropped until streams is synchronised
|
||||
pub unsafe fn left_shift_assign_async<T>(
|
||||
pub fn right_shift_assign<T>(
|
||||
&self,
|
||||
ct: &mut T,
|
||||
shift: &CudaUnsignedRadixCiphertext,
|
||||
@@ -486,7 +299,7 @@ impl CudaServerKey {
|
||||
}
|
||||
};
|
||||
|
||||
self.unchecked_left_shift_assign_async(lhs, rhs, streams);
|
||||
self.unchecked_right_shift_assign(lhs, rhs, streams);
|
||||
}
|
||||
|
||||
/// Computes homomorphically a left shift by an encrypted amount
|
||||
@@ -536,8 +349,36 @@ impl CudaServerKey {
|
||||
where
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let result = unsafe { self.left_shift_async(ct, shift, streams) };
|
||||
streams.synchronize();
|
||||
let mut tmp_lhs: T;
|
||||
let mut tmp_rhs: CudaUnsignedRadixCiphertext;
|
||||
|
||||
let (lhs, rhs) = match (
|
||||
ct.block_carries_are_empty(),
|
||||
shift.block_carries_are_empty(),
|
||||
) {
|
||||
(true, true) => (ct, shift),
|
||||
(true, false) => {
|
||||
tmp_rhs = shift.duplicate(streams);
|
||||
self.full_propagate_assign(&mut tmp_rhs, streams);
|
||||
(ct, &tmp_rhs)
|
||||
}
|
||||
(false, true) => {
|
||||
tmp_lhs = ct.duplicate(streams);
|
||||
self.full_propagate_assign(&mut tmp_lhs, streams);
|
||||
(&tmp_lhs, shift)
|
||||
}
|
||||
(false, false) => {
|
||||
tmp_lhs = ct.duplicate(streams);
|
||||
tmp_rhs = shift.duplicate(streams);
|
||||
|
||||
self.full_propagate_assign(&mut tmp_lhs, streams);
|
||||
self.full_propagate_assign(&mut tmp_rhs, streams);
|
||||
(&tmp_lhs, &tmp_rhs)
|
||||
}
|
||||
};
|
||||
|
||||
let mut result = lhs.duplicate(streams);
|
||||
self.unchecked_left_shift_assign(&mut result, rhs, streams);
|
||||
result
|
||||
}
|
||||
|
||||
@@ -549,8 +390,35 @@ impl CudaServerKey {
|
||||
) where
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
unsafe { self.left_shift_assign_async(ct, shift, streams) };
|
||||
streams.synchronize();
|
||||
let mut tmp_lhs: T;
|
||||
let mut tmp_rhs: CudaUnsignedRadixCiphertext;
|
||||
|
||||
let (lhs, rhs) = match (
|
||||
ct.block_carries_are_empty(),
|
||||
shift.block_carries_are_empty(),
|
||||
) {
|
||||
(true, true) => (ct, shift),
|
||||
(true, false) => {
|
||||
tmp_rhs = shift.duplicate(streams);
|
||||
self.full_propagate_assign(&mut tmp_rhs, streams);
|
||||
(ct, &tmp_rhs)
|
||||
}
|
||||
(false, true) => {
|
||||
tmp_lhs = ct.duplicate(streams);
|
||||
self.full_propagate_assign(&mut tmp_lhs, streams);
|
||||
(&mut tmp_lhs, shift)
|
||||
}
|
||||
(false, false) => {
|
||||
tmp_lhs = ct.duplicate(streams);
|
||||
tmp_rhs = shift.duplicate(streams);
|
||||
|
||||
self.full_propagate_assign(&mut tmp_lhs, streams);
|
||||
self.full_propagate_assign(&mut tmp_rhs, streams);
|
||||
(&mut tmp_lhs, &tmp_rhs)
|
||||
}
|
||||
};
|
||||
|
||||
self.unchecked_left_shift_assign(lhs, rhs, streams);
|
||||
}
|
||||
|
||||
pub fn get_left_shift_size_on_gpu<T: CudaIntegerRadixCiphertext>(