mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-09 22:57:59 -05:00
refactor(gpu): moving vector_comparisons's functions to the backend
This commit is contained in:
committed by
Agnès Leroy
parent
f5bfc7f79d
commit
6752249f7f
@@ -6128,132 +6128,6 @@ pub(crate) unsafe fn cuda_backend_apply_many_univariate_lut<T: UnsignedInteger,
|
||||
cleanup_cuda_apply_univariate_lut_64(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr));
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// # Safety
|
||||
///
|
||||
/// - The data must not be moved or dropped while being used by the CUDA kernel.
|
||||
/// - This function assumes exclusive access to the passed data; violating this may lead to
|
||||
/// undefined behavior.
|
||||
pub(crate) unsafe fn cuda_backend_apply_bivariate_lut<T: UnsignedInteger, B: Numeric>(
|
||||
streams: &CudaStreams,
|
||||
output: &mut CudaSliceMut<T>,
|
||||
output_degrees: &mut Vec<u64>,
|
||||
output_noise_levels: &mut Vec<u64>,
|
||||
input_1: &CudaSlice<T>,
|
||||
input_2: &CudaSlice<T>,
|
||||
input_lut: &[T],
|
||||
lut_degree: u64,
|
||||
bootstrapping_key: &CudaVec<B>,
|
||||
keyswitch_key: &CudaVec<T>,
|
||||
lwe_dimension: LweDimension,
|
||||
glwe_dimension: GlweDimension,
|
||||
polynomial_size: PolynomialSize,
|
||||
ks_level: DecompositionLevelCount,
|
||||
ks_base_log: DecompositionBaseLog,
|
||||
pbs_level: DecompositionLevelCount,
|
||||
pbs_base_log: DecompositionBaseLog,
|
||||
num_blocks: u32,
|
||||
message_modulus: MessageModulus,
|
||||
carry_modulus: CarryModulus,
|
||||
pbs_type: PBSType,
|
||||
grouping_factor: LweBskGroupingFactor,
|
||||
shift: u32,
|
||||
ms_noise_reduction_configuration: Option<&CudaModulusSwitchNoiseReductionConfiguration>,
|
||||
) {
|
||||
assert_eq!(
|
||||
streams.gpu_indexes[0],
|
||||
input_1.gpu_index(0),
|
||||
"GPU error: first stream is on GPU {}, first input 1 pointer is on GPU {}",
|
||||
streams.gpu_indexes[0].get(),
|
||||
input_1.gpu_index(0).get(),
|
||||
);
|
||||
assert_eq!(
|
||||
streams.gpu_indexes[0],
|
||||
input_2.gpu_index(0),
|
||||
"GPU error: first stream is on GPU {}, first input 2 pointer is on GPU {}",
|
||||
streams.gpu_indexes[0].get(),
|
||||
input_2.gpu_index(0).get(),
|
||||
);
|
||||
assert_eq!(
|
||||
streams.gpu_indexes[0],
|
||||
output.gpu_index(0),
|
||||
"GPU error: first stream is on GPU {}, first output pointer is on GPU {}",
|
||||
streams.gpu_indexes[0].get(),
|
||||
output.gpu_index(0).get(),
|
||||
);
|
||||
assert_eq!(
|
||||
streams.gpu_indexes[0],
|
||||
bootstrapping_key.gpu_index(0),
|
||||
"GPU error: first stream is on GPU {}, first bsk pointer is on GPU {}",
|
||||
streams.gpu_indexes[0].get(),
|
||||
bootstrapping_key.gpu_index(0).get(),
|
||||
);
|
||||
assert_eq!(
|
||||
streams.gpu_indexes[0],
|
||||
keyswitch_key.gpu_index(0),
|
||||
"GPU error: first stream is on GPU {}, first ksk pointer is on GPU {}",
|
||||
streams.gpu_indexes[0].get(),
|
||||
keyswitch_key.gpu_index(0).get(),
|
||||
);
|
||||
|
||||
let noise_reduction_type = resolve_ms_noise_reduction_config(ms_noise_reduction_configuration);
|
||||
|
||||
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
|
||||
let mut cuda_ffi_output = prepare_cuda_radix_ffi_from_slice_mut(
|
||||
output,
|
||||
output_degrees,
|
||||
output_noise_levels,
|
||||
num_blocks,
|
||||
(glwe_dimension.0 * polynomial_size.0) as u32,
|
||||
);
|
||||
let cuda_ffi_input_1 = prepare_cuda_radix_ffi_from_slice(
|
||||
input_1,
|
||||
output_degrees,
|
||||
output_noise_levels,
|
||||
num_blocks,
|
||||
(glwe_dimension.0 * polynomial_size.0) as u32,
|
||||
);
|
||||
let cuda_ffi_input_2 = prepare_cuda_radix_ffi_from_slice(
|
||||
input_2,
|
||||
output_degrees,
|
||||
output_noise_levels,
|
||||
num_blocks,
|
||||
(glwe_dimension.0 * polynomial_size.0) as u32,
|
||||
);
|
||||
scratch_cuda_apply_bivariate_lut_64(
|
||||
streams.ffi(),
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
input_lut.as_ptr().cast(),
|
||||
lwe_dimension.0 as u32,
|
||||
glwe_dimension.0 as u32,
|
||||
polynomial_size.0 as u32,
|
||||
ks_level.0 as u32,
|
||||
ks_base_log.0 as u32,
|
||||
pbs_level.0 as u32,
|
||||
pbs_base_log.0 as u32,
|
||||
grouping_factor.0 as u32,
|
||||
num_blocks,
|
||||
message_modulus.0 as u32,
|
||||
carry_modulus.0 as u32,
|
||||
pbs_type as u32,
|
||||
lut_degree,
|
||||
true,
|
||||
noise_reduction_type as u32,
|
||||
);
|
||||
cuda_apply_bivariate_lut_64(
|
||||
streams.ffi(),
|
||||
&raw mut cuda_ffi_output,
|
||||
&raw const cuda_ffi_input_1,
|
||||
&raw const cuda_ffi_input_2,
|
||||
mem_ptr,
|
||||
keyswitch_key.ptr.as_ptr(),
|
||||
bootstrapping_key.ptr.as_ptr(),
|
||||
num_blocks,
|
||||
shift,
|
||||
);
|
||||
cleanup_cuda_apply_bivariate_lut_64(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr));
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// # Safety
|
||||
///
|
||||
@@ -10109,6 +9983,300 @@ pub(crate) unsafe fn cuda_backend_unchecked_index_of_clear<
|
||||
update_noise_degree(&mut match_ct.0.ciphertext, &ffi_match);
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// # Safety
|
||||
///
|
||||
/// - The data must not be moved or dropped while being used by the CUDA kernel.
|
||||
/// - This function assumes exclusive access to the passed data; violating this may lead to
|
||||
/// undefined behavior.
|
||||
pub(crate) unsafe fn cuda_backend_unchecked_all_eq_slices<
|
||||
T: UnsignedInteger,
|
||||
B: Numeric,
|
||||
C: CudaIntegerRadixCiphertext,
|
||||
>(
|
||||
streams: &CudaStreams,
|
||||
match_ct: &mut CudaBooleanBlock,
|
||||
lhs: &[C],
|
||||
rhs: &[C],
|
||||
bootstrapping_key: &CudaVec<B>,
|
||||
keyswitch_key: &CudaVec<T>,
|
||||
message_modulus: MessageModulus,
|
||||
carry_modulus: CarryModulus,
|
||||
glwe_dimension: GlweDimension,
|
||||
polynomial_size: PolynomialSize,
|
||||
big_lwe_dimension: LweDimension,
|
||||
small_lwe_dimension: LweDimension,
|
||||
ks_level: DecompositionLevelCount,
|
||||
ks_base_log: DecompositionBaseLog,
|
||||
pbs_level: DecompositionLevelCount,
|
||||
pbs_base_log: DecompositionBaseLog,
|
||||
pbs_type: PBSType,
|
||||
grouping_factor: LweBskGroupingFactor,
|
||||
ms_noise_reduction_configuration: Option<&CudaModulusSwitchNoiseReductionConfiguration>,
|
||||
) {
|
||||
assert_eq!(streams.gpu_indexes[0], bootstrapping_key.gpu_index(0));
|
||||
assert_eq!(streams.gpu_indexes[0], keyswitch_key.gpu_index(0));
|
||||
|
||||
let num_inputs = lhs.len() as u32;
|
||||
let num_blocks = lhs[0].as_ref().d_blocks.lwe_ciphertext_count().0 as u32;
|
||||
let noise_reduction_type = resolve_ms_noise_reduction_config(ms_noise_reduction_configuration);
|
||||
|
||||
let mut match_degrees = vec![match_ct.0.ciphertext.info.blocks[0].degree.get()];
|
||||
let mut match_noise_levels = vec![match_ct.0.ciphertext.info.blocks[0].noise_level.0];
|
||||
let mut ffi_match = prepare_cuda_radix_ffi(
|
||||
&match_ct.0.ciphertext,
|
||||
&mut match_degrees,
|
||||
&mut match_noise_levels,
|
||||
);
|
||||
|
||||
let mut ffi_lhs_degrees: Vec<Vec<u64>> = Vec::with_capacity(lhs.len());
|
||||
let mut ffi_lhs_noise_levels: Vec<Vec<u64>> = Vec::with_capacity(lhs.len());
|
||||
let ffi_lhs: Vec<CudaRadixCiphertextFFI> = lhs
|
||||
.iter()
|
||||
.map(|ct| {
|
||||
let degrees = ct
|
||||
.as_ref()
|
||||
.info
|
||||
.blocks
|
||||
.iter()
|
||||
.map(|b| b.degree.get())
|
||||
.collect();
|
||||
let noise_levels = ct
|
||||
.as_ref()
|
||||
.info
|
||||
.blocks
|
||||
.iter()
|
||||
.map(|b| b.noise_level.0)
|
||||
.collect();
|
||||
ffi_lhs_degrees.push(degrees);
|
||||
ffi_lhs_noise_levels.push(noise_levels);
|
||||
|
||||
prepare_cuda_radix_ffi(
|
||||
ct.as_ref(),
|
||||
ffi_lhs_degrees.last_mut().unwrap(),
|
||||
ffi_lhs_noise_levels.last_mut().unwrap(),
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let mut ffi_rhs_degrees: Vec<Vec<u64>> = Vec::with_capacity(rhs.len());
|
||||
let mut ffi_rhs_noise_levels: Vec<Vec<u64>> = Vec::with_capacity(rhs.len());
|
||||
let ffi_rhs: Vec<CudaRadixCiphertextFFI> = rhs
|
||||
.iter()
|
||||
.map(|ct| {
|
||||
let degrees = ct
|
||||
.as_ref()
|
||||
.info
|
||||
.blocks
|
||||
.iter()
|
||||
.map(|b| b.degree.get())
|
||||
.collect();
|
||||
let noise_levels = ct
|
||||
.as_ref()
|
||||
.info
|
||||
.blocks
|
||||
.iter()
|
||||
.map(|b| b.noise_level.0)
|
||||
.collect();
|
||||
ffi_rhs_degrees.push(degrees);
|
||||
ffi_rhs_noise_levels.push(noise_levels);
|
||||
|
||||
prepare_cuda_radix_ffi(
|
||||
ct.as_ref(),
|
||||
ffi_rhs_degrees.last_mut().unwrap(),
|
||||
ffi_rhs_noise_levels.last_mut().unwrap(),
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
|
||||
|
||||
scratch_cuda_unchecked_all_eq_slices_64(
|
||||
streams.ffi(),
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
glwe_dimension.0 as u32,
|
||||
polynomial_size.0 as u32,
|
||||
big_lwe_dimension.0 as u32,
|
||||
small_lwe_dimension.0 as u32,
|
||||
ks_level.0 as u32,
|
||||
ks_base_log.0 as u32,
|
||||
pbs_level.0 as u32,
|
||||
pbs_base_log.0 as u32,
|
||||
grouping_factor.0 as u32,
|
||||
num_inputs,
|
||||
num_blocks,
|
||||
message_modulus.0 as u32,
|
||||
carry_modulus.0 as u32,
|
||||
pbs_type as u32,
|
||||
true,
|
||||
noise_reduction_type as u32,
|
||||
);
|
||||
|
||||
cuda_unchecked_all_eq_slices_64(
|
||||
streams.ffi(),
|
||||
&raw mut ffi_match,
|
||||
ffi_lhs.as_ptr(),
|
||||
ffi_rhs.as_ptr(),
|
||||
num_inputs,
|
||||
num_blocks,
|
||||
mem_ptr,
|
||||
bootstrapping_key.ptr.as_ptr(),
|
||||
keyswitch_key.ptr.as_ptr(),
|
||||
);
|
||||
|
||||
cleanup_cuda_unchecked_all_eq_slices_64(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr));
|
||||
|
||||
update_noise_degree(&mut match_ct.0.ciphertext, &ffi_match);
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// # Safety
|
||||
///
|
||||
/// - The data must not be moved or dropped while being used by the CUDA kernel.
|
||||
/// - This function assumes exclusive access to the passed data; violating this may lead to
|
||||
/// undefined behavior.
|
||||
pub(crate) unsafe fn cuda_backend_unchecked_contains_sub_slice<
|
||||
T: UnsignedInteger,
|
||||
B: Numeric,
|
||||
C: CudaIntegerRadixCiphertext,
|
||||
>(
|
||||
streams: &CudaStreams,
|
||||
match_ct: &mut CudaBooleanBlock,
|
||||
lhs: &[C],
|
||||
rhs: &[C],
|
||||
bootstrapping_key: &CudaVec<B>,
|
||||
keyswitch_key: &CudaVec<T>,
|
||||
message_modulus: MessageModulus,
|
||||
carry_modulus: CarryModulus,
|
||||
glwe_dimension: GlweDimension,
|
||||
polynomial_size: PolynomialSize,
|
||||
big_lwe_dimension: LweDimension,
|
||||
small_lwe_dimension: LweDimension,
|
||||
ks_level: DecompositionLevelCount,
|
||||
ks_base_log: DecompositionBaseLog,
|
||||
pbs_level: DecompositionLevelCount,
|
||||
pbs_base_log: DecompositionBaseLog,
|
||||
pbs_type: PBSType,
|
||||
grouping_factor: LweBskGroupingFactor,
|
||||
ms_noise_reduction_configuration: Option<&CudaModulusSwitchNoiseReductionConfiguration>,
|
||||
) {
|
||||
assert_eq!(streams.gpu_indexes[0], bootstrapping_key.gpu_index(0));
|
||||
assert_eq!(streams.gpu_indexes[0], keyswitch_key.gpu_index(0));
|
||||
|
||||
let num_inputs_lhs = lhs.len() as u32;
|
||||
let num_inputs_rhs = rhs.len() as u32;
|
||||
let num_blocks = lhs[0].as_ref().d_blocks.lwe_ciphertext_count().0 as u32;
|
||||
let noise_reduction_type = resolve_ms_noise_reduction_config(ms_noise_reduction_configuration);
|
||||
|
||||
let mut match_degrees = vec![match_ct.0.ciphertext.info.blocks[0].degree.get()];
|
||||
let mut match_noise_levels = vec![match_ct.0.ciphertext.info.blocks[0].noise_level.0];
|
||||
let mut ffi_match = prepare_cuda_radix_ffi(
|
||||
&match_ct.0.ciphertext,
|
||||
&mut match_degrees,
|
||||
&mut match_noise_levels,
|
||||
);
|
||||
|
||||
let mut ffi_lhs_degrees: Vec<Vec<u64>> = Vec::with_capacity(lhs.len());
|
||||
let mut ffi_lhs_noise_levels: Vec<Vec<u64>> = Vec::with_capacity(lhs.len());
|
||||
let ffi_lhs: Vec<CudaRadixCiphertextFFI> = lhs
|
||||
.iter()
|
||||
.map(|ct| {
|
||||
let degrees = ct
|
||||
.as_ref()
|
||||
.info
|
||||
.blocks
|
||||
.iter()
|
||||
.map(|b| b.degree.get())
|
||||
.collect();
|
||||
let noise_levels = ct
|
||||
.as_ref()
|
||||
.info
|
||||
.blocks
|
||||
.iter()
|
||||
.map(|b| b.noise_level.0)
|
||||
.collect();
|
||||
ffi_lhs_degrees.push(degrees);
|
||||
ffi_lhs_noise_levels.push(noise_levels);
|
||||
|
||||
prepare_cuda_radix_ffi(
|
||||
ct.as_ref(),
|
||||
ffi_lhs_degrees.last_mut().unwrap(),
|
||||
ffi_lhs_noise_levels.last_mut().unwrap(),
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let mut ffi_rhs_degrees: Vec<Vec<u64>> = Vec::with_capacity(rhs.len());
|
||||
let mut ffi_rhs_noise_levels: Vec<Vec<u64>> = Vec::with_capacity(rhs.len());
|
||||
let ffi_rhs: Vec<CudaRadixCiphertextFFI> = rhs
|
||||
.iter()
|
||||
.map(|ct| {
|
||||
let degrees = ct
|
||||
.as_ref()
|
||||
.info
|
||||
.blocks
|
||||
.iter()
|
||||
.map(|b| b.degree.get())
|
||||
.collect();
|
||||
let noise_levels = ct
|
||||
.as_ref()
|
||||
.info
|
||||
.blocks
|
||||
.iter()
|
||||
.map(|b| b.noise_level.0)
|
||||
.collect();
|
||||
ffi_rhs_degrees.push(degrees);
|
||||
ffi_rhs_noise_levels.push(noise_levels);
|
||||
|
||||
prepare_cuda_radix_ffi(
|
||||
ct.as_ref(),
|
||||
ffi_rhs_degrees.last_mut().unwrap(),
|
||||
ffi_rhs_noise_levels.last_mut().unwrap(),
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let mut mem_ptr: *mut i8 = std::ptr::null_mut();
|
||||
|
||||
scratch_cuda_unchecked_contains_sub_slice_64(
|
||||
streams.ffi(),
|
||||
std::ptr::addr_of_mut!(mem_ptr),
|
||||
glwe_dimension.0 as u32,
|
||||
polynomial_size.0 as u32,
|
||||
big_lwe_dimension.0 as u32,
|
||||
small_lwe_dimension.0 as u32,
|
||||
ks_level.0 as u32,
|
||||
ks_base_log.0 as u32,
|
||||
pbs_level.0 as u32,
|
||||
pbs_base_log.0 as u32,
|
||||
grouping_factor.0 as u32,
|
||||
num_inputs_lhs,
|
||||
num_inputs_rhs,
|
||||
num_blocks,
|
||||
message_modulus.0 as u32,
|
||||
carry_modulus.0 as u32,
|
||||
pbs_type as u32,
|
||||
true,
|
||||
noise_reduction_type as u32,
|
||||
);
|
||||
|
||||
cuda_unchecked_contains_sub_slice_64(
|
||||
streams.ffi(),
|
||||
&raw mut ffi_match,
|
||||
ffi_lhs.as_ptr(),
|
||||
ffi_rhs.as_ptr(),
|
||||
num_inputs_rhs,
|
||||
num_blocks,
|
||||
mem_ptr,
|
||||
bootstrapping_key.ptr.as_ptr(),
|
||||
keyswitch_key.ptr.as_ptr(),
|
||||
);
|
||||
|
||||
cleanup_cuda_unchecked_contains_sub_slice_64(streams.ffi(), std::ptr::addr_of_mut!(mem_ptr));
|
||||
|
||||
update_noise_degree(&mut match_ct.0.ciphertext, &ffi_match);
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// # Safety
|
||||
///
|
||||
|
||||
@@ -16,8 +16,8 @@ use crate::integer::gpu::ciphertext::{
|
||||
use crate::integer::gpu::noise_squashing::keys::CudaNoiseSquashingKey;
|
||||
use crate::integer::gpu::server_key::CudaBootstrappingKey;
|
||||
use crate::integer::gpu::{
|
||||
cuda_backend_apply_bivariate_lut, cuda_backend_apply_many_univariate_lut,
|
||||
cuda_backend_apply_univariate_lut, cuda_backend_cast_to_signed, cuda_backend_cast_to_unsigned,
|
||||
cuda_backend_apply_many_univariate_lut, cuda_backend_apply_univariate_lut,
|
||||
cuda_backend_cast_to_signed, cuda_backend_cast_to_unsigned,
|
||||
cuda_backend_extend_radix_with_trivial_zero_blocks_msb, cuda_backend_full_propagate_assign,
|
||||
cuda_backend_noise_squashing, cuda_backend_propagate_single_carry_assign,
|
||||
cuda_backend_trim_radix_blocks_lsb, cuda_backend_trim_radix_blocks_msb, CudaServerKey, PBSType,
|
||||
@@ -27,8 +27,7 @@ use crate::shortint::ciphertext::{Degree, NoiseLevel};
|
||||
use crate::shortint::engine::fill_many_lut_accumulator;
|
||||
use crate::shortint::parameters::AtomicPatternKind;
|
||||
use crate::shortint::server_key::{
|
||||
generate_lookup_table, BivariateLookupTableOwned, LookupTableOwned, LookupTableSize,
|
||||
ManyLookupTableOwned,
|
||||
generate_lookup_table, LookupTableOwned, LookupTableSize, ManyLookupTableOwned,
|
||||
};
|
||||
use crate::shortint::{PBSOrder, PaddingBit, ShortintEncoding};
|
||||
|
||||
@@ -654,30 +653,6 @@ impl CudaServerKey {
|
||||
}
|
||||
}
|
||||
|
||||
/// Generates a bivariate accumulator
|
||||
pub(crate) fn generate_lookup_table_bivariate<F>(&self, f: F) -> BivariateLookupTableOwned
|
||||
where
|
||||
F: Fn(u64, u64) -> u64,
|
||||
{
|
||||
// Depending on the factor used, rhs and / or lhs may have carries
|
||||
// (degree >= message_modulus) which is why we need to apply the message_modulus
|
||||
// to clear them
|
||||
let message_modulus = self.message_modulus.0;
|
||||
let factor_u64 = message_modulus;
|
||||
let wrapped_f = |input: u64| -> u64 {
|
||||
let lhs = (input / factor_u64) % message_modulus;
|
||||
let rhs = (input % factor_u64) % message_modulus;
|
||||
|
||||
f(lhs, rhs)
|
||||
};
|
||||
let accumulator = self.generate_lookup_table(wrapped_f);
|
||||
|
||||
BivariateLookupTableOwned {
|
||||
acc: accumulator,
|
||||
ct_right_modulus: self.message_modulus,
|
||||
}
|
||||
}
|
||||
|
||||
/// Applies the lookup table on the range of ciphertexts
|
||||
///
|
||||
/// The output must have exactly block_range.len() blocks
|
||||
@@ -780,122 +755,6 @@ impl CudaServerKey {
|
||||
}
|
||||
}
|
||||
|
||||
/// Applies the bivariate lookup table on the range of ciphertexts
|
||||
///
|
||||
/// The output must have exactly block_range.len() blocks
|
||||
pub(crate) fn apply_bivariate_lookup_table(
|
||||
&self,
|
||||
output: &mut CudaRadixCiphertext,
|
||||
input_1: &CudaRadixCiphertext,
|
||||
input_2: &CudaRadixCiphertext,
|
||||
lut: &BivariateLookupTableOwned,
|
||||
block_range: std::ops::Range<usize>,
|
||||
streams: &CudaStreams,
|
||||
) {
|
||||
if block_range.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
input_1.d_blocks.lwe_dimension(),
|
||||
output.d_blocks.lwe_dimension()
|
||||
);
|
||||
assert_eq!(
|
||||
input_2.d_blocks.lwe_dimension(),
|
||||
output.d_blocks.lwe_dimension()
|
||||
);
|
||||
|
||||
let lwe_dimension = input_1.d_blocks.lwe_dimension();
|
||||
let lwe_size = lwe_dimension.to_lwe_size().0;
|
||||
let num_output_blocks = output.d_blocks.lwe_ciphertext_count().0;
|
||||
|
||||
let input_slice_1 = input_1
|
||||
.d_blocks
|
||||
.0
|
||||
.d_vec
|
||||
.as_slice(lwe_size * block_range.start..lwe_size * block_range.end, 0)
|
||||
.unwrap();
|
||||
let input_slice_2 = input_2
|
||||
.d_blocks
|
||||
.0
|
||||
.d_vec
|
||||
.as_slice(lwe_size * block_range.start..lwe_size * block_range.end, 0)
|
||||
.unwrap();
|
||||
let mut output_slice = output.d_blocks.0.d_vec.as_mut_slice(.., 0).unwrap();
|
||||
let mut output_degrees = vec![0_u64; num_output_blocks];
|
||||
let mut output_noise_levels = vec![0_u64; num_output_blocks];
|
||||
|
||||
let num_ct_blocks = block_range.len() as u32;
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_apply_bivariate_lut(
|
||||
streams,
|
||||
&mut output_slice,
|
||||
&mut output_degrees,
|
||||
&mut output_noise_levels,
|
||||
&input_slice_1,
|
||||
&input_slice_2,
|
||||
lut.acc.acc.as_ref(),
|
||||
lut.acc.degree.0,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
num_ct_blocks,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
self.message_modulus.0 as u32,
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_apply_bivariate_lut(
|
||||
streams,
|
||||
&mut output_slice,
|
||||
&mut output_degrees,
|
||||
&mut output_noise_levels,
|
||||
&input_slice_1,
|
||||
&input_slice_2,
|
||||
lut.acc.acc.as_ref(),
|
||||
lut.acc.degree.0,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
num_ct_blocks,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
self.message_modulus.0 as u32,
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (i, info) in output.info.blocks[block_range].iter_mut().enumerate() {
|
||||
info.degree = Degree(output_degrees[i]);
|
||||
info.noise_level = NoiseLevel(output_noise_levels[i]);
|
||||
}
|
||||
}
|
||||
/// Applies many lookup tables on the range of ciphertexts
|
||||
///
|
||||
/// # Example
|
||||
|
||||
@@ -1,43 +1,13 @@
|
||||
use crate::core_crypto::gpu::lwe_ciphertext_list::CudaLweCiphertextList;
|
||||
use crate::core_crypto::gpu::CudaStreams;
|
||||
use crate::core_crypto::prelude::LweBskGroupingFactor;
|
||||
use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock;
|
||||
use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext};
|
||||
use crate::integer::gpu::server_key::radix::{
|
||||
CudaBlockInfo, CudaRadixCiphertext, CudaRadixCiphertextInfo,
|
||||
use crate::integer::gpu::server_key::{CudaBootstrappingKey, CudaServerKey};
|
||||
use crate::integer::gpu::{
|
||||
cuda_backend_unchecked_all_eq_slices, cuda_backend_unchecked_contains_sub_slice, PBSType,
|
||||
};
|
||||
use crate::integer::gpu::server_key::CudaServerKey;
|
||||
use crate::shortint::ciphertext::Degree;
|
||||
use crate::shortint::parameters::NoiseLevel;
|
||||
|
||||
impl CudaServerKey {
|
||||
#[allow(clippy::unused_self)]
|
||||
pub(crate) fn convert_integer_radixes_vec_to_single_integer_radix_ciphertext<T>(
|
||||
&self,
|
||||
radixes: &[T],
|
||||
streams: &CudaStreams,
|
||||
) -> T
|
||||
where
|
||||
T: CudaIntegerRadixCiphertext,
|
||||
{
|
||||
let packed_list = CudaLweCiphertextList::from_vec_cuda_lwe_ciphertexts_list(
|
||||
radixes
|
||||
.iter()
|
||||
.map(|ciphertext| &ciphertext.as_ref().d_blocks),
|
||||
streams,
|
||||
);
|
||||
let vec_block_info: Vec<CudaBlockInfo> = radixes
|
||||
.iter()
|
||||
.flat_map(|ct| ct.as_ref().info.blocks.clone())
|
||||
.collect();
|
||||
let radix_info = CudaRadixCiphertextInfo {
|
||||
blocks: vec_block_info,
|
||||
};
|
||||
CudaIntegerRadixCiphertext::from(CudaRadixCiphertext {
|
||||
d_blocks: packed_list,
|
||||
info: radix_info,
|
||||
})
|
||||
}
|
||||
|
||||
/// Compares two slices containing ciphertexts and returns an encryption of `true` if all
|
||||
/// pairs are equal, otherwise, returns an encryption of `false`.
|
||||
///
|
||||
@@ -61,7 +31,7 @@ impl CudaServerKey {
|
||||
);
|
||||
return trivial_bool;
|
||||
}
|
||||
// If both are empty, return true
|
||||
|
||||
if lhs.is_empty() {
|
||||
let trivial_ct: CudaUnsignedRadixCiphertext = self.create_trivial_radix(1, 1, streams);
|
||||
|
||||
@@ -82,73 +52,70 @@ impl CudaServerKey {
|
||||
return trivial_bool;
|
||||
}
|
||||
|
||||
let block_equality_lut = self.generate_lookup_table_bivariate(|l, r| u64::from(l == r));
|
||||
let trivial_bool =
|
||||
self.create_trivial_zero_radix::<CudaUnsignedRadixCiphertext>(1, streams);
|
||||
let mut match_ct = CudaBooleanBlock::from_cuda_radix_ciphertext(trivial_bool.into_inner());
|
||||
|
||||
let packed_lhs_list = CudaLweCiphertextList::from_vec_cuda_lwe_ciphertexts_list(
|
||||
lhs.iter().map(|ciphertext| &ciphertext.as_ref().d_blocks),
|
||||
streams,
|
||||
);
|
||||
let packed_rhs_list = CudaLweCiphertextList::from_vec_cuda_lwe_ciphertexts_list(
|
||||
rhs.iter().map(|ciphertext| &ciphertext.as_ref().d_blocks),
|
||||
streams,
|
||||
);
|
||||
let num_radix_blocks = packed_rhs_list.lwe_ciphertext_count().0;
|
||||
let block_info = CudaBlockInfo {
|
||||
degree: Degree(0),
|
||||
message_modulus: lhs
|
||||
.first()
|
||||
.unwrap()
|
||||
.as_ref()
|
||||
.info
|
||||
.blocks
|
||||
.first()
|
||||
.unwrap()
|
||||
.message_modulus,
|
||||
carry_modulus: lhs
|
||||
.first()
|
||||
.unwrap()
|
||||
.as_ref()
|
||||
.info
|
||||
.blocks
|
||||
.first()
|
||||
.unwrap()
|
||||
.carry_modulus,
|
||||
atomic_pattern: lhs
|
||||
.first()
|
||||
.unwrap()
|
||||
.as_ref()
|
||||
.info
|
||||
.blocks
|
||||
.first()
|
||||
.unwrap()
|
||||
.atomic_pattern,
|
||||
noise_level: NoiseLevel::ZERO,
|
||||
};
|
||||
let info = CudaRadixCiphertextInfo {
|
||||
blocks: vec![block_info; num_radix_blocks],
|
||||
};
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_all_eq_slices(
|
||||
streams,
|
||||
&mut match_ct,
|
||||
lhs,
|
||||
rhs,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_all_eq_slices(
|
||||
streams,
|
||||
&mut match_ct,
|
||||
lhs,
|
||||
rhs,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let packed_lhs = CudaRadixCiphertext {
|
||||
d_blocks: packed_lhs_list,
|
||||
info: info.clone(),
|
||||
};
|
||||
let packed_rhs = CudaRadixCiphertext {
|
||||
d_blocks: packed_rhs_list,
|
||||
info,
|
||||
};
|
||||
|
||||
let mut comparison_blocks: CudaUnsignedRadixCiphertext =
|
||||
self.create_trivial_radix(0, num_radix_blocks, streams);
|
||||
|
||||
self.apply_bivariate_lookup_table(
|
||||
comparison_blocks.as_mut(),
|
||||
&packed_lhs,
|
||||
&packed_rhs,
|
||||
&block_equality_lut,
|
||||
0..num_radix_blocks,
|
||||
streams,
|
||||
);
|
||||
self.unchecked_are_all_comparisons_block_true(&comparison_blocks, streams)
|
||||
match_ct
|
||||
}
|
||||
|
||||
/// Compares two slices containing ciphertexts and returns an encryption of `true` if all
|
||||
@@ -290,16 +257,78 @@ impl CudaServerKey {
|
||||
return trivial_bool;
|
||||
}
|
||||
|
||||
let windows_results = lhs
|
||||
.windows(rhs.len())
|
||||
.map(|lhs_sub_slice| self.unchecked_all_eq_slices(lhs_sub_slice, rhs, streams).0)
|
||||
.collect::<Vec<_>>();
|
||||
let packed_windows_results = self
|
||||
.convert_integer_radixes_vec_to_single_integer_radix_ciphertext(
|
||||
&windows_results,
|
||||
streams,
|
||||
if rhs.is_empty() {
|
||||
let trivial_ct: CudaUnsignedRadixCiphertext = self.create_trivial_radix(1, 1, streams);
|
||||
let trivial_bool = CudaBooleanBlock::from_cuda_radix_ciphertext(
|
||||
trivial_ct.duplicate(streams).into_inner(),
|
||||
);
|
||||
self.unchecked_is_at_least_one_comparisons_block_true(&packed_windows_results, streams)
|
||||
return trivial_bool;
|
||||
}
|
||||
|
||||
let trivial_bool =
|
||||
self.create_trivial_zero_radix::<CudaUnsignedRadixCiphertext>(1, streams);
|
||||
let mut match_ct = CudaBooleanBlock::from_cuda_radix_ciphertext(trivial_bool.into_inner());
|
||||
|
||||
unsafe {
|
||||
match &self.bootstrapping_key {
|
||||
CudaBootstrappingKey::Classic(d_bsk) => {
|
||||
cuda_backend_unchecked_contains_sub_slice(
|
||||
streams,
|
||||
&mut match_ct,
|
||||
lhs,
|
||||
rhs,
|
||||
&d_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_bsk.glwe_dimension,
|
||||
d_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_bsk.decomp_level_count,
|
||||
d_bsk.decomp_base_log,
|
||||
PBSType::Classical,
|
||||
LweBskGroupingFactor(0),
|
||||
d_bsk.ms_noise_reduction_configuration.as_ref(),
|
||||
);
|
||||
}
|
||||
CudaBootstrappingKey::MultiBit(d_multibit_bsk) => {
|
||||
cuda_backend_unchecked_contains_sub_slice(
|
||||
streams,
|
||||
&mut match_ct,
|
||||
lhs,
|
||||
rhs,
|
||||
&d_multibit_bsk.d_vec,
|
||||
&self.key_switching_key.d_vec,
|
||||
self.message_modulus,
|
||||
self.carry_modulus,
|
||||
d_multibit_bsk.glwe_dimension,
|
||||
d_multibit_bsk.polynomial_size,
|
||||
self.key_switching_key
|
||||
.input_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key
|
||||
.output_key_lwe_size()
|
||||
.to_lwe_dimension(),
|
||||
self.key_switching_key.decomposition_level_count(),
|
||||
self.key_switching_key.decomposition_base_log(),
|
||||
d_multibit_bsk.decomp_level_count,
|
||||
d_multibit_bsk.decomp_base_log,
|
||||
PBSType::MultiBit,
|
||||
d_multibit_bsk.grouping_factor,
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
match_ct
|
||||
}
|
||||
|
||||
/// Returns a boolean ciphertext encrypting `true` if `lhs` contains `rhs`, `false` otherwise
|
||||
|
||||
Reference in New Issue
Block a user